Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
#
#  The implementation of this file is based on:
# https://github.com/intel/neural-compressor/tree/master/neural_compressor
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications:
# Add k-quant quantization method.
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""WeightOnly for onnxrt adaptor."""

import copy
import logging
import os
import sys

import numpy as np
import onnx
from onnx import numpy_helper
from onnx.helper import np_dtype_to_tensor_dtype

import onnxruntime as ort

from .onnx_model import ONNXModel
from .util import simple_progress_bar

logger = logging.getLogger("neural_compressor")


def make_matmul_weight_only_node(
    node,
    weight_shape,
    num_bits,
    group_size,
    k_blocks,
    q_weight,
    scale,
    zero_point,
    accuracy_level=0,
):  # pragma: no cover
    """Build MatMulNBits node.

    Args:
        node: original matmul node
        weight_shape: original weight shape
        num_bits (int): num_bits
        group_size (int): how many elements share one scale/zp
        k_blocks (int): block number
        q_weight (array): quantized weight
        scale (array): scale
        zero_point (array): zero point
        accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).

    Returns:
        matmul_weight_only_node: MatMulNBits node
        new_inits: initializers of the new node
    """
    blob_size = group_size * num_bits // 8
    packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
    q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}"
    input_names = [node.input[0], q_weight_name]
    new_inits = []
    kwargs = {}

    op_type = "MatMulNBits"

    # pack quantized weight
    if num_bits == 4:
        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
        packed[:, :] = q_weight_pairs[:, :blob_size]
    elif num_bits == 8:
        packed = q_weight
    else:
        logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")

    packed = np.reshape(packed, (-1, k_blocks, blob_size))

    # build scale tensor
    scale = np.reshape(scale, (-1, k_blocks))
    assert scale.dtype == np.float32 or scale.dtype == np.float16
    scale_tensor = onnx.helper.make_tensor(
        name=node.input[1] + "_scale",
        data_type=np_dtype_to_tensor_dtype(scale.dtype),
        dims=scale.shape,
        vals=scale.tobytes(),
        raw=True,
    )
    input_names.append(scale_tensor.name)
    new_inits.append(scale_tensor)

    # build zero_point tensor
    if zero_point is not None:
        if num_bits == 8:
            packed_zp = zero_point.astype("uint8")
        elif num_bits == 4:
            # For 4-bit case, the default zeros is 0x8. So it is 0x88 = 136 if we fill lower/higher 4 bits with 0x8.
            packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
            # create an index array
            idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
            # separate odd and even indices
            even_idx = idx[::2]
            odd_idx = idx[1::2]
            # vectorized operation for even and odd indices
            packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
            packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)
        else:
            raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")

        packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
        zp_tensor = onnx.helper.make_tensor(
            name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
        )
        input_names.append(zp_tensor.name)
        new_inits.append(zp_tensor)

    # set kwargs
    kwargs["K"] = weight_shape[0]
    kwargs["N"] = weight_shape[1]
    kwargs["bits"] = num_bits
    kwargs["block_size"] = group_size
    if accuracy_level > 0:
        # require onnxruntime > 1.16.3
        kwargs["accuracy_level"] = accuracy_level

    q_weight_tensor = onnx.helper.make_tensor(
        name=q_weight_name,
        data_type=2,
        dims=packed.shape,
        vals=packed.tobytes(),
        raw=True,
    )
    new_inits.append(q_weight_tensor)

    matmul_weight_only_node = onnx.helper.make_node(
        op_type,
        inputs=input_names,
        outputs=node.output,
        name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
        domain="com.microsoft",
        **kwargs,
    )
    return matmul_weight_only_node, new_inits


def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
    """Quantize tensor per group.

    Args:
        data : input weight
        num_bits (int, optional): num_bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 4.
        scheme (str, optional): quantization scheme. Defaults to "asym".
        dtype (str, optional): data type. Defaults to "int".
        ratio (float, optional): percentile of clip. Defaults to 1.0.

    Returns:
        output: quantized weight
        scale: scale
        zero_point: zero point
    """
    data = np.reshape(data, (-1, group_size))
    if scheme == "asym" or dtype == "uint":
        maxq = 2**num_bits - 1
        minq = 0
    elif scheme == "sym":
        maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0
        minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1

    rmin = np.min(data, axis=1, keepdims=True) * ratio
    rmax = np.max(data, axis=1, keepdims=True) * ratio
    if scheme == "sym":
        max_range = np.maximum(np.abs(rmin), np.abs(rmax))
        scale = np.ones(rmax.shape)
        mask = max_range > 0
        scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq)
        zero_point = (
            np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1))
        )
    else:
        scale = np.ones(rmax.shape)
        scale[rmin != rmax] = np.array(
            [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()]
        )
        zero_point = (
            ((np.zeros(scale.shape) - rmin) / scale).round()
            if dtype == "int"
            else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")
        )

    q_weight = np.empty_like(data, dtype=scale.dtype)
    np.divide(data, scale, out=q_weight)
    np.add(q_weight, zero_point, out=q_weight)
    np.round(q_weight, out=q_weight)
    np.clip(q_weight, minq, maxq, out=q_weight)

    return q_weight, scale, zero_point


def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
    """Quantize tensor per group based on k quant.

    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c

    Args:
        data : input weight
        num_bits (int, optional): num_bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.

    Returns:
        output: quantized weight
        scale: scale
        zero_point: zero point
    """
    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # nb = data.shape[0], (nb, group_size)
    maxq = 2**num_bits - 1
    minq = 0
    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, group_size)
    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
    mask = rmin != rmax
    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
    scale = 1 / iscale
    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
    diff = scale * quant_data + rmin - data  # (nb, group_size)
    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
    nstep = 20
    rdelta = 0.1
    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
    rrmin = -1
    for is_ in range(nstep):
        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
        mask = rmin != rmax
        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
        mul_weights_quant_data_new = weights * quant_data_new
        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
        D = np.subtract(sum_w * sum_l2, sum_l**2)  # noqa: N806

        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)

        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)

        mad_1 = np.array(mad)
        best_mad_1 = np.array(best_mad)
        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
        best_mad[idx_to_replace] = mad[idx_to_replace]
        scale[idx_to_replace] = this_scale[idx_to_replace]
        rmin[idx_to_replace] = this_min[idx_to_replace]

    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
    scale = scale.astype(np.float64)
    q_weight = np.empty_like(data, dtype=scale.dtype)
    np.divide(data, scale, out=q_weight)
    np.add(q_weight, zero_point, out=q_weight)
    np.round(q_weight, out=q_weight)
    np.clip(q_weight, minq, maxq, out=q_weight)

    return q_weight, scale, zero_point


def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
    """Quantize tensor per group based on k quant.

    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c

    Args:
        data : input weight
        num_bits (int, optional): num_bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 4.

    Returns:
        output: quantized weight
        scale: scale
        zero_point: zero point
    """
    try:
        import cupy as cp  # noqa: PLC0415
        import torch  # noqa: PLC0415

        if torch.cuda.is_available():
            data = cp.asarray(data)
            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
            maxq = 2**num_bits - 1
            minq = 0
            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, group_size)
            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
            mask = rmin != rmax
            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
            scale = 1 / iscale
            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
            diff = scale * quant_data + rmin - data  # (nb, group_size)
            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
            nstep = 20
            rdelta = 0.1
            rrmin = -1
            for is_ in range(nstep):
                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
                mask = rmin != rmax
                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
                mul_weights_quant_data_new = weights * quant_data_new
                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # noqa: N806

                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)

                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)

                mad_1 = cp.array(mad)
                best_mad_1 = cp.array(best_mad)
                idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
                best_mad[idx_to_replace] = mad[idx_to_replace]
                scale[idx_to_replace] = this_scale[idx_to_replace]
                rmin[idx_to_replace] = this_min[idx_to_replace]

            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
            scale = scale.astype(cp.float64)
            q_weight = cp.empty_like(data, dtype=scale.dtype)
            cp.divide(data, scale, out=q_weight)
            cp.add(q_weight, zero_point, out=q_weight)
            cp.round(q_weight, out=q_weight)
            cp.clip(q_weight, minq, maxq, out=q_weight)

            return q_weight.get(), scale.get(), zero_point.get()
        else:
            logger.warning(
                "Try to use k-quant quantization on CUDA. However, CUDA is not available."
                "Fall back to k-quant quantization on CPU."
            )
            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
    except ImportError:
        logger.info(
            "Now we are using k-quant quantization on cpu, which is time consuming."
            "Please consider install cupy to speed up on CUDA. See https://cupy.dev/"
            "Please also install torch to check CUDA availability."
        )
        return quant_tensor_k_quant_cpu(data, num_bits, group_size)


def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
    """Quant dequant tensor per group.

    Args:
        data : input weight
        num_bits (int, optional): num_bits. Defaults to 4.
        group_size (int, optional): how many elements share one scale/zp. Defaults to 4.
        scheme (str, optional): quantization scheme. Defaults to "asym".
        dtype (str, optional): data type. Defaults to "int".
        ratio (float, optional): percentile of clip. Defaults to 1.0.

    Returns:
        output: quant-dequant weight
    """
    org_shape = data.shape
    weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio)
    return np.reshape(scale * (weight - zp), org_shape)


def pad_tensor(weight, group_size, k_blocks):
    """Pad tensor rowi so that it can be is divisible by group_size.

    Args:
        weight (array): weight
        group_size (int): how many elements share one scale/zp
        k_blocks (int): the number of block

    Returns:
        weight: paded weight
    """
    if group_size == -1:
        return weight

    org_w_shape = weight.shape
    padded_rows = k_blocks * group_size
    pad_len = padded_rows - org_w_shape[0]

    if pad_len > 0:
        weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")

    return weight


def rtn_quantize(
    model,
    weight_config={},  # noqa: B006
    num_bits=4,
    group_size=32,
    scheme="asym",
    ratios={},  # noqa: B006
    accuracy_level=0,
    providers=["CPUExecutionProvider"],  # noqa: B006
    algorithm="k_quant",
):
    """Quant the model with round to nearst method.

    Args:
        model (ModelProto or ONNXModel): onnx model
        weight_config (dict): quantization config
                For example,
                weight_config = {
                    'fc2':
                        {
                            'bits': 4,
                            'group_size': 32,
                            'scheme': 'sym',
                            'algorithm': 'RTN'
                        }
                }
        num_bits (int, optional): num_bits. Default is 4.
        group_size (int, optional): how many elements share one scale/zp. Default is 32.
        scheme (str, optional): sym or asym. Defaults to "asym".
        ratios (dict, optional): percentile of clip. Defaults to {}.
        accuracy_level (int): accuracy level. Support 0 (unset),1(fp32), 2(fp16), 3(bf16), or 4(int8).
        providers (list): providers to use

    Returns:
        model: fake quantized ONNXModel
    """
    model = ONNXModel(model)
    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
    new_nodes = []
    remove_nodes = []
    total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]])
    curr_id = 0
    for node in model.nodes():
        if node.op_type in ["MatMul"]:
            curr_id += 1
            simple_progress_bar(total_num, curr_id)
        if (
            node.op_type in ["MatMul"]
            and model.get_initializer(node.input[1]) is not None
            and weight_config.get(node.name, {}) != "fp32"
        ):
            weight_tensor = model.get_initializer(node.input[1])
            weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
            if len(weight.shape) != 2:
                continue

            dtype = weight.dtype

            if node.name in weight_config:
                num_bits = weight_config[node.name]["bits"]
                group_size = weight_config[node.name]["group_size"]
                scheme = weight_config[node.name]["scheme"]

            org_w_shape = weight.shape  # ic, oc
            group_size = group_size if group_size != -1 else org_w_shape[0]

            k_blocks = (org_w_shape[0] - 1) // group_size + 1
            init_share_num = model.get_initializer_share_num(node.input[1])

            weight = pad_tensor(weight, group_size, k_blocks)

            satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8  # noqa: N806

            if satisfy_MatMulNBits_condition:  # pragma: no cover
                if algorithm == "k_quant":
                    q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
                else:
                    q_weight, scale, zp = quant_tensor(
                        weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
                    )

                q_matmul_node, new_inits = make_matmul_weight_only_node(
                    node=node,
                    weight_shape=org_w_shape,
                    num_bits=num_bits,
                    group_size=group_size,
                    k_blocks=k_blocks,
                    q_weight=q_weight.astype("uint8"),
                    scale=scale.astype(dtype),
                    zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None,
                    accuracy_level=accuracy_level,
                )

                model.add_initializers(new_inits)
                remove_nodes.append(node)
                new_nodes.append(q_matmul_node)
            else:
                q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
                q_weight = np.reshape(q_weight, (org_w_shape[1], -1))
                q_weight = np.transpose(q_weight)
                q_weight = q_weight[: org_w_shape[0], :].astype(dtype)
                q_weight_tensor = onnx.helper.make_tensor(
                    name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
                    data_type=np_dtype_to_tensor_dtype(dtype),
                    dims=weight.shape,
                    vals=q_weight.tobytes(),
                    raw=True,
                )
                model.add_initializer(q_weight_tensor)
                node.input[1] = q_weight_tensor.name
            if init_share_num == 1:
                model.remove_initializer(weight_tensor)

    model.add_nodes(new_nodes)
    model.remove_nodes(remove_nodes)
    model.topological_sort()
    return model


def get_weight_scale(weight, group_size):
    """Get the scale of weight."""
    org_shape = weight.shape
    weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight
    scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0)
    return scale


def prepare_inputs(model, n_samples, dataloader, providers):
    """Prepare inputs for weight only quantization.

    Args:
        model (ModelProto or ONNXModel): onnx model
        n_samples (int, optional): calibration sample number. -1 means all samples.
        dataloader (object): dataloader for calibration.
        providers (list): providers to use

    Returns:
        inputs: prepared inputs.
        so: session options
    """
    from importlib.util import find_spec  # noqa: PLC0415

    from .util import to_numpy  # noqa: PLC0415

    so = ort.SessionOptions()
    if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"):  # pragma: no cover
        from onnxruntime_extensions import get_library_path  # noqa: PLC0415

        so.register_custom_ops_library(get_library_path())
    if model.is_large_model:
        onnx.save_model(
            model.model,
            model.model_path + "_augment.onnx",
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            convert_attribute=False,
        )

    session = (
        ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
        if not model.is_large_model
        else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
    )
    inputs_names = [i.name for i in session.get_inputs()]
    del session

    inputs = []
    for i, data in enumerate(dataloader):
        if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples:
            break
        if len(inputs_names) != 1 or isinstance(data[0], dict):
            assert len(data[0]) == len(inputs_names), (
                f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}"
            )

        if isinstance(data[0], dict):
            inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()]))  # noqa: C404
        elif isinstance(data[0], np.ndarray):  # pragma: no cover
            inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)]))  # noqa: C404
        else:  # pragma: no cover
            inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)]))  # noqa: C404
    return inputs, so


def gptq(
    W,
    H,
    num_bits=4,
    group_size=32,
    scheme="asym",
    blocksize=128,
    percdamp=0.01,
    actorder=False,
    mse=False,
    perchannel=True,
):
    """Quant the weight with GPTQ method.

    Args:
        W (array): weight.
        H (array): Hessian matrix.
        num_bits (int, optional): num_bits. Default is 4.
        group_size (int, optional): how many elements share one scale/zp. Default is 32.
        scheme (str, optional): sym or asym. Defaults to "asym".
        blocksize (int, optional): blocksize to quantize weight.
        percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
        actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
        mse (bool, optional): whether get scale and zero point with mse error.
        perchannel (bool, optional): whether quantize weight per-channel.

    Returns:
        Q: fake quantized weight
    """
    maxq = 2**num_bits - 1
    grid = 100
    maxshrink = 0.8
    norm = 2.4

    def find_params(weight):
        org_shape = weight.shape
        # find zp, scale
        if not perchannel:
            weight = np.expand_dims(weight.flatten(), axis=1)
        tmp = np.zeros(weight.shape[1])
        xmin = np.minimum(np.min(weight, axis=0), tmp)
        xmax = np.maximum(np.max(weight, axis=0), tmp)
        if scheme == "sym":
            xmax = np.maximum(np.abs(xmin), xmax)
            tmp = xmin < 0
            if np.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        scale = (xmax - xmin) / maxq
        if scheme == "sym":
            zero = np.ones(scale.shape) * (maxq + 1) / 2
        else:
            zero = np.round(-xmin / scale)
        if mse:
            best = np.ones([weight.shape[1]]) * float("inf")
            for i in range(int(maxshrink * grid)):
                p = 1 - i / grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / maxq
                zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero
                q = np.clip(np.round(weight / scale1) + zero1, 0, maxq)
                q -= weight
                q = np.power(np.abs(q), norm)
                err = np.sum(q, 0)
                tmp = err < best
                if np.any(tmp):
                    best[tmp] = err[tmp]
                    scale[tmp] = scale1[tmp]
                    zero[tmp] = zero1[tmp]
        if not perchannel:
            tmp = org_shape[1]
            scale = np.repeat(scale, tmp)
            zero = np.repeat(zero, tmp)
        shape = [-1] + [1] * (len(org_shape) - 1)
        scale = np.reshape(scale, shape)
        zero = np.reshape(zero, shape)
        return scale, zero

    shape = W.shape
    scale, zp = find_params(W)
    dead = np.diag(H) == 0
    H[dead, dead] = 1
    W[dead, :] = 0  # such channel makes no contribution to quantization computation

    # rearrange considering the diag's value
    if actorder:
        perm = np.argsort(np.diag(H))[::-1]
        W = W[perm, :]  # noqa: N806
        H = H[perm, :][:, perm]  # noqa: N806
    Losses = np.zeros_like(W)  # noqa: N806
    Q = np.zeros_like(W)  # noqa: N806
    damp = percdamp * np.mean(np.diag(H))
    diag = np.arange(shape[0])
    H[diag, diag] += damp  # add a average value of
    H = np.linalg.cholesky(np.linalg.inv(H)).T  # noqa: N806
    Hinv = H  # noqa: N806
    for i1 in range(0, shape[0], blocksize):
        i2 = min(i1 + blocksize, shape[0])
        count = i2 - i1

        W1 = copy.deepcopy(W[i1:i2, :])  # noqa: N806
        Q1 = np.zeros_like(W1)  # noqa: N806
        Err1 = np.zeros_like(W1)  # noqa: N806
        Losses1 = np.zeros_like(W1)  # noqa: N806
        Hinv1 = Hinv[i1:i2, i1:i2]  # noqa: N806

        for i in range(count):  # within a block, channel wise
            w = W1[i, :]
            d = Hinv1[i, i]

            if group_size != -1:
                if (i1 + i) % group_size == 0:
                    scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :])

            q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()
            Q1[i, :] = q
            Losses1[i, :] = (w - q) ** 2 / d**2

            err1 = (w - q) / d
            W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0))
            Err1[i, :] = err1

        Q[i1:i2, :] = Q1
        Losses[i1:i2, :] = Losses1 / 2

        W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1)

    if actorder:
        invperm = np.argsort(perm)
        Q = Q[invperm, :]  # noqa: N806

    Q = np.reshape(Q, W.shape)  # noqa: N806
    del W
    return Q


def gptq_quantize(
    model,
    dataloader,
    weight_config={},  # noqa: B006
    num_bits=4,
    group_size=32,
    scheme="asym",
    n_samples=128,
    percdamp=0.01,
    blocksize=128,
    actorder=False,
    mse=False,
    perchannel=True,
    accuracy_level=0,
    providers=["CPUExecutionProvider"],  # noqa: B006
):
    """Quant the model with GPTQ method.

    Args:
        model (ModelProto or ONNXModel): onnx model
        dataloader (object): dataloader for calibration.
        weight_config (dict): quantization config
                For example,
                weight_config = {
                    'fc2':
                        {
                            'bits': 4,
                            'group_size': 32,
                            'scheme': 'sym',
                            'algorithm': 'GPTQ'
                        }
                }
        num_bits (int, optional): num_bits. Default is 4.
        group_size (int, optional): how many elements share one scale/zp. Default is 32.
        scheme (str, optional): sym or asym. Defaults to "asym".
        n_samples (int, optional): calibration sample number.
        percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
        blocksize (int, optional): blocksize to quantize weight.
        actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
        mse (bool, optional): whether get scale and zero point with mse error.
        perchannel (bool, optional): whether quantize weight per-channel.
        accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
        providers (list): providers to use

    Returns:
        model: fake quantized ONNXModel
    """
    model = ONNXModel(model)
    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""

    inputs, so = prepare_inputs(model, n_samples, dataloader, providers)
    del dataloader
    org_output = copy.deepcopy(model.model.graph.output)
    model.remove_tensors_from_outputs([i.name for i in org_output])
    output_names = []
    for node in model.nodes():
        if (
            node.op_type in ["MatMul"]
            and weight_config.get(node.name, {}) != "fp32"
            and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
        ):
            output_names.append(node.input[0])
    output_names = list(set(output_names))
    model.add_tensors_to_outputs(output_names)
    if model.is_large_model:
        onnx.save_model(
            model.model,
            model.model_path + "_augment.onnx",
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            convert_attribute=False,
        )

    session = (
        ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
        if not model.is_large_model
        else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
    )

    for idx, input_name in enumerate(output_names):
        simple_progress_bar(len(output_names), idx + 1)
        node_list = []
        weights = []

        for node in model.input_name_to_nodes[input_name]:
            if (
                node.op_type in ["MatMul"]
                and weight_config.get(node.name, {}) != "fp32"
                and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
                and model.get_initializer(node.input[1]) is not None
            ):
                weight = numpy_helper.to_array(
                    model.get_initializer(model.get_node(node.name).input[1]), base_dir
                ).copy()
                if len(weight.shape) != 2:
                    continue

                weights.append(weight)
                node_list.append(model.get_node(node.name))

        if len(weights) == 0:
            continue

        Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights]  # noqa: N806
        nsamples = 0
        for data in inputs:
            inp = session.run([input_name], data)[0]
            tmp = inp.shape[0]
            inp = np.reshape(inp, (-1, inp.shape[-1]))
            Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs]  # noqa: N806
            nsamples += tmp
            inp = np.sqrt(2 / nsamples) * inp
            Hs = [i + np.matmul(inp.T, inp) for i in Hs]  # noqa: N806

        for (
            node,
            weight,
            H,  # noqa: N806
        ) in zip(node_list, weights, Hs, strict=False):
            if node.name in weight_config:
                num_bits = weight_config[node.name]["bits"]
                group_size = weight_config[node.name]["group_size"]
                scheme = weight_config[node.name]["scheme"]
            group_size = group_size if group_size != -1 else weight.shape[0]
            dtype = weight.dtype

            q_weight = gptq(
                weight,
                H,
                num_bits=num_bits,
                group_size=group_size,
                scheme=scheme,
                blocksize=blocksize,
                percdamp=percdamp,
                actorder=actorder,
                mse=mse,
                perchannel=perchannel,
            )

            weight_tensor = model.get_initializer(node.input[1])
            init_share_num = model.get_initializer_share_num(node.input[1])

            satisfy_MatMulNBits_condition = num_bits == 4  # noqa: N806

            if satisfy_MatMulNBits_condition:  # pragma: no cover
                org_shape = weight.shape
                k_blocks = (org_shape[0] + group_size - 1) // group_size
                q_weight = pad_tensor(q_weight, group_size, k_blocks)
                q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint")
                q_matmul_node, new_inits = make_matmul_weight_only_node(
                    node=node,
                    weight_shape=org_shape,
                    num_bits=num_bits,
                    group_size=group_size,
                    k_blocks=k_blocks,
                    q_weight=q_weight.astype("uint8"),
                    scale=scale.astype(dtype),
                    zero_point=zp if scheme == "asym" else None,
                    accuracy_level=accuracy_level,
                )

                model.add_initializers(new_inits)
                model.remove_node(node)
                model.add_node(q_matmul_node)
            else:
                q_weight_tensor = onnx.helper.make_tensor(
                    name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
                    data_type=np_dtype_to_tensor_dtype(dtype),
                    dims=q_weight.shape,
                    vals=q_weight.astype(dtype).tobytes(),
                    raw=True,
                )
                model.add_initializer(q_weight_tensor)
                node.input[1] = q_weight_tensor.name
            if init_share_num == 1:
                model.remove_initializer(weight_tensor)

    model.remove_tensors_from_outputs(output_names)
    model.model.graph.output.MergeFrom(org_output)

    model.topological_sort()

    # reload external data to prevent external data file path errors
    if model.is_large_model:
        from onnx.external_data_helper import load_external_data_for_model  # noqa: PLC0415

        load_external_data_for_model(model.model, os.path.split(model.model_path)[0])

    return model