# Repository URL to install this package:
# Version: 1.1.3
from functools import partial
from typing import List, Optional
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt
from torch import nn
import torch
import math
from ..modules import (
CausalSelfAttention,
FeedForward,
RMSNorm,
RotaryPositionalEmbeddings,
TransformerDecoder,
TransformerDecoderLayer,
)
from ..modules.peft.utils import LORA_ATTN_MODULES
from ..modules.peft.lora import LoRALinear
"""
Component builders for the Llama3 model and popular variants such as LoRA.
torchtune provides composable building blocks. Builder functions help
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``CausalSelfAttention``
can take either ``nn.Linear`` or ``LoRALinear`` for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
"""
class Llama3ScaledRoPE(RotaryPositionalEmbeddings):
    """Rotary positional embeddings with the Llama 3.1 frequency rescaling.

    Behaves like :class:`RotaryPositionalEmbeddings`, except that the base
    rotation frequencies are passed through :meth:`apply_scaling` before the
    RoPE cache is built.
    """

    def apply_scaling(self, freqs: torch.Tensor) -> torch.Tensor:
        """Rescale RoPE frequencies, following the Meta-Llama reference code:
        https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1/api/model.py#L41

        Each frequency is classified by its wavelength ``2 * pi / freq``:

        - wavelength < ``old_context_len / high_freq_factor``: kept unchanged;
        - wavelength > ``old_context_len / low_freq_factor``: divided by
          ``scale_factor``;
        - otherwise: blended between ``freq / scale_factor`` and ``freq``
          using the interpolation weight ``smooth``.

        The computation is vectorized (``torch.where``). This avoids a
        per-element Python loop, which on accelerators would also force a
        host sync for every comparison.

        Args:
            freqs (torch.Tensor): frequencies to rescale, as produced by
                :meth:`_rope_init`.

        Returns:
            torch.Tensor: rescaled frequencies with the same shape, dtype and
            device as ``freqs``.
        """
        # Values obtained from grid search
        scale_factor = 8
        low_freq_factor = 1
        high_freq_factor = 4
        old_context_len = 8192  # original llama3 length
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        # The interpolation below divides by (high_freq_factor - low_freq_factor);
        # this invariant depends only on the constants, so check it once.
        assert low_freq_wavelen != high_freq_wavelen

        wavelen = 2 * math.pi / freqs
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        # Value used for the intermediate band; ignored elsewhere by torch.where.
        blended = (1 - smooth) * freqs / scale_factor + smooth * freqs
        return torch.where(
            wavelen < high_freq_wavelen,
            freqs,
            torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, blended),
        )

    def _rope_init(self) -> None:
        """Compute the (scaled) base frequencies and build the RoPE cache.

        Computes ``1 / base ** (2i / dim)`` for ``i in range(dim // 2)`` and
        rescales it with :meth:`apply_scaling`. The result is registered as
        the non-persistent ``theta`` buffer, and then
        ``build_rope_cache(self.max_seq_len)`` is called.

        NOTE(review): presumably invoked from the parent constructor; confirm
        against RotaryPositionalEmbeddings.
        """
        freqs = 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2)[: (self.dim // 2)].float()
                / self.dim
            )
        )
        # The model is initialized with meta tensors so that pretrained weights
        # can be loaded afterwards. Computations are not allowed on meta
        # devices, so the scaling is skipped there.
        if freqs.device.type != "meta":
            theta = self.apply_scaling(freqs)
        else:
            theta = freqs
        self.register_buffer("theta", theta, persistent=False)
        self.build_rope_cache(self.max_seq_len)
