import typing as t
from ._component_builders import mistral, lora_mistral
from ..modules import TransformerDecoder
from ..modules.peft.utils import LORA_ATTN_MODULES

def mistral_7b(
    gradient_checkpoint: bool,
) -> TransformerDecoder:
    """
    Builder for creating a Mistral 7B model initialized with the default 7B parameter values
    from https://mistral.ai/news/announcing-mistral-7b/

    Args:
        gradient_checkpoint (bool): if True, enable gradient (activation) checkpointing
            to reduce memory usage at the cost of recomputing activations in the backward pass.

    Returns:
        TransformerDecoder: Instantiation of Mistral 7B model
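
    Example:
        A minimal usage sketch, relying only on the builder signature above:

        >>> model = mistral_7b(gradient_checkpoint=False)
        >>> # `model` is a TransformerDecoder ready for training or inference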
"""
return mistral(
vocab_size=32768,
num_layers=32,
num_heads=32,
num_kv_heads=8,
embed_dim=4096,
intermediate_dim=14336,
max_seq_len=32768,
attn_dropout=0.0,
norm_eps=1e-5,
gradient_checkpoint=gradient_checkpoint,
)
def lora_mistral_7b(
    gradient_checkpoint: bool,
    lora_attn_modules: t.List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 128,
    lora_alpha: float = 256,
) -> TransformerDecoder:
    """
    Builder for creating a Mistral 7B model with LoRA adapters applied, initialized with
    the default 7B parameter values from https://mistral.ai/news/announcing-mistral-7b/

    Args:
        gradient_checkpoint (bool): if True, enable gradient (activation) checkpointing.
        lora_attn_modules (t.List[LORA_ATTN_MODULES]): list of attention projections
            to apply LoRA to.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer
            layer. Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output
            projection. Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 128
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 256

    Returns:
        TransformerDecoder: Instantiation of Mistral 7B model with LoRA weights on top
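
    Example:
        A minimal usage sketch. The attention-module names below are illustrative;
        the accepted values are defined by ``LORA_ATTN_MODULES``.

        >>> model = lora_mistral_7b(
        ...     gradient_checkpoint=False,
        ...     lora_attn_modules=["q_proj", "v_proj"],
        ... )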
"""
return lora_mistral(
lora_attn_modules=lora_attn_modules,
apply_lora_to_mlp=apply_lora_to_mlp,
apply_lora_to_output=apply_lora_to_output,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
vocab_size=32768,
num_layers=32,
num_heads=32,
num_kv_heads=8,
embed_dim=4096,
intermediate_dim=14336,
max_seq_len=32768,
attn_dropout=0.0,
norm_eps=1e-5,
gradient_checkpoint=gradient_checkpoint,
)