Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / layers / label_smooth.py

# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

# @package label_smooth
# Module caffe2.python.layers.label_smooth





from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
import numpy as np


class LabelSmooth(ModelLayer):
    """Map labels to smoothed label vectors using a user-supplied matrix.

    Two modes are selected from the number of elements in ``smooth_matrix``:

    * exactly 2 elements -> "binary probability label" mode. The label is
      cast to float32 and passed through ``StumpFunc`` with threshold 0.5,
      emitting ``smooth_matrix[0]`` (low) or ``smooth_matrix[1]`` (high).
      The output has dimension 1.
    * ``dim * dim`` elements -> "categorical label" mode. The label is cast
      to int64, one-hot encoded to width ``dim`` and multiplied by the
      ``dim x dim`` matrix, so class ``i`` is mapped to row ``i`` of the
      matrix. The output has dimension ``dim``.
    """

    def __init__(
        self, model, label, smooth_matrix, name='label_smooth', **kwargs
    ):
        """Build the layer.

        Args:
            model: the layer model the constants are registered on.
            label: scalar schema record holding the input labels.
            smooth_matrix: array-like smoothing values; flattened to 1-D
                float32 before the mode is chosen (see class docstring).
            name: layer name, also used to prefix the global constants.
        """
        super(LabelSmooth, self).__init__(model, name, label, **kwargs)
        self.label = label
        # Flatten to a 1-D float32 copy; the mode/dimension is inferred from
        # the element count alone, so the caller's original shape is ignored.
        smooth_matrix = np.array(smooth_matrix).astype(np.float32).flatten()
        self.set_dim(smooth_matrix)
        self.set_smooth_matrix(smooth_matrix)
        self.output_schema = schema.Scalar(
            (np.float32, (self.dim, )),
            self.get_next_blob_reference('smoothed_label')
        )

    def set_dim(self, smooth_matrix):
        """Set ``self.binary_prob_label`` and ``self.dim`` from the size.

        A 2-element input selects binary mode with ``dim == 1``. Any other
        input must contain a non-zero perfect-square number of elements.

        Raises:
            AssertionError: if the element count is neither 2 nor a
                non-zero perfect square.
        """
        num_elements = smooth_matrix.size
        self.binary_prob_label = (num_elements == 2)
        if self.binary_prob_label:
            self.dim = 1
        else:
            # Verify the square exactly in integer arithmetic; comparing
            # np.sqrt(n) ** 2 == n in floating point can misclassify large
            # non-square sizes. An empty matrix (dim == 0) is rejected here
            # instead of failing later inside OneHot.
            dim = int(round(np.sqrt(num_elements)))
            assert num_elements > 0 and dim * dim == num_elements, (
                'smooth_matrix must have 2 elements (binary label) or a '
                'non-zero perfect-square number of elements, got %d'
                % num_elements
            )
            self.dim = dim

    def set_smooth_matrix(self, smooth_matrix):
        """Store the smoothing values in the form ``add_ops`` needs.

        In categorical mode the ``dim x dim`` matrix and the scalar ``dim``
        are registered as model global constants (used by MatMul and OneHot).
        In binary mode the 2-element numpy array is kept as-is, because its
        values are passed as operator arguments.
        """
        if not self.binary_prob_label:
            self.smooth_matrix = self.model.add_global_constant(
                '%s_label_smooth_matrix' % self.name,
                array=smooth_matrix.reshape((self.dim, self.dim)),
                dtype=np.dtype(np.float32),
            )
            self.len = self.model.add_global_constant(
                '%s_label_dim' % self.name,
                array=self.dim,
                dtype=np.dtype(np.int64),
            )
        else:
            self.smooth_matrix = smooth_matrix

    def add_ops_for_binary_prob_label(self, net):
        """Emit Cast (if needed) + StumpFunc for the binary mode."""
        if self.label.field_type().base != np.float32:
            float32_label = net.NextScopedBlob('float32_label')
            net.Cast([self.label()], [float32_label], to=core.DataType.FLOAT)
        else:
            float32_label = self.label()
        # StumpFunc picks low_value or high_value per element depending on
        # the 0.5 threshold (tie handling per the Caffe2 operator; confirm).
        net.StumpFunc(
            float32_label,
            self.output_schema(),
            threshold=0.5,
            low_value=self.smooth_matrix[0],
            high_value=self.smooth_matrix[1],
        )

    def add_ops_for_categorical_label(self, net):
        """Emit Cast (if needed) + OneHot + MatMul for the categorical mode."""
        if self.label.field_type().base != np.int64:
            int64_label = net.NextScopedBlob('int64_label')
            net.Cast([self.label()], [int64_label], to=core.DataType.INT64)
        else:
            int64_label = self.label()
        one_hot_label = net.NextScopedBlob('one_hot_label')
        # OneHot width comes from the registered `dim` constant; multiplying
        # a one-hot row by the smoothing matrix selects that class's row.
        net.OneHot([int64_label, self.len], [one_hot_label])
        net.MatMul([one_hot_label, self.smooth_matrix], self.output_schema())

    def add_ops(self, net):
        """Add the operators for whichever mode ``set_dim`` selected."""
        if self.binary_prob_label:
            self.add_ops_for_binary_prob_label(net)
        else:
            self.add_ops_for_categorical_label(net)