Repository URL to install this package:
|
Version:
0.7.1+cu122
|
import torch.nn as nn
class GeneralQuantLinear(nn.Linear):
    """``nn.Linear`` facade over an existing quantized linear module.

    The wrapper copies the quantization metadata of the source module,
    re-registers its packed tensors as buffers, and delegates ``forward`` to
    the source module's bound ``forward``. Presumably this exists so that
    code which only recognizes ``nn.Linear`` layers (e.g. adapter injection)
    can see quantized layers -- NOTE(review): confirm against callers.
    """

    def __init__(self, quant_linear_module):
        """Build the facade from ``quant_linear_module``.

        Args:
            quant_linear_module: module exposing ``infeatures``,
                ``outfeatures``, ``bits``, ``group_size``, ``maxq``,
                ``qweight``, ``qzeros``, ``scales``, ``g_idx``, ``bias``
                (a tensor or ``None``), ``trainable`` and ``forward``.
                ``wf``, ``kernel_switch_threshold`` and
                ``autogptq_cuda_available`` are copied only when present.

        Raises:
            AttributeError: if one of the required attributes is missing.
        """
        source_bias = quant_linear_module.bias
        super().__init__(
            in_features=quant_linear_module.infeatures,
            out_features=quant_linear_module.outfeatures,
            # Only allocate a bias parameter when there is one to mirror;
            # a ``None`` bias used to crash on ``self.bias.data = None``.
            bias=source_bias is not None,
        )
        self.infeatures = quant_linear_module.infeatures
        self.outfeatures = quant_linear_module.outfeatures
        self.bits = quant_linear_module.bits
        self.group_size = quant_linear_module.group_size
        self.maxq = quant_linear_module.maxq

        # Clear requires_grad *before* swapping in qweight: a tensor that
        # requires grad may only hold floating-point/complex data, and
        # qweight is presumably a packed integer tensor.
        self.weight.requires_grad = False
        self.weight.data = quant_linear_module.qweight
        self.register_buffer("qweight", quant_linear_module.qweight)
        self.qweight.requires_grad = False

        if source_bias is not None:
            self.bias.data = source_bias
            self.bias.requires_grad = False

        self.register_buffer("qzeros", quant_linear_module.qzeros)
        self.register_buffer("scales", quant_linear_module.scales)
        self.register_buffer("g_idx", quant_linear_module.g_idx)

        # Optional, kernel-specific attributes: copy only those that exist.
        for attr in ("wf", "kernel_switch_threshold", "autogptq_cuda_available"):
            if hasattr(quant_linear_module, attr):
                setattr(self, attr, getattr(quant_linear_module, attr))

        self.trainable = quant_linear_module.trainable

        # Delegate computation to the source module's bound forward, which
        # keeps reading the *source* module's own attributes and buffers.
        # NOTE(review): the buffers registered above share tensor objects
        # with the source only at construction time; if something later
        # replaces this module's buffers (e.g. a device move), the delegated
        # forward will not see that change -- verify.
        self.forward = quant_linear_module.forward

    @classmethod
    def inject_to_model(cls, model, target_module_type):
        """Replace every ``target_module_type`` submodule of ``model`` in place.

        Each matching submodule is wrapped as ``cls(submodule)`` and
        re-assigned on its parent under the same attribute name.

        Args:
            model: root ``nn.Module`` whose submodules are rewritten.
            target_module_type: class (or tuple of classes, anything accepted
                by ``isinstance``) identifying the modules to wrap.

        The root module itself is never replaced: it has no parent to
        re-assign it on, and callers keep their own reference to it.
        """
        # Snapshot the matches first so the module tree is not mutated while
        # ``named_modules()`` is still walking it.
        targets = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, target_module_type)
        ]
        for name, module in targets:
            if not name:
                # ``model`` itself matched; there is no parent slot to
                # overwrite (previously the wrapper was registered as a child
                # named "" on the very module it wraps).
                continue
            parent_name, _, child_name = name.rpartition(".")
            # ``get_submodule("")`` returns ``model`` itself, which covers
            # direct children of the root.
            parent = model.get_submodule(parent_name)
            setattr(parent, child_name, cls(module))