# Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
logger = getLogger(__name__)
try:
import autogptq_marlin_cuda
except ImportError as e:
marlin_import_exception = e
def error_raiser_marlin(*args, **kwargs):
raise ValueError(
f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
)
autogptq_marlin_cuda = error_raiser_marlin
def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
"""Marlin FP16xINT4 multiply; can be used within `torch.compile`.
@A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
    @B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `QuantLinear.pack()`
@C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
    @s: `torch.half` scales of shape `(k / group_size, n)`
@workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
@thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
@thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
@sms: number of SMs to use for the kernel (can usually be left as auto -1)
@max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
"""
autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)
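
# Rough usage sketch for `mul` (illustrative only; shapes assume m=16, k=4096, n=4096,
# group_size=128 on a CUDA device, and `B` must already be in Marlin format, e.g. the
# `B` buffer produced by `QuantLinear.pack()`):
#
#   A = torch.randn((16, 4096), dtype=torch.half, device="cuda")
#   C = torch.empty((16, 4096), dtype=torch.half, device="cuda")
#   s = torch.rand((4096 // 128, 4096), dtype=torch.half, device="cuda")
#   workspace = torch.zeros(4096 // 128 * 16, dtype=torch.int, device="cuda")
#   mul(A, B, C, s, workspace)
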
# Precompute permutations for Marlin weight and scale shuffling
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
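# For reference: `_perm` holds 1024 indices (32 x 32) used to shuffle the quantized weights
# into the tile layout the Marlin kernel expects, while `_scale_perm` (64 entries) and
# `_scale_perm_single` (32 entries) reorder the scales for the grouped and per-column
# (group_size == infeatures) cases, respectively.
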
class QuantLinear(nn.Module):
QUANT_TYPE = "marlin"
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if torch.version.hip:
raise ValueError("Can not use Marlin int4*fp16 kernel with AMD ROCm version of PyTorch as the kernel is not compatible. Please do not use `use_marlin=True` when using ROCm devices.")
if not torch.cuda.get_device_capability()[0] >= 8:
            raise ValueError(f'Cannot use the Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}; the minimum compute capability for the Marlin kernel is 8.0. Please do not use `use_marlin=True`, or upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).')
if infeatures % 128 != 0 or outfeatures % 256 != 0:
raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
if group_size not in [-1, 128] and group_size != infeatures:
raise ValueError("Only group_size -1 and 128 are supported.")
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")
if trainable:
raise NotImplementedError("Marlin does not support train.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.group_size = group_size if group_size != -1 else infeatures
self.register_buffer(
"B",
torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int),
)
self.register_buffer(
"s",
            torch.empty((self.infeatures // self.group_size, self.outfeatures), dtype=torch.half),
)
# 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
self.register_buffer(
"workspace",
torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int),
persistent=False,
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half))
else:
self.bias = None
def post_init(self):
pass
def pack(self, linear, scales):
"""Pack a fake-quantized linear layer into this actual Marlin representation.
@linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
        @scales: corresponding quantization scales of shape `(outfeatures, groups)`
"""
if linear.weight.dtype != torch.half:
raise ValueError("Only `torch.half` weights are supported.")
tile = 16
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()
if self.group_size != self.infeatures:
w = w.reshape((-1, self.group_size, self.outfeatures))
w = w.permute(1, 0, 2)
w = w.reshape((self.group_size, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if self.group_size != self.infeatures:
w = w.reshape((self.group_size, -1, self.outfeatures))
w = w.permute(1, 0, 2)
w = w.reshape((self.infeatures, self.outfeatures)).contiguous()
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, self.outfeatures)).contiguous()
w = w.reshape((self.infeatures // tile, tile, self.outfeatures // tile, tile))
w = w.permute((0, 2, 1, 3))
w = w.reshape((self.infeatures // tile, self.outfeatures * tile))
res = w
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
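        # Pack eight 4-bit values into each 32-bit word: column j of `q` collects columns
        # 8*j .. 8*j+7 of `res` in its successive nibbles (lowest nibble first).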
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
self.B[:, :] = q.to(self.B.device)
self.s[:, :] = s.to(self.s.device)
if self.bias is not None:
self.bias[:] = linear.bias.data.to(self.bias.device)
def forward(self, A):
A = A.half()
C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
mul(
A.view((-1, A.shape[-1])),
self.B,
C.view((-1, C.shape[-1])),
self.s,
self.workspace,
)
C = C + self.bias if self.bias is not None else C
return C
# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack the packed 4-bit values (eight per int32) into int8 tensors holding values in [0, 15]
unpacked_weights = torch.zeros(
(qweight.shape[0] * 8, qweight.shape[1]),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * 8),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for row in range(unpacked_weights.shape[0]):
i = row % 8
unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
for col in range(unpacked_zeros.shape[1]):
i = col % 8
unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
return unpacked_weights, unpacked_zeros + 1
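
# Note: GPTQ-format checkpoints store each zero-point minus one (the standard GPTQ packing
# convention), which is why both unpack helpers add 1 back after extracting the 4-bit values.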
def unpack_qzeros(qzeros):
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * 8),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for col in range(unpacked_zeros.shape[1]):
i = col % 8
unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
return unpacked_zeros + 1
# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def dequantize_weight(layer):
qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
group_size = unpacked_qweight.shape[0] // scales.shape[0]
scales = scales.repeat_interleave(group_size, dim=0)
unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
return unpacked_qweight.T, unpacked_qzeros
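
# Rough sketch of how `dequantize_weight` can be used to repack an existing GPTQ-format layer
# (`gptq_layer`, with `qweight`/`qzeros`/`scales` buffers) into a Marlin `QuantLinear`; the
# names are illustrative, and the scales are transposed to the `(outfeatures, groups)` layout
# that `pack()` expects:
#
#   fp16_weight, _ = dequantize_weight(gptq_layer)
#   linear = nn.Linear(fp16_weight.shape[1], fp16_weight.shape[0], bias=False)
#   linear.weight.data = fp16_weight.half()
#   marlin_layer = QuantLinear(4, gptq_layer.group_size, linear.in_features, linear.out_features, bias=False)
#   marlin_layer.pack(linear, scales=gptq_layer.scales.t())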
def dequantize_qzeros(layer):
qzeros = layer.qzeros
unpacked_qzeros = unpack_qzeros(qzeros)
group_size = layer.group_size
unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
return unpacked_qzeros
__all__ = ["QuantLinear", "dequantize_weight"]
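

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes a CUDA GPU with compute capability >= 8.0 and the
    # compiled `autogptq_marlin_cuda` extension; sizes and names are illustrative only.
    torch.manual_seed(0)
    k, n, group_size = 256, 256, 128
    # Per-group scales in the `(outfeatures, groups)` layout expected by `pack()`, and a
    # fake-quantized fp16 weight whose entries are exact multiples of their group scale.
    scales = torch.rand((n, k // group_size), dtype=torch.half, device="cuda") + 0.5
    int_w = torch.randint(-8, 8, (n, k), device="cuda")
    weight = int_w.half() * scales.repeat_interleave(group_size, dim=1)
    linear = nn.Linear(k, n, bias=False).half().cuda()
    linear.weight.data = weight
    layer = QuantLinear(4, group_size, k, n, bias=False).cuda()
    layer.pack(linear, scales)
    x = torch.randn((8, k), dtype=torch.half, device="cuda")
    # The Marlin output should closely match the fp16 reference matmul.
    print("max abs err:", (layer(x) - linear(x)).abs().max().item())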