# Repository URL to install this package: (not captured in this copy)
# Version: 2.1.2+cpu
from typing import Callable
import torch
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from torch.fx import Node
def _is_share_obs_or_fq_op(op: Callable) -> bool:
    """Return True when ``op`` is one of the listed aten overloads.

    ``propagate_annotation`` uses this predicate to pick the ``call_function``
    targets that receive a ``SharedQuantizationSpec`` tied to their first
    argument.

    Args:
        op: The callable to look up (typically a node's ``target``).

    Returns:
        Whether ``op`` appears in the fixed collection of overloads below.
    """
    sharing_overloads = (
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
        torch.ops.aten.mean.default,
        torch.ops.aten.mean.dim,
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dim,
        # TODO: remove?
        torch.ops.aten.adaptive_avg_pool2d.default,
        torch.ops.aten.view_copy.default,
        torch.ops.aten.view.default,
        torch.ops.aten.slice_copy.Tensor,
        torch.ops.aten.flatten.using_ints,
    )
    return op in sharing_overloads
def propagate_annotation(model: torch.fx.GraphModule) -> None:
for n in model.graph.nodes:
if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
continue
prev_node = n.args[0]
if not isinstance(prev_node, Node):
continue
quantization_annotation = prev_node.meta.get("quantization_annotation", None)
if not quantization_annotation:
continue
output_qspec = quantization_annotation.output_qspec
if not output_qspec:
continue
# make sure current node is not annotated
if (
"quantization_annotation" in n.meta
and n.meta["quantization_annotation"]._annotated
):
continue
shared_qspec = SharedQuantizationSpec(prev_node)
# propagate the previous output_qspec to the current node
n.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
prev_node: shared_qspec,
},
output_qspec=shared_qspec,
_annotated=True,
)