# Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages
#
# Repository URL to install this package:
#
# Details
# sarus-llm / sarus_llm / models / quantization.py
# Size: Mime:
import bitsandbytes as bnb
import torch
import torch.nn as nn
import logging
import dataclasses
import typing as t

logger = logging.getLogger(__name__)


class TorchDType:
    """Validator type that coerces a value into a ``torch.dtype``.

    The class exposes ``__get_validators__``, the hook used by pydantic (v1)
    custom field types, and yields :meth:`validate` as its only validator.
    """

    @classmethod
    def __get_validators__(cls) -> t.Generator:
        """Yield the validators to run for this type (only :meth:`validate`)."""
        yield cls.validate

    @classmethod
    def validate(cls, v: t.Union[torch.dtype, str]) -> torch.dtype:
        """Return ``v`` as a ``torch.dtype``.

        Args:
            v: Either a ``torch.dtype`` (returned unchanged) or the name of a
                dtype attribute of the ``torch`` module, e.g. ``"float16"``.

        Returns:
            The corresponding ``torch.dtype``.

        Raises:
            ValueError: If ``v`` is a string that does not name a
                ``torch.dtype``, or if ``v`` is neither a string nor a
                ``torch.dtype``.
        """
        if isinstance(v, torch.dtype):
            return v
        if isinstance(v, str):
            # `torch` exposes many non-dtype attributes (sub-modules,
            # functions, ...), so a successful lookup is not enough: the
            # resolved object must itself be a dtype.
            candidate = getattr(torch, v, None)
            if not isinstance(candidate, torch.dtype):
                raise ValueError(f"Invalid torch.dtype: {v}")
            return candidate
        raise ValueError(f"Invalid type for torch.dtype: {type(v)}")


@dataclasses.dataclass()
class QuantizationConfig:
    """Settings used when swapping ``nn.Linear`` layers for 4-bit bitsandbytes layers.

    The ``bnb_4bit_*`` fields are forwarded to ``bnb.nn.Linear4bit`` by
    ``_replace_with_bnb_linear`` in this module.
    """

    # Not read anywhere in this module; presumably consumed by the caller to
    # decide whether quantization is applied at all -- TODO confirm.
    load_in_4bit: bool
    # Passed to `Linear4bit` as `compress_statistics`.
    bnb_4bit_use_double_quant: bool = False
    # Passed to `Linear4bit` as `quant_type`.
    bnb_4bit_quant_type: t.Literal["nf4", "fp4"] = "nf4"
    # Passed to `Linear4bit` as its fourth positional argument (the compute
    # dtype). NOTE(review): a plain dataclass does not run
    # `TorchDType.validate`, so a string given here is stored as-is.
    bnb_4bit_compute_dtype: TorchDType = torch.float16  # type: ignore


def set_module_quantized_tensor_to_device(
    module: nn.Module,
    tensor_name: str,
    device: t.Union[int, str, torch.device],
    value: t.Optional[torch.Tensor] = None,
    quantized_stats: t.Optional[t.Dict[str, t.Any]] = None,
) -> None:
    """
    Set a given tensor (parameter or buffer) of a module on a specific device.

    Doing `param.to(device)` creates a new tensor that is not linked to the parameter, which is why this helper
    re-registers the moved tensor on the owning module. The function is adapted from accelerate's
    `set_module_tensor_to_device`, extended to support the `Params4bit` class from `bitsandbytes`.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full (possibly dotted) name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).
        quantized_stats (`dict[str, Any]`, *optional*):
            Quantization statistics of an already-quantized 4-bit weight. When given, the parameter is rebuilt
            with `Params4bit.from_prequantized`.

    Raises:
        ValueError: If `tensor_name` is neither a parameter nor a buffer of the resolved module, if the tensor is
            on the meta device and no `value` is given for a non-meta target, or if the dtype of the new value
            does not match the pre-quantized status (int8/uint8 data if and only if `quantized_stats` is given).
        AttributeError: Propagated from `getattr` when a component of a dotted `tensor_name` does not exist.
    """
    # Walk down the dotted path to the sub-module that directly owns the tensor.
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    if (
        tensor_name not in module._parameters
        and tensor_name not in module._buffers
    ):
        raise ValueError(
            f"{module} does not have a parameter or a buffer named {tensor_name}."
        )
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    if (
        old_value.device == torch.device("meta")
        and device not in ["meta", torch.device("meta")]
        and value is None
    ):
        raise ValueError(
            f"{tensor_name} is on the meta device, we need a `value` to put in on {device}."
        )

    prequantized_loading = quantized_stats is not None
    # Buffers live in `_buffers`, not `_parameters`: test `is_buffer` first so
    # that moving a buffer does not raise KeyError on the `_parameters` lookup.
    is_4bit = (
        not is_buffer
        and hasattr(bnb.nn, "Params4bit")
        and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
    )

    if is_4bit:
        param = module._parameters[tensor_name]
        # A parameter that already lives on a CUDA device is left untouched:
        # the tensors needed to rebuild it are only prepared on the non-CUDA
        # path (rebuilding unconditionally raised UnboundLocalError).
        if param.device.type != "cuda":  # type:ignore[union-attr] #check is done by is_4bit
            if value is None:
                new_value = old_value.to(device)
            elif isinstance(value, torch.Tensor):
                new_value = value.to("cpu")
            else:
                new_value = torch.tensor(value, device="cpu")

            # Re-use the attributes stored on the existing parameter as
            # constructor arguments for the rebuilt one.
            kwargs = old_value.__dict__

            # Pre-quantized data must be int8/uint8, and raw data must not be.
            if prequantized_loading != (
                new_value.dtype in (torch.int8, torch.uint8)
            ):
                raise ValueError(
                    f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
                )

            if prequantized_loading:
                new_value = bnb.nn.Params4bit.from_prequantized(
                    data=new_value,
                    quantized_stats=quantized_stats,
                    requires_grad=False,
                    device=device,
                    **kwargs,
                )
            else:
                new_value = bnb.nn.Params4bit(
                    new_value, requires_grad=False, **kwargs
                ).to(device)
            module._parameters[tensor_name] = new_value
    else:
        if value is None:
            new_value = old_value.to(device)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        else:
            new_value = torch.tensor(value, device=device)
        if is_buffer:
            module._buffers[tensor_name] = new_value
        else:
            # Wrap in a fresh Parameter, keeping the original grad setting.
            new_value = nn.Parameter(
                new_value, requires_grad=old_value.requires_grad
            )
            module._parameters[tensor_name] = new_value


