Repository URL to install this package:
|
Version:
1.1.3 ▾
|
import copy
from typing import Optional, Union, cast
import torch
from torch import nn, Tensor
from .kv_cache import KVCache
from .attention import CausalSelfAttention
from .peft.lora import LoRALinear
class TransformerDecoderLayer(nn.Module):
    """Pre-norm transformer layer derived from the Llama2 model.

    Each of the two sub-blocks (self-attention, then feed-forward) normalizes
    its input first and adds the un-normalized input back as a residual.

    Args:
        attn (CausalSelfAttention): Attention module.
        mlp (nn.Module): Feed-forward module.
        sa_norm (nn.Module): Normalization applied before self-attention.
        mlp_norm (nn.Module): Normalization applied before the feed-forward
            layer.
    """

    def __init__(
        self,
        attn: CausalSelfAttention,
        mlp: nn.Module,
        sa_norm: nn.Module,
        mlp_norm: nn.Module,
    ) -> None:
        super().__init__()
        # Registration order is kept as (sa_norm, attn, mlp_norm, mlp) so that
        # child-module / state-dict ordering is unchanged.
        self.sa_norm = sa_norm
        self.attn = attn
        self.mlp_norm = mlp_norm
        self.mlp = mlp

    def forward(
        self,
        x: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
        triton_kernel: bool = False,
    ) -> Tensor:
        """Run the attention and feed-forward sub-blocks on ``x``.

        Args:
            x (Tensor): input tensor with shape
                [batch_size x seq_length x embed_dim]
            mask (Optional[Tensor]): Optional boolean attention mask with shape
                [batch_size x seq_length x seq_length], forwarded unchanged to
                the attention module. True in row i, column j means token i
                attends to token j; False means it does not. If no mask is
                specified, a causal mask is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional position ids of each token,
                forwarded unchanged to the attention module. During training
                these are the positions of each token relative to its sample
                when packed, shape [b x s]. During inference this is the
                position of the current token. If None, the index of the token
                is assumed to be its position id. Default is None.
            triton_kernel (bool): Flag forwarded as a keyword argument to both
                norms, the attention module and the feed-forward module.
                Default is False.

        Returns:
            Tensor: output tensor with same shape as input
                [batch_size x seq_length x embed_dim]

        TODO:
            - Make position of norm configurable
        """
        # --- self-attention sub-block (norm first), shapes stay [b, s, d] ---
        normed_input = self.sa_norm(x, triton_kernel=triton_kernel)
        attended = self.attn(
            normed_input,
            mask=mask,
            input_pos=input_pos,
            triton_kernel=triton_kernel,
        )
        # First residual connection
        hidden = attended + x

        # --- feed-forward sub-block (norm first) ---
        normed_hidden = self.mlp_norm(hidden, triton_kernel=triton_kernel)
        fed_forward = self.mlp(normed_hidden, triton_kernel=triton_kernel)

        # Second residual connection; shape: [batch_size, seq_length, embed_dim]
        return cast(Tensor, hidden + fed_forward)
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """
    Build a ``ModuleList`` holding ``n`` independent deep copies of ``module``.

    Args:
        module (nn.Module): module to be cloned
        n (int): number of clones

    Returns:
        nn.ModuleList: list of ``n`` identical layers
    """
    clones = nn.ModuleList()
    for _ in range(n):
        # FIXME: copy.deepcopy() is not defined on nn.module
        clones.append(copy.deepcopy(module))
    return clones
