# sarus-llm / sarus_llm / train.py
from __future__ import annotations

import json
import logging
import os
import time
import typing as t
from pathlib import Path
import torch.nn.functional as F
import bitsandbytes as bnb
import deepspeed
import torch
from datasets import Dataset as HFDataset
from fastDP.privacy_engine import PrivacyEngine
from torch.utils.data import DataLoader, Dataset, RandomSampler
from tqdm.auto import trange, tqdm

from sarus_llm.privacy import compute_epsilon
from sarus_llm.data.data_collator import (
    TrainingDataCollator,
    PreferenceDataCollator,
    DataCollator,
)
from sarus_llm.models.base import SarusModelProvider
import sarus_llm.liger_kernels as liger_kernels
from sarus_llm.models.modules.peft.utils import disable_adapter
import math

# Tensorboard support is optional: if torch's SummaryWriter cannot be
# imported, training still runs but nothing is written to tensorboard.
try:
    from torch.utils.tensorboard import SummaryWriter

    # NB: for logging, low level tensorflow ops are used
    # as they allow to log text tensors
except ModuleNotFoundError:
    PT_TENSORBOARD_LOGGING = False
else:
    PT_TENSORBOARD_LOGGING = True

# File name under which optimizer state would be serialized.
OPTIMIZER_NAME = "optimizer.pt"

