import json
import os
import re
import typing as t
from typing import Dict

import torch
from safetensors.torch import load_file as safe_load_file

from .modules.peft.utils import LORA_ATTN_MODULES
from .quantization import set_module_quantized_tensor_to_device

SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"


# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
    # specific to phi3
    "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
}


def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    """Map a single key from HF's state dict format to torchtune's format.

    Layer indices are abstracted to "{}" for the lookup and substituted back
    into the mapped key afterwards.
    """
    try:
        if "layers" in key:
            # Replace the layer number with "{}" to build the lookup key
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            layer_num = re.search(r"\d+", key).group(0)  # type:ignore[union-attr]
            new_key = mapping_dict[abstract_key]
            new_key = new_key.format(layer_num)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format."
        ) from e

    return new_key
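
# Illustrative mapping (example only, not executed): for a llama-style key,
#   get_mapped_key("model.layers.3.self_attn.q_proj.weight", _FROM_HF)
# returns "layers.3.attn.q_proj.weight".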


def hf_to_sarus(
    state_dict: Dict[str, torch.Tensor],
    foundation_model_name: str,
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: t.Optional[int] = None,
    triton_kernel: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from HF's format to torchtune's format. State dicts
    from multiple checkpoint files should be consolidated into a single state dict
    before calling this function.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
        num_heads (int): Number of heads in the model.
        num_kv_heads (int): Number of heads in the key/value projection layers.
        dim (int): Dimension of the model.
        head_dim (int): Dimension of the head. If not provided, it will be calculated
            as dim // num_heads.

    Returns:
        Dict[str, torch.Tensor]: State dict in torchtune's format.
    """

    if "phi" in foundation_model_name:
        return convert_phi_state_dict(state_dict=state_dict)

    converted_state_dict = {}
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
        # Re-order q/k projection weights to account for the difference in
        # rotary-embedding conventions between HF and torchtune
        return (
            w.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        # Skip the rotary position-embedding buffers
        if "rotary_emb.inv_freq" not in key:
            new_key = get_mapped_key(key, _FROM_HF)  # type:ignore
            if "q_proj" in key and not triton_kernel:
                value = _permute(value, num_heads)
            elif "k_proj" in key and not triton_kernel:
                value = _permute(value, num_kv_heads)

            converted_state_dict[new_key] = value
    return converted_state_dict
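
# Illustrative usage (the path and model name below are hypothetical):
#   hf_sd = safe_load_file("/ckpts/llama-7b/model.safetensors", "cpu")
#   sarus_sd = hf_to_sarus(
#       state_dict=hf_sd,
#       foundation_model_name="llama-7b",
#       num_heads=32,
#       num_kv_heads=32,
#       dim=4096,
#   )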


def convert_phi_state_dict(
    state_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Convert a phi3-style HF state dict, splitting the fused qkv_proj and
    gate_up_proj weights into separate q/k/v and w1/w3 projections."""
    converted_state_dict = {}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _FROM_HF)  # type:ignore
        if "qkv" in key:
            q, k, v = value.chunk(3, dim=0)
            converted_state_dict[new_key] = q
            converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
            converted_state_dict[new_key.replace("q_proj", "v_proj")] = v
        elif "gate" in key:
            w1, w3 = value.chunk(2, dim=0)
            converted_state_dict[new_key] = w1
            converted_state_dict[new_key.replace("w1", "w3")] = w3
        else:
            converted_state_dict[new_key] = value
    return converted_state_dict
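
# Illustrative mapping (example only): a fused phi3 key such as
# "model.layers.0.self_attn.qkv_proj.weight" is split into equal chunks along
# dim 0 and stored as "layers.0.attn.q_proj.weight", ".k_proj.weight" and
# ".v_proj.weight"; "gate_up_proj" is likewise split into "w1" and "w3".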


def load_sharded_checkpoint(
    folder: str,
    foundation_model_name: str,
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: t.Optional[int] = None,
    triton_kernel: bool = False,
) -> t.Dict[str, torch.Tensor]:
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.

    """
    # Load the index
    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
    assert os.path.isfile(safe_index_file), f"Missing index file: {safe_index_file}"
    with open(safe_index_file, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))
    state_dict: t.Dict[str, torch.Tensor] = {}
    for shard_file in shard_files:
        state_dict |= safe_load_file(os.path.join(folder, shard_file), "cpu")
    state_dict = hf_to_sarus(
        state_dict=state_dict,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        dim=dim,
        head_dim=head_dim,
        foundation_model_name=foundation_model_name,
        triton_kernel=triton_kernel,
    )
    return state_dict
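
# Illustrative usage (the checkpoint path is hypothetical):
#   state_dict = load_sharded_checkpoint(
#       folder="/ckpts/llama-7b-hf",
#       foundation_model_name="llama-7b",
#   )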


def load_state_dict_into_model(
    state_dict: t.Dict[str, torch.Tensor],
    model: torch.nn.Module,
    device: t.Union[str, int],
) -> None:
    """Move each tensor of `state_dict` into `model` on `device` via
    `set_module_quantized_tensor_to_device`, matching the dtype and contiguity
    of the existing parameter when one is found."""
    for param_name, param in state_dict.items():
        # Walk the attribute path to locate the existing parameter/buffer
        # (it may legitimately be None, e.g. an absent bias)
        old_param = model
        splits = param_name.split(".")
        for split in splits:
            old_param = getattr(old_param, split)
            if old_param is None:
                break

        if old_param is not None:
            param = param.to(old_param.dtype)
            if old_param.is_contiguous():
                param = param.contiguous()
        set_module_quantized_tensor_to_device(
            module=model,
            tensor_name=param_name,
            device=device,
            value=param,
            quantized_stats=None,
        )
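
# Illustrative usage (`model` is a placeholder for an already-built module):
#   sd = load_sharded_checkpoint("/ckpts/llama-7b-hf", "llama-7b")
#   load_state_dict_into_model(sd, model, device="cuda:0")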


def adapt_keynames_for_lora(
    state_dict: t.Dict[str, torch.Tensor],
    lora_attn_modules: t.List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> t.Dict[str, torch.Tensor]:
    """The input state dict correspond to the base model without any Lora, we need
    to replace the corresponding keys accordingly"""
    converted_state_dict = {}

    for key, value in state_dict.items():
        new_key = key
        if "mlp." in key:
            if apply_lora_to_mlp:
                new_key = key.removesuffix("weight") + "base.weight"
        elif key == "output.weight":
            if apply_lora_to_output:
                new_key = "output.base.weight"
        else:
            for changed_key in lora_attn_modules:
                if changed_key in key:
                    new_key = key.removesuffix("weight") + "base.weight"
        converted_state_dict[new_key] = value
    return converted_state_dict
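
# Illustrative mapping (example only): with "q_proj" among `lora_attn_modules`,
# "layers.0.attn.q_proj.weight" becomes "layers.0.attn.q_proj.base.weight",
# while keys for modules without LoRA are left unchanged.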