# ------------------ Vanilla Llama3.1 ------------------
def llama3(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    rope_base: int = 500000,
    intermediate_dim: Optional[int] = None,
    norm_eps: float = 1e-5,
    gradient_checkpoint: bool = True,
) -> TransformerDecoder:
    """
    Build the decoder associated with the Llama3 model. This includes:
    - Token embeddings
    - num_layers number of TransformerDecoderLayer blocks
    - RMS Norm layer applied to the output of the transformer
    - Final projection into token space

    Args:
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. A falsy value (e.g. ``None``)
            falls back to ``num_heads``, i.e. standard MHA. This argument has
            no default and must be passed explicitly.
        embed_dim (int): embedding dimension for self-attention
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        rope_base (int): base of the scaled rotary positional embeddings
            (:class:`Llama3ScaledRoPE`). Default: 500000
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using :func:`scale_hidden_dim_for_mlp` (defined in this module)
        norm_eps (float): epsilon in RMS norms. Default: 1e-5
        gradient_checkpoint (bool): if True, wrap the decoder layer with a
            non-reentrant activation-checkpoint wrapper. Default: True

    Returns:
        TransformerDecoder: Instantiation of Llama3 model.
    """
    head_dim = embed_dim // num_heads
    # A falsy num_kv_heads (e.g. None) means plain multi-head attention.
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    rope = Llama3ScaledRoPE(
        dim=head_dim, max_seq_len=max_seq_len, base=rope_base
    )
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    hidden_dim = (
        intermediate_dim
        if intermediate_dim
        else scale_hidden_dim_for_mlp(embed_dim)
    )
    mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    if gradient_checkpoint:
        # Activation checkpointing (non-reentrant variant), see:
        # https://pytorch.org/docs/stable/checkpoint.html
        layer = torch_ckpt.checkpoint_wrapper(  # type: ignore[assignment]
            layer,
            checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
        )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )
def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward:
    """
    Build the MLP layer associated with the Llama model.

    Args:
        dim (int): model (input and output) dimension.
        hidden_dim (int): width of the gate/up projections.

    Returns:
        FeedForward: MLP whose three projections are bias-free ``nn.Linear``
        layers (gate: dim -> hidden_dim, down: hidden_dim -> dim,
        up: dim -> hidden_dim).
    """

    def _bias_free(in_features: int, out_features: int) -> nn.Linear:
        return nn.Linear(in_features, out_features, bias=False)

    # Keyword arguments are evaluated left to right, so the modules are still
    # created in gate -> down -> up order.
    return FeedForward(
        gate_proj=_bias_free(dim, hidden_dim),
        down_proj=_bias_free(hidden_dim, dim),
        up_proj=_bias_free(dim, hidden_dim),
    )
