Repository URL to install this package:
|
Version:
1.20.1 ▾
|
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import logging
from dataclasses import dataclass
import onnx
from ...quant_utils import QuantType
from ...tensor_quant_overrides import QuantTypeInfo, TensorQuantOverridesHelper
@dataclass
class TensorTypeRequest:
"""
Bundles desired quantization type requests for a tensor. A distinction is made between the
produced type and the consumed type.
"""
# The tensor's quant type at the producer end. If None, assumed to be the default activation quant type.
producer: QuantTypeInfo | None
# The tensor's quant type received by a set of consumer nodes.
# If None, assumed to be the default activation quant type for all consumers.
# consumers[1] is a set of consumer node names.
consumers: tuple[QuantTypeInfo, set[str]] | None
class MixedPrecisionTensorQuantOverridesFixer:
"""
Helper that generates tensor quantization overrides for mixed-precision QDQ models.
Specifically, this helper fixes an initial set of quantization overrides that assign a non-default
activation quantization type to one or more tensors by doing the following:
- Inferring which other tensors need to be overridden to the non-default activation quantization type.
- Inserting quantization data type conversions.
Example:
--------
Float model:
input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0
^
|
input_1 --> Op2 -+-> Op4 ----+
|
+-> Op7 --> output_1
|
+-> Op8 --> output_2
If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out"
is quantized to 16-bit, then we would specify the following initial tensor quantization overrides:
```
init_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]}
```
These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output
to be the same type (e.g., uint16). This helper fixes the overrides so that input/output data types
are valid:
```
overrides = TensorQuantOverridesHelper(init_overrides)
fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, QuantType.QUInt8)
fixer.apply(
default_activation_qtype=QuantType.QUInt8,
default_activation_symmetric=False,
)
```
The above snippet generates the following "fixed" overrides (get via overrides.get_dict()):
{
"Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}],
"Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}],
"Op4_out": [{"quant_type": QUInt16}],
"Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}]
}
How to interpret the fixed overrides:
- Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type,
but Op7 and Op8 consume the original u8 type.
- Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type.
- Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type.
- Op5's output is converted from u16 to u8. Op6 consumes the u8 type.
"""
def __init__(
self,
overrides: TensorQuantOverridesHelper,
producers: dict[str, onnx.NodeProto],
consumers: dict[str, list[onnx.NodeProto]],
value_infos: dict[str, onnx.ValueInfoProto],
initializers: dict[str, onnx.TensorProto],
):
"""
Params:
overrides: The initial tensor quantization overrides to fix.
producers: Dictionary that maps a tensor name to the producer node that generates the tensor.
consumers: Dictionary that maps a tensor name to the consumer nodes that take the tensor as input.
value_infos: Dictionary that maps a tensor name to its onnx.ValueInfoProto.
initializers: Dictionary that maps an initializer name to its onnx.TensorProto.
"""
self.overrides = overrides
self.consumers = consumers
self.producers = producers
self.value_infos = value_infos
self.initializers = initializers
@staticmethod
def create_from_model(
overrides: TensorQuantOverridesHelper, model: onnx.ModelProto, default_activation_qtype: QuantType
) -> MixedPrecisionTensorQuantOverridesFixer:
"""
Helper function that creates an instance of this class from a loaded ONNX model.
Params:
overrides: The initial tensor quantization overrides to fix.
model: Loaded ONNX model
default_activation_qtype: The intended default activation quantization type.
Used to validate the initial overrides.
Returns:
Initialized MixedPrecisionTensorQuantOverridesFixer object
"""
model = onnx.shape_inference.infer_shapes(model) # Need to infer shapes to get value_infos
# Build dictionaries that enable convenient lookups of initializers and value_infos by name.
initializers = {initializer.name: initializer for initializer in model.graph.initializer}
value_infos = {vi.name: vi for vi in model.graph.value_info}
value_infos.update({ot.name: ot for ot in model.graph.output})
value_infos.update({it.name: it for it in model.graph.input})
# Ensure that the user-provided initial overrides are actually valid.
valid, err = overrides.is_valid(initializers, set(value_infos), default_activation_qtype)
if not valid:
pprint_overrides = overrides.pprint_str(indent=4)
logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}")
raise ValueError(err)
consumers = {}
producers = {}
# Build dictionaries that map a tensor name to the consumer or producer nodes.
for node in model.graph.node:
for input_name in node.input:
if input_name:
if input_name not in consumers:
consumers[input_name] = []
consumers[input_name].append(node)
for output_name in node.output:
producers[output_name] = node
return MixedPrecisionTensorQuantOverridesFixer(overrides, producers, consumers, value_infos, initializers)
def apply(
self,
default_activation_qtype: QuantType,
default_activation_symmetric: bool,
):
"""
Fixes the initial tensor quantization overrides (in-place) for use in mixed-precision QDQ models.
Params:
default_activation_qtype: The intended default activation quantization type.
default_activation_symmetric: The intended default symmetry used to quantize activations.
"""
type_requests = self.get_desired_tensor_types(default_activation_qtype, default_activation_symmetric)
# Use type requests to "fix" tensor quantization overrides by adding
# quantization type conversions where necessary.
for tensor_name, type_req in type_requests.items():
all_consumers = set([node.name for node in self.consumers.get(tensor_name, [])])
has_producer_req = type_req.producer is not None
has_consumer_req = bool(type_req.consumers)
# Only producer type: Add conversion back to default activation type
if has_producer_req and not has_consumer_req:
self._update_converted_tensor(
tensor_name, type_req.producer, QuantTypeInfo(default_activation_qtype), all_consumers
)
# Only consumers
elif not has_producer_req and has_consumer_req:
prod_type_info = self.overrides.get_node_output_qtype_info(tensor_name, default_activation_qtype)
consumer_type_info = type_req.consumers[0]
if prod_type_info != consumer_type_info:
self._update_converted_tensor(
tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
)
else:
if not self._check_nodes_are_not_convert_consumers(tensor_name, type_req.consumers[1]):
raise ValueError(
f"Tensor override for '{tensor_name}' converts the type for consumers that need the original type."
)
# Both producer and consumers
elif has_producer_req and has_consumer_req:
prod_type_info = type_req.producer
consumer_type_info = type_req.consumers[0]
if prod_type_info != consumer_type_info:
self._update_converted_tensor(
tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
)
else:
consumers_for_original_type = all_consumers.difference(type_req.consumers[1])
if len(consumers_for_original_type) == 0:
# All consumers want the overridden type, so no need for convert nodes!
# Just add the override to the new new if not already present.
if tensor_name not in self.overrides:
self.overrides[tensor_name] = [{}]
prod_type_info.save_to_dict(self.overrides[tensor_name][0])
assert "convert" not in self.overrides[tensor_name][0]
else:
# Some consumers don't want the overridden type.
self._update_converted_tensor(
tensor_name,
prod_type_info,
QuantTypeInfo(default_activation_qtype),
consumers_for_original_type,
)
else:
raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.")
# Done. Check if the overrides are valid.
valid, err = self.overrides.is_valid(self.initializers, set(self.value_infos), default_activation_qtype)
if not valid:
pprint_overrides = self.overrides.pprint_str(indent=4)
logging.error(
f"Generated invalid tensor quantization overrides for mixed-precision QDQ model:\n{pprint_overrides}"
)
raise ValueError(err)
def get_desired_tensor_types(
self,
default_activation_qtype: QuantType,
default_activation_symmetric: bool,
) -> dict[str, TensorTypeRequest]:
"""
Iterates through the initial tensor quantization overrides and builds a set of TensorTypeRequests objects
that describe the quantization types required at each tensor. These TensorTypeRequests objects are ultimately
used to generated the "fixed" overrides.
Params:
default_activation_qtype: The intended default activation quantization type.
default_activation_symmetric: The intended default symmetry used to quantize activations.
Returns:
TensorTypeRequest objects as a dict that maps a tensor name to its requested types.
"""
type_requests = {}
default_activation_type_info = QuantTypeInfo(default_activation_qtype, default_activation_symmetric)
# Scan tensor overrides for type conversion requests.
for tensor_name, override_list in self.overrides.items():
if not self.__is_tensor_quantizable(tensor_name):
continue # Skip non-quantizable tensors (e.g., not a float)
if tensor_name in self.initializers:
continue # Skip initializers
if not override_list or len(override_list) > 1:
continue # Skip per-channel stuff
override_dict = override_list[0]
quant_type_info = QuantTypeInfo.load_from_dict(override_dict, default_activation_type_info.quant_type)
producer_node = self.producers.get(tensor_name) # None if this is a model input
if quant_type_info != default_activation_type_info and "convert" not in override_dict:
if producer_node is not None:
self._add_type_requests_for_node(type_requests, quant_type_info, producer_node)
# Find all consumer nodes of `tensor_name` and update their inputs/outputs to the new type.
for consumer_node in self.consumers.get(tensor_name, []):
self._add_type_requests_for_node(type_requests, quant_type_info, consumer_node)
return type_requests
def _add_type_requests_for_node(
self,
type_requests: dict[str, TensorTypeRequest],
quant_type_info: QuantTypeInfo,
node: onnx.NodeProto,
):
"""
Adds TensorTypeRequest objects for a given node, assuming that we want all its inputs and outputs
to have the same quantization type (as specified by the `quant_type_info` parameter).
Params:
type_requests: Dictionary of type requests to append to for this node.
quant_type_info: The quantization type to use for inputs and outputs.
node: The node for which the TensorTypeRequest objects are created and added to type_requests.
"""
# Add output side
for output_name in node.output:
if not self.__is_tensor_quantizable(output_name):
continue
if output_name not in type_requests:
type_requests[output_name] = TensorTypeRequest(quant_type_info, None)
else:
if (
type_requests[output_name].producer is not None
and type_requests[output_name].producer != quant_type_info
):
raise ValueError(f"Tensor {output_name} has multiple types.")
type_requests[output_name].producer = quant_type_info
# Add the consumer side
for input_name in node.input:
if input_name and input_name not in self.initializers and self.__is_tensor_quantizable(input_name):
if input_name not in type_requests:
type_requests[input_name] = TensorTypeRequest(None, None)
if type_requests[input_name].consumers is None:
type_requests[input_name].consumers = (quant_type_info, set())
if type_requests[input_name].consumers[0] != quant_type_info:
raise ValueError(f"Tensor {input_name} has consumers requesting different types.")
if not node.name:
raise ValueError(
f"Node of type {node.op_type} with output 0 {node.output[0]} does not have a name!"
)
type_requests[input_name].consumers[1].add(node.name)
def _update_converted_tensor(
self,
tensor_name: str,
producer_type_info: QuantTypeInfo,
consumer_type_info: QuantTypeInfo,
consumer_names: set[str],
):
"""
Updates the tensor quantization overrides for a tensor that is converted from one type to another.
Params:
tensor_name: The name of the tensor for which to update overrides.
producer_type_info: Info for the tensor's produced type.
consumer_type_info: Info for the tensor's consumed (i.e., converted) type.
consumer_names: Nodes names of consumers that consume the converted type.
"""
if tensor_name not in self.overrides or not self.overrides[tensor_name]:
self.overrides[tensor_name] = [{}]
producer_type_info.save_to_dict(self.overrides[tensor_name][0])
overrides = self.overrides[tensor_name][0]
if producer_type_info != QuantTypeInfo.load_from_dict(overrides):
raise ValueError(f"Desired producer quant_type for {tensor_name} doesn't match existing type.")
if consumer_names:
if "convert" not in overrides:
overrides["convert"] = {}
consumer_type_info.save_to_dict(overrides["convert"])
convert_dict = overrides["convert"]
if consumer_type_info != QuantTypeInfo.load_from_dict(convert_dict):
raise ValueError(f"Desired consumer quant_type for {tensor_name} doesn't match existing type.")
if "recv_nodes" not in convert_dict:
convert_dict["recv_nodes"] = set()
convert_dict["recv_nodes"].update(consumer_names)
def _check_nodes_are_not_convert_consumers(self, tensor_name: str, node_names: set[str]):
"""
Returns true if the given nodes do not consume/receive a converted quantization type.
Params:
tensor_name: The name of the tensor to check.
node_names: Set of node names that should not be consumers of the converted type.
"""
if tensor_name not in self.overrides or not self.overrides[tensor_name]:
return True
overrides = self.overrides[tensor_name][0]
if "convert" not in overrides:
return True
convert_dict = overrides["convert"]
if "recv_nodes" not in convert_dict:
return False
return not convert_dict["recv_nodes"].intersection(node_names)
def __is_tensor_quantizable(self, tensor_name):
weight = self.initializers.get(tensor_name)
if weight is not None:
if weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16):
return True
elif tensor_name in self.value_infos:
vi = self.value_infos[tensor_name]
if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
onnx.TensorProto.FLOAT,
onnx.TensorProto.FLOAT16,
):
return True
return False