# Copyright 2025 The Levanter Authors
#
# SPDX-License-Identifier: Apache-2.0


import dataclasses
import math
from functools import partial
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.megablox import gmm
from jax.experimental.shard_map import shard_map
from jax.random import PRNGKey
from jaxtyping import PRNGKeyArray

import haliax as hax


from . import mup
from .mup import AbstractLinearReparam, ReparamEnabled, LinearStandardParam
from .._src.state_dict import (
    Mod,
    ModuleWithStateDictSerialization,
    StateDict,
    default_eqx_module_from_state_dict,
    default_eqx_module_to_state_dict,
)
from ..axis import Axis, AxisSpec
from ..core import NamedArray
from ..jax_utils import named_call
from ..partitioning import ResourceAxis
from ..quantization import DotGeneralOp
from ..util import ensure_tuple


class Linear(ModuleWithStateDictSerialization, ReparamEnabled):
    """A named Linear layer. This module allows you to specify multiple named axes for both input
    and output, which is occasionally useful."""

    weight: NamedArray
    bias: NamedArray | None

    In: AxisSpec = eqx.field(static=True)
    Out: AxisSpec = eqx.field(static=True)
    dot_general: DotGeneralOp = eqx.field(default_factory=DotGeneralOp.default)

    _reparam_cls: type[AbstractLinearReparam] = eqx.field(static=True, default=LinearStandardParam)

    @property
    def reparam(self) -> AbstractLinearReparam:
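        """The reparameterization instance for this layer, built from ``_reparam_cls`` and the In/Out axis specs."""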
        return self._reparam_cls(self.In, self.Out)

    @staticmethod
    def init(
        In: AxisSpec,
        Out: AxisSpec,
        *,
        key: PRNGKey,
        use_bias: bool = True,
        out_first: bool = True,
        dot_general: DotGeneralOp | None = None,
        init_scale: float = 1.0,
        reparam_cls: type[AbstractLinearReparam] = LinearStandardParam,
    ) -> "Linear":
        """
        Args:
            In: AxisSpec: The input axis spec
            Out: AxisSpec: The output axis spec
            key: PRNGKeyArray: The PRNG key to use for initialization
            use_bias: bool: Whether to use a bias term
            out_first: bool: Whether to put output axes first in the weight matrix. out_first is how PyTorch does it.
            dot_general: Callable: The dot_general function to use. Defaults to jax.lax.dot_general.
            init_scale: float: The scale to use for initialization. We scale init by 1/sqrt(Input.size)*init_scale
        """
        joint_spec = hax.concat_axis_specs(Out, In) if out_first else hax.concat_axis_specs(In, Out)
        weight = hax.random.truncated_normal(key, joint_spec, -3, 3) * (init_scale * reparam_cls.init_scale(In, Out))
        bias = hax.zeros(Out) if use_bias else None

        if dot_general is None:
            dot_general = DotGeneralOp.default()

        return Linear(weight, bias, In, Out, dot_general=dot_general, _reparam_cls=reparam_cls)

    @named_call
    def __call__(self, inputs, *, key: PRNGKeyArray | None = None):
        """
        Args:
            inputs (NamedArray): Input array
            key: Not used, but there for compat with other modules
        """
        del key
        q = inputs.dot(
            self.weight * self.reparam.active_scale,
            axis=self.In,
            dot_general=self.dot_general,
        )
        q = hax.auto_sharded(q)

        if self.bias is not None:
            q = q + self.bias
            q = hax.auto_sharded(q)

        return q

    def flatten_for_export(self: Mod) -> Mod:
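        """Collapse multi-axis ``In``/``Out`` specs into single ``__IN__``/``__OUT__`` axes so the weight
        (and bias) can be exported as plain 2-D (and 1-D) tensors. A no-op when both specs are already
        single axes."""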
        if isinstance(self.Out, hax.Axis) and isinstance(self.In, hax.Axis):
            return self

        weight = self.weight
        bias = self.bias

        new_Out = hax.flatten_axes(self.Out, "__OUT__")
        new_In = hax.flatten_axes(self.In, "__IN__")

        if weight is not None and weight.array is not None:
            out_first = self._out_first
            weight = weight.flatten_axes(self.Out, new_Out).flatten_axes(self.In, new_In)

            if out_first:
                weight = weight.rearrange((..., "__OUT__", "__IN__"))
            else:
                weight = weight.rearrange((..., "__IN__", "__OUT__"))

        if isinstance(bias, NamedArray):
            bias = bias.flatten_axes(self.Out, new_Out)

        return dataclasses.replace(self, weight=weight, bias=bias, In=new_In, Out=new_Out)

    def unflatten_from_export(self: Mod, template: Mod) -> Mod:
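        """Inverse of :meth:`flatten_for_export`: restore the multi-axis ``In``/``Out`` specs from
        ``template`` and rearrange the weight and bias to match the template's axis order."""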
        weight = self.weight
        bias = self.bias

        if (template.In, template.Out) == (self.In, self.Out):
            return self

        if weight.array is not None:
            weight = weight.unflatten_axis("__OUT__", template.Out).unflatten_axis("__IN__", template.In)
            weight = weight.rearrange(template.weight.axes)

        if isinstance(bias, NamedArray):
            bias = bias.unflatten_axis("__OUT__", template.Out)
            bias = bias.rearrange(template.bias.axes)

        return dataclasses.replace(template, weight=weight, bias=bias)

    @property
    def _out_first(self):
        """
        Returns: bool: Whether the output axes are first in the weight matrix
        """
        # We do it this way because of scan layers
        if isinstance(self.Out, hax.Axis):
            return self.weight.axes[-1] != self.Out
        else:
            return self.weight.axes[-len(self.Out) :] != self.Out

    def to_state_dict(self, prefix: Optional[str] = None) -> StateDict:
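        """Serialize with the reparameterization's active scale folded into the weight, so exported
        checkpoints look like a standard-parameterization Linear."""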
        # weight can be None for certain filtering things like LoRA
        scaled_weight = self.weight * self.reparam.active_scale if self.weight is not None else None
        scaled = dataclasses.replace(self, weight=scaled_weight)
        return default_eqx_module_to_state_dict(scaled, prefix)

    def from_state_dict(self: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod:
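        """Load from ``state_dict`` and divide the weight by the active scale, undoing the scaling
        applied in :meth:`to_state_dict`."""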
        unscaled = default_eqx_module_from_state_dict(self, state_dict, prefix)
        if unscaled.weight is not None:
            unscaled = dataclasses.replace(unscaled, weight=unscaled.weight / self.reparam.active_scale)
        return unscaled

    @staticmethod
    def input_reparam(use_mup: bool = True) -> type[AbstractLinearReparam]:
        """Return the reparameterization class for an input linear layer."""

        return mup.InputLinearMup if use_mup else mup.LinearStandardParam

    @staticmethod
    def hidden_reparam(use_mup: bool = True) -> type[AbstractLinearReparam]:
        """Return the reparameterization class for a hidden linear layer."""

        return mup.HiddenLinearMup if use_mup else mup.LinearStandardParam

    @staticmethod
    def output_reparam(use_mup: bool = True) -> type[AbstractLinearReparam]:
        """Return the reparameterization class for an output linear layer."""

        return mup.OutputLinearMup if use_mup else mup.LinearStandardParam
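
# Usage sketch (illustrative only, not executed as part of the module; the axis names
# `Embed`, `Hidden`, and `Batch` below are assumptions for this example):
#
#   Embed = hax.Axis("embed", 512)
#   Hidden = hax.Axis("hidden", 2048)
#   Batch = hax.Axis("batch", 8)
#   linear = Linear.init(In=Embed, Out=Hidden, key=jax.random.PRNGKey(0),
#                        reparam_cls=Linear.hidden_reparam(use_mup=False))
#   x = hax.random.normal(jax.random.PRNGKey(1), (Batch, Embed))
#   y = linear(x)  # NamedArray with axes (Batch, Hidden)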


class MoELinear(eqx.Module):
    """A named Linear layer for MoE. This module allows you to specify multiple named axes for both input
    and output, which is occasionally useful."""

    weight: NamedArray
    bias: NamedArray | None

    Experts: AxisSpec = eqx.field(static=True)
    In: Axis = eqx.field(static=True)
    Out: Axis = eqx.field(static=True)
    # TODO: support quantization for ragged_dot?
    # dot_general: DotGeneralOp = eqx.field(default_factory=DotGeneralOp.default)

    use_gmm: bool = eqx.field(static=True)

    @staticmethod
    def init(
        Experts: Axis,
        In: Axis,
        Out: Axis,
        *,
        key: PRNGKey,
        use_bias: bool = True,
        out_first: bool = False,
        init_scale: float = 1.0,
        use_gmm: bool = False,
    ) -> "MoELinear":
        """
        Args:
            Experts: Axis: The expert axis
            In: Axis: The input axis
            Out: Axis: The output axis
            key: PRNGKeyArray: The PRNG key to use for initialization
            use_bias: bool: Whether to use a bias term
            out_first: bool: Whether to put output axes first in the weight matrix. out_first is how PyTorch does it.
            init_scale: float: The scale to use for initialization. We scale init by init_scale / sqrt(In.size).
            use_gmm: bool: Whether to use the Pallas megablox ``gmm`` kernel instead of ``jax.lax.ragged_dot_general``.
        """
        joint_spec = hax.concat_axis_specs(Out, In) if out_first else hax.concat_axis_specs(In, Out)
        joint_spec = hax.concat_axis_specs(Experts, joint_spec)
        input_size = hax.axis_size(In)
        weight = hax.random.truncated_normal(key, joint_spec, -3, 3) * (init_scale / math.sqrt(input_size))
        bias = hax.zeros(Out) if use_bias else None

        return MoELinear(weight, bias, Experts, In, Out, use_gmm=use_gmm)

    @named_call
    def __call__(self, inputs, group_sizes, *, key: PRNGKeyArray | None = None):
        """
        Args:
            inputs (NamedArray): Input array    (Batch, In)
            group_sizes (NamedArray): MoE expert sizes (Experts)
            key: Not used, but there for compat with other modules
        """
        del key

        dim_numbers = jax.lax.RaggedDotDimensionNumbers(
            dot_dimension_numbers=(
                # contracting
                (
                    ensure_tuple(inputs.axis_indices(self.In)),
                    ensure_tuple(self.weight.axis_indices(self.In)),
                ),
                # batch
                ((), ()),
            ),
            # Everything other than contracting dim is ragged
            lhs_ragged_dimensions=(inputs.axis_indices(hax.axis.without_axes(inputs.axes, self.In))),
            rhs_group_dimensions=(self.weight.axis_indices(self.Experts),),
        )

        if self.use_gmm:
            inputs = inputs.rearrange((..., self.In))
            out_axes = hax.replace_axis(inputs.axes, self.In, self.Out)
            q = _gmm(
                inputs,
                self.weight,
                group_sizes,
                out_axes,
                ar=hax.partitioning.physical_axis_name(self.In) == ResourceAxis.MODEL,
            )  # gmm((B, D), (E, D, d)) -> (B, d)
        else:
            q_raw = jax.lax.ragged_dot_general(
                lhs=inputs.array,
                rhs=self.weight.array,
                group_sizes=group_sizes.rearrange((..., self.Experts)).array,
                ragged_dot_dimension_numbers=dim_numbers,
            )
            out_axes = hax.replace_axis(inputs.axes, self.In, self.Out)
            q = hax.named(q_raw, out_axes)

        if self.bias is not None:
            q = q + self.bias

        q = hax.auto_sharded(q)

        return q

    @property
    def out_first(self):
        """
        Returns: bool: Whether the output axes are first in the weight matrix
        """
        # We do it this way because of scan layers
        if isinstance(self.Out, hax.Axis):
            return self.weight.axes[-1] != self.Out
        else:
            return self.weight.axes[-len(self.Out) :] != self.Out
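
# Usage sketch (illustrative only; axis names and sizes are assumptions for this example).
# `tokens` is the flattened set of routed tokens, grouped by expert, and `group_sizes`
# holds the per-expert token counts (summing to the token axis size):
#
#   Experts = hax.Axis("experts", 4)
#   Embed = hax.Axis("embed", 512)
#   Mlp = hax.Axis("mlp", 2048)
#   Token = hax.Axis("token", 128)
#   moe = MoELinear.init(Experts, Embed, Mlp, key=jax.random.PRNGKey(0), use_bias=False)
#   tokens = hax.random.normal(jax.random.PRNGKey(1), (Token, Embed))
#   group_sizes = hax.full(Experts, Token.size // Experts.size, dtype=jnp.int32)
#   out = moe(tokens, group_sizes)  # NamedArray with axes (Token, Mlp)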


def _gmm(lhs, rhs, group_sizes, out_axes, sharded=False, ar=False):
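    """Grouped matmul over named arrays using the megablox ``gmm`` kernel. Unless ``sharded`` is True
    (i.e. the caller is already inside a ``shard_map``), the kernel is wrapped in ``shard_map`` over the
    current abstract mesh. ``ar`` requests a ``psum`` over the model axis after the local matmul."""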
    if sharded:
        gmm_fn = gmm_sharded
    else:
        gmm_fn = shard_map(
            partial(gmm_sharded, ar=ar),
            mesh=jax.sharding.get_abstract_mesh(),
            in_specs=(
                hax.partitioning.pspec_for_axis(lhs.axes),
                hax.partitioning.pspec_for_axis(rhs.axes),
                hax.partitioning.pspec_for_axis(group_sizes.axes),
            ),
            out_specs=hax.partitioning.pspec_for_axis(out_axes),
            check_rep=False,
        )

    out = gmm_fn(lhs.array, rhs.array, group_sizes.array)

    return hax.NamedArray(out, axes=out_axes)


def gmm_sharded(lhs_: jnp.ndarray, rhs_: jnp.ndarray, group_sizes_: jnp.ndarray, ar: bool = False) -> jnp.ndarray:
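    """Per-shard grouped matmul: pads the leading (token) dimension of ``lhs_`` to a multiple of 512
    (the m tile size), runs the megablox ``gmm`` kernel, optionally all-reduces over the model axis,
    and strips the padding from the result."""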
    hs_shape = lhs_.shape
    if hs_shape[0] % 512:
        pad_length = 512 - hs_shape[0] % 512
        lhs_ = jax.lax.pad(lhs_, 0.0, [(0, pad_length, 0), (0, 0, 0)])

    tile_size = (512, 1024, 1024)  # (m, k, n)
    m, k, n = lhs_.shape[0], lhs_.shape[1], rhs_.shape[2]
    out = gmm(
        lhs_,
        rhs_,
        group_sizes_,
        preferred_element_type=lhs_.dtype,
        tiling=(min(m, tile_size[0]), min(k, tile_size[1]), min(n, tile_size[2])),
        interpret=jax.default_backend() == "cpu",
    )

    if ar:
        out = jax.lax.psum(out, ResourceAxis.MODEL)

    if hs_shape[0] % 512:
        out = out[: hs_shape[0]]

    return out