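"""Statically quantize an ONNX model with onnxruntime.

Calibration inputs are read from test_data_set_* directories located next to the
input model (the standard ONNX test-data layout; see the note after the data
reader class below).
"""
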
import argparse
import json
import os
import numpy as np
import onnx
import onnxruntime
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod


class OnnxModelCalibrationDataReader(CalibrationDataReader):
    """Feeds the inputs found in test_data_set_* directories next to the model to the calibrator."""

    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        data_dirs = [
            os.path.join(self.model_dir, a) for a in os.listdir(self.model_dir) if a.startswith("test_data_set_")
        ]
        # A throwaway CPU session is created only to discover the model's input names.
        model_inputs = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            name2tensor = {}
            # Sorted lexicographically, so input_*.pb files come before output_*.pb files.
            data_paths = [os.path.join(data_dir, a) for a in sorted(os.listdir(data_dir))]
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            # zip truncates to the model inputs, so any reference output_*.pb files are ignored.
            for model_input, data_ndarray in zip(model_inputs, data_ndarrays, strict=False):
                name2tensor[model_input.name] = data_ndarray
            name2tensors.append(name2tensor)
        assert len(name2tensors) == len(data_dirs)
        assert name2tensors, f"No test_data_set_* directories found in {self.model_dir}"
        assert len(name2tensors[0]) == len(model_inputs)
self.calibration_data = iter(name2tensors)
    def get_next(self) -> dict:
        """Generate the input dict for an onnxruntime InferenceSession run; returns None when exhausted."""
        return next(self.calibration_data, None)
    def read_onnx_pb_data(self, file_pb):
        """Load a serialized onnx.TensorProto (*.pb) file and return it as a numpy array."""
        tensor = onnx.TensorProto()
        with open(file_pb, "rb") as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)
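
# The data reader above assumes the standard ONNX test-data layout next to the model
# (names below are illustrative; adjust if your calibration data is stored differently):
#
#   model.onnx
#   test_data_set_0/
#       input_0.pb      # one serialized onnx.TensorProto per model input
#       output_0.pb     # reference outputs, ignored by the reader
#   test_data_set_1/
#       ...
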
def parse_arguments():
    parser = argparse.ArgumentParser(description="Arguments for static quantization of an ONNX model")
parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input onnx model")
parser.add_argument(
"-o", "--output_quantized_model_path", required=True, help="Path to the output quantized onnx model"
)
parser.add_argument(
"--activation_type",
choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
default="quint8",
help="Activation quantization type used",
)
parser.add_argument(
"--weight_type",
choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
default="qint8",
help="Weight quantization type used",
)
    parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraphs will be quantized.")
parser.add_argument(
"--force_quantize_no_input_check",
action="store_true",
        help="By default, some latent operators (e.g., MaxPool, Transpose) are not quantized if their input"
        " is not already quantized. If set, such operators are forced to quantize their input and thus"
        " produce quantized output. This behavior can still be disabled per node via --nodes_to_exclude.",
)
parser.add_argument(
"--matmul_const_b_only",
action="store_true",
        help="If set, only MatMul nodes with a constant B input will be quantized.",
)
parser.add_argument(
"--add_qdq_pair_to_weight",
action="store_true",
        help="If set, the weight is kept in floating point and a QuantizeLinear/DeQuantizeLinear pair"
        " is inserted on it.",
)
parser.add_argument(
"--dedicated_qdq_pair",
action="store_true",
        help="If set, an identical, dedicated QDQ pair is created for each node.",
)
parser.add_argument(
"--op_types_to_exclude_output_quantization",
nargs="+",
default=[],
        help="If any op types are specified, the outputs of ops with these op types will not be quantized.",
)
parser.add_argument(
"--calibration_method",
default="minmax",
choices=["minmax", "entropy", "percentile", "distribution"],
help="Calibration method used",
)
parser.add_argument("--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used")
parser.add_argument(
"--calib_tensor_range_symmetric",
action="store_true",
        help="If enabled, the final range of a tensor during calibration will be explicitly"
        " set to be symmetric around the central point 0.",
)
    # TODO: --calib_strided_minmax
    # TODO: --calib_moving_average_constant
    # TODO: --calib_max_intermediate_outputs
parser.add_argument(
"--calib_moving_average",
action="store_true",
        help="If enabled, a moving average of the minimum and maximum values will be computed"
        " when the selected calibration method is MinMax.",
)
parser.add_argument(
"--disable_quantize_bias",
action="store_true",
        help="If set, biases are kept in floating point and no quantization nodes are inserted for them."
        " By default, floating-point biases are quantized by solely inserting a DeQuantizeLinear node.",
)
# TODO: Add arguments related to Smooth Quant
parser.add_argument(
"--use_qdq_contrib_ops",
action="store_true",
help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the com.microsoft domain,"
" which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
)
parser.add_argument(
"--minimum_real_range",
type=float,
default=0.0001,
        help="Enforces a minimum range between rmin and rmax when computing the quantization parameters"
        " (i.e., scale and zero point). If (rmax - rmin) is less than this value, rmax is set to"
        " rmin + minimum_real_range. This is necessary for EPs like QNN that require a minimum"
        " floating-point range when determining quantization parameters.",
)
parser.add_argument(
"--qdq_keep_removable_activations",
action="store_true",
help="If set, removable activations (e.g., Clip or Relu) will not be removed,"
" and will be explicitly represented in the QDQ model.",
)
parser.add_argument(
"--qdq_disable_weight_adjust_for_int32_bias",
action="store_true",
help="If set, QDQ quantizer will not adjust the weight's scale when the bias"
" has a scale (input_scale * weight_scale) that is too small.",
)
    parser.add_argument("--per_channel", action="store_true", help="If set, per-channel quantization is used.")
parser.add_argument(
"--nodes_to_quantize",
nargs="+",
default=None,
        help="List of node names to quantize. When provided, only the nodes in this list are quantized.",
)
parser.add_argument(
"--nodes_to_exclude",
nargs="+",
default=None,
        help="List of node names to exclude. When provided, the nodes in this list are excluded from quantization.",
)
parser.add_argument(
"--op_per_channel_axis",
nargs=2,
action="append",
metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
default=[],
        help="Set the channel axis for a specific op type, for example: --op_per_channel_axis MatMul 1."
        " Effective only when the op type supports per-channel quantization and --per_channel is set."
        " If an op type supports per-channel quantization but no axis is explicitly specified,"
        " its default channel axis is used.",
)
    parser.add_argument("--tensor_quant_overrides", help="Path to a JSON file with tensor quantization overrides.")
return parser.parse_args()


def get_tensor_quant_overrides(file):
    """Load per-tensor quantization overrides from a JSON file into the format expected by TensorQuantOverrides."""
    # TODO: Enhance the function to handle more real cases of json file
    if not file:
        return {}
    with open(file) as f:
        quant_override_dict = json.load(f)
    for tensor in quant_override_dict:
        for enc_dict in quant_override_dict[tensor]:
            # Scales must be float32 arrays; zero points keep the integer dtype json.load produced.
            enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
            enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
    return quant_override_dict
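
# A minimal sketch of the overrides JSON this script consumes ("conv1_output" is a
# hypothetical tensor name; onnxruntime also accepts per-tensor keys such as "quant_type",
# which this loader passes through unconverted):
#
#   {
#       "conv1_output": [
#           {"scale": 0.015, "zero_point": 0}
#       ]
#   }
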
def main():
args = parse_arguments()
data_reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)
arg2quant_type = {
"qint8": QuantType.QInt8,
"quint8": QuantType.QUInt8,
"qint16": QuantType.QInt16,
"quint16": QuantType.QUInt16,
"qint4": QuantType.QInt4,
"quint4": QuantType.QUInt4,
"qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
}
activation_type = arg2quant_type[args.activation_type]
weight_type = arg2quant_type[args.weight_type]
    # argparse yields string pairs; the quantizer expects integer channel axes.
    qdq_op_type_per_channel_support_to_axis = {op_type: int(axis) for op_type, axis in args.op_per_channel_axis}
extra_options = {
"EnableSubgraph": args.enable_subgraph,
"ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
"MatMulConstBOnly": args.matmul_const_b_only,
"AddQDQPairToWeight": args.add_qdq_pair_to_weight,
"OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
"DedicatedQDQPair": args.dedicated_qdq_pair,
"QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
"CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
"CalibMovingAverage": args.calib_moving_average,
"QuantizeBias": not args.disable_quantize_bias,
"UseQDQContribOps": args.use_qdq_contrib_ops,
"MinimumRealRange": args.minimum_real_range,
"QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
"QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
# Load json file for encoding override
"TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
}
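    # The keys above follow onnxruntime.quantization's extra_options naming; see the
    # quantize_static/StaticQuantConfig documentation for the full list of supported options.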
arg2calib_method = {
"minmax": CalibrationMethod.MinMax,
"entropy": CalibrationMethod.Entropy,
"percentile": CalibrationMethod.Percentile,
"distribution": CalibrationMethod.Distribution,
}
arg2quant_format = {
"qdq": QuantFormat.QDQ,
"qoperator": QuantFormat.QOperator,
}
sqc = StaticQuantConfig(
calibration_data_reader=data_reader,
calibrate_method=arg2calib_method[args.calibration_method],
quant_format=arg2quant_format[args.quant_format],
activation_type=activation_type,
weight_type=weight_type,
op_types_to_quantize=None,
nodes_to_quantize=args.nodes_to_quantize,
nodes_to_exclude=args.nodes_to_exclude,
per_channel=args.per_channel,
reduce_range=False,
use_external_data_format=False,
calibration_providers=None, # Use CPUExecutionProvider
extra_options=extra_options,
)
quantize(model_input=args.input_model_path, model_output=args.output_quantized_model_path, quant_config=sqc)


if __name__ == "__main__":
    main()
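
# Example invocation (hypothetical script and model paths; the model must sit next to its
# test_data_set_* calibration directories):
#
#   python run_static_quantization.py -i model.onnx -o model.quant.onnx \
#       --activation_type quint8 --weight_type qint8 --calibration_method minmax --per_channel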