Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ nn / quantized / _reference / modules / linear.py

import torch
import torch.nn.quantized as nnq
import torch.nn.functional as F
from typing import Optional

class Linear(nnq.Linear):
    """ A backend independent version of nn.quantized.Linear
        we will not pack the parameters in this module, since weight packing is an
        optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
        this is useful when user want to use this module in other backends like Glow.
    """
    def __init__(self, in_features, out_features, bias_=True,
                 dtype=torch.qint8):
        super().__init__(in_features, out_features, bias_, dtype)
        self._qweight, self._bias = self._packed_params._weight_bias()
        del self._packed_params

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dequant = x.dequantize()
        weight_dequant = self._qweight.dequantize()
        float_result = F.linear(x_dequant, weight_dequant, self._bias)
        # NEEDFIX: we don't have dtype in the Linear module APIs right now!
        result = torch.quantize_per_tensor(
            float_result, self.scale, self.zero_point, torch.quint8)
        return result

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + '_qweight'] = self._qweight
        destination[prefix + '_bias'] = self._bias

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self._qweight = state_dict[prefix + '_qweight']
        self._bias = state_dict[prefix + '_bias']
        state_dict.pop(prefix + '_qweight')
        state_dict.pop(prefix + '_bias')

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)

    def _weight_bias(self):
        return self._qweight, self._bias

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._qweight = w
        self._bias = b