logger = logging.getLogger(__name__)


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier grows linearly from 0.0 to 1.0 over ``num_warmup_steps``,
    then follows a cosine curve down to 0.0 over the remaining
    ``num_training_steps - num_warmup_steps`` steps (for ``num_cycles`` = 0.5).

    This is based on the Hugging Face implementation
    https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/optimization.py#L104.
    """
    # Hoist the (clamped) denominators out of the per-step lambda.
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(step: int) -> float:
        # Linear warmup phase: 0 -> 1.
        if step < num_warmup_steps:
            return step / warmup_span
        # Cosine decay phase: 1 -> 0, clamped at 0.
        progress = (step - num_warmup_steps) / decay_span
        cosine_multiple = 0.5 * (
            1.0 + math.cos(2.0 * math.pi * num_cycles * progress)
        )
        return max(0.0, cosine_multiple)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)


class TrainerMixin:
    """A class providing utils methods for a standard training, as well
    as utility methods for multiprocessing.

    Sub-classes must implement :meth:`compute_loss` and populate the
    attributes declared below before calling :meth:`train`.
    """

    # Collator turning dataset rows into tensor batches.
    data_collator: DataCollator
    # Data-parallel world size; used to scale throughput accounting.
    n_gpus: int
    train_dataset: HFDataset
    # Rank of this process on the node; -1 or 0 means "main process".
    local_rank: int
    # Bare model; ``training_model`` is the deepspeed engine wrapping it.
    torch_model: torch.nn.Module
    training_model: torch.nn.Module
    model_provider: SarusModelProvider
    deepspeed_config: t.Dict[str, t.Any]
    checkpoint_path: str
    physical_batch_size: int
    epochs: int
    gradient_accumulation_steps: int
    # Evaluate every N gradient updates (<= 0 disables periodic eval).
    eval_every_n_grad_steps: int
    # Save a checkpoint every N gradient updates (<= 0 disables).
    save_every: int
    has_validation: bool
    validation_dataset: HFDataset
    optimizer: torch.optim.AdamW
    lr_schedule: bool
    num_warmup_steps: int = 0  # number of steps for the warmup phase.
    num_cycles: float = 0.5

    def training_step(
        self,
        batch: t.Any,
    ) -> t.Tuple[torch.Tensor, t.Any]:
        """One training step consists in:
        - putting data on GPU
        - computing loss/metrics
        - backward

        Returns the detached loss and any sub-class specific metrics.
        """
        self.training_model.train()
        inputs = self._prepare_inputs(batch)
        loss, additional_metrics = self.compute_loss(inputs)
        self.training_model.backward(loss)
        return loss.detach(), additional_metrics

    def _prepare_inputs(
        self, inputs: t.Dict[str, torch.Tensor]
    ) -> t.Dict[str, torch.Tensor]:
        """Moves every tensor of the batch onto this process' device.

        NOTE: mutates ``inputs`` in place (and also returns it).
        """
        device = self.training_model.local_rank
        for key, tensor in inputs.items():
            inputs[key] = tensor.to(device)
        return inputs

    def train_dataloader(
        self,
        dataset: Dataset,
        batch_size: int,
    ) -> DataLoader:
        """Build the training DataLoader (random or distributed sampling)."""
        dataloader_params: t.Dict[str, t.Any] = {
            "batch_size": batch_size,
            "collate_fn": self.data_collator.collate_batch,
            "num_workers": 1,
            "sampler": self._get_train_sampler(),
            "drop_last": False,
            "worker_init_fn": seed_worker,
            "pin_memory": True,
        }
        return DataLoader(dataset, **dataloader_params)

    def validation_dataloader(
        self,
        dataset: HFDataset,
        batch_size: int,
    ) -> DataLoader:
        """Build the validation DataLoader (deterministic sequential order)."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=self.data_collator.collate_batch,
            num_workers=1,
            sampler=torch.utils.data.sampler.SequentialSampler(dataset),
            drop_last=False,
            worker_init_fn=seed_worker,
            pin_memory=True,
        )

    def _get_train_sampler(self) -> torch.utils.data.sampler.Sampler:
        """Random sampler on a single GPU, distributed sampler otherwise."""
        return (
            RandomSampler(self.train_dataset)
            if self.n_gpus == 1
            else torch.utils.data.distributed.DistributedSampler(
                self.train_dataset
            )
        )

    def save_state(self, checkpoint_path: str, step: int) -> None:
        """Save model & optimizer state under ``checkpoint_path/<step>``."""
        save_dir = os.path.join(checkpoint_path, str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.model_provider.save_state(
            save_dir, self.training_model.module, self.optimizer
        )

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
        """
        # TODO: handle multiple nodes, for now same as is_local_process_zero
        return self.local_rank == 0

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.
        """
        return self.local_rank in [-1, 0]

    def train(self) -> None:
        """Methods that encloses the skeleton of training:
        - recover model
        - initialize distributed training
        - iterates over data making training steps
        - evaluates on test set when required"""

        self.torch_model = self.torch_model.to(self.local_rank)
        train_dataloader = self.train_dataloader(
            self.train_dataset, batch_size=self.physical_batch_size
        )
        if self.lr_schedule:
            # One scheduler step per *gradient* update, not per micro-batch.
            num_training_steps = (
                self.epochs
                * len(train_dataloader)
                // self.gradient_accumulation_steps
            )
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=self.optimizer,
                num_warmup_steps=self.num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=self.num_cycles,
                last_epoch=-1,
            )
        else:
            lr_scheduler = None
        self.training_model, self.optimizer, _, _ = deepspeed.initialize(
            config=self.deepspeed_config,
            model=self.torch_model,
            optimizer=self.optimizer,
            lr_scheduler=lr_scheduler,
            model_parameters=self.torch_model.parameters(),
        )
        if self.local_rank == 0:
            if self.training_model.fp16_enabled():
                precision = "torch.float16"
            elif self.training_model.bfloat16_enabled():
                precision = "torch.bfloat16"
            else:
                precision = "torch.float32"
            log_begin_training(
                epochs=self.epochs,
                physical_batch_size=self.physical_batch_size,
                gradient_accumulation_steps=self.gradient_accumulation_steps,
                precision=precision,
                n_gpus=self.n_gpus,
            )

        # Tensorboard logging needs both the package and a checkpoint dir.
        should_log = (
            PT_TENSORBOARD_LOGGING and self.checkpoint_path is not None
        )

        summary_writer = (
            SummaryWriter(
                log_dir=os.path.join(self.checkpoint_path, "tensorboard")  # type: ignore
            )
            if should_log
            else None
        )

        # Loss accumulated over one gradient-accumulation window.
        training_loss = torch.tensor(0.0).to(self.local_rank)
        disable_tqdm = not self.is_local_process_zero()
        train_pbar = trange(self.epochs, desc="Epoch", disable=disable_tqdm)
        curr_step = 1
        start = time.time()
        num_samples = 0
        num_tokens = 0

        grad_step_update = 0
        self.training_model.zero_grad()
        for epoch in range(1, self.epochs + 1):
            if isinstance(
                train_dataloader.sampler, torch.utils.data.DistributedSampler
            ):
                # Reshuffle differently at each epoch in distributed mode.
                train_dataloader.sampler.set_epoch(epoch)

            epoch_iterator = train_dataloader
            epoch_pbar = tqdm(
                epoch_iterator, desc="Iteration", disable=disable_tqdm
            )
            add_metrics = []
            for batch in epoch_iterator:
                loss, additional_metrics = self.training_step(batch)
                # Deepspeed only applies the optimizer (and scheduler) every
                # ``gradient_accumulation_steps`` calls to step().
                self.training_model.step()
                training_loss += loss
                add_metrics.append(additional_metrics)
                is_gradient_update = (
                    curr_step % self.gradient_accumulation_steps == 0
                )
                grad_step_update = (
                    curr_step // self.gradient_accumulation_steps
                )
                tensor_key = next(iter(batch))
                num_tokens += batch[tensor_key].numel() * self.n_gpus
                # Scale the per-rank batch size by the number of GPUs.
                # Fixed: this was ``shape[0] ** self.n_gpus`` (exponentiation
                # typo, only accidentally correct for a single GPU).
                num_samples += batch[tensor_key].shape[0] * self.n_gpus

                self._log_step_training(
                    is_gradient_update=is_gradient_update,
                    loss=float(
                        training_loss / self.gradient_accumulation_steps
                    ),
                    additional_metrics=add_metrics,
                    start_time=start,
                    num_samples=num_samples,
                    num_tokens=num_tokens,
                    summary_writer=summary_writer,
                    grad_step_update=grad_step_update,
                )
                if (
                    self.checkpoint_path
                    and self.save_every > 0
                    and grad_step_update > 0
                    and grad_step_update % self.save_every == 0
                ):
                    self.save_state(
                        step=grad_step_update,
                        checkpoint_path=self.checkpoint_path,
                    )
                epoch_pbar.update(1)

                if (
                    self.has_validation
                    and is_gradient_update
                    and self.eval_every_n_grad_steps > 0
                    and grad_step_update % self.eval_every_n_grad_steps == 0
                ):
                    self.evaluate(grad_step_update, summary_writer)
                curr_step += 1
                if is_gradient_update:
                    # Reset per-window accounting after each optimizer step.
                    num_samples = 0
                    num_tokens = 0
                    if should_log:
                        assert summary_writer is not None
                        summary_writer.add_scalar(  # type:ignore[no-untyped-call]
                            "Grad L2 norm",
                            self.training_model._global_grad_norm,
                            global_step=grad_step_update,
                        )
                    add_metrics = []
                    training_loss = torch.tensor(0.0).to(self.local_rank)
                    start = time.time()

            epoch_pbar.close()
            train_pbar.update(1)

        train_pbar.close()
        # Final checkpoint; skipped when no checkpoint path was given
        # (consistent with the periodic-save condition above — previously
        # this crashed on ``os.path.join(None, ...)``).
        if self.checkpoint_path and self.is_world_process_zero():
            self.save_state(
                checkpoint_path=self.checkpoint_path, step=grad_step_update
            )
        return

    def evaluate(
        self, step: int, summary_writer: t.Optional[SummaryWriter]
    ) -> None:
        """Compute the mean loss over the validation set and log it."""
        eval_dataloader = self.validation_dataloader(
            self.validation_dataset, self.physical_batch_size
        )
        test_loss = torch.tensor(0.0, device=self.local_rank)
        self.training_model.eval()
        add_metrics = []
        num_batches = 0
        for batch in eval_dataloader:
            inputs = self._prepare_inputs(batch)
            with torch.no_grad():
                loss, additional_metrics = self.compute_loss(inputs)
                test_loss += loss
                add_metrics.append(additional_metrics)
            num_batches += 1
        # max(..., 1) guards against an empty validation dataloader (the
        # previous code read the loop variable after the loop and raised
        # NameError when the dataloader yielded no batch).
        test_loss /= max(num_batches, 1)
        self.log_step_validation(
            test_loss=float(test_loss),
            grad_step_update=step,
            summary_writer=summary_writer,
            additional_metrics=add_metrics,
        )

    def _log_step_training(
        self,
        is_gradient_update: bool,
        grad_step_update: int,
        loss: float,
        start_time: float,
        num_samples: int,
        num_tokens: int,
        summary_writer: t.Optional[SummaryWriter],
        additional_metrics: t.Any,
    ) -> None:
        """Log training loss and throughput; no-op unless this micro-batch
        completes a gradient update and this is the main process."""
        if is_gradient_update and self.is_world_process_zero():
            log_step_training_info(
                step=grad_step_update,
                is_training=True,
                loss=loss,
            )
            if summary_writer is not None:
                # compute speed metrics
                speed_metrics = hf_speed_metrics(
                    split="train",
                    start_time=start_time,
                    num_samples=num_samples,
                    num_tokens=num_tokens,
                )
                for metric_name, metric_value in speed_metrics.items():
                    summary_writer.add_scalar(  # type:ignore[no-untyped-call]
                        metric_name,
                        float(metric_value),
                        global_step=grad_step_update,
                    )
                summary_writer.add_scalar(  # type:ignore[no-untyped-call]
                    "train_loss",
                    loss,
                    global_step=grad_step_update,
                )

    def log_step_validation(
        self,
        test_loss: float,
        grad_step_update: int,
        summary_writer: t.Optional[SummaryWriter],
        additional_metrics: t.Any,
    ) -> None:
        """Log the validation loss (stdout and, if available, tensorboard)."""
        if self.is_world_process_zero():
            log_step_training_info(
                step=grad_step_update,
                is_training=False,
                loss=test_loss,
            )

            if summary_writer is not None:
                summary_writer.add_scalar(  # type:ignore[no-untyped-call]
                    "validation_loss",
                    test_loss,
                    global_step=grad_step_update,
                )

    # Methods to be implemented by sub-classes

    def compute_loss(
        self,
        batch: t.Dict[str, t.Any],
    ) -> t.Tuple[torch.Tensor, t.Any]:
        """Compute the loss (and extra metrics) on an on-device batch."""
        raise NotImplementedError


