# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
# @package label_smooth
# Module caffe2.python.layers.label_smooth
from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
import numpy as np
class LabelSmooth(ModelLayer):
    """Layer that maps hard labels to smoothed label values.

    The mode is chosen from the number of elements in ``smooth_matrix``:

    * exactly 2 elements -> binary mode: the label is cast to float32 (if
      needed) and passed through ``StumpFunc`` with ``threshold=0.5``,
      ``low_value=smooth_matrix[0]`` and ``high_value=smooth_matrix[1]``.
      The output dimension is 1.
    * otherwise -> categorical mode: the element count must be a perfect
      square ``dim * dim``. The label is cast to int64 (if needed), expanded
      with ``OneHot`` to ``dim`` classes and multiplied (``MatMul``) by the
      ``dim x dim`` smoothing matrix, i.e. each label selects one row of
      the matrix.

    The output is a float32 ``schema.Scalar`` of shape ``(dim,)``.
    """

    def __init__(
        self, model, label, smooth_matrix, name='label_smooth', **kwargs
    ):
        """Build the layer.

        Args:
            model: layer model helper; used here for ``add_global_constant``.
            label: input schema record holding the labels.
            smooth_matrix: array-like with either 2 values (binary mode) or
                ``dim * dim`` values (categorical mode). Any shape is
                accepted; it is flattened before use.
            name: layer name, also used as the prefix of the global
                constants created in categorical mode.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.
        """
        super(LabelSmooth, self).__init__(model, name, label, **kwargs)
        self.label = label
        # Normalize the user input to a flat float32 vector; the element
        # count (not the original shape) decides the mode and dimension.
        smooth_matrix = np.array(smooth_matrix).astype(np.float32).flatten()
        self.set_dim(smooth_matrix)
        self.set_smooth_matrix(smooth_matrix)
        self.output_schema = schema.Scalar(
            (np.float32, (self.dim, )),
            self.get_next_blob_reference('smoothed_label')
        )

    def set_dim(self, smooth_matrix):
        """Derive ``self.binary_prob_label`` and ``self.dim`` from the size.

        Args:
            smooth_matrix: flattened numpy array of smoothing values.

        Raises:
            AssertionError: if the array has neither 2 elements nor a
                perfect-square number of elements.
        """
        num_elements = smooth_matrix.size
        self.binary_prob_label = (num_elements == 2)
        if self.binary_prob_label:
            self.dim = 1
        else:
            # Validate with exact integer arithmetic. Comparing
            # ``np.sqrt(n)**2 == n`` in floating point can round back to n
            # for a non-square n, which would let a truncated (wrong) dim
            # through and only fail later in the reshape.
            dim = int(round(np.sqrt(num_elements)))
            assert dim * dim == num_elements, (
                'smooth_matrix must have 2 elements or a perfect-square '
                'number of elements, got %d' % num_elements
            )
            self.dim = dim

    def set_smooth_matrix(self, smooth_matrix):
        """Store the smoothing values in the form the ops need.

        In categorical mode two global constants are registered on the
        model: the ``dim x dim`` float32 matrix and the int64 class count
        (kept in ``self.len`` and used as the second input of ``OneHot``).
        In binary mode the flat numpy array is kept as-is, since its two
        values are passed as operator arguments rather than blobs.

        Args:
            smooth_matrix: flattened float32 numpy array.
        """
        if not self.binary_prob_label:
            self.smooth_matrix = self.model.add_global_constant(
                '%s_label_smooth_matrix' % self.name,
                array=smooth_matrix.reshape((self.dim, self.dim)),
                dtype=np.dtype(np.float32),
            )
            self.len = self.model.add_global_constant(
                '%s_label_dim' % self.name,
                array=self.dim,
                dtype=np.dtype(np.int64),
            )
        else:
            self.smooth_matrix = smooth_matrix

    def add_ops_for_binary_prob_label(self, net):
        """Add the ops for binary mode to ``net``.

        The label is cast to float32 when its base type differs, then fed
        to ``StumpFunc`` with a 0.5 threshold.
        NOTE(review): presumably StumpFunc emits ``high_value`` for inputs
        above the threshold and ``low_value`` otherwise - confirm against
        the operator documentation.
        """
        if self.label.field_type().base != np.float32:
            float32_label = net.NextScopedBlob('float32_label')
            net.Cast([self.label()], [float32_label], to=core.DataType.FLOAT)
        else:
            float32_label = self.label()
        net.StumpFunc(
            float32_label,
            self.output_schema(),
            threshold=0.5,
            low_value=self.smooth_matrix[0],
            high_value=self.smooth_matrix[1],
        )

    def add_ops_for_categorical_label(self, net):
        """Add the ops for categorical mode to ``net``.

        The label is cast to int64 when its base type differs, one-hot
        encoded using the class-count constant, and multiplied by the
        smoothing matrix to produce the output blob.
        """
        if self.label.field_type().base != np.int64:
            int64_label = net.NextScopedBlob('int64_label')
            net.Cast([self.label()], [int64_label], to=core.DataType.INT64)
        else:
            int64_label = self.label()
        one_hot_label = net.NextScopedBlob('one_hot_label')
        net.OneHot([int64_label, self.len], [one_hot_label])
        net.MatMul([one_hot_label, self.smooth_matrix], self.output_schema())

    def add_ops(self, net):
        """Add this layer's operators to ``net`` according to the mode."""
        if self.binary_prob_label:
            self.add_ops_for_binary_prob_label(net)
        else:
            self.add_ops_for_categorical_label(net)