# Repository URL to install this package: (see package index)
# Version: 1.1.3
from typing import cast
import torch
from torch import nn, Tensor
import sarus_llm.liger_kernels as liger_kernels
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization.

    Implements the normalization introduced in
    https://arxiv.org/pdf/1910.07467.pdf.
    A reference implementation (used for correctness verification)
    can be found here:
    https://github.com/facebookresearch/llama/blob/main/llama/model.py

    Args:
        dim (int): embedding size
        eps (float): small value to avoid division by zero. Default: 1e-6
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learned per-channel gain, initialized to identity.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor, triton_kernel: bool = False) -> Tensor:
        """Apply RMSNorm over the last dimension of ``x``.

        Args:
            x (Tensor): input tensor to normalize
            triton_kernel (bool): if True, dispatch to the fused Liger
                Triton kernel instead of the eager implementation.

        Returns:
            Tensor: The output tensor after applying RMSNorm.
        """
        if triton_kernel:
            fused = liger_kernels.LigerRMSNormFunction.apply(
                x, self.scale, self.eps, 0.0, "gemma"
            )
            return cast(Tensor, fused)
        # Accumulate statistics in fp32 for numerical stability, then cast
        # back to the input dtype before applying the learned scale.
        upcast = x.float()
        inv_rms = torch.rsqrt(
            upcast.pow(2).mean(dim=-1, keepdim=True) + self.eps
        )
        normed = (upcast * inv_rms).type_as(x)
        return normed * self.scale