#
# The implementation of this file is based on:
# https://github.com/intel/neural-compressor/tree/master/neural_compressor
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications:
# Add k-quant quantization method.
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""WeightOnly for onnxrt adaptor."""
import copy
import logging
import os
import sys
import numpy as np
import onnx
from onnx import numpy_helper
from onnx.helper import np_dtype_to_tensor_dtype
import onnxruntime as ort
from .onnx_model import ONNXModel
from .util import simple_progress_bar
logger = logging.getLogger("neural_compressor")
def make_matmul_weight_only_node(
node,
weight_shape,
num_bits,
group_size,
k_blocks,
q_weight,
scale,
zero_point,
accuracy_level=0,
): # pragma: no cover
"""Build MatMulNBits node.
Args:
node: original matmul node
weight_shape: original weight shape
        num_bits (int): number of bits used for each quantized weight
        group_size (int): how many elements share one scale/zp
        k_blocks (int): number of blocks
q_weight (array): quantized weight
scale (array): scale
zero_point (array): zero point
        accuracy_level (int): accuracy level. Supports 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8).
Returns:
matmul_weight_only_node: MatMulNBits node
new_inits: initializers of the new node
"""
blob_size = group_size * num_bits // 8
packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}"
input_names = [node.input[0], q_weight_name]
new_inits = []
kwargs = {}
op_type = "MatMulNBits"
# pack quantized weight
if num_bits == 4:
q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
packed[:, :] = q_weight_pairs[:, :blob_size]
elif num_bits == 8:
packed = q_weight
else:
logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
packed = np.reshape(packed, (-1, k_blocks, blob_size))
# build scale tensor
scale = np.reshape(scale, (-1, k_blocks))
assert scale.dtype == np.float32 or scale.dtype == np.float16
scale_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_scale",
data_type=np_dtype_to_tensor_dtype(scale.dtype),
dims=scale.shape,
vals=scale.tobytes(),
raw=True,
)
input_names.append(scale_tensor.name)
new_inits.append(scale_tensor)
# build zero_point tensor
if zero_point is not None:
if num_bits == 8:
packed_zp = zero_point.astype("uint8")
elif num_bits == 4:
            # For the 4-bit case, the default zero point is 0x8, so filling both the
            # low and high nibbles with 0x8 gives 0x88 = 136.
packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
# create an index array
idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
# separate odd and even indices
even_idx = idx[::2]
odd_idx = idx[1::2]
# vectorized operation for even and odd indices
packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)
else:
raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
zp_tensor = onnx.helper.make_tensor(
name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
)
input_names.append(zp_tensor.name)
new_inits.append(zp_tensor)
# set kwargs
kwargs["K"] = weight_shape[0]
kwargs["N"] = weight_shape[1]
kwargs["bits"] = num_bits
kwargs["block_size"] = group_size
if accuracy_level > 0:
# require onnxruntime > 1.16.3
kwargs["accuracy_level"] = accuracy_level
q_weight_tensor = onnx.helper.make_tensor(
name=q_weight_name,
data_type=2,
dims=packed.shape,
vals=packed.tobytes(),
raw=True,
)
new_inits.append(q_weight_tensor)
matmul_weight_only_node = onnx.helper.make_node(
op_type,
inputs=input_names,
outputs=node.output,
name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
domain="com.microsoft",
**kwargs,
)
return matmul_weight_only_node, new_inits
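# A minimal usage sketch for make_matmul_weight_only_node, kept in comments so the
# module still imports cleanly. All shapes are illustrative assumptions: K=64, N=8,
# 4-bit weights in groups of 32 (k_blocks=2), with q_weight/scale/zp laid out the way
# quant_tensor returns them.
#
#   import numpy as np
#   import onnx
#   matmul = onnx.helper.make_node("MatMul", ["X", "W"], ["Y"], name="fc")
#   q_weight = np.random.randint(0, 16, (8 * 2, 32), dtype=np.uint8)  # (N * k_blocks, group_size)
#   scale = np.random.rand(8 * 2, 1).astype(np.float32)
#   zp = np.random.randint(0, 16, (8 * 2, 1), dtype=np.uint8)
#   qnode, inits = make_matmul_weight_only_node(
#       matmul, weight_shape=(64, 8), num_bits=4, group_size=32, k_blocks=2,
#       q_weight=q_weight, scale=scale, zero_point=zp,
#   )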
def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
"""Quantize tensor per group.
Args:
data : input weight
        num_bits (int, optional): number of bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
scheme (str, optional): quantization scheme. Defaults to "asym".
dtype (str, optional): data type. Defaults to "int".
ratio (float, optional): percentile of clip. Defaults to 1.0.
Returns:
output: quantized weight
scale: scale
zero_point: zero point
"""
data = np.reshape(data, (-1, group_size))
if scheme == "asym" or dtype == "uint":
maxq = 2**num_bits - 1
minq = 0
elif scheme == "sym":
maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0
minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1
rmin = np.min(data, axis=1, keepdims=True) * ratio
rmax = np.max(data, axis=1, keepdims=True) * ratio
if scheme == "sym":
max_range = np.maximum(np.abs(rmin), np.abs(rmax))
scale = np.ones(rmax.shape)
mask = max_range > 0
scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq)
zero_point = (
np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1))
)
else:
scale = np.ones(rmax.shape)
scale[rmin != rmax] = np.array(
[float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()]
)
zero_point = (
((np.zeros(scale.shape) - rmin) / scale).round()
if dtype == "int"
else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")
)
q_weight = np.empty_like(data, dtype=scale.dtype)
np.divide(data, scale, out=q_weight)
np.add(q_weight, zero_point, out=q_weight)
np.round(q_weight, out=q_weight)
np.clip(q_weight, minq, maxq, out=q_weight)
return q_weight, scale, zero_point
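# A commented sketch of quant_tensor (illustrative data): asymmetric 4-bit
# quantization with one scale/zero-point per group of 32, then dequantization.
#
#   import numpy as np
#   w = np.random.randn(4, 32).astype(np.float32)
#   q, scale, zp = quant_tensor(w, num_bits=4, group_size=32, scheme="asym", dtype="uint")
#   w_dq = scale * (q - zp)                      # reconstruction, shape (4, 32)
#   max_err = np.abs(w_dq - np.reshape(w, (-1, 32))).max()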
def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
"""Quantize tensor per group based on k quant.
Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
Args:
data : input weight
        num_bits (int, optional): number of bits. Defaults to 4.
group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
Returns:
output: quantized weight
scale: scale
zero_point: zero point
"""
data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
maxq = 2**num_bits - 1
minq = 0
sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
weights = np.add(av_x, np.abs(data)) # (nb, group_size)
rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
mask = rmin != rmax
iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
scale = 1 / iscale
quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
diff = scale * quant_data + rmin - data # (nb, group_size)
best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
nstep = 20
rdelta = 0.1
# nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
rrmin = -1
for is_ in range(nstep):
iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
mask = rmin != rmax
iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
mul_weights_quant_data_new = weights * quant_data_new
sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
D = np.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
mad_1 = np.array(mad)
best_mad_1 = np.array(best_mad)
idx_to_replace = np.where(mad_1 < best_mad_1)[0]
quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
best_mad[idx_to_replace] = mad[idx_to_replace]
scale[idx_to_replace] = this_scale[idx_to_replace]
rmin[idx_to_replace] = this_min[idx_to_replace]
zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
scale = scale.astype(np.float64)
q_weight = np.empty_like(data, dtype=scale.dtype)
np.divide(data, scale, out=q_weight)
np.add(q_weight, zero_point, out=q_weight)
np.round(q_weight, out=q_weight)
np.clip(q_weight, minq, maxq, out=q_weight)
return q_weight, scale, zero_point
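# A commented sketch of the k-quant path (illustrative data): the grid search above
# tries nstep candidate scales per group and keeps the one with the lowest weighted
# squared reconstruction error before the final rounding.
#
#   import numpy as np
#   w = np.random.randn(4, 32).astype(np.float32)
#   q, scale, zp = quant_tensor_k_quant_cpu(w, num_bits=4, group_size=32)
#   w_dq = scale * (q - zp)                      # reconstruction, shape (4, 32)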
def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
"""Quantize tensor per group based on k quant.
Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
Args:
data : input weight
        num_bits (int, optional): number of bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
Returns:
output: quantized weight
scale: scale
zero_point: zero point
"""
try:
import cupy as cp # noqa: PLC0415
import torch # noqa: PLC0415
if torch.cuda.is_available():
data = cp.asarray(data)
data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
maxq = 2**num_bits - 1
minq = 0
sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
mask = rmin != rmax
iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
scale = 1 / iscale
quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
diff = scale * quant_data + rmin - data # (nb, group_size)
best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
nstep = 20
rdelta = 0.1
rrmin = -1
for is_ in range(nstep):
iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
mask = rmin != rmax
iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
mul_weights_quant_data_new = weights * quant_data_new
sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
D = cp.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
mad_1 = cp.array(mad)
best_mad_1 = cp.array(best_mad)
idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
best_mad[idx_to_replace] = mad[idx_to_replace]
scale[idx_to_replace] = this_scale[idx_to_replace]
rmin[idx_to_replace] = this_min[idx_to_replace]
zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
scale = scale.astype(cp.float64)
q_weight = cp.empty_like(data, dtype=scale.dtype)
cp.divide(data, scale, out=q_weight)
cp.add(q_weight, zero_point, out=q_weight)
cp.round(q_weight, out=q_weight)
cp.clip(q_weight, minq, maxq, out=q_weight)
return q_weight.get(), scale.get(), zero_point.get()
else:
            logger.warning(
                "Tried to use k-quant quantization on CUDA, but CUDA is not available. "
                "Falling back to k-quant quantization on CPU."
            )
return quant_tensor_k_quant_cpu(data, num_bits, group_size)
except ImportError:
        logger.info(
            "Using k-quant quantization on CPU, which is time consuming. "
            "Consider installing cupy to speed it up on CUDA (see https://cupy.dev/), "
            "and install torch so that CUDA availability can be checked."
        )
return quant_tensor_k_quant_cpu(data, num_bits, group_size)
def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
"""Quant dequant tensor per group.
Args:
data : input weight
        num_bits (int, optional): number of bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
scheme (str, optional): quantization scheme. Defaults to "asym".
dtype (str, optional): data type. Defaults to "int".
ratio (float, optional): percentile of clip. Defaults to 1.0.
Returns:
output: quant-dequant weight
"""
org_shape = data.shape
weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio)
return np.reshape(scale * (weight - zp), org_shape)
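# A commented sketch of qdq_tensor (illustrative data): the returned array keeps the
# original shape with the 4-bit group-wise quantization error baked in.
#
#   import numpy as np
#   w = np.random.randn(8, 64).astype(np.float32)
#   w_fake = qdq_tensor(w, num_bits=4, group_size=32)   # same shape as w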
def pad_tensor(weight, group_size, k_blocks):
"""Pad tensor rowi so that it can be is divisible by group_size.
Args:
weight (array): weight
group_size (int): how many elements share one scale/zp
k_blocks (int): the number of block
Returns:
weight: paded weight
"""
if group_size == -1:
return weight
org_w_shape = weight.shape
padded_rows = k_blocks * group_size
pad_len = padded_rows - org_w_shape[0]
if pad_len > 0:
weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")
return weight
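# A commented sketch of pad_tensor: 100 rows with group_size=32 round up to
# k_blocks=4, so 28 zero rows are appended.
#
#   import numpy as np
#   w = np.ones((100, 16), dtype=np.float32)
#   k_blocks = (100 - 1) // 32 + 1               # 4
#   pad_tensor(w, 32, k_blocks).shape            # (128, 16)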
def rtn_quantize(
model,
weight_config={}, # noqa: B006
num_bits=4,
group_size=32,
scheme="asym",
ratios={}, # noqa: B006
accuracy_level=0,
providers=["CPUExecutionProvider"], # noqa: B006
algorithm="k_quant",
):
"""Quant the model with round to nearst method.
Args:
model (ModelProto or ONNXModel): onnx model
weight_config (dict): quantization config
For example,
weight_config = {
'fc2':
{
'bits': 4,
'group_size': 32,
'scheme': 'sym',
'algorithm': 'RTN'
}
}
        num_bits (int, optional): number of bits. Default is 4.
group_size (int, optional): how many elements share one scale/zp. Default is 32.
scheme (str, optional): sym or asym. Defaults to "asym".
ratios (dict, optional): percentile of clip. Defaults to {}.
        accuracy_level (int): accuracy level. Supports 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8).
        providers (list): providers to use
        algorithm (str, optional): weight-only algorithm, "k_quant" or plain RTN. Defaults to "k_quant".
Returns:
model: fake quantized ONNXModel
"""
model = ONNXModel(model)
base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
new_nodes = []
remove_nodes = []
total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]])
curr_id = 0
for node in model.nodes():
if node.op_type in ["MatMul"]:
curr_id += 1
simple_progress_bar(total_num, curr_id)
if (
node.op_type in ["MatMul"]
and model.get_initializer(node.input[1]) is not None
and weight_config.get(node.name, {}) != "fp32"
):
weight_tensor = model.get_initializer(node.input[1])
weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
if len(weight.shape) != 2:
continue
dtype = weight.dtype
if node.name in weight_config:
num_bits = weight_config[node.name]["bits"]
group_size = weight_config[node.name]["group_size"]
scheme = weight_config[node.name]["scheme"]
org_w_shape = weight.shape # ic, oc
group_size = group_size if group_size != -1 else org_w_shape[0]
k_blocks = (org_w_shape[0] - 1) // group_size + 1
init_share_num = model.get_initializer_share_num(node.input[1])
weight = pad_tensor(weight, group_size, k_blocks)
satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8 # noqa: N806
if satisfy_MatMulNBits_condition: # pragma: no cover
if algorithm == "k_quant":
q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
else:
q_weight, scale, zp = quant_tensor(
weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
)
q_matmul_node, new_inits = make_matmul_weight_only_node(
node=node,
weight_shape=org_w_shape,
num_bits=num_bits,
group_size=group_size,
k_blocks=k_blocks,
q_weight=q_weight.astype("uint8"),
scale=scale.astype(dtype),
zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None,
accuracy_level=accuracy_level,
)
model.add_initializers(new_inits)
remove_nodes.append(node)
new_nodes.append(q_matmul_node)
else:
q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
q_weight = np.reshape(q_weight, (org_w_shape[1], -1))
q_weight = np.transpose(q_weight)
q_weight = q_weight[: org_w_shape[0], :].astype(dtype)
q_weight_tensor = onnx.helper.make_tensor(
name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
data_type=np_dtype_to_tensor_dtype(dtype),
dims=weight.shape,
vals=q_weight.tobytes(),
raw=True,
)
model.add_initializer(q_weight_tensor)
node.input[1] = q_weight_tensor.name
if init_share_num == 1:
model.remove_initializer(weight_tensor)
model.add_nodes(new_nodes)
model.remove_nodes(remove_nodes)
model.topological_sort()
return model
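# A commented end-to-end sketch for rtn_quantize (file names hypothetical): 4-bit
# k-quant applied to every eligible MatMul initializer.
#
#   import onnx
#   m = onnx.load("model.onnx")
#   qmodel = rtn_quantize(m, num_bits=4, group_size=32, algorithm="k_quant")
#   qmodel.save("model_q4.onnx")                 # assumes ONNXModel exposes save()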
def get_weight_scale(weight, group_size):
"""Get the scale of weight."""
org_shape = weight.shape
weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight
scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0)
return scale
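# A commented sketch of get_weight_scale (illustrative data): the per-output-column
# mean of group-normalized weight magnitudes.
#
#   import numpy as np
#   w = np.random.randn(128, 64).astype(np.float32)
#   s = get_weight_scale(w, group_size=32)       # shape (64,)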
def prepare_inputs(model, n_samples, dataloader, providers):
"""Prepare inputs for weight only quantization.
Args:
model (ModelProto or ONNXModel): onnx model
n_samples (int, optional): calibration sample number. -1 means all samples.
dataloader (object): dataloader for calibration.
providers (list): providers to use
Returns:
inputs: prepared inputs.
so: session options
"""
from importlib.util import find_spec # noqa: PLC0415
from .util import to_numpy # noqa: PLC0415
so = ort.SessionOptions()
if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover
from onnxruntime_extensions import get_library_path # noqa: PLC0415
so.register_custom_ops_library(get_library_path())
if model.is_large_model:
onnx.save_model(
model.model,
model.model_path + "_augment.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
convert_attribute=False,
)
session = (
ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
if not model.is_large_model
else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
)
inputs_names = [i.name for i in session.get_inputs()]
del session
inputs = []
for i, data in enumerate(dataloader):
if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples:
break
if len(inputs_names) != 1 or isinstance(data[0], dict):
assert len(data[0]) == len(inputs_names), (
f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}"
)
if isinstance(data[0], dict):
inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) # noqa: C404
elif isinstance(data[0], np.ndarray): # pragma: no cover
inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)])) # noqa: C404
else: # pragma: no cover
inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)])) # noqa: C404
return inputs, so
def gptq(
W,
H,
num_bits=4,
group_size=32,
scheme="asym",
blocksize=128,
percdamp=0.01,
actorder=False,
mse=False,
perchannel=True,
):
"""Quant the weight with GPTQ method.
Args:
W (array): weight.
H (array): Hessian matrix.
        num_bits (int, optional): number of bits. Default is 4.
group_size (int, optional): how many elements share one scale/zp. Default is 32.
scheme (str, optional): sym or asym. Defaults to "asym".
blocksize (int, optional): blocksize to quantize weight.
percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
mse (bool, optional): whether get scale and zero point with mse error.
perchannel (bool, optional): whether quantize weight per-channel.
Returns:
Q: fake quantized weight
"""
maxq = 2**num_bits - 1
grid = 100
maxshrink = 0.8
norm = 2.4
def find_params(weight):
org_shape = weight.shape
# find zp, scale
if not perchannel:
weight = np.expand_dims(weight.flatten(), axis=1)
tmp = np.zeros(weight.shape[1])
xmin = np.minimum(np.min(weight, axis=0), tmp)
xmax = np.maximum(np.max(weight, axis=0), tmp)
if scheme == "sym":
xmax = np.maximum(np.abs(xmin), xmax)
tmp = xmin < 0
if np.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
scale = (xmax - xmin) / maxq
if scheme == "sym":
zero = np.ones(scale.shape) * (maxq + 1) / 2
else:
zero = np.round(-xmin / scale)
if mse:
best = np.ones([weight.shape[1]]) * float("inf")
for i in range(int(maxshrink * grid)):
p = 1 - i / grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / maxq
zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero
q = np.clip(np.round(weight / scale1) + zero1, 0, maxq)
q -= weight
q = np.power(np.abs(q), norm)
err = np.sum(q, 0)
tmp = err < best
if np.any(tmp):
best[tmp] = err[tmp]
scale[tmp] = scale1[tmp]
zero[tmp] = zero1[tmp]
if not perchannel:
tmp = org_shape[1]
scale = np.repeat(scale, tmp)
zero = np.repeat(zero, tmp)
shape = [-1] + [1] * (len(org_shape) - 1)
scale = np.reshape(scale, shape)
zero = np.reshape(zero, shape)
return scale, zero
shape = W.shape
scale, zp = find_params(W)
dead = np.diag(H) == 0
H[dead, dead] = 1
    W[dead, :] = 0  # such channels make no contribution to the quantization computation
# rearrange considering the diag's value
if actorder:
perm = np.argsort(np.diag(H))[::-1]
W = W[perm, :] # noqa: N806
H = H[perm, :][:, perm] # noqa: N806
Losses = np.zeros_like(W) # noqa: N806
Q = np.zeros_like(W) # noqa: N806
damp = percdamp * np.mean(np.diag(H))
diag = np.arange(shape[0])
    H[diag, diag] += damp  # dampen: add a fraction of the average Hessian diagonal
H = np.linalg.cholesky(np.linalg.inv(H)).T # noqa: N806
Hinv = H # noqa: N806
for i1 in range(0, shape[0], blocksize):
i2 = min(i1 + blocksize, shape[0])
count = i2 - i1
W1 = copy.deepcopy(W[i1:i2, :]) # noqa: N806
Q1 = np.zeros_like(W1) # noqa: N806
Err1 = np.zeros_like(W1) # noqa: N806
Losses1 = np.zeros_like(W1) # noqa: N806
Hinv1 = Hinv[i1:i2, i1:i2] # noqa: N806
for i in range(count): # within a block, channel wise
w = W1[i, :]
d = Hinv1[i, i]
if group_size != -1:
if (i1 + i) % group_size == 0:
scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :])
q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()
Q1[i, :] = q
Losses1[i, :] = (w - q) ** 2 / d**2
err1 = (w - q) / d
W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0))
Err1[i, :] = err1
Q[i1:i2, :] = Q1
Losses[i1:i2, :] = Losses1 / 2
W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1)
if actorder:
invperm = np.argsort(perm)
Q = Q[invperm, :] # noqa: N806
Q = np.reshape(Q, W.shape) # noqa: N806
del W
return Q
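# A commented sketch of gptq (synthetic calibration data, illustrative only): the
# Hessian proxy H = 2/n * X^T X matches how gptq_quantize accumulates it below.
# Copies are passed because gptq modifies W and H in place.
#
#   import numpy as np
#   W = np.random.randn(64, 16).astype(np.float32)    # (in_features, out_features)
#   X = np.random.randn(256, 64).astype(np.float32)   # calibration activations
#   H = 2 / 256 * (X.T @ X)
#   Q = gptq(W.copy(), H.copy(), num_bits=4, group_size=32)  # fake-quantized weight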
def gptq_quantize(
model,
dataloader,
weight_config={}, # noqa: B006
num_bits=4,
group_size=32,
scheme="asym",
n_samples=128,
percdamp=0.01,
blocksize=128,
actorder=False,
mse=False,
perchannel=True,
accuracy_level=0,
providers=["CPUExecutionProvider"], # noqa: B006
):
"""Quant the model with GPTQ method.
Args:
model (ModelProto or ONNXModel): onnx model
dataloader (object): dataloader for calibration.
weight_config (dict): quantization config
For example,
weight_config = {
'fc2':
{
'bits': 4,
'group_size': 32,
'scheme': 'sym',
'algorithm': 'GPTQ'
}
}
        num_bits (int, optional): number of bits. Default is 4.
group_size (int, optional): how many elements share one scale/zp. Default is 32.
scheme (str, optional): sym or asym. Defaults to "asym".
n_samples (int, optional): calibration sample number.
percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
blocksize (int, optional): blocksize to quantize weight.
actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
mse (bool, optional): whether get scale and zero point with mse error.
perchannel (bool, optional): whether quantize weight per-channel.
        accuracy_level (int): accuracy level. Supports 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8).
providers (list): providers to use
Returns:
model: fake quantized ONNXModel
"""
model = ONNXModel(model)
base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
inputs, so = prepare_inputs(model, n_samples, dataloader, providers)
del dataloader
org_output = copy.deepcopy(model.model.graph.output)
model.remove_tensors_from_outputs([i.name for i in org_output])
output_names = []
for node in model.nodes():
if (
node.op_type in ["MatMul"]
and weight_config.get(node.name, {}) != "fp32"
and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
):
output_names.append(node.input[0])
output_names = list(set(output_names))
model.add_tensors_to_outputs(output_names)
if model.is_large_model:
onnx.save_model(
model.model,
model.model_path + "_augment.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
convert_attribute=False,
)
session = (
ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
if not model.is_large_model
else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
)
for idx, input_name in enumerate(output_names):
simple_progress_bar(len(output_names), idx + 1)
node_list = []
weights = []
for node in model.input_name_to_nodes[input_name]:
if (
node.op_type in ["MatMul"]
and weight_config.get(node.name, {}) != "fp32"
and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
and model.get_initializer(node.input[1]) is not None
):
weight = numpy_helper.to_array(
model.get_initializer(model.get_node(node.name).input[1]), base_dir
).copy()
if len(weight.shape) != 2:
continue
weights.append(weight)
node_list.append(model.get_node(node.name))
if len(weights) == 0:
continue
Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] # noqa: N806
nsamples = 0
for data in inputs:
inp = session.run([input_name], data)[0]
tmp = inp.shape[0]
inp = np.reshape(inp, (-1, inp.shape[-1]))
Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] # noqa: N806
nsamples += tmp
inp = np.sqrt(2 / nsamples) * inp
Hs = [i + np.matmul(inp.T, inp) for i in Hs] # noqa: N806
for (
node,
weight,
H, # noqa: N806
) in zip(node_list, weights, Hs, strict=False):
if node.name in weight_config:
num_bits = weight_config[node.name]["bits"]
group_size = weight_config[node.name]["group_size"]
scheme = weight_config[node.name]["scheme"]
group_size = group_size if group_size != -1 else weight.shape[0]
dtype = weight.dtype
q_weight = gptq(
weight,
H,
num_bits=num_bits,
group_size=group_size,
scheme=scheme,
blocksize=blocksize,
percdamp=percdamp,
actorder=actorder,
mse=mse,
perchannel=perchannel,
)
weight_tensor = model.get_initializer(node.input[1])
init_share_num = model.get_initializer_share_num(node.input[1])
satisfy_MatMulNBits_condition = num_bits == 4 # noqa: N806
if satisfy_MatMulNBits_condition: # pragma: no cover
org_shape = weight.shape
k_blocks = (org_shape[0] + group_size - 1) // group_size
q_weight = pad_tensor(q_weight, group_size, k_blocks)
q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint")
q_matmul_node, new_inits = make_matmul_weight_only_node(
node=node,
weight_shape=org_shape,
num_bits=num_bits,
group_size=group_size,
k_blocks=k_blocks,
q_weight=q_weight.astype("uint8"),
scale=scale.astype(dtype),
zero_point=zp if scheme == "asym" else None,
accuracy_level=accuracy_level,
)
model.add_initializers(new_inits)
model.remove_node(node)
model.add_node(q_matmul_node)
else:
q_weight_tensor = onnx.helper.make_tensor(
name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
data_type=np_dtype_to_tensor_dtype(dtype),
dims=q_weight.shape,
vals=q_weight.astype(dtype).tobytes(),
raw=True,
)
model.add_initializer(q_weight_tensor)
node.input[1] = q_weight_tensor.name
if init_share_num == 1:
model.remove_initializer(weight_tensor)
model.remove_tensors_from_outputs(output_names)
model.model.graph.output.MergeFrom(org_output)
model.topological_sort()
# reload external data to prevent external data file path errors
if model.is_large_model:
from onnx.external_data_helper import load_external_data_for_model # noqa: PLC0415
load_external_data_for_model(model.model, os.path.split(model.model_path)[0])
return model
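# A commented sketch for gptq_quantize (names hypothetical): `dl` must yield
# (inputs, label) batches and expose a batch_size attribute, as prepare_inputs expects.
#
#   import onnx
#   m = onnx.load("model.onnx")
#   qmodel = gptq_quantize(m, dl, num_bits=4, group_size=32, n_samples=128)
#   qmodel.save("model_gptq_q4.onnx")            # assumes ONNXModel exposes save()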