class TransformerDecoder(nn.Module):
    """
    Transformer Decoder derived from the Llama2 architecture.

    Args:
        tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
            tokens to an embedding space.
        layer (TransformerDecoderLayer): Transformer Decoder layer. It is cloned
            ``num_layers`` times with ``_get_clones``.
        num_layers (int): Number of Transformer Decoder layers.
        max_seq_len (int): maximum sequence length the model will be run with.
            It is stored on the instance; note that :meth:`setup_caches` takes
            its own ``max_seq_len`` argument.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value. This is used to setup the
            ``KVCache``.
        head_dim (int): embedding dimension for each head in self-attention. This is used
            to setup the ``KVCache``.
        norm (nn.Module): Callable that applies normalization to the output of the decoder,
            before final MLP.
        output (nn.Module): Callable that applies a linear transformation to the output of
            the decoder.

    Note:
        Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
        in the module where they are used. This helps reduces the number of raise
        statements in code and improves readability.
    """

    def __init__(
        self,
        tok_embeddings: nn.Embedding,
        layer: TransformerDecoderLayer,
        num_layers: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        norm: nn.Module,
        output: nn.Module,
    ) -> None:
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm
        self.output = output
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.causal_mask = None

    def setup_caches(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
        max_seq_len: int,
    ) -> None:
        """Setup key value caches for attention calculation.

        A new ``KVCache`` is assigned to ``attn.kv_cache`` of every layer,
        replacing whatever was there before.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            device (torch.device): device passed to each cache.
            max_seq_len (int): maximum sequence length passed to each cache.
                This argument is used, not ``self.max_seq_len``.
        """
        for layer in self.layers:
            layer.attn.kv_cache = KVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                dtype=dtype,
                device=device,
            )

    def reset_caches(self, device: Union[str, int]) -> None:
        """Reset the key value caches of every layer.

        Args:
            device (Union[str, int]): device forwarded to each cache's
                ``reset`` method.

        Raises:
            RuntimeError: if any layer has no key value cache set up.
        """
        # Check every layer (not only the first) so that an empty or partially
        # set-up model produces this error instead of an IndexError or
        # AttributeError.
        if any(layer.attn.kv_cache is None for layer in self.layers):
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )
        for layer in self.layers:
            layer.attn.kv_cache.reset(device)

    def forward(
        self,
        tokens: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
        triton_kernel: bool = False,
    ) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape [b x s]
            mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
                with shape [b x s x s]. This is applied after the query-key multiplication and
                before the softmax. A value of True in row i and column j means token i attends
                to token j. A value of False means token i does not attend to token j. If no
                mask is specified, a causal mask is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.
            triton_kernel (bool): flag forwarded as a keyword argument to every
                decoder layer and to the final norm. Default is False.

        Note: At the very first step of inference, when the model is provided with a prompt,
        ``input_pos`` would contain the positions of all of the tokens in the prompt
        (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
        KV values for each position.

        Returns:
            Tensor: output tensor with shape [b x s x v], converted with ``.float()``

        Raises:
            ValueError: if ``tokens`` does not have exactly two dimensions.

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - v: vocab size
            - d: embed dim
            - m_s: max seq len
        """
        # input tensor of shape [b, s]; the values are not used below, but the
        # unpacking rejects inputs that are not rank-2.
        bsz, seq_len = tokens.shape

        # shape: [b, s, d]
        h = self.tok_embeddings(tokens)

        for layer in self.layers:
            # shape: [b, s, d]
            h = layer(
                h, mask=mask, input_pos=input_pos, triton_kernel=triton_kernel
            )

        # shape: [b, s, d]
        h = self.norm(h, triton_kernel=triton_kernel)

        # shape: [b, s, out_dim] - out_dim is usually the vocab size
        output = self.output(h).float()
        return cast(Tensor, output)

    def setup_rope(self, device: Union[str, int]) -> None:
        """Initialise rope values after all pretrained weights
        have been loaded.

        For every layer, ``attn.pos_embeddings`` is moved to ``device`` with
        ``to_empty`` and then re-initialised with ``reset_parameters``.

        Args:
            device (Union[str, int]): target device for the position embeddings.
        """
        for layer in self.layers:
            layer.attn.pos_embeddings.to_empty(device=device)
            layer.attn.pos_embeddings.reset_parameters()

    def setup_lora(self, device: Union[str, int]) -> None:
        """Initialise all lora layers after pre-trained
        weights have been loaded.

        Every attention projection (q, k, v, output), every MLP projection
        (w1, w2, w3) and the final output layer is inspected; those that are
        ``LoRALinear`` instances are handed to ``move_lora_params_and_reset``.

        Args:
            device (Union[str, int]): target device for the lora parameters.
        """
        for layer in self.layers:
            attn = layer.attn
            # Iterating over the projections (rather than repeating the call
            # per attribute) guarantees the module that is checked is also the
            # module that is reset; previously ``output_proj`` was checked but
            # ``v_proj`` was reset.
            for proj in (attn.q_proj, attn.k_proj, attn.v_proj, attn.output_proj):
                if isinstance(proj, LoRALinear):
                    move_lora_params_and_reset(proj, device=device)

            # TODO: MLPs might be different for LLAMA/PHi-3
            mlp = layer.mlp
            for proj in (mlp.w1, mlp.w2, mlp.w3):
                if isinstance(proj, LoRALinear):
                    move_lora_params_and_reset(proj, device=device)

        if isinstance(self.output, LoRALinear):
            move_lora_params_and_reset(self.output, device=device)
def move_lora_params_and_reset(
    module: LoRALinear, device: Union[str, int]
) -> None:
    """Materialise the LoRA adapters of ``module`` on ``device`` and re-initialise them.

    Lora parameters are initialized as meta tensors; here ``lora_a`` and
    ``lora_b`` are moved to ``device`` with ``to_empty`` and then
    ``module.initialize_parameters()`` is called.

    Args:
        module (LoRALinear): layer whose ``lora_a`` / ``lora_b`` are moved.
        device (Union[str, int]): target device.
    """
    # Same order as before: lora_a first, then lora_b, then the re-init.
    for adapter_name in ("lora_a", "lora_b"):
        getattr(module, adapter_name).to_empty(device=device)
    module.initialize_parameters()