# sarus-llm / sarus_llm / config / _pydantic.py
# Pydantic models providing validation for configs loaded from YAML (via pyaml).
# Sub-configs are referenced before they are defined; the postponed-annotations
# import lets pydantic resolve these forward references lazily.
from __future__ import annotations

import typing as t

from pydantic import BaseModel

from sarus_llm.data.tokenizers.utils import SampleType
from sarus_llm.models.base import LoraConfig
from sarus_llm.models.quantization import QuantizationConfig

# Foundation models shared by the finetuning, DPO and sampling configs.
FoundationModelName = t.Literal[
    "llama3-8b", "open_mistral_7b", "llama3-70b", "phi_3_mini"
]


class FinetuningConfig(BaseModel):
    sample_type: SampleType
    foundation_model_name: FoundationModelName
    gradient_checkpointing: bool
    quantization: QuantizationConfig
    lora: t.Optional[LoraConfig]
    dataset: DatasetConfig
    deepspeed: DeepSpeedConfig
    trainer: TrainerConfig
    local_rank: int = -1
    dp: DPConfig


class DPOConfig(BaseModel):
    sample_type: SampleType
    foundation_model_name: FoundationModelName
    gradient_checkpointing: bool
    lora: t.Optional[LoraConfig]
    dataset: DatasetConfig
    deepspeed: DeepSpeedConfig
    trainer: TrainerConfig
    beta_dpo: float  # beta temperature in the DPO loss
    local_rank: int = -1


class SamplingConfig(BaseModel):
    sample_type: SampleType
    foundation_model_name: FoundationModelName
    quantization: QuantizationConfig
    lora: t.Optional[LoraConfig]
    checkpoint_path: t.Optional[str]
    dataset: SamplingDataset
    sampler: Sampler
    saving_dir: str
    triton_kernel: bool = False


class DatasetConfig(BaseModel):
    train_dir: str
    train_tokenization_dir: str
    test_dir: t.Optional[str] = None
    test_tokenization_dir: t.Optional[str] = None


class SamplingDataset(BaseModel):
    data_dir: str
    tokenization_dir: str


class DeepSpeedConfig(BaseModel):
    gradient_accumulation_steps: int
    physical_batch_size: int
    dtype: str
    zero_stage: int  # DeepSpeed ZeRO optimization stage (0-3)


class TrainerConfig(BaseModel):
    epochs: int
    eval_every_n_grad_steps: int
    checkpoint_path: str
    save_every_n_grad_steps: int
    learning_rate: LrConfig  # lr-schedule parameters live in LrConfig below
    triton_kernel: bool = False


class LrConfig(BaseModel):
    value: float
    lr_schedule: bool = False
    num_warmup_steps: int = 0  # number of steps for the warmup phase.
    num_cycles: float = 0.5  # number of waves in the cosine schedule.
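

# A hedged sketch of how LrConfig's schedule fields could drive a scheduler.
# The actual trainer is not part of this module; get_cosine_schedule_with_warmup
# (Hugging Face transformers) is an assumption about the intended backend, and
# this helper name is illustrative only. LrConfig.value is the base learning
# rate and would be passed to the optimizer, not the scheduler.
def _example_build_scheduler(optimizer, lr_config: LrConfig, num_training_steps: int):
    if not lr_config.lr_schedule:
        return None  # constant learning rate

    from transformers import get_cosine_schedule_with_warmup

    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=lr_config.num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=lr_config.num_cycles,  # 0.5 = a single half-cosine decay
    )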


class DPConfig(BaseModel):
    # Differential-privacy (DP-SGD) training settings.
    is_dp: bool
    noise_multiplier: t.Optional[float] = None
    l2_norm_clip: t.Optional[float] = None
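

# A hedged sketch of how DPConfig could map onto Opacus' DP-SGD engine; the
# training loop is not part of this module, and PrivacyEngine.make_private is
# an assumption about the intended DP backend. The helper name is illustrative.
def _example_make_private(model, optimizer, data_loader, dp: DPConfig):
    if not dp.is_dp:
        return model, optimizer, data_loader

    from opacus import PrivacyEngine

    return PrivacyEngine().make_private(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        noise_multiplier=dp.noise_multiplier,  # Gaussian noise scale
        max_grad_norm=dp.l2_norm_clip,  # per-sample gradient clipping bound
    )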


class Sampler(BaseModel):
    batch_size: int
    max_length: int
    temperature: float
    top_k: t.Optional[int]
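

# A minimal usage sketch (not part of the original module), assuming a YAML
# file whose keys mirror FinetuningConfig's fields; the path "finetuning.yaml"
# and the __main__ guard are illustrative only.
if __name__ == "__main__":
    import yaml  # PyYAML

    with open("finetuning.yaml") as f:
        raw = yaml.safe_load(f)

    # Pydantic checks types, literals and required fields, raising a
    # ValidationError on a malformed config.
    config = FinetuningConfig(**raw)
    print(config.foundation_model_name, config.trainer.epochs)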