# ------------------ LoRA Llama3 ------------------
def lora_llama3(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # llama3 args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    intermediate_dim: Optional[int] = None,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: int = 500000,
    gradient_checkpoint: bool = True,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
) -> TransformerDecoder:
    """
    Return a version of Llama3 (an instance of :func:`~torchtune.modules.TransformerDecoder`)
    with LoRA applied based on the passed in configuration.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. A falsy value (e.g. ``None``)
            falls back to ``num_heads``, i.e. standard MHA. This argument has
            no default and must be passed explicitly.
        embed_dim (int): embedding dimension for self-attention
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
            this is computed using :func:`scale_hidden_dim_for_mlp` (defined in this module)
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        norm_eps (float): epsilon in RMS norms. Default: 1e-5
        rope_base (int): base of the scaled rotary positional embeddings. Default: 500000
        gradient_checkpoint (bool): if True, wrap the decoder layer with a
            non-reentrant activation-checkpoint wrapper. Default: True
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0

    Returns:
        TransformerDecoder: Instantiation of Llama3 model with LoRA applied to
        the selected attention projections in each layer. If requested, LoRA is
        also applied to each layer's MLP and to the final output projection.
    """
    self_attn = lora_llama3_self_attention(
        lora_modules=lora_attn_modules,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
        rope_base=rope_base,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    hidden_dim = (
        intermediate_dim
        if intermediate_dim
        else scale_hidden_dim_for_mlp(embed_dim)
    )
    if apply_lora_to_mlp:
        mlp = lora_llama3_mlp(
            dim=embed_dim,
            hidden_dim=hidden_dim,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
    else:
        mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    if gradient_checkpoint:
        # Activation checkpointing (non-reentrant variant), see:
        # https://pytorch.org/docs/stable/checkpoint.html
        layer = torch_ckpt.checkpoint_wrapper(  # type: ignore[assignment]
            layer,
            checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
        )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = (
        LoRALinear(
            embed_dim,
            vocab_size,
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
        )
        if apply_lora_to_output
        else nn.Linear(embed_dim, vocab_size, bias=False)
    )
    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=(embed_dim // num_heads),
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )
    return model
def lora_llama3_self_attention(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # CausalSelfAttention args
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    rope_base: int = 500000,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
) -> CausalSelfAttention:
    """
    Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA
    applied to a subset of its linear layers

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj",
            "output_proj"}``.
        embed_dim (int): embedding dimension for self-attention
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. A falsy value (e.g. ``None``)
            falls back to ``num_heads``, i.e. standard MHA.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        rope_base (int): base of the scaled rotary positional embeddings. Default: 500000
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0

    Returns:
        CausalSelfAttention: instantiation of self-attention module with LoRA
        applied to a subset of Q, K, V, output projections.

    Raises:
        ValueError: If lora_modules arg is an empty list, or if it contains a
            name other than the four supported projections. Without this
            check, a typo would silently leave that projection un-adapted.
    """
    if not lora_modules:
        raise ValueError(
            f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
        )
    supported = ("q_proj", "k_proj", "v_proj", "output_proj")
    unknown = [name for name in lora_modules if name not in supported]
    if unknown:
        raise ValueError(
            f"Unsupported entries {unknown} in lora_modules; "
            f"expected a subset of {list(supported)}"
        )

    head_dim = embed_dim // num_heads
    # A falsy num_kv_heads (e.g. None) means plain multi-head attention.
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads

    def _projection(name: str, in_dim: int, out_dim: int) -> nn.Module:
        # LoRA-adapted projection when requested, else a bias-free nn.Linear.
        if name in lora_modules:
            return LoRALinear(
                in_dim,
                out_dim,
                rank=lora_rank,
                alpha=lora_alpha,
                dropout=lora_dropout,
            )
        return nn.Linear(in_dim, out_dim, bias=False)

    # Modules are created in q -> k -> v -> output order.
    q_proj = _projection("q_proj", embed_dim, num_heads * head_dim)
    k_proj = _projection("k_proj", embed_dim, num_kv_heads * head_dim)
    v_proj = _projection("v_proj", embed_dim, num_kv_heads * head_dim)
    output_proj = _projection("output_proj", embed_dim, embed_dim)

    rope = Llama3ScaledRoPE(
        dim=head_dim, max_seq_len=max_seq_len, base=rope_base
    )
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=output_proj,
        pos_embeddings=rope,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    return self_attn
def lora_llama3_mlp(
    *,
    dim: int,
    hidden_dim: int,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
) -> FeedForward:
    """
    Build the Llama3 MLP with LoRA applied to all three of its projections.

    Args:
        dim (int): model (input and output) dimension.
        hidden_dim (int): width of the gate/up projections.
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0

    Returns:
        FeedForward: MLP whose gate (dim -> hidden_dim), down (hidden_dim -> dim)
        and up (dim -> hidden_dim) projections are all ``LoRALinear``.
    """
    lora_kwargs = dict(rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
    # Keyword arguments are evaluated left to right, so the modules are
    # created in gate -> down -> up order.
    return FeedForward(
        gate_proj=LoRALinear(in_dim=dim, out_dim=hidden_dim, **lora_kwargs),
        down_proj=LoRALinear(in_dim=hidden_dim, out_dim=dim, **lora_kwargs),
        up_proj=LoRALinear(in_dim=dim, out_dim=hidden_dim, **lora_kwargs),
    )
def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    """Scale hidden dimension for MLP to keep number of parameters and computation constant.

    The hidden size is first set to ``4 * int(2 * dim / 3)`` and then rounded
    **up** (ceiling) to the next multiple of ``multiple_of``. A value that is
    already a multiple is left unchanged.

    Args:
        dim (int): Input dimension.
        multiple_of (int): Round the scaled dimension up to a multiple of
            `multiple_of` for clean computation. Must be positive. Default: 256

    Returns:
        int: Scaled hidden dimension.

    Raises:
        ValueError: If ``multiple_of`` is not a positive integer.

    Example:
        >>> scale_hidden_dim_for_mlp(4096)
        11008
    """
    if multiple_of <= 0:
        # Zero would divide by zero and negatives give meaningless sizes.
        raise ValueError(
            f"multiple_of must be a positive integer, got {multiple_of}"
        )
    # Scale hidden dimension by (2/3)4d for SwiGLU to keep number of
    # parameters and computation constant
    hidden_dim = 4 * int(2 * dim / 3)
    # Round up (ceiling division) to the next multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)