# Source: auto_gptq / test_repacking.py (package version 0.7.1+cu122).
import copy
import unittest
import autogptq_marlin_cuda
import torch
import torch.nn as nn
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as CudaOldQuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_marlin import _get_perms, dequantize_weight
def gen_quant4(k, n, groupsize=-1):
maxq = 2 ** 4 - 1
w = torch.randn((k, n), dtype=torch.half, device="cpu")
original_w = w.clone()
if groupsize != -1:
w = w.reshape((-1, groupsize, n))
w = w.permute(1, 0, 2)
w = w.reshape((groupsize, -1))
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / maxq
# Quantize.
w = torch.round(w / s).int()
# Unsigned storage.
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
# Dequantize.
ref = (w - (maxq + 1) // 2).half() * s
if groupsize != -1:
def reshape(w):
w = w.reshape((groupsize, -1, n))
w = w.permute(1, 0, 2)
w = w.reshape((k, n)).contiguous()
return w
ref = reshape(ref)
w = reshape(w)
s = s.reshape((-1, n)).contiguous()
linear = nn.Linear(k, n, bias=False)
linear.weight.data = ref.t()
return original_w, linear, s
# NOTE(review): module-level side effect — this runs a (64, 128) quantization
# at import time; the resulting names are not used by the test below.
# Presumably leftover scratch code — confirm before removing.
original_w, linear, s = gen_quant4(64, 128)
class TestRepacking(unittest.TestCase):
    def test_marlin_fast_repacking(self):
        """Verify that a CUDA-old 4-bit layer repacks into the Marlin format.

        Packs the same reference weights through both quant-linear
        implementations, then checks output agreement, the fast-repacked
        qweight, and the permuted scales against the Marlin layer's buffers.
        """
        k, n, m = 2048, 1024, 5
        group_size = 128

        _, ref_linear, ref_scales = gen_quant4(k, n, group_size)

        cuda_old_linear = CudaOldQuantLinear(
            bits=4,
            group_size=group_size,
            infeatures=k,
            outfeatures=n,
            bias=False,
        )
        # Symmetric 4-bit storage: every group/column gets zero-point 8.
        qzeros_init = torch.full((k // group_size, n), 8, dtype=torch.int32)
        cuda_old_linear.pack(ref_linear, ref_scales.T, qzeros_init.T, g_idx=None)

        # Adapted from utils.marlin_utils.convert_to_marlin
        dq_weight, dq_qzeros = dequantize_weight(cuda_old_linear)
        dq_weight = dq_weight.to(torch.float16)
        self.assertTrue(torch.all(dq_qzeros == 8))

        fp16_linear = torch.nn.Linear(
            in_features=k,
            out_features=n,
            bias=False,
            dtype=torch.float16,
            device="cuda",
        )
        # Copy the exact reference weight rather than dq_weight to avoid the
        # dequantization approximation.
        fp16_linear.weight.data.copy_(ref_linear.weight.data)

        # Create new linear method and copy to model.
        marlin_linear = MarlinQuantLinear(
            bits=4,
            group_size=group_size,
            infeatures=k,
            outfeatures=n,
            bias=False,
            trainable=False,
        )
        marlin_linear.pack(
            fp16_linear.to("cuda"),
            scales=copy.deepcopy(cuda_old_linear.scales.data.t()).to("cuda"),
        )

        x = torch.rand(m, k, dtype=torch.float16, device="cuda")
        cuda_old_linear = cuda_old_linear.to("cuda")
        marlin_linear = marlin_linear.to("cuda")
        with torch.no_grad():
            out_cuda_old = cuda_old_linear(x)
            out_marlin = marlin_linear(x)

        # Both kernels quantize identically, so outputs should agree closely.
        rel_err = (out_cuda_old - out_marlin).abs() / (out_cuda_old.abs() + 1e-12)
        self.assertTrue(torch.mean(rel_err) < 4e-3)

        # The fast CUDA repack must reproduce Marlin's packed weight exactly.
        repacked = autogptq_marlin_cuda.gptq_repack(cuda_old_linear.qweight)
        self.assertTrue(torch.allclose(repacked, marlin_linear.B))

        # Re-derive Marlin's permuted scale layout and compare.
        _, scale_perm, scale_perm_single = _get_perms()
        expected_s = cuda_old_linear.scales.data.clone()
        if group_size != k:
            expected_s = expected_s.reshape((1, -1))
            expected_s = expected_s.reshape((-1, len(scale_perm)))[:, scale_perm]
        else:
            expected_s = expected_s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
        expected_s = expected_s.reshape((-1, n)).contiguous()
        self.assertTrue(torch.allclose(expected_s, marlin_linear.s))