# sarus_llm/models/modules/position_embeddings.py

from typing import Optional, Tuple, cast

import torch
from torch import Tensor, nn

import sarus_llm.liger_kernels as liger_kernels


class RotaryPositionalEmbeddings(nn.Module):
    """
    This class implements Rotary Positional Embeddings (RoPE)
    proposed in https://arxiv.org/abs/2104.09864.

    Reference implementation (used for correctness verification)
    can be found here:
    https://github.com/facebookresearch/llama/blob/main/llama/model.py#L450

    In this implementation we cache the embeddings for each position up to
    ``max_seq_len`` by computing them during init.

    Args:
        dim (int): Embedding dimension. This is usually set to the dim of each
            head in the attention module, computed as ``embed_dim // num_heads``.
        max_seq_len (int): Maximum expected sequence length for the
            model; if exceeded, the cached freqs will be recomputed.
        base (int): The base for the geometric progression used to compute
            the rotation angles.
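
    Example:
        A minimal usage sketch (the batch/shape values below are
        illustrative, not required by the module)::

            >>> rope = RotaryPositionalEmbeddings(dim=64, max_seq_len=4096)
            >>> q = torch.randn(2, 16, 8, 64)  # [b, s, n_h, h_d]
            >>> k = torch.randn(2, 16, 8, 64)
            >>> q_out, k_out = rope(q, k)
            >>> q_out.shape
            torch.Size([2, 16, 8, 64])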
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self._rope_init()

    # We need to explicitly define reset_parameters for FSDP initialization, see
    # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885
    def reset_parameters(self) -> None:
        self._rope_init()

    def _rope_init(self) -> None:
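        # standard RoPE frequencies: theta_i = 1 / base ** (2 * i / dim)
        # for i in [0, dim // 2); lower i means a faster-rotating pair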
        theta = 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2)[: (self.dim // 2)].float()
                / self.dim
            )
        )
        self.register_buffer("theta", theta, persistent=False)
        self.build_rope_cache(self.max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )

        # Outer product of theta and position index; output tensor has
        # a shape of [max_seq_len, dim // 2]
        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()

        # cache includes both the cos and sin components and so the output shape is
        # [max_seq_len, dim // 2, 2]
        cache = torch.stack(
            [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1
        )
        self.register_buffer("cache", cache, persistent=False)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        *,
        input_pos: Optional[Tensor] = None,
        triton_kernel: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        if triton_kernel:
            seq_len = q.size(1)
            # extract the values based on whether input_pos is set or not
            rope_cache = (
                self.cache[:seq_len]
                if input_pos is None
                else self.cache[input_pos]
            )
            cos, sin = rope_cache[..., 0][None, :], rope_cache[..., 1][None, :]
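            # tile cos/sin across both halves of head_dim: the fused kernel
            # uses the rotate-half (GPT-NeoX style) layout rather than the
            # interleaved (even, odd) pairs used in forward_no_kernel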
            cos = torch.cat([cos, cos], dim=-1)
            sin = torch.cat([sin, sin], dim=-1)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            # the kernel expects
            # q: (bsz, n_q_head, seq_len, head_dim)
            # k: (bsz, n_kv_head, seq_len, head_dim)
            # cos: (1, seq_len, head_dim)
            # sin: (1, seq_len, head_dim)
            # note: q and k are returned in this same transposed layout
            return cast(
                Tuple[Tensor, Tensor],
                liger_kernels.LigerRopeFunction.apply(q, k, cos, sin),
            )
        else:
            return self.forward_no_kernel(
                q, input_pos=input_pos
            ), self.forward_no_kernel(k, input_pos=input_pos)

    def forward_no_kernel(
        self, x: Tensor, *, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape
                [b, s, n_h, h_d]
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b, s].
                During inference, this indicates the position of the current token.
                If None, the index of the token is assumed to be its position id. Default is None.

        Returns:
            Tensor: output tensor with RoPE applied

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - n_h: num heads
            - h_d: head dim
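
        Example:
            A minimal packed-sample sketch (all values illustrative)::

                >>> rope = RotaryPositionalEmbeddings(dim=8, max_seq_len=32)
                >>> x = torch.randn(1, 4, 2, 8)         # [b, s, n_h, h_d]
                >>> # two packed samples of length 2, so positions restart
                >>> pos = torch.tensor([[0, 1, 0, 1]])  # [b, s]
                >>> rope.forward_no_kernel(x, input_pos=pos).shape
                torch.Size([1, 4, 2, 8])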

        TODO: The implementation below can be made more efficient
        for inference.
        """
        # input tensor has shape [b, s, n_h, h_d]
        seq_len = x.size(1)

        # extract the values based on whether input_pos is set or not
        rope_cache = (
            self.cache[:seq_len]
            if input_pos is None
            else self.cache[input_pos]
        )

        # reshape input; the last dimension is used for computing the output.
        # Cast to float to match the reference implementation
        # tensor has shape [b, s, n_h, h_d // 2, 2]
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)

        # reshape the cache for broadcasting
        # tensor has shape [b, s, 1, h_d // 2, 2] when input_pos is set
        # (packed samples), otherwise [1, s, 1, h_d // 2, 2]
        rope_cache = rope_cache.view(
            -1, xshaped.size(1), 1, xshaped.size(3), 2
        )

        # tensor has shape [b, s, n_h, h_d // 2, 2]
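        # apply the rotation: each (even, odd) pair is treated as the complex
        # number x0 + i*x1 and multiplied by cos(m*theta) + i*sin(m*theta)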
        x_out = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )

        # tensor has shape [b, s, n_h, h_d]
        x_out = x_out.flatten(3)
        return x_out.type_as(x)
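

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the module
    # API): RoPE rotates each (even, odd) feature pair in place, so it
    # should preserve both shapes and the L2 norm of every head vector.
    rope = RotaryPositionalEmbeddings(dim=64, max_seq_len=128)
    q = torch.randn(2, 16, 8, 64)  # [b, s, n_h, h_d]
    k = torch.randn(2, 16, 8, 64)
    q_out, k_out = rope(q, k)
    assert q_out.shape == q.shape and k_out.shape == k.shape
    torch.testing.assert_close(
        q_out.norm(dim=-1), q.norm(dim=-1), rtol=1e-4, atol=1e-4
    )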