# nn/quantized/modules/batchnorm.py

import torch
import torch.nn.quantized.functional
import torch.nn.intrinsic as nni

class BatchNorm2d(torch.nn.BatchNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # ``momentum`` is accepted for API parity with the float module but is
        # not forwarded to the parent; running statistics are not updated here.
        super(BatchNorm2d, self).__init__(num_features)
        self.eps = eps
        # Output quantization parameters; ``from_float`` overwrites these with
        # values computed by the attached observer during conversion.
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        return torch.ops.quantized.batch_norm2d(input, self.weight, self.bias, self.running_mean,
                                                self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm2d'

    @classmethod
    def from_float(cls, mod):
        activation_post_process = mod.activation_post_process
        if type(mod) == nni.BNReLU2d:
            # For fused BN+ReLU, the observer lives on the fused wrapper; the
            # float parameters are copied from the inner BatchNorm2d (mod[0]).
            mod = mod[0]
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod
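
# ``from_float`` is normally reached via ``torch.quantization.convert`` during
# eager-mode post-training quantization. A minimal sketch of calling it
# directly (the observer choice and tensor shapes below are illustrative):
#
#   >>> float_bn = torch.nn.BatchNorm2d(4).eval()
#   >>> float_bn.activation_post_process = torch.quantization.MinMaxObserver()
#   >>> float_bn.activation_post_process(float_bn(torch.randn(2, 4, 8, 8)))  # calibrate
#   >>> qbn = BatchNorm2d.from_float(float_bn)
#   >>> qbn.scale, qbn.zero_point  # qparams computed by the observer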

# TODO: dedup with BatchNorm2d
class BatchNorm3d(torch.nn.BatchNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm3d, self).__init__(num_features)
        self.eps = eps
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        return torch.ops.quantized.batch_norm3d(input, self.weight, self.bias, self.running_mean,
                                                self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    @classmethod
    def from_float(cls, mod):
        activation_post_process = mod.activation_post_process
        if type(mod) == nni.BNReLU3d:
            mod = mod[0]
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod
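

if __name__ == "__main__":
    # Quick smoke test (a minimal sketch; the shapes and qparams are illustrative).
    # The quantized op consumes a quantized input tensor and emits a quantized
    # output using the module's ``scale`` and ``zero_point``.
    qbn = BatchNorm2d(4)
    qbn.scale, qbn.zero_point = 0.1, 64
    x = torch.quantize_per_tensor(torch.randn(2, 4, 8, 8),
                                  scale=0.1, zero_point=64, dtype=torch.quint8)
    y = qbn(x)
    print(y.shape, y.dtype)  # torch.Size([2, 4, 8, 8]) torch.quint8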