from typing import List, Optional, cast, Tuple
from torch import nn
import torch
from sarus_llm.models.modules import (
CausalSelfAttention,
FeedForward,
RMSNorm,
TransformerDecoder,
TransformerDecoderLayer,
)
from ..modules.peft.utils import LORA_ATTN_MODULES
from ..modules.peft.lora import LoRALinear
from functools import partial
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt
import sarus_llm.liger_kernels as liger_kernels
"""
Component builders for the Phi3 4K Mini Instruct model.
The model is assembled from small, composable building blocks; builder functions
help stitch these blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``CausalSelfAttention``
  can take either ``nn.Linear`` or ``LoRALinear`` for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
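
For example, a minimal sketch of the ``q_proj`` swap mentioned above (dims are
illustrative, matching Phi3 Mini's ``embed_dim=3072``):
    >>> q_proj = LoRALinear(3072, 3072, rank=8, alpha=16.0, dropout=0.0)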
"""
class Phi3RotaryPositionalEmbeddings(nn.Module):
"""
RoPE Embeddings used in the Phi3 model.
Ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
This class is not numerically equivalent to the RoPE Embedding module
used by Llama2 and Llama3.
Args:
dim (int): Embedding dimension. This is usually set to the dim of each
head in the attention module computed as ``embed_dim`` // ``num_heads``
max_seq_len (int): Maximum expected sequence length for the
model, if exceeded the cached freqs will be recomputed
base (int): The base for the geometric progression used to compute
the rotation angles
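
    Example (a usage sketch; shapes are illustrative, using Phi3 Mini's
    ``head_dim = 3072 // 32 = 96``):
        >>> rope = Phi3RotaryPositionalEmbeddings(dim=96, max_seq_len=4096)
        >>> q = torch.randn(2, 128, 32, 96)  # [b, s, n_h, h_d]
        >>> k = torch.randn(2, 128, 32, 96)
        >>> q_out, k_out = rope(q, k)  # shapes preserved: [b, s, n_h, h_d]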
"""
def __init__(
self,
dim: int,
max_seq_len: int = 4096,
base: int = 10_000,
) -> None:
super().__init__()
self.dim = dim
self.base = base
self.max_seq_len = max_seq_len
self._rope_init()
def _rope_init(self) -> None:
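        # Standard RoPE frequency schedule: theta_j = base^(-2j / dim)
        # for j in [0, dim // 2)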
theta = 1.0 / (
self.base
** (
torch.arange(0, self.dim, 2)[: (self.dim // 2)].float()
/ self.dim
)
)
self.register_buffer("theta", theta, persistent=False)
self.build_rope_cache(self.max_seq_len)
def build_rope_cache(self, max_seq_len: int = 4096) -> None:
# Create position indexes `[0, 1, ..., max_seq_len - 1]`
seq_idx = torch.arange(
max_seq_len, dtype=self.theta.dtype, device=self.theta.device
)
# Outer product of theta and position index; output tensor has
# a shape of [max_seq_len, dim // 2]
idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
        # We cache the cos and sin embeddings instead of ``idx_theta``. This
        # helps ensure we have correct behavior when training with bf16
# Size: [max_seq_len, (dim * 2)]
freqs = torch.cat([idx_theta, idx_theta], dim=-1)
cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
self.register_buffer("cache", cache, persistent=False)
def reset_parameters(self) -> None:
self._rope_init()
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
triton_kernel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
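        """
        Apply RoPE to ``q`` and ``k`` (both of shape [b, s, n_h, h_d]). When
        ``triton_kernel`` is True, the fused Liger Triton kernel is used;
        otherwise each tensor falls back to :meth:`forward_no_kernel`.
        """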
if triton_kernel:
seq_len = q.size(1)
head_dim = q.size(-1)
# extract the values based on whether input_pos is set or not
rope_cache = (
self.cache[:seq_len]
if input_pos is None
else self.cache[input_pos]
)
            # reshape the cache for broadcasting
            # tensor has shape [b, s, h_d * 2] if packed samples,
            # otherwise has shape [1, s, h_d * 2]
rope_cache = rope_cache.view(-1, seq_len, head_dim * 2)
# [b, s, h_d]
cos = rope_cache[..., :head_dim]
sin = rope_cache[..., head_dim:]
            # the Liger kernel expects inputs of shape
            # q: (bsz, n_q_head, seq_len, head_dim)
            # k: (bsz, n_kv_head, seq_len, head_dim)
            # cos: (1, seq_len, head_dim)
            # sin: (1, seq_len, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
return cast(
Tuple[torch.Tensor, torch.Tensor],
liger_kernels.LigerRopeFunction.apply(q, k, cos, sin),
)
return self.forward_no_kernel(q, input_pos), self.forward_no_kernel(
k, input_pos
)
def forward_no_kernel(
self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x (Tensor): input tensor with shape
[b, s, n_h, h_d]
input_pos (Optional[Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b, s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.
Returns:
Tensor: output tensor with RoPE applied
Notation used for tensor shapes:
- b: batch size
- s: sequence length
- n_h: num heads
- h_d: head dim
TODO: The implementation below can be made more efficient
for inference.
"""
# input tensor has shape [b, s, n_h, h_d]
seq_len = x.size(1)
head_dim = x.size(-1)
# extract the values based on whether input_pos is set or not. When
# input_pos is provided, we're in inference mode
rope_cache = (
self.cache[:seq_len]
if input_pos is None
else self.cache[input_pos]
)
# reshape the cache for broadcasting
# tensor has shape [b, s, 1, h_d * 2] if packed samples,
# otherwise has shape [1, s, 1, h_d * 2]
rope_cache = rope_cache.view(-1, seq_len, 1, head_dim * 2)
# [b, s, 1, h_d]
cos = rope_cache[..., :head_dim]
sin = rope_cache[..., head_dim:]
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
rotated = torch.cat((-x2, x1), dim=-1)
# cos: [b, s, 1, h_d]
# x: [b, s, n_h, h_d]
x_out = (x * cos) + (rotated * sin)
return cast(torch.Tensor, x_out.type_as(x))
def phi3(
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
intermediate_dim: int,
max_seq_len: int,
    gradient_checkpoint: bool,
attn_dropout: float = 0.0,
norm_eps: float = 1e-5,
rope_base: int = 10_000,
) -> TransformerDecoder:
"""
Args:
vocab_size (int): number of tokens in vocabulary.
num_layers (int): number of layers in the transformer decoder.
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. User should ensure
`num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
embed_dim (int): embedding dimension for self-attention
intermediate_dim (int): intermediate dimension for MLP
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~.modules.KVCache`
        gradient_checkpoint (bool): whether to wrap each decoder layer with
            non-reentrant activation checkpointing
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
norm_eps (float): epsilon in RMS norms
rope_base (int): base for the rotary positional embeddings. Default: 10_000
Returns:
TransformerDecoder: Instantiation of Phi3 Mini 4K Instruct model.
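
    Example (a hedged sketch; the hyperparameters mirror the Hugging Face
    config for Phi-3-mini-4k-instruct, adjust them to your checkpoint):
        >>> model = phi3(
        ...     vocab_size=32_064,
        ...     num_layers=32,
        ...     num_heads=32,
        ...     num_kv_heads=32,
        ...     embed_dim=3072,
        ...     intermediate_dim=8192,
        ...     max_seq_len=4096,
        ...     gradient_checkpoint=False,
        ... )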
"""
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
rope = Phi3RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
pos_embeddings=rope,
kv_cache=None,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
layer = TransformerDecoderLayer(
attn=self_attn,
mlp=mlp,
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)
if gradient_checkpoint:
        # activate non-reentrant gradient checkpointing,
        # see: https://pytorch.org/docs/stable/checkpoint.html
non_reentrant_wrapper = partial(
torch_ckpt.checkpoint_wrapper,
checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
)
layer = non_reentrant_wrapper(layer) # type: ignore[assignment]
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
return TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)
def phi3_mlp(dim: int, hidden_dim: int) -> FeedForward:
"""
Build the MLP layer associated with the Phi3 Mini 4K Instruct model.
"""
gate_proj = nn.Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj
)
# ------------------ LoRA Phi3 ------------------
def lora_phi3(
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
apply_lora_to_output: bool = False,
*,
# phi3 args
gradient_checkpoint: bool,
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
intermediate_dim: int,
max_seq_len: int,
attn_dropout: float = 0.0,
norm_eps: float = 1e-5,
rope_base: int = 10_000,
# LoRA args
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
) -> TransformerDecoder:
"""
    Return a version of Phi3
    with LoRA applied based on the passed-in configuration.
Args:
lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to in each self-attention block. Options are
``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
Default: False
apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
Default: False
        gradient_checkpoint (bool): whether to wrap each decoder layer with
            non-reentrant activation checkpointing
        vocab_size (int): number of tokens in vocabulary.
num_layers (int): number of layers in the transformer decoder.
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. User should ensure
`num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
embed_dim (int): embedding dimension for self-attention
intermediate_dim (int): intermediate dimension for MLP
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~.modules.KVCache`
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
norm_eps (float): epsilon in RMS norms.
rope_base (int): base value for Rotary Position Embeddings.
Default: 10000
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
Returns:
        TransformerDecoder: Instantiation of Phi3 model with LoRA applied to
a subset of the attention projections in each layer.
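
    Example (a hedged sketch; model dims mirror Phi-3-mini-4k-instruct and the
    LoRA rank/alpha are illustrative, not tuned values):
        >>> model = lora_phi3(
        ...     lora_attn_modules=["q_proj", "v_proj"],
        ...     gradient_checkpoint=False,
        ...     vocab_size=32_064,
        ...     num_layers=32,
        ...     num_heads=32,
        ...     num_kv_heads=32,
        ...     embed_dim=3072,
        ...     intermediate_dim=8192,
        ...     max_seq_len=4096,
        ...     lora_rank=8,
        ...     lora_alpha=16.0,
        ... )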
"""
self_attn = lora_phi3_self_attention(
lora_modules=lora_attn_modules,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
rope_base=rope_base,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
if apply_lora_to_mlp:
mlp = lora_phi3_mlp(
dim=embed_dim,
hidden_dim=intermediate_dim,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
else:
mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
layer = TransformerDecoderLayer(
attn=self_attn,
mlp=mlp,
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)
if gradient_checkpoint:
        # activate non-reentrant gradient checkpointing,
        # see: https://pytorch.org/docs/stable/checkpoint.html
non_reentrant_wrapper = partial(
torch_ckpt.checkpoint_wrapper,
checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
)
layer = non_reentrant_wrapper(layer) # type: ignore[assignment]
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = (
LoRALinear(
embed_dim,
vocab_size,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if apply_lora_to_output
else nn.Linear(embed_dim, vocab_size, bias=False)
)
model = TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=(embed_dim // num_heads),
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)
return model
def lora_phi3_self_attention(
lora_modules: List[LORA_ATTN_MODULES],
*,
# CausalSelfAttention args
embed_dim: int,
num_heads: int,
num_kv_heads: int,
max_seq_len: int,
attn_dropout: float = 0.0,
rope_base: int = 10_000,
# LoRA args
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
) -> CausalSelfAttention:
"""
    Return an instance of :class:`~.modules.CausalSelfAttention` with LoRA
    applied to a subset of its linear layers.
Args:
lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj",
"output_proj"}``.
embed_dim (int): embedding dimension for self-attention
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. User should ensure
`num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~.modules.KVCache`
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
rope_base (int): base value for Rotary Position Embeddings.
Default: 10000
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
Returns:
CausalSelfAttention: instantiation of self-attention module with LoRA
applied to a subset of Q, K, V, output projections.
Raises:
ValueError: If lora_modules arg is an empty list
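
    Example (a hedged sketch; dims match Phi3 Mini, rank/alpha are illustrative):
        >>> attn = lora_phi3_self_attention(
        ...     lora_modules=["q_proj", "v_proj"],
        ...     embed_dim=3072,
        ...     num_heads=32,
        ...     num_kv_heads=32,
        ...     max_seq_len=4096,
        ...     lora_rank=8,
        ...     lora_alpha=16.0,
        ... )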
"""
if not lora_modules:
raise ValueError(
f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
)
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
q_proj = (
LoRALinear(
embed_dim,
num_heads * head_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if "q_proj" in lora_modules
else nn.Linear(embed_dim, num_heads * head_dim, bias=False)
)
k_proj = (
LoRALinear(
embed_dim,
num_kv_heads * head_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if "k_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
v_proj = (
LoRALinear(
embed_dim,
num_kv_heads * head_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if "v_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
output_proj = (
LoRALinear(
embed_dim,
embed_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if "output_proj" in lora_modules
else nn.Linear(embed_dim, embed_dim, bias=False)
)
rope = Phi3RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=q_proj,
k_proj=k_proj,
v_proj=v_proj,
output_proj=output_proj,
pos_embeddings=rope,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
return self_attn
def lora_phi3_mlp(
*,
dim: int,
hidden_dim: int,
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
) -> FeedForward:
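    """
    Build the MLP layer associated with the Phi3 Mini 4K Instruct model, with
    LoRA applied to the gate, down, and up projections.
    """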
gate_proj = LoRALinear(
in_dim=dim,
out_dim=hidden_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
down_proj = LoRALinear(
in_dim=hidden_dim,
out_dim=dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
up_proj = LoRALinear(
in_dim=dim,
out_dim=hidden_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
return FeedForward(
gate_proj=gate_proj,
down_proj=down_proj,
up_proj=up_proj,
)