import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from .utils import ReferenceQuantizedModule
__all__ = ['Linear']


class Linear(nn.Linear, ReferenceQuantizedModule):
""" A reference quantized linear module that fits into the FX
Graph Mode Quantization workflow
activation will be floating point Tensor, we will store floating
point weight as well in the module, but in forward we'll quantize
and dequantize the weight before running the floating point functional
linear operator.
"""
_IS_REFERENCE = True
def __init__(
self,
in_features: int,
out_features: int,
bias_: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_qparams: Optional[Dict[str, Any]] = None):
super().__init__(in_features, out_features, bias_, device, dtype)
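        # _init_weight_qparams (from ReferenceQuantizedModule) records the weight
        # qscheme/dtype and registers the scale/zero_point (and axis, for
        # per-channel schemes) used to quantize-dequantize the weight in forward.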
self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.linear ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.linear --- *quant - dequant
and the backend should be able to fuse the ops with `*` into a quantized linear
"""
weight_quant_dequant = self.get_weight()
result = F.linear(x, weight_quant_dequant, self.bias)
return result

    @classmethod
def from_float(cls, float_linear, weight_qparams):
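        """Create a reference quantized Linear from a float ``nn.Linear`` and a
        dict of weight quantization parameters, copying the float weight (and
        bias, if present) into the new module.
        """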
qref_linear = Linear(
float_linear.in_features, float_linear.out_features,
float_linear.bias is not None, device=float_linear.weight.device,
dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
if float_linear.bias is not None:
qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
return qref_linear
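

# Usage sketch (illustrative only, kept as a comment so nothing runs on import):
# the weight_qparams keys follow what ReferenceQuantizedModule._init_weight_qparams
# expects, and the scale/zero_point values are made-up examples; in the FX Graph
# Mode Quantization workflow they are produced by observers during calibration.
#
#     float_linear = nn.Linear(16, 8)
#     weight_qparams = {
#         "qscheme": torch.per_tensor_affine,
#         "dtype": torch.qint8,
#         "scale": 0.05,
#         "zero_point": 0,
#     }
#     ref_linear = Linear.from_float(float_linear, weight_qparams)
#     out = ref_linear(torch.randn(4, 16))  # float in, float out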