from caffe2.python import schema
from caffe2.python.layers.layers import ModelLayer
import numpy as np


class LayerNormalization(ModelLayer):
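    """Applies layer normalization to the input record:

        y = (x - E[x]) / sqrt(Var[x] + epsilon) * scale + bias

    where mean and variance are computed per example over the feature
    dimension(s), and scale / bias are learnable per-feature parameters.
    """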
    def __init__(
        self,
        model,
        input_record,
        name='layer_normalization',
        scale_optim=None,
        bias_optim=None,
        epsilon=1e-4,
        axis=1,
        use_layer_norm_op=True,
        scale_init_value=1.0,
        **kwargs
    ):
        super(LayerNormalization, self).__init__(
            model, name, input_record, **kwargs)

        assert isinstance(input_record, schema.Scalar), (
            "Incorrect input type: {}".format(input_record))

        self.input_shape = input_record.field_type().shape
        self.axis = axis
        assert len(self.input_shape) >= 1, (
            "This layer supports only >= 2D tensors")
        input_dims = self.input_shape[0]

        self.output_schema = schema.Scalar(
            (np.float32, self.input_shape),
            self.get_next_blob_reference('output')
        )
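
        # Learnable per-feature scale and bias, applied after normalization.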
        self.scale = self.create_param(
            param_name='scale',
            shape=[input_dims],
            initializer=('ConstantFill', {'value': scale_init_value}),
            optimizer=scale_optim,
        )
        self.bias = self.create_param(
            param_name='bias',
            shape=[input_dims],
            initializer=('ConstantFill', {'value': 0.0}),
            optimizer=bias_optim,
        )
        self.use_layer_norm_op = use_layer_norm_op

        if self.use_layer_norm_op:
            self.epsilon = epsilon
        else:
            assert len(self.input_shape) == 1, (
                "When using the alternative implementation, "
                "input data can only be 2D"
            )
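            # The primitive-op fallback consumes epsilon as an input blob
            # (see net.Add in add_ops_without_layer_norm_op), so it must be
            # materialized as a global constant rather than kept as a float.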
            self.epsilon = model.maybe_add_global_constant(
                "%s_epsilon" % self.name, float(epsilon)
            )

    def add_ops_with_layer_norm_op(self, net):
        input_blob = self.input_record.field_blobs()
        ln_output = self.output_schema.field_blobs()
        output_blobs = [
            net.NextScopedBlob('ln_output'),
            net.NextScopedBlob('ln_mean'),
            net.NextScopedBlob('ln_stdev'),
        ]
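        # The fused LayerNorm op normalizes over the dimensions starting at
        # `axis` and also emits the per-example mean and stdev blobs.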
        normalized, mean, stdev = net.LayerNorm(
            input_blob,
            output_blobs,
            axis=self.axis,
            epsilon=self.epsilon,
        )
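        # Elementwise affine transform: broadcast=1 aligns the [input_dims]
        # scale/bias with the input starting at `axis`.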
        scaled = net.Mul(
            [normalized, self.scale],
            [net.NextScopedBlob('ln_scaled')],
            broadcast=1,
            axis=self.axis,
        )
        net.Add(
            [scaled, self.bias],
            ln_output,
            broadcast=1,
            axis=self.axis,
        )

    def add_ops_without_layer_norm_op(self, net):
        # This fallback reproduces LayerNorm with primitive ops and
        # deliberately avoids legacy broadcasting.
        ln_output = net.NextScopedBlob("ln_output")
        ln_mean = net.NextScopedBlob("ln_mean")
        ln_stdev = net.NextScopedBlob("ln_stdev")
ln_mean_arr = net.NextScopedBlob("ln_mean_arr")
net.ReduceBackMean(self.input_record.field_blobs(), [ln_mean_arr])
net.ExpandDims([ln_mean_arr], [ln_mean], dims=[1])
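        # Center the input around its per-example mean.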
ln_centered = net.NextScopedBlob("ln_centered")
net.Sub(self.input_record.field_blobs() + [ln_mean], [ln_centered])
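        # Variance is the mean of squared deviations; epsilon is added for
        # numerical stability before taking the square root.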
ln_sqr = net.NextScopedBlob("ln_sqr")
net.Sqr([ln_centered], [ln_sqr])
ln_sqr_mean = net.NextScopedBlob("ln_sqr_mean")
net.ReduceBackMean([ln_sqr], [ln_sqr_mean])
ln_var = net.NextScopedBlob("ln_var")
net.Add([ln_sqr_mean, self.epsilon], ln_var)
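        # stdev = (var + epsilon) ** 0.5, reshaped for broadcasting.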
ln_std_arr = net.NextScopedBlob("ln_std_arr")
net.Pow([ln_var], [ln_std_arr], exponent=0.5)
net.ExpandDims([ln_std_arr], [ln_stdev], dims=[1])
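        # Normalize: (x - mean) / stdev, via numpy-style broadcasting.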
        net.Div([ln_centered, ln_stdev], [ln_output])
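        # Apply the learnable affine transform.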
ln_scaled = net.NextScopedBlob("ln_scaled")
net.Mul([ln_output, self.scale], [ln_scaled])
net.Add([ln_scaled, self.bias], self.output_schema.field_blobs())

    def add_ops(self, net):
        if self.use_layer_norm_op:
            self.add_ops_with_layer_norm_op(net)
        else:
            self.add_ops_without_layer_norm_op(net)
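

# A minimal usage sketch (an illustration, not part of this module): layers
# derived from ModelLayer are normally instantiated by name through
# LayerModelHelper. The feature dimension (64), the "input" blob name, and
# the input_feature_schema / trainer_extra_schema records are hypothetical.
#
#     from caffe2.python import layer_model_helper, schema
#
#     model = layer_model_helper.LayerModelHelper(
#         "sample_model", input_feature_schema, trainer_extra_schema)
#     input_record = schema.Scalar((np.float32, (64,)), "input")
#     normalized = model.LayerNormalization(input_record, epsilon=1e-4)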