# Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages,
# Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages
# Repository URL to install this package:
# Details
# sarus-llm / sarus_llm / models / modules / transformer.py
# Size: Mime:
import copy
from typing import Optional, Union, cast

import torch
from torch import nn, Tensor

from .kv_cache import KVCache
from .attention import CausalSelfAttention
from .peft.lora import LoRALinear


class TransformerDecoderLayer(nn.Module):
    """Pre-norm Transformer decoder block derived from the Llama2 model.

    Each of the two sub-layers (self-attention, then feed-forward) is preceded
    by its own normalization and wrapped in a residual connection.

    Args:
        attn (CausalSelfAttention): Attention module.
        mlp (nn.Module): Feed-forward module.
        sa_norm (nn.Module): Normalization applied to the input of self-attention.
        mlp_norm (nn.Module): Normalization applied to the input of the feed-forward layer.
    """

    def __init__(
        self,
        attn: CausalSelfAttention,
        mlp: nn.Module,
        sa_norm: nn.Module,
        mlp_norm: nn.Module,
    ) -> None:
        super().__init__()
        # Registration order is kept (norm, attn, norm, mlp) so that
        # ``state_dict``/``children()`` ordering is unchanged.
        self.sa_norm = sa_norm
        self.attn = attn
        self.mlp_norm = mlp_norm
        self.mlp = mlp

    def forward(
        self,
        x: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
        triton_kernel: bool = False,
    ) -> Tensor:
        """Run one decoder block: ``x + attn(norm(x))`` followed by ``h + mlp(norm(h))``.

        Args:
            x (Tensor): input tensor with shape
                [batch_size x seq_length x embed_dim]
            mask (Optional[Tensor]): Optional boolean attention mask with shape
                [batch_size x seq_length x seq_length], applied after the query-key
                multiplication and before the softmax. True at (i, j) means token i
                attends to token j; False means it does not. If no mask is specified,
                a causal mask is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor of position ids.
                During training it gives the position of each token relative to its
                sample when packed, shape [b x s]. During inference it gives the
                position of the current token. If None, the index of the token is
                taken as its position id. Default is None.
            triton_kernel (bool): Flag forwarded unchanged to both norms, the
                attention module and the feed-forward module. Default is False.

        Returns:
            Tensor: output tensor with same shape as input
                [batch_size x seq_length x embed_dim]

        TODO:
            - Make position of norm configurable
        """
        # Pre-attention normalization; shape stays [b, s, d].
        normed_input = self.sa_norm(x, triton_kernel=triton_kernel)
        attended = self.attn(
            normed_input,
            mask=mask,
            input_pos=input_pos,
            triton_kernel=triton_kernel,
        )

        # First residual connection; shape: [batch_size, seq_length, embed_dim]
        residual = attended + x

        # Pre-feed-forward normalization, then the feed-forward layer itself.
        normed_residual = self.mlp_norm(residual, triton_kernel=triton_kernel)
        ff_out = self.mlp(normed_residual, triton_kernel=triton_kernel)

        # Second residual connection; shape: [batch_size, seq_length, embed_dim]
        return cast(Tensor, residual + ff_out)


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """
    Build a ``ModuleList`` holding ``n`` independent deep copies of ``module``.

    Args:
        module (nn.Module): module to be cloned
        n (int): number of clones

    Returns:
        nn.ModuleList: list of ``n`` identical layers
    """
    # FIXME: copy.deepcopy() is not defined on nn.module
    clones = nn.ModuleList()
    for _ in range(n):
        # Each entry is its own deep copy, so the clones do not share parameters.
        clones.append(copy.deepcopy(module))
    return clones


