# Pydantic config models used via pyaml to provide validation
from __future__ import annotations
from pydantic import BaseModel
from sarus_llm.data.tokenizers.utils import SampleType
from sarus_llm.models.quantization import QuantizationConfig
from sarus_llm.models.base import LoraConfig
import typing as t


class FinetuningConfig(BaseModel):
    sample_type: SampleType
    foundation_model_name: t.Literal[
        "llama3-8b", "open_mistral_7b", "llama3-70b", "phi_3_mini"
    ]
    gradient_checkpointing: bool
    quantization: QuantizationConfig
    lora: t.Optional[LoraConfig]
    dataset: DatasetConfig
    deepspeed: DeepSpeedConfig
    trainer: TrainerConfig
    local_rank: int = -1
    dp: DPConfig


class DPOConfig(BaseModel):
    sample_type: SampleType
    foundation_model_name: t.Literal[
        "llama3-8b", "open_mistral_7b", "llama3-70b", "phi_3_mini"
    ]
    gradient_checkpointing: bool
    lora: t.Optional[LoraConfig]
    dataset: DatasetConfig
    deepspeed: DeepSpeedConfig
    trainer: TrainerConfig
    beta_dpo: float
    local_rank: int = -1


class SamplingConfig(BaseModel):
    sample_type: SampleType
    foundation_model_name: t.Literal[
        "llama3-8b", "open_mistral_7b", "llama3-70b", "phi_3_mini"
    ]
    quantization: QuantizationConfig
    lora: t.Optional[LoraConfig]
    checkpoint_path: t.Optional[str]
    dataset: SamplingDataset
    sampler: Sampler
    saving_dir: str
    triton_kernel: bool = False


class DatasetConfig(BaseModel):
    train_dir: str
    train_tokenization_dir: str
    test_dir: t.Optional[str] = None
    test_tokenization_dir: t.Optional[str] = None


class SamplingDataset(BaseModel):
    data_dir: str
    tokenization_dir: str


class DeepSpeedConfig(BaseModel):
    gradient_accumulation_steps: int
    physical_batch_size: int
    dtype: str
    zero_stage: int


class TrainerConfig(BaseModel):
    epochs: int
    eval_every_n_grad_steps: int
    checkpoint_path: str
    save_every_n_grad_steps: int
    learning_rate: LrConfig
    triton_kernel: bool = False


# lr schedule parameters
class LrConfig(BaseModel):
    value: float
    lr_schedule: bool = False
    num_warmup_steps: int = 0  # number of steps for the warmup phase.
    num_cycles: float = 0.5  # number of waves in the cosine schedule.


class DPConfig(BaseModel):
    is_dp: bool
    noise_multiplier: t.Optional[float] = None
    l2_norm_clip: t.Optional[float] = None


class Sampler(BaseModel):
    batch_size: int
    max_length: int
    temperature: float
    top_k: t.Optional[int]
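

# Usage sketch (illustrative only, not part of the package): validate a YAML
# config file against one of the models above. Assumes pydantic v2
# (`model_validate`) and PyYAML; the file name "finetuning.yaml" is a
# hypothetical placeholder for your own config file.
if __name__ == "__main__":
    import yaml

    with open("finetuning.yaml") as f:
        raw = yaml.safe_load(f)

    # Raises pydantic.ValidationError if the YAML does not match the schema.
    config = FinetuningConfig.model_validate(raw)
    print(config.foundation_model_name, config.trainer.epochs)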