import gc
import os
from logging import getLogger

import accelerate
import huggingface_hub
import torch
from accelerate.utils import find_tied_parameters
from safetensors.torch import save_file as safe_save
from tqdm import tqdm

from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_marlin import _get_perms, unpack_qzeros
from .import_utils import MARLIN_AVAILABLE, MARLIN_EXCEPTION
from .modeling_utils import recurse_getattr, recurse_setattr

if MARLIN_AVAILABLE:
    import autogptq_marlin_cuda

logger = getLogger(__name__)


def prepare_model_for_marlin_load(
model_name_or_path,
model,
quantize_config,
quant_linear_class,
torch_dtype,
current_model_save_name,
device_map,
):
    # The model (e.g. model.safetensors) is already serialized in the Marlin format; load it directly.
if hasattr(quantize_config, "is_marlin_format") and quantize_config.is_marlin_format:
model_save_name = current_model_save_name
logger.info(f"Loading a GPTQ model, detected Marlin serialized format at {model_save_name}.")
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
else:
model_save_name = _get_cached_marlin_save_name(model_name_or_path)
        # If the GPTQ model has a Marlin version cached locally, load from the cached version (no repacking needed).
if os.path.isfile(model_save_name):
logger.info(
f"Loading a GPTQ model, detected a cached repacked weight for Marlin kernel at {model_save_name}."
)
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
# Otherwise, convert the model to Marlin format first and cache locally.
else:
# Loading the GPTQ checkpoint to do the conversion.
            # TODO: Avoid loading the model with the wrong QuantLinear, and directly use
            # the Marlin one. The repacking can be done directly on the safetensors, just
# as for AWQ checkpoints.
accelerate.utils.modeling.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=current_model_save_name,
device_map=device_map,
offload_state_dict=True,
offload_buffers=True,
)
            # Convert the model to Marlin, repacking the weights into the Marlin format.
model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=True)
# Safetensors is unable to save tied weights, so we untie them here. Reference: https://github.com/huggingface/safetensors/issues/202
tied_params = find_tied_parameters(model)
for weight_group in tied_params:
for param_name in weight_group:
if isinstance(recurse_getattr(model, param_name), torch.nn.Parameter):
recurse_setattr(
model,
param_name,
torch.nn.Parameter(recurse_getattr(model, param_name).clone()),
)
else:
recurse_setattr(
model,
param_name,
recurse_getattr(model, param_name).clone(),
)
# Cache the converted model.
safe_save(model.state_dict(), model_save_name)
return model, model_save_name
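
# A minimal call-site sketch (hypothetical; the real entry point lives in the
# model loading code and may pass different arguments):
#
#     model, model_save_name = prepare_model_for_marlin_load(
#         model_name_or_path="org/model-GPTQ",   # hypothetical hub id
#         model=model,
#         quantize_config=quantize_config,
#         quant_linear_class=quant_linear_class,  # GPTQ QuantLinear class matching the checkpoint
#         torch_dtype=torch.float16,
#         current_model_save_name=current_model_save_name,
#         device_map=device_map,
#     )
#
# The returned `model_save_name` points to Marlin-format safetensors that can be
# loaded without any further repacking.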


def _get_cached_marlin_save_name(model_name_or_path):
"""
    Gets the cached Marlin weight path.
    If remote: $HF_HOME/assets/auto_gptq/{model_name_or_path}/autogptq_model.safetensors
If local: {model_name_or_path}/autogptq_model.safetensors
"""
if os.path.isdir(model_name_or_path):
return os.path.join(model_name_or_path, "autogptq_model.safetensors")
else:
namespace, subfolder = model_name_or_path.split("/")
assets_path = huggingface_hub.cached_assets_path(
library_name="auto_gptq", namespace=namespace, subfolder=subfolder
)
return os.path.join(assets_path, "autogptq_model.safetensors")
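
# For illustration (hedged; the exact cache root depends on HF_HOME), a hub id such as
# "org/model-GPTQ" resolves to
#     $HF_HOME/assets/auto_gptq/org/model-GPTQ/autogptq_model.safetensors
# while a local directory such as "/path/to/model-GPTQ" resolves to
#     /path/to/model-GPTQ/autogptq_model.safetensors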


# Validate Marlin support on the current CUDA device: the kernel requires compute capability 8.0 or 8.6.
def _validate_marlin_device_support():
device_capacity = torch.cuda.get_device_capability()
return device_capacity[0] == 8 and (device_capacity[1] == 0 or device_capacity[1] == 6)


# Adapted from https://github.com/rib-2/marlin/tree/conversion
def _validate_marlin_compatibility(quantization_config):
if not MARLIN_AVAILABLE:
return f"AutoGPTQ is not compiled with the Marlin kernel, with the following error: {MARLIN_EXCEPTION}"
if quantization_config.bits != 4:
return f"The quantized model uses a bitwidth different than 4 (found {quantization_config.bits})"
if quantization_config.group_size != 128 and quantization_config.group_size != -1:
return "The quantized model uses a group size that is not 128 or -1 (found quantization_config.group_size)"
if not quantization_config.sym:
return "The quantized model uses asymmetric quantization"
if quantization_config.desc_act:
return "The quantized model uses act-order (also called desc-act) scheme"
return None
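
# `_validate_marlin_compatibility` returns None when the quantization config is
# Marlin-compatible and a human-readable reason otherwise. A caller sketch
# (hypothetical; the real checks live in the model loading code):
#
#     unsupported_reason = _validate_marlin_compatibility(quantize_config)
#     if unsupported_reason is not None:
#         raise ValueError(f"Cannot use the Marlin kernel: {unsupported_reason}")
#     if not _validate_marlin_device_support():
#         raise ValueError("The Marlin kernel requires compute capability 8.0 or 8.6.")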


@torch.no_grad()
def convert_to_marlin(model, model_quantlinear, quantization_config, repack: bool, strict: bool = False):
"""
Converts GPTQ-packed weights to the Marlin format. This assumes that the model already meets Marlin kernel constraints.
Arguments:
repack (`bool`):
            Whether to repack the qweights from `model` into Marlin's QuantLinear layers.
"""
if repack:
message = "Repacking weights to be compatible with Marlin kernel..."
else:
        # TODO: load the Marlin QuantLinear directly.
message = "Overriding QuantLinear layers to use Marlin's QuantLinear..."
for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue
parent_name = ".".join(name.split(".")[:-1])
layer_name = name[len(parent_name) + 1 :]
# We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
# from checkpoints holding zero bias.
with torch.device("meta"):
new_module = MarlinQuantLinear(
bits=4,
group_size=module.group_size,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
bias=module.bias is not None,
trainable=False,
)
# workspace is never in the state_dict, thus we need to allocate it manually.
new_module.workspace = torch.zeros(module.outfeatures // 128 * 16, dtype=torch.int, device=module.device)
        # Repack the GPTQ qweight and permute the scales into the Marlin layout.
if repack:
marlin_repacked_weight = autogptq_marlin_cuda.gptq_repack(module.qweight)
if strict:
dequantized_qzeros = unpack_qzeros(module.qzeros)
if not torch.all(dequantized_qzeros == 8):
raise ValueError(
"Marlin kernel is compatible only with checkpoints using symetric quantization. "
"Found non-symmetric quantization for the weight {name}."
)
_, _scale_perm, _scale_perm_single = _get_perms()
s = module.scales.data.clone()
if module.group_size != module.infeatures:
s = s.reshape((1, -1))
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, module.outfeatures)).contiguous()
new_module.B = marlin_repacked_weight
new_module.s = s
new_module.bias = module.bias
new_module = new_module.to(module.device)
# Save to parent.
parent_module = model.get_submodule(parent_name)
setattr(parent_module, layer_name, new_module)
# Free cuda memory.
del module
if repack:
del marlin_repacked_weight
gc.collect()
# Set quantization config to be Marlin.
quantization_config.is_marlin_format = True
return model
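
# Conversion sketch (hedged; assumes `model` holds GPTQ `QuantLinear` layers of type
# `quant_linear_class` and a config that passes the checks above):
#
#     if _validate_marlin_compatibility(quantize_config) is None and _validate_marlin_device_support():
#         model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=True)
#
# With `repack=False` the layers are only swapped to Marlin's QuantLinear without
# touching the packed weights, which is the path used for checkpoints already
# serialized in the Marlin format.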