class TransformerDecoder(nn.Module):
    """
    Transformer Decoder derived from the Llama2 architecture.

    Args:
        tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
            tokens to an embedding space.
        layer (TransformerDecoderLayer): Transformer Decoder layer. It is deep-copied
            ``num_layers`` times; the instance passed in is not itself used.
        num_layers (int): Number of Transformer Decoder layers.
        max_seq_len (int): maximum sequence length the model will be run with.
            Stored on the module; note that :meth:`setup_caches` sizes the
            :class:`KVCache` from its own ``max_seq_len`` argument.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value. This is used to setup the
            :class:`KVCache`.
        head_dim (int): embedding dimension for each head in self-attention. This is used
            to setup the :class:`KVCache`.
        norm (nn.Module): Callable that applies normalization to the output of the decoder,
            before final MLP.
        output (nn.Module): Callable that applies a linear transformation to the output of
            the decoder.

    Note:
        Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
        in the module where they are used. This helps reduce the number of raise
        statements in code and improves readability.
    """

    def __init__(
        self,
        tok_embeddings: nn.Embedding,
        layer: TransformerDecoderLayer,
        num_layers: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        norm: nn.Module,
        output: nn.Module,
    ) -> None:
        super().__init__()

        self.tok_embeddings = tok_embeddings
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm
        self.output = output
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.causal_mask = None

    def setup_caches(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
        max_seq_len: int,
    ) -> None:
        """Setup key value caches for attention calculation.

        A fresh :class:`KVCache` is assigned to ``attn.kv_cache`` of every layer,
        replacing whatever was there before.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            device (torch.device): device passed to each cache.
            max_seq_len (int): maximum sequence length passed to each cache
                (this argument is used, not ``self.max_seq_len``).
        """
        for layer in self.layers:
            layer.attn.kv_cache = KVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                dtype=dtype,
                device=device,
            )

    def reset_caches(self, device: Union[str, int]) -> None:
        """Reset the key value caches of every layer.

        Args:
            device (Union[str, int]): device forwarded to each cache's ``reset``.

        Raises:
            RuntimeError: if the first layer has no key value cache, i.e.
                ``setup_caches()`` has not been called.
        """
        # Caches are created for all layers at once, so checking the first
        # layer is enough to detect a missing setup.
        if self.layers[0].attn.kv_cache is None:
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )

        for layer in self.layers:
            layer.attn.kv_cache.reset(device)

    def forward(
        self,
        tokens: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
        triton_kernel: bool = False,
    ) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape [b x s]
            mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
                with shape [b x s x s]. This is applied after the query-key multiplication and
                before the softmax. A value of True in row i and column j means token i attends
                to token j. A value of False means token i does not attend to token j. If no
                mask is specified, a causal mask is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.
            triton_kernel (bool): Flag forwarded unchanged to every decoder layer and
                to the final norm. Default is False.

        Note: At the very first step of inference, when the model is provided with a prompt,
        ``input_pos`` would contain the positions of all of the tokens in the prompt
        (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
        KV values for each position.

        Returns:
            Tensor: float32 output tensor with shape [b x s x v]

        Raises:
            ValueError: if ``tokens`` is not 2-dimensional (its shape cannot be
                unpacked into batch size and sequence length).

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - v: vocab size
            - d: embed dim
            - m_s: max seq len
        """
        # input tensor of shape [b, s]; the values are not needed below, the
        # unpacking is kept because it rejects inputs that are not 2-D.
        _bsz, _seq_len = tokens.shape

        # shape: [b, s, d]
        h = self.tok_embeddings(tokens)

        for layer in self.layers:
            # shape: [b, s, d]
            h = layer(
                h, mask=mask, input_pos=input_pos, triton_kernel=triton_kernel
            )

        # shape: [b, s, d]
        h = self.norm(h, triton_kernel=triton_kernel)

        # shape: [b, s, out_dim] - out_dim is usually the vocab size
        output = self.output(h).float()
        return cast(Tensor, output)

    def setup_rope(self, device: Union[str, int]) -> None:
        """Initialise rope values after all pretrained weights
        have been loaded.

        For every layer, ``attn.pos_embeddings`` is moved to ``device`` with
        ``to_empty`` (storage is allocated but not initialised) and then
        re-initialised through its ``reset_parameters()``.

        Args:
            device (Union[str, int]): target device for the positional embeddings.
        """
        for layer in self.layers:
            layer.attn.pos_embeddings.to_empty(device=device)
            layer.attn.pos_embeddings.reset_parameters()

    def setup_lora(self, device: Union[str, int]) -> None:
        """Initialise all lora layers after pre-trained
        weights have been loaded.

        Every attention projection (``q_proj``, ``k_proj``, ``v_proj``,
        ``output_proj``), every MLP projection (``w1``, ``w2``, ``w3``) and the
        final ``output`` module is inspected; those that are ``LoRALinear`` get
        their adapter weights moved to ``device`` and re-initialised.

        Args:
            device (Union[str, int]): target device for the LoRA parameters.
        """
        attn_proj_names = ("q_proj", "k_proj", "v_proj", "output_proj")
        # TODO: MLPs might be different for LLAMA/PHi-3
        mlp_proj_names = ("w1", "w2", "w3")

        for layer in self.layers:
            # Each projection is looked up once and that same module is reset,
            # so ``output_proj`` is handled itself rather than ``v_proj`` twice.
            for name in attn_proj_names:
                proj = getattr(layer.attn, name)
                if isinstance(proj, LoRALinear):
                    move_lora_params_and_reset(proj, device=device)
            for name in mlp_proj_names:
                proj = getattr(layer.mlp, name)
                if isinstance(proj, LoRALinear):
                    move_lora_params_and_reset(proj, device=device)
        if isinstance(self.output, LoRALinear):
            move_lora_params_and_reset(self.output, device=device)

def move_lora_params_and_reset(
    module: LoRALinear, device: Union[str, int]
) -> None:
    """Move the LoRA adapter weights of ``module`` to ``device`` and re-initialise them.

    Lora parameters are initialized as meta tensors; here ``lora_a`` and
    ``lora_b`` are given storage on ``device`` via ``to_empty`` and then
    ``module.initialize_parameters()`` is called to fill in their values.

    Args:
        module (LoRALinear): LoRA layer whose adapters are to be set up.
        device (Union[str, int]): target device for the adapter weights.
    """
    # Same order as before: ``lora_a`` first, then ``lora_b``.
    for adapter_name in ("lora_a", "lora_b"):
        adapter = getattr(module, adapter_name)
        adapter.to_empty(device=device)
    module.initialize_parameters()