import torch

try:
    import awq_v2_ext  # with CUDA kernels (AutoAWQ_kernels)

    AWQ_INSTALLED = True
except ImportError:
    AWQ_INSTALLED = False
def make_divisible(c, divisor):
    # Ceiling division: the smallest number of `divisor`-sized chunks covering `c`.
    return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    # Width (in packed words) of the scales/zeros tensors, padded so the
    # per-group metadata stays aligned for the fused kernels.
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError(f"Unsupported group_size: {group_size}")

    base_width = make_divisible(in_features // group_size, pack_num)
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width
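
# Illustrative sanity check (not part of the original module): with
# in_features=4096 and group_size=128 there are 32 groups, and the padded
# width is ceil(32 / 8) = 4 words; 4 * pack_num = 32 slots, one per group.
#
#   calculate_zeros_width(4096, group_size=128)  # -> 4
#   calculate_zeros_width(4096, group_size=64)   # -> 8 (64 groups, padded to a multiple of 2)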
def pack_intweight(unpacked_qweight, interleave, kstride):
    # Pack 4-bit integer weights [N, K] into int16 words with the reordering
    # and interleaving expected by the fast GEMV/GEMM kernels.
    # unpacked_qweight: [N, K]
N = unpacked_qweight.shape[0]
K = unpacked_qweight.shape[1]
Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
# reorder each 8 weights for fast dequantization
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
Packed_Kernel = Packed_Kernel.reshape(N, K)
# interleaving every four rows
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, interleave, K // kstride, kstride
)
# N // 4, K // 64, 4, 64
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, K // kstride, kstride, interleave
)
# Packing -> (N // 4, K // 64, 64)
Packed_Kernel = (
Packed_Kernel[..., 0]
| (Packed_Kernel[..., 1] << 4)
| (Packed_Kernel[..., 2] << 8)
| (Packed_Kernel[..., 3] << 12)
)
    # reshape to (N // 4, K); each 16-bit word now holds four 4-bit weights
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
qweight = (
torch.tensor(Packed_Kernel.astype("int16"))
.to(unpacked_qweight.device)
.contiguous()
)
return qweight
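
# Hedged round-trip sketch (illustrative, not part of the original module):
# each int16 word packed above carries four 4-bit weights in nibbles
# 0/4/8/12. The check recovers them with an arithmetic shift plus a 0xF mask
# (the mask also discards the sign extension from int16 storage) and confirms
# the packed tensor holds the same multiset of values as the input.
def _check_pack_intweight_roundtrip():
    w = torch.randint(0, 16, (8, 128), dtype=torch.int32)
    packed = pack_intweight(w, interleave=4, kstride=64)  # -> (2, 128) int16
    nibbles = torch.stack([(packed >> s) & 0xF for s in (0, 4, 8, 12)], dim=-1)
    assert sorted(nibbles.flatten().tolist()) == sorted(w.flatten().tolist())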
class WQLinear_GEMVFast(torch.nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
self.interleave = 4
        # quick sanity check (make sure everything is aligned)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = 32 // self.w_bit
int16_pack_num = 16 // self.w_bit
assert out_features % (self.interleave) == 0
self.register_buffer(
"qweight",
torch.zeros(
(
out_features // self.interleave,
in_features // int16_pack_num * self.interleave,
),
dtype=torch.int16,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(
calculate_zeros_width(in_features, self.group_size) * pack_num,
out_features,
),
dtype=torch.float16,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
calculate_zeros_width(in_features, self.group_size) * pack_num,
out_features,
),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
)
else:
self.bias = None
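
    # Shape sketch (illustrative numbers, not from the original source): with
    # in_features=4096, out_features=4096, w_bit=4, group_size=128 we get
    # pack_num=8, int16_pack_num=4, interleave=4, so
    #   qweight:        (4096 // 4, 4096 // 4 * 4) = (1024, 4096), int16
    #   scales, qzeros: (4 * 8, 4096)              = (32, 4096),  float16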
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only:
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(linear.in_features, group_size) * pack_num,
),
dtype=torch.float16,
device=scales.device,
)
        qscales[:, : scales.shape[1]] = scales
        awq_linear.scales = qscales.transpose(1, 0).contiguous()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
        # Quantize column by column: q = round(w / scale + zero).
        intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ qscales[:, idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
awq_linear.qweight = pack_intweight(
intweight.contiguous(), interleave=4, kstride=64
)
        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros_like(qscales)
        # Store pre-scaled negative zero points so dequantization becomes a
        # fused multiply-add: w_fp16 = q * scale + qzero.
        qzeros[:, : scales.shape[1]] = -(
            qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
        ).to(torch.float16)
        awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()
return awq_linear
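
    # Worked identity (an assumption, mirroring the packing above): for input
    # column k in group g = k // group_size, the kernels reconstruct
    #   w_fp16[:, k] ≈ qweight[:, k] * scales[g, :] + qzeros[g, :]
    #                = round(w / s + z) * s - z * s ≈ w
    # which is why qzeros is stored as -(scale * zero) in float16.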
    @torch.no_grad()
    def forward(self, x):
        assert AWQ_INSTALLED, "awq_v2_ext kernels are not available."
        inputs = x
        # Small token counts (decode) take the GEMV kernel; larger batches
        # (prefill) take the GEMM kernel.
        if inputs.numel() / inputs.shape[-1] < 8:
out = awq_v2_ext.gemv_forward_cuda_decode(
inputs,
self.qweight,
self.scales,
self.qzeros,
inputs.numel() // inputs.shape[-1],
self.out_features,
self.in_features,
self.group_size,
)
else:
out = awq_v2_ext.gemm_forward_cuda_prefill(
inputs, self.qweight, self.scales, self.qzeros
)
out = out + self.bias if self.bias is not None else out
return out
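
# Hedged usage sketch (assumes a CUDA device and the awq_v2_ext kernels; the
# min/max scale and zero-point computation below is a naive placeholder, not
# AutoAWQ's activation-aware search):
if __name__ == "__main__" and AWQ_INSTALLED:
    linear = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.float16)
    group_size = 128
    w = linear.weight.data.view(linear.out_features, -1, group_size)
    scales = ((w.amax(-1) - w.amin(-1)).clamp(min=1e-5) / 15).half()  # 4-bit range
    zeros = torch.clamp(-torch.round(w.amin(-1) / scales), 0, 15)
    qlinear = WQLinear_GEMVFast.from_linear(
        linear, w_bit=4, group_size=group_size, scales=scales, zeros=zeros
    )
    x = torch.randn(1, 1, 4096, device="cuda", dtype=torch.float16)
    out = qlinear(x)  # fewer than 8 tokens -> decode (GEMV) path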