def _replace_with_bnb_linear(
    model: nn.Module,
    quantization_config: QuantizationConfig,
    modules_to_not_convert: t.Optional[t.List[str]] = None,
    current_key_name: t.Optional[t.List[str]] = None,
    has_been_replaced: bool = False,
) -> t.Tuple[nn.Module, bool]:
    """
    Private method that wraps the recursion for module replacement.

    Every direct or nested `nn.Linear` child of `model` is replaced in place by a `bnb.nn.Linear4bit` built from
    `quantization_config`, unless its own name is listed in `modules_to_not_convert` or one of those entries
    occurs as a substring of its dotted path.

    Args:
        model: Module whose children are inspected (mutated in place).
        quantization_config: Source of the compute dtype, double-quantization flag and quantization type.
        modules_to_not_convert: Names / path fragments of modules that must be left untouched.
        current_key_name: Path components of `model` within the root module; extended and restored during the
            recursion.
        has_been_replaced: Running flag carried through the recursion.

    Returns:
        The (mutated) model and a boolean that indicates whether at least one module has been replaced.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        if (
            isinstance(module, nn.Linear)
            and name not in modules_to_not_convert
        ):
            # Also skip the module when an excluded key appears anywhere in
            # its dotted path.
            current_key = ".".join(current_key_name)
            if not any(key in current_key for key in modules_to_not_convert):
                replacement = bnb.nn.Linear4bit(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    quantization_config.bnb_4bit_compute_dtype,
                    compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                    quant_type=quantization_config.bnb_4bit_quant_type,
                )
                # Only freshly created layers are tagged and frozen; a layer
                # excluded from conversion keeps its `requires_grad` state.
                # Store the module class in case we need to transpose the weight later
                replacement.source_cls = type(module)
                # Force requires grad to False to avoid unexpected errors
                replacement.requires_grad_(False)
                model._modules[name] = replacement
                has_been_replaced = True

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_bnb_linear(
    model: nn.Module,
    quantization_config: QuantizationConfig,
    current_key_name: t.Optional[t.List[str]] = None,
    modules_to_not_convert: t.Optional[t.List[str]] = None,
) -> nn.Module:
    """
    Replace the `torch.nn.Linear` modules of `model` by `bnb.nn.Linear4bit` modules from the `bitsandbytes`
    library.

    The replacement is done in place and recursively through `_replace_with_bnb_linear`. Modules whose name is
    listed in `modules_to_not_convert`, or whose dotted path contains one of those entries, are kept as they are.

    Parameters:
        model (`torch.nn.Module`):
            Input model; its sub-modules are replaced in place.
        quantization_config (`QuantizationConfig`):
            4-bit settings (compute dtype, double quantization, quantization type) forwarded to `Linear4bit`.
        current_key_name (`List[`str`]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part
            of it) is not in the list of modules to not convert.
        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["output", "lora_a", "lora_b"]`):
            Names of the modules to leave unconverted. NOTE(review): the default presumably targets the output
            head and the LoRA adapter layers -- confirm against the model definitions.

    Returns:
        The same `model` object, with its linear layers replaced.

    Raises:
        AssertionError: If no module was replaced.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = ["output", "lora_a", "lora_b"]
    model, has_been_replaced = _replace_with_bnb_linear(
        model,
        modules_to_not_convert=modules_to_not_convert,
        current_key_name=current_key_name,
        quantization_config=quantization_config,
    )
    # Raised explicitly (rather than via `assert`) so the check still runs
    # when Python is started with -O; the exception type is unchanged.
    if not has_been_replaced:
        raise AssertionError(
            "You are loading your model in 4bit but no linear modules were found in your model."
        )
    return model