Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ ao / nn / quantized / reference / modules / linear.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from .utils import ReferenceQuantizedModule

__all__ = ['Linear']

class Linear(nn.Linear, ReferenceQuantizedModule):
    """Reference quantized linear module for FX Graph Mode Quantization.

    Activations stay floating point Tensors and the module also stores a
    floating point weight. In ``forward`` the weight is quantized and
    dequantized (via ``self.get_weight()``) before the floating point
    functional linear operator is run.
    """
    # Class-level marker identifying this module as a reference module.
    _IS_REFERENCE = True

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias_: bool = True,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            weight_qparams: Optional[Dict[str, Any]] = None):
        """Build the float linear layer and record the weight qparams.

        Args:
            in_features: forwarded to ``nn.Linear`` as ``in_features``.
            out_features: forwarded to ``nn.Linear`` as ``out_features``.
            bias_: forwarded positionally as the ``bias`` flag of
                ``nn.Linear``.
            device: forwarded to ``nn.Linear`` and to
                ``self._init_weight_qparams``.
            dtype: forwarded to ``nn.Linear``.
            weight_qparams: weight quantization parameters, handed as-is to
                ``self._init_weight_qparams`` (expected to come from
                ``ReferenceQuantizedModule``; the accepted keys are defined
                there, not in this class).
        """
        super().__init__(in_features, out_features, bias_, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self) -> str:
        """Return the display name used for this module."""
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Apply ``F.linear`` using the quantize-dequantized weight.

        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.linear ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.linear --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized linear

        Args:
            x: floating point input activation.

        Returns:
            The result of ``F.linear(x, self.get_weight(), self.bias)``.
        """
        # NOTE: the docstring is raw so the trailing backslashes in the
        # diagram are kept literally instead of acting as line continuations.
        weight_quant_dequant = self.get_weight()
        return F.linear(x, weight_quant_dequant, self.bias)

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        """Create a reference quantized linear module from a float one.

        Args:
            float_linear: source module; its ``in_features``,
                ``out_features``, ``weight`` and ``bias`` attributes are read.
            weight_qparams: weight quantization parameters passed through to
                the constructor.

        Returns:
            A new instance of ``cls`` whose ``weight`` (and ``bias``, when the
            source has one) are Parameters wrapping the detached source
            tensors.
        """
        has_bias = float_linear.bias is not None
        # Instantiate through ``cls`` so that subclasses calling this
        # alternate constructor get an instance of their own type.
        qref_linear = cls(
            float_linear.in_features,
            float_linear.out_features,
            has_bias,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
            weight_qparams=weight_qparams)
        # detach() drops the autograd history; the tensor data is not cloned.
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if has_bias:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear