Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus-llm / sarus_llm / models / mistral / model.py
Size: Mime:
import typing as t
from ._component_builders import mistral, lora_mistral

from ..modules import TransformerDecoder
from ..modules.peft.utils import LORA_ATTN_MODULES


# Architecture hyperparameters shared by every Mistral 7B builder in this module.
# Both the base and the LoRA builder unpack this single mapping, so the two
# variants cannot drift apart when one of the values is changed.
# NOTE(review): vocab_size=32768 matches the extended (v0.3) Mistral tokenizer;
# the original v0.1 release used 32000 - confirm which checkpoint is targeted.
_MISTRAL_7B_ARCH: t.Dict[str, t.Any] = {
    "vocab_size": 32768,
    "num_layers": 32,
    "num_heads": 32,
    "num_kv_heads": 8,
    "embed_dim": 4096,
    "intermediate_dim": 14336,
    "max_seq_len": 32768,
    "attn_dropout": 0.0,
    "norm_eps": 1e-5,
}


def mistral_7b(
    gradient_checkpoint: bool,
) -> TransformerDecoder:
    """
    Builder for creating a Mistral 7B model initialized w/ the default 7b parameter values
    from https://mistral.ai/news/announcing-mistral-7b/

    Args:
        gradient_checkpoint (bool): Forwarded unchanged to the ``mistral`` component
            builder (presumably enables gradient/activation checkpointing - confirm
            in ``_component_builders``).

    Returns:
        TransformerDecoder: Instantiation of Mistral 7B model
    """
    return mistral(
        **_MISTRAL_7B_ARCH,
        gradient_checkpoint=gradient_checkpoint,
    )


def lora_mistral_7b(
    gradient_checkpoint: bool,
    lora_attn_modules: t.List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 128,
    lora_alpha: float = 256,
) -> TransformerDecoder:
    """
    Builder for creating a Lora Mistral 7B model initialized w/ the default 7b parameter values
    from https://mistral.ai/news/announcing-mistral-7b/
    and Lora weights on top.

    The base architecture is identical to ``mistral_7b``; only the LoRA-specific
    arguments below differ.

    Args:
        gradient_checkpoint (bool): Forwarded unchanged to the ``lora_mistral``
            component builder (presumably enables gradient/activation
            checkpointing - confirm in ``_component_builders``).
        lora_attn_modules (List[LORA_ATTN_MODULES]): Attention projections that
            should receive LoRA adapters; forwarded to ``lora_mistral``.
        apply_lora_to_mlp (bool): Whether to add LoRA to the MLP layers.
            Default: False
        apply_lora_to_output (bool): Whether to add LoRA to the final output
            projection. Default: False
        lora_rank (int): LoRA rank passed to ``lora_mistral``. Default: 128
        lora_alpha (float): LoRA alpha (scaling) value passed to
            ``lora_mistral``. Default: 256

    Returns:
        TransformerDecoder: Instantiation of Mistral 7B model
    """
    return lora_mistral(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        gradient_checkpoint=gradient_checkpoint,
        **_MISTRAL_7B_ARCH,
    )