import bitsandbytes as bnb
import torch
import torch.nn as nn
import logging
import dataclasses
import typing as t
logger = logging.getLogger(__name__)
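# Custom field type following pydantic v1's custom-type protocol
# (`__get_validators__`): declaring a field as `TorchDType` makes pydantic
# call `validate` below, coercing string names such as "float16" into real
# `torch.dtype` objects.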
class TorchDType:
@classmethod
def __get_validators__(cls) -> t.Generator:
yield cls.validate
@classmethod
def validate(cls, v: t.Union[torch.dtype, str]) -> torch.dtype:
if isinstance(v, torch.dtype):
return v
elif isinstance(v, str):
try:
return t.cast(torch.dtype, getattr(torch, v))
except AttributeError:
raise ValueError(f"Invalid torch.dtype: {v}") from None
else:
raise ValueError(f"Invalid type for torch.dtype: {type(v)}")
@dataclasses.dataclass()
class QuantizationConfig:
load_in_4bit: bool
bnb_4bit_use_double_quant: bool = False
bnb_4bit_quant_type: t.Literal["nf4", "fp4"] = "nf4"
bnb_4bit_compute_dtype: TorchDType = torch.float16 # type: ignore
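# Illustrative only: a typical 4-bit (QLoRA-style) configuration. The values
# below are common choices, not defaults mandated by this module.
def _example_quantization_config() -> QuantizationConfig:
    return QuantizationConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,  # quantize the quantization statistics too
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )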
def set_module_quantized_tensor_to_device(
module: nn.Module,
tensor_name: str,
device: t.Union[int, str, torch.device],
value: t.Optional[torch.Tensor] = None,
quantized_stats: t.Optional[t.Dict[str, t.Any]] = None,
) -> None:
"""
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that
    doing `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this
    function). It is adapted from the `set_module_tensor_to_device` function in `accelerate`, extended to
    support the quantized parameter classes (e.g. `Params4bit`) from `bitsandbytes`.
Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
tensor_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
quantized_stats (`dict[str, Any]`, *optional*):
            Dictionary of quantization statistics, required when loading serialized (prequantized) 4-bit or
            8-bit weights.
"""
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
if (
tensor_name not in module._parameters
and tensor_name not in module._buffers
):
raise ValueError(
f"{module} does not have a parameter or a buffer named {tensor_name}."
)
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
if (
old_value.device == torch.device("meta")
and device not in ["meta", torch.device("meta")]
and value is None
):
raise ValueError(
f"{tensor_name} is on the meta device, we need a `value` to put in on {device}."
)
prequantized_loading = quantized_stats is not None
is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(
module._parameters[tensor_name], bnb.nn.Params4bit
)
if is_4bit:
param = module._parameters[tensor_name]
if param.device.type != "cuda": # type:ignore[union-attr] #check is done by is_4bit
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
else:
new_value = torch.tensor(value, device="cpu")
kwargs = old_value.__dict__
if prequantized_loading != (
new_value.dtype in (torch.int8, torch.uint8)
):
raise ValueError(
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
)
if prequantized_loading:
new_value = bnb.nn.Params4bit.from_prequantized(
data=new_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=device,
**kwargs,
)
else:
new_value = bnb.nn.Params4bit(
new_value, requires_grad=False, **kwargs
).to(device)
module._parameters[tensor_name] = new_value
else:
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)
if is_buffer:
module._buffers[tensor_name] = new_value
else:
new_value = nn.Parameter(
new_value, requires_grad=old_value.requires_grad
)
module._parameters[tensor_name] = new_value
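# Illustrative sketch (the names `model` and `state_dict` are hypothetical):
# typical usage is to materialize checkpoint weights into a model whose
# parameters still live on the meta device.
def _example_load_state_dict(
    model: nn.Module, state_dict: t.Dict[str, torch.Tensor]
) -> None:
    for name, tensor in state_dict.items():
        # Dotted names such as "encoder.layers.0.weight" are resolved by the
        # recursion at the top of set_module_quantized_tensor_to_device.
        set_module_quantized_tensor_to_device(model, name, "cuda:0", value=tensor)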
def _replace_with_bnb_linear(
model: nn.Module,
quantization_config: QuantizationConfig,
modules_to_not_convert: t.Optional[t.List[str]] = None,
current_key_name: t.Optional[t.List[str]] = None,
has_been_replaced: bool = False,
) -> t.Tuple[nn.Module, bool]:
"""
Private method that wraps the recursion for module replacement.
    Returns the converted model and a boolean that indicates whether the conversion has been successful or not.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if (
isinstance(module, nn.Linear)
and name not in modules_to_not_convert
):
# Check if the current key is not in the `modules_to_not_convert`
if not any(
key in ".".join(current_key_name)
for key in modules_to_not_convert
):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
)
has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)  # type: ignore[union-attr]
                # Force requires_grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)  # type: ignore[union-attr]
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_bnb_linear(
model: nn.Module,
quantization_config: QuantizationConfig,
current_key_name: t.Optional[t.List[str]] = None,
) -> nn.Module:
"""
    A helper function to replace all `torch.nn.Linear` modules with `bnb.nn.Linear4bit` modules from the
    `bitsandbytes` library. This enables running your model in 4-bit precision, using the `nf4`/`fp4` data
    types introduced in the paper `QLoRA: Efficient Finetuning of Quantized LLMs`. Make sure a `bitsandbytes`
    build compiled for the CUDA version of your hardware is installed before running this function
    (`pip install -i https://test.pypi.org/simple/ bitsandbytes`).
    The function runs recursively and replaces every `torch.nn.Linear` module except those named `output`,
    `lora_a`, or `lora_b`, which are kept in full precision (the output head for numerical stability, the
    LoRA adapters so they remain trainable).
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module`, as the function is run recursively.
        quantization_config (`QuantizationConfig`):
            The 4-bit quantization settings to apply (quant type, compute dtype, and double quantization).
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key
            (or part of it) is in the list of modules to not convert (for instance modules that are offloaded
            to `cpu` or `disk`).
"""
modules_to_not_convert = ["output", "lora_a", "lora_b"]
model, has_been_replaced = _replace_with_bnb_linear(
model,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
quantization_config=quantization_config,
)
    assert has_been_replaced, (
        "You are loading your model in 4-bit, but no `torch.nn.Linear` modules were found in your model."
    )
return model
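# Illustrative sketch: convert a toy model's linear layers to 4-bit. Note that
# replacement creates fresh `Linear4bit` modules; the pretrained weights are
# typically loaded afterwards, e.g. with `set_module_quantized_tensor_to_device`,
# and the quantization itself happens when the parameters move to a CUDA device.
def _example_convert() -> nn.Module:
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))
    config = QuantizationConfig(load_in_4bit=True)
    return replace_with_bnb_linear(model, config)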