import json
import os
import re
import typing as t
from typing import Dict
import torch
from safetensors.torch import load_file as safe_load_file
from .modules.peft.utils import LORA_ATTN_MODULES
from .quantization import set_module_quantized_tensor_to_device
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
"model.norm.weight": "norm.scale",
"lm_head.weight": "output.weight",
# specific to phi3
"model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
"model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
}

def get_mapped_key(key: str, mapping_dict: t.Mapping[str, t.Optional[str]]) -> str:
    """Map a single state dict key from HF's naming scheme to torchtune's."""
    try:
        if "layers" in key:
            # Replace the layer number with "{}" to build the lookup key, e.g.
            # "model.layers.7.self_attn.q_proj.weight" ->
            # "model.layers.{}.self_attn.q_proj.weight".
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            layer_num = re.search(r"\d+", key).group(0)  # type: ignore[union-attr]
            mapped_key = mapping_dict[abstract_key]
            if mapped_key is None:
                # Keys mapped to None (e.g. rotary inv_freq buffers) have no
                # torchtune equivalent and must be skipped by the caller.
                raise KeyError(abstract_key)
            new_key = mapped_key.format(layer_num)
        else:
            mapped_key = mapping_dict[key]
            if mapped_key is None:
                raise KeyError(key)
            new_key = mapped_key
    except KeyError as e:
        raise ValueError(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format."
        ) from e
    return new_key
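
# For illustration, a hypothetical doctest (not executed at import time):
#
#     >>> get_mapped_key("model.layers.3.self_attn.q_proj.weight", _FROM_HF)
#     'layers.3.attn.q_proj.weight'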

def hf_to_sarus(
    state_dict: Dict[str, torch.Tensor],
    foundation_model_name: str,
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: t.Optional[int] = None,
    triton_kernel: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from HF's format to torchtune's format. State dicts
    from multiple checkpoint files should be consolidated into a single state
    dict before calling this function.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
        foundation_model_name (str): Name of the base model; "phi" checkpoints
            are dispatched to convert_phi_state_dict.
        num_heads (int): Number of heads in the model.
        num_kv_heads (int): Number of heads in the key/value projection layers.
        dim (int): Dimension of the model.
        head_dim (Optional[int]): Dimension of each head. If not provided, it is
            computed as dim // num_heads.
        triton_kernel (bool): If True, skip the rotary permutation of the q/k
            projection weights.

    Returns:
        Dict[str, torch.Tensor]: State dict in torchtune's format.
    """
    if "phi" in foundation_model_name:
        return convert_phi_state_dict(state_dict=state_dict)
    converted_state_dict = {}
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
        # Reorder the rows of a q/k projection weight from HF's rotary
        # embedding layout to the one expected by torchtune's RoPE.
        return (
            w.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        if "rotary_emb.inv_freq" not in key:  # Skip the position embeddings
            new_key = get_mapped_key(key, _FROM_HF)
            if "q_proj" in key and not triton_kernel:
                value = _permute(value, num_heads)
            elif "k_proj" in key and not triton_kernel:
                value = _permute(value, num_kv_heads)
            converted_state_dict[new_key] = value
    return converted_state_dict
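
# A typical call on a consolidated Llama-style checkpoint might look like this
# (a sketch; the file name is an assumption, the hyper-parameters are the
# function defaults):
#
#     >>> sd = safe_load_file("model.safetensors", "cpu")
#     >>> tune_sd = hf_to_sarus(
#     ...     sd, foundation_model_name="llama",
#     ...     num_heads=32, num_kv_heads=32, dim=4096,
#     ... )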

def convert_phi_state_dict(
    state_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    converted_state_dict = {}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _FROM_HF)
        if "qkv" in key:
            # phi3 fuses the q, k and v projections row-wise into a single
            # qkv_proj weight; split it back into three equal parts.
            q, k, v = value.chunk(3, dim=0)
            converted_state_dict[new_key] = q
            converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
            converted_state_dict[new_key.replace("q_proj", "v_proj")] = v
        elif "gate" in key:
            # Likewise, gate_up_proj fuses the gate (w1) and up (w3) MLP
            # projections.
            w1, w3 = value.chunk(2, dim=0)
            converted_state_dict[new_key] = w1
            converted_state_dict[new_key.replace("w1", "w3")] = w3
        else:
            converted_state_dict[new_key] = value
    return converted_state_dict
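
# The 3-way row split assumes q, k and v have equal output dimensions; on a toy
# tensor:
#
#     >>> w = torch.arange(6.0).reshape(6, 1)  # stand-in (3 * 2, 1) qkv weight
#     >>> q, k, v = w.chunk(3, dim=0)
#     >>> [p.shape for p in (q, k, v)]
#     [torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1])]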

def load_sharded_checkpoint(
    folder: str,
    foundation_model_name: str,
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: t.Optional[int] = None,
    triton_kernel: bool = False,
) -> t.Dict[str, torch.Tensor]:
    """
    Load a sharded safetensors checkpoint from `folder`, consolidate the shards
    into a single state dict, and convert it to torchtune's format via
    hf_to_sarus.

    Unlike
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict),
    this does not load anything into a model; it only returns the converted
    state dict. All shards are merged into one dict in CPU RAM.

    Args:
        folder (str): Path to a folder containing the sharded checkpoint and
            its index file.
        foundation_model_name (str): Name of the base model. This and the
            remaining arguments are forwarded to hf_to_sarus.
    """
    # Load the index that maps each parameter name to its shard file.
    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(safe_index_file):
        raise FileNotFoundError(f"Missing checkpoint index file: {safe_index_file}")
    with open(safe_index_file, "r", encoding="utf-8") as f:
        index = json.load(f)
    shard_files = list(set(index["weight_map"].values()))
    state_dict: t.Dict[str, torch.Tensor] = {}
    for shard_file in shard_files:
        # Merge each shard into the consolidated state dict (on CPU).
        state_dict |= safe_load_file(os.path.join(folder, shard_file), "cpu")
state_dict = hf_to_sarus(
state_dict=state_dict,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim=dim,
head_dim=head_dim,
foundation_model_name=foundation_model_name,
triton_kernel=triton_kernel,
)
return state_dict
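
# The index file follows HF's standard layout: a "weight_map" object mapping
# each parameter to the shard that holds it, e.g. (abridged, shard names are
# illustrative):
#
#     {"weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors",
#                     "lm_head.weight": "model-00002-of-00002.safetensors"}}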

def load_state_dict_into_model(
    state_dict: t.Dict[str, torch.Tensor],
    model: torch.nn.Module,
    device: t.Union[str, int],
) -> None:
    for param_name, param in state_dict.items():
        # Walk the module tree to find the existing parameter so that the
        # incoming tensor can be cast to its dtype and memory layout.
        old_param = model
        splits = param_name.split(".")
        for split in splits:
            old_param = getattr(old_param, split)
            if old_param is None:
                break
        if old_param is not None:
            param = param.to(old_param.dtype)
            if old_param.is_contiguous():
                param = param.contiguous()
        # Place (and, where applicable, quantize) the tensor on the target
        # device.
        set_module_quantized_tensor_to_device(
            module=model,
            tensor_name=param_name,
            device=device,
            value=param,
            quantized_stats=None,
        )
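
# Putting the two together (a sketch; the path, model and device are
# illustrative):
#
#     >>> sd = load_sharded_checkpoint("/ckpt/dir", foundation_model_name="llama")
#     >>> load_state_dict_into_model(sd, model, device=0)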

def adapt_keynames_for_lora(
    state_dict: t.Dict[str, torch.Tensor],
    lora_attn_modules: t.List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool,
    apply_lora_to_output: bool,
) -> t.Dict[str, torch.Tensor]:
    """The input state dict corresponds to the base model without any LoRA
    adapters; rename the keys of every linear layer that receives a LoRA
    wrapper so its weight lands on the wrapped base module."""
    converted_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        if "mlp." in key:
            if apply_lora_to_mlp:
                new_key = key.removesuffix("weight") + "base.weight"
        elif key == "output.weight":
            if apply_lora_to_output:
                new_key = "output.base.weight"
        else:
            # Attention projections: rename only the modules selected for LoRA.
            for changed_key in lora_attn_modules:
                if changed_key in key:
                    new_key = key.removesuffix("weight") + "base.weight"
        converted_state_dict[new_key] = value
    return converted_state_dict
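
# Example: with lora_attn_modules=["q_proj", "v_proj"], the key
# "layers.0.attn.q_proj.weight" becomes "layers.0.attn.q_proj.base.weight",
# while "layers.0.attn.k_proj.weight" is left unchanged.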