class SFTTrainer(TrainerMixin):
    """Class responsible for standard fine tuning"""

    def __init__(
        self,
        model_provider: SarusModelProvider,
        deepspeed_config: t.Dict[str, t.Any],
        train_ds_uri: str,
        checkpoint_path: str,
        local_rank: int,
        physical_batch_size: int,
        lr_schedule: bool,
        learning_rate: float = 1e-4,
        num_warmup_steps: int = 0,
        num_cycles: float = 0.5,
        validation_dataset_uri: t.Optional[str] = None,
        epochs: int = 1,
        gradient_accumulation_steps: int = 1,
        quantized_optimizer: bool = False,
        eval_every_n_grad_steps: int = -1,
        triton_kernel: bool = False,
        save_every_n_grad_steps: int = -1,
        keep_ds_in_memory: bool = False,  # defaults in HF
    ) -> None:
        """Build the model, optimizer and datasets for supervised fine tuning."""
        self.local_rank = local_rank
        self.n_gpus = torch.distributed.get_world_size()
        self.model_provider = model_provider

        model = model_provider.torch_model(self.local_rank, triton_kernel)
        # 8-bit AdamW (bitsandbytes) when a quantized optimizer is requested,
        # plain torch AdamW otherwise.
        if quantized_optimizer:
            opt: torch.optim.Optimizer = bnb.optim.adamw.AdamW(
                params=model.parameters(), lr=learning_rate, optim_bits=8
            )
        else:
            opt = torch.optim.AdamW(
                params=model.parameters(), lr=learning_rate
            )
        self.torch_model = model
        self.optimizer = opt
        self.triton_kernel = triton_kernel

        # Training-loop configuration.
        self.eval_every_n_grad_steps = eval_every_n_grad_steps
        self.save_every = save_every_n_grad_steps
        self.checkpoint_path = checkpoint_path
        self.epochs = epochs
        self.physical_batch_size = physical_batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.lr_schedule = lr_schedule
        self.num_cycles = num_cycles
        self.num_warmup_steps = num_warmup_steps

        # Datasets.
        self.train_dataset = HFDataset.load_from_disk(
            train_ds_uri, keep_in_memory=keep_ds_in_memory
        )
        self.has_validation = validation_dataset_uri is not None
        if validation_dataset_uri is not None:
            self.validation_dataset = HFDataset.load_from_disk(
                validation_dataset_uri, keep_in_memory=keep_ds_in_memory
            )
        self.deepspeed_config = deepspeed_config
        self.data_collator = TrainingDataCollator(
            tokenizer_pad_token=0
        )  # the value is not important

    def compute_loss(
        self,
        batch: t.Dict[str, torch.Tensor],
    ) -> t.Tuple[torch.Tensor, None]:
        """Next-token cross-entropy on the shifted logits/labels."""
        raw_logits = self.training_model.forward(
            tokens=batch["tokens"], triton_kernel=self.triton_kernel
        )
        # Shift so position i predicts token i+1.
        shifted_logits = raw_logits[..., :-1, :].contiguous()
        shifted_labels = batch["labels"][..., 1:].contiguous()

        if self.triton_kernel:
            # Fused liger kernel operates on flattened (N, vocab) logits;
            # -100 entries are ignored.
            flat_labels = shifted_labels.view(-1)
            flat_logits = shifted_logits.view(-1, shifted_logits.shape[-1])
            loss = liger_kernels.LigerCrossEntropyFunction.apply(
                flat_logits, flat_labels, -100
            )
        else:
            # CrossEntropyLoss expects (batch, vocab, seq).
            criterion = torch.nn.CrossEntropyLoss()
            loss = criterion(shifted_logits.transpose(1, 2), shifted_labels)
        return t.cast(torch.Tensor, loss), None


class DPOTrainer(TrainerMixin):
    """Trainer for DPO (Direct Preference Optimization).

    Optimizes the policy on (chosen, rejected) preference pairs; the
    reference model is obtained by running the same weights with the PEFT
    adapter disabled (see :meth:`compute_loss`).
    """

    def __init__(
        self,
        model_provider: SarusModelProvider,
        deepspeed_config: t.Dict[str, t.Any],
        beta_dpo: float,
        train_ds_uri: str,
        checkpoint_path: str,
        local_rank: int,
        physical_batch_size: int,
        lr_schedule: bool,
        learning_rate: float = 1e-4,
        num_warmup_steps: int = 0,
        num_cycles: float = 0.5,
        validation_dataset_uri: t.Optional[str] = None,
        epochs: int = 1,
        gradient_accumulation_steps: int = 1,
        quantized_optimizer: bool = False,
        eval_every_n_grad_steps: int = -1,
        save_every_n_grad_steps: int = -1,
        keep_ds_in_memory: bool = False,
        triton_kernel: bool = False,
    ) -> None:
        """Build model, optimizer and preference datasets for DPO training.

        ``beta_dpo`` scales the implicit reward (policy-vs-reference
        log-prob difference) in the DPO loss.
        """
        self.local_rank = local_rank
        self.n_gpus = torch.distributed.get_world_size()
        self.model_provider = model_provider
        torch_model = model_provider.torch_model(
            self.local_rank, triton_kernel
        )
        # 8-bit bitsandbytes AdamW when requested, plain AdamW otherwise.
        optimizer = (
            torch.optim.AdamW(
                params=torch_model.parameters(), lr=learning_rate
            )
            if not quantized_optimizer
            else bnb.optim.adamw.AdamW(
                params=torch_model.parameters(), lr=learning_rate, optim_bits=8
            )
        )
        self.torch_model = torch_model
        self.optimizer = optimizer
        self.triton_kernel = triton_kernel
        # Store other info
        self.eval_every_n_grad_steps = eval_every_n_grad_steps
        self.save_every = save_every_n_grad_steps
        self.checkpoint_path = checkpoint_path
        self.epochs = epochs
        self.physical_batch_size = physical_batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.train_dataset = HFDataset.load_from_disk(
            train_ds_uri, keep_in_memory=keep_ds_in_memory
        )
        if validation_dataset_uri is not None:
            self.validation_dataset = HFDataset.load_from_disk(
                validation_dataset_uri, keep_in_memory=keep_ds_in_memory
            )
            self.has_validation = True
        else:
            self.has_validation = False
        self.deepspeed_config = deepspeed_config
        self.beta_dpo = beta_dpo
        self.data_collator = PreferenceDataCollator(
            tokenizer_pad_token=0
        )  # the value is not important
        self.lr_schedule = lr_schedule
        self.num_cycles = num_cycles
        self.num_warmup_steps = num_warmup_steps

    def compute_loss(
        self,
        batch: t.Dict[str, torch.Tensor],
    ) -> t.Tuple[torch.Tensor, t.Any]:
        """DPO loss on a (chosen, rejected) preference batch.

        Returns the scalar loss and a tuple of per-example diagnostics
        (implicit rewards and log-probs) consumed by the logging hooks.
        """
        # Policy log-probs (adapter active).
        new_chosen_log_probs, new_rejected_log_probs = (
            self.concatenate_forward(batch)
        )

        # Reference log-probs: same weights with the PEFT adapter disabled,
        # no gradients needed.
        with torch.no_grad(), disable_adapter(self.torch_model):
            reference_chosen_log_probs, reference_rejected_log_probs = (
                self.concatenate_forward(batch)
            )

        # Implicit rewards: beta * (policy - reference) log-prob gap,
        # detached so they are diagnostics only.
        chosen_rewards = (
            self.beta_dpo
            * (new_chosen_log_probs - reference_chosen_log_probs).detach()
        )
        rejected_rewards = (
            self.beta_dpo
            * (new_rejected_log_probs - reference_rejected_log_probs).detach()
        )
        logratios = new_chosen_log_probs - new_rejected_log_probs
        ref_logratios = (
            reference_chosen_log_probs - reference_rejected_log_probs
        )

        # DPO objective: -log sigmoid(beta * (policy logratio - ref logratio)).
        return (
            -F.logsigmoid(self.beta_dpo * (logratios - ref_logratios)).mean(),
            (
                chosen_rewards,
                rejected_rewards,
                new_chosen_log_probs,
                new_rejected_log_probs,
                reference_chosen_log_probs,
                reference_rejected_log_probs,
            ),
        )

    def concatenate_forward(
        self, batch: t.Dict[str, torch.Tensor]
    ) -> t.Tuple[torch.Tensor, torch.Tensor]:
        """Single forward over chosen and rejected sequences concatenated
        along the batch dimension; the first half of the outputs corresponds
        to the chosen sequences."""
        concatenated_ids = torch.cat(
            [batch["chosen_tokens"], batch["rejected_tokens"]]
        )
        concatenated_labels = torch.cat(
            [batch["chosen_labels"], batch["rejected_labels"]]
        )
        all_logits = self.training_model.forward(
            tokens=concatenated_ids, triton_kernel=self.triton_kernel
        )

        all_log_probs = get_log_probs(all_logits, concatenated_labels)
        chosen_log_probs = all_log_probs[: all_log_probs.shape[0] // 2]
        rejected_log_probs = all_log_probs[all_log_probs.shape[0] // 2 :]
        return chosen_log_probs, rejected_log_probs

    def log_step_validation(
        self,
        test_loss: float,
        grad_step_update: int,
        summary_writer: t.Optional[SummaryWriter],
        additional_metrics: t.Any,
    ) -> None:
        """Standard validation logging plus DPO reward/log-prob scalars."""
        super().log_step_validation(
            test_loss=test_loss,
            grad_step_update=grad_step_update,
            summary_writer=summary_writer,
            additional_metrics=additional_metrics,
        )
        if summary_writer is not None:
            self._log_additional_dpo(
                additional_metrics=additional_metrics,
                summary_writer=summary_writer,
                grad_step_update=grad_step_update,
                name_prefix="validation",
            )
        return

    def _log_step_training(
        self,
        is_gradient_update: bool,
        grad_step_update: int,
        loss: float,
        start_time: float,
        num_samples: int,
        num_tokens: int,
        summary_writer: t.Optional[SummaryWriter],
        additional_metrics: t.Any,
    ) -> None:
        """Standard training logging plus DPO reward/log-prob scalars."""
        super()._log_step_training(
            is_gradient_update=is_gradient_update,
            grad_step_update=grad_step_update,
            loss=loss,
            start_time=start_time,
            num_samples=num_samples,
            num_tokens=num_tokens,
            summary_writer=summary_writer,
            additional_metrics=additional_metrics,
        )
        if summary_writer is not None:
            self._log_additional_dpo(
                additional_metrics=additional_metrics,
                summary_writer=summary_writer,
                grad_step_update=grad_step_update,
                name_prefix="train",
            )

    def _log_additional_dpo(
        self,
        additional_metrics: t.List[t.Tuple[torch.Tensor, ...]],
        summary_writer: SummaryWriter,
        grad_step_update: int,
        name_prefix: str,
    ) -> None:
        """Aggregate the per-step diagnostic tuples emitted by
        :meth:`compute_loss` and write mean rewards/accuracy/log-probs to
        tensorboard under ``<name_prefix>_...`` tags."""
        # Re-group the list of 6-tuples into 6 concatenated tensors.
        (
            chosen_rewards,
            rejected_rewards,
            new_chosen_log_probs,
            new_rejected_log_probs,
            reference_chosen_log_probs,
            reference_rejected_log_probs,
        ) = [torch.cat(el) for el in zip(*additional_metrics)]
        # Fraction of pairs where the chosen completion out-rewards the
        # rejected one.
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        log_dict = {
            name_prefix
            + "_rewards/chosen_difference": chosen_rewards.mean().cpu(),
            name_prefix
            + "_rewards/rejected_difference": rejected_rewards.mean().cpu(),
            name_prefix
            + "_rewards/accuracies": reward_accuracies.mean().cpu(),
            name_prefix + "_rewards/margins": (
                chosen_rewards - rejected_rewards
            )
            .mean()
            .cpu(),
            name_prefix
            + "_log_probs/rejected": new_rejected_log_probs.detach()
            .mean()
            .cpu(),
            name_prefix + "_log_probs/chosen": new_chosen_log_probs.detach()
            .mean()
            .cpu(),
        }

        if summary_writer is not None:
            for name, value in log_dict.items():
                summary_writer.add_scalar(  # type:ignore[no-untyped-call]
                    name,
                    value,
                    global_step=grad_step_update,
                )

        return


def get_log_probs(
    all_logits: torch.Tensor, all_labels: torch.Tensor
) -> torch.Tensor:
    """Sum of per-token log-probabilities of ``all_labels`` under
    ``all_logits``, per sequence.

    Logits and labels are shifted so position ``i`` scores token ``i+1``;
    label positions equal to -100 are masked out of the sum.

    Args:
        all_logits: (batch, seq, vocab) unnormalized scores.
        all_labels: (batch, seq) token ids, -100 for ignored positions.

    Returns:
        (batch,) tensor of summed log-probs.
    """
    all_logits = all_logits[..., :-1, :].contiguous()
    all_labels = all_labels[..., 1:].contiguous()
    mask = all_labels != -100

    # masked_fill (out-of-place) replaces the previous in-place
    # ``all_labels[~mask] = 0``: when the ``[..., 1:]`` slice is already
    # contiguous (e.g. batch size 1), ``.contiguous()`` returns a view and
    # the in-place write silently zeroed the caller's label tensor.
    safe_labels = all_labels.masked_fill(~mask, 0)
    log_probs = torch.gather(
        F.log_softmax(all_logits, dim=-1),
        dim=2,
        index=safe_labels.unsqueeze(-1),
    ).squeeze(-1)  # squeeze only the gathered dim, never the batch dim

    log_probs = log_probs * mask
    return log_probs.sum(-1)


class DPTrainer(SFTTrainer):
    """Supervised fine tuning with differential privacy (DP-SGD).

    Same loop as :class:`SFTTrainer`; the model is wrapped by fastDP's
    ``PrivacyEngine`` (per-sample gradient clipping + Gaussian noise) and
    the spent epsilon is logged at every gradient update.
    """

    def __init__(
        self,
        model_provider: SarusModelProvider,
        deepspeed_config: t.Dict[str, t.Any],
        noise_multiplier: float,
        train_ds_uri: str,
        l2_norm_clip: float,
        checkpoint_path: str,
        local_rank: int,
        physical_batch_size: int,
        lr_schedule: bool,
        learning_rate: float = 1e-4,
        num_warmup_steps: int = 0,
        num_cycles: float = 0.5,
        validation_dataset_uri: t.Optional[str] = None,
        epochs: int = 1,
        gradient_accumulation_steps: int = 1,
        quantized_optimizer: bool = False,
        eval_every_n_grad_steps: int = -1,
        save_every_n_grad_steps: int = -1,
        keep_ds_in_memory: bool = False,
        triton_kernel: bool = False,
    ) -> None:
        """Build the SFT trainer, then attach the DP privacy engine.

        Args (DP-specific):
            noise_multiplier: std of the Gaussian noise relative to the
                clipping norm.
            l2_norm_clip: per-sample gradient clipping bound.
        """
        super().__init__(
            model_provider=model_provider,
            deepspeed_config=deepspeed_config,
            train_ds_uri=train_ds_uri,
            checkpoint_path=checkpoint_path,
            local_rank=local_rank,
            validation_dataset_uri=validation_dataset_uri,
            epochs=epochs,
            gradient_accumulation_steps=gradient_accumulation_steps,
            quantized_optimizer=quantized_optimizer,
            eval_every_n_grad_steps=eval_every_n_grad_steps,
            physical_batch_size=physical_batch_size,
            save_every_n_grad_steps=save_every_n_grad_steps,
            keep_ds_in_memory=keep_ds_in_memory,
            learning_rate=learning_rate,
            lr_schedule=lr_schedule,
            num_warmup_steps=num_warmup_steps,
            num_cycles=num_cycles,
            triton_kernel=triton_kernel,
        )
        self.noise_multiplier = noise_multiplier
        self.l2_norm_clip = l2_norm_clip

        # Wrap with DP. NOTE(review): the engine is constructed for its side
        # effects on ``self.torch_model`` (per-sample grad hooks); the
        # instance itself is intentionally discarded.
        _ = PrivacyEngine(
            module=self.torch_model,
            epochs=epochs,
            noise_multiplier=noise_multiplier,
            clipping_mode="MixOpt",
            clipping_fn="Abadi",
            max_grad_norm=l2_norm_clip,
            batch_size=physical_batch_size * gradient_accumulation_steps,
            sample_size=1,  # only needed for accounting but not a kwarg
            num_GPUs=self.n_gpus,
        )

    def _log_step_training(
        self,
        is_gradient_update: bool,
        grad_step_update: int,
        loss: float,
        start_time: float,
        num_samples: int,
        num_tokens: int,
        summary_writer: t.Optional[SummaryWriter],
        additional_metrics: t.Any,
    ) -> None:
        """Standard training logging plus the privacy budget (epsilon at
        delta=1e-5) spent so far."""
        super()._log_step_training(
            is_gradient_update=is_gradient_update,
            grad_step_update=grad_step_update,
            loss=loss,
            start_time=start_time,
            num_samples=num_samples,
            num_tokens=num_tokens,
            summary_writer=summary_writer,
            additional_metrics=additional_metrics,
        )

        if (
            is_gradient_update
            and self.is_world_process_zero()
            and summary_writer is not None
        ):
            # Accounting uses the *logical* global batch size across GPUs.
            epsilon = compute_epsilon(
                steps=grad_step_update,
                noise_multiplier=self.noise_multiplier,
                batch_size=self.physical_batch_size
                * self.gradient_accumulation_steps
                * self.n_gpus,
                dataset_size=len(self.train_dataset),
                target_delta=1e-5,
            )
            summary_writer.add_scalar(  # type:ignore[no-untyped-call]
                "Epsilon for Delta=1e-5",
                epsilon,
                global_step=grad_step_update,
            )


def seed_worker(_) -> None:  # type:ignore[no-untyped-def]
    """
    Copied from HuggingFace.
    Seeds a DataLoader worker process from torch's initial seed so that
    numpy/random state is reproducible per worker.
    """
    set_seed(torch.initial_seed() % (2**32))


def set_seed(seed: int) -> None:
    """
    Copied from HuggingFace.
    Helper function for reproducible behavior: seeds `random`, `numpy`
    and `torch` (CPU and every CUDA device).

    Args:
        seed (`int`): The seed to set.
    """
    import random

    import numpy as np

    # Apply the same seed to every RNG the training stack may touch.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)


def log_begin_training(
    epochs: int,
    physical_batch_size: int,
    gradient_accumulation_steps: int,
    precision: str,
    n_gpus: int,
) -> None:
    """Log a one-time summary of the main training hyper-parameters."""
    for line in (
        "***** Running training *****",
        f" Num epochs = {epochs}",
        f"  Gradient accumulation steps = {gradient_accumulation_steps}",
        f" Physical Batch size per GPU = {physical_batch_size}",
        f" Number of GPUs = {n_gpus}",
        f"Precision: {precision}",
    ):
        logger.info(line)


def log_step_training_info(
    step: int, loss: float, is_training: bool = True
) -> None:
    """This method logs minimal training info in the stdout.
    More infos are logged in the SummaryWriter if tensorboard
    exists"""
    if is_training:
        phase = "Train "
    else:
        phase = "Test"
    logger.info(f"***** Training step {step} *****")
    logger.info(f"Current {phase} Loss = {loss}")


def default_deepspeed_config() -> t.Dict[str, t.Any]:
    """Load the default deepspeed JSON config shipped next to this module."""
    config_file = Path(__file__).parent / "deepspeed_config.json"
    with open(config_file) as file:
        config: t.Dict[str, t.Any] = json.load(file)
    return config


def hf_speed_metrics(
    split: str,
    start_time: float,
    num_samples: t.Optional[int] = None,
    num_steps: t.Optional[int] = None,
    num_tokens: t.Optional[int] = None,
) -> t.Dict[str, float]:
    """
    Measure and return speed performance metrics.

    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
    should be run immediately after the operation to be measured has completed.

    Args:
    - split: name to prefix metric (like train, eval, test...)
    - start_time: operation start time
    - num_samples: number of samples processed
    - num_steps: number of steps processed
    - num_tokens: number of tokens processed
    """
    elapsed = time.time() - start_time
    metrics = {f"{split}_runtime": round(elapsed, 4)}
    # Avoid a zero-division on degenerate (instantaneous) measurements.
    if elapsed == 0:
        return metrics

    # One "<split>_<unit>_per_second" entry per provided counter.
    throughput_counters = (
        ("samples", num_samples),
        ("steps", num_steps),
        ("tokens", num_tokens),
    )
    for unit, count in throughput_counters:
        if count is not None:
            metrics[f"{split}_{unit}_per_second"] = round(count / elapsed, 3)
    return metrics