# torch-wrapper / pl / _training.py

import typing
from os import PathLike
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.utils.data

from torch_wrapper.metrics import Metric
from .callbacks import ModelCheckpoint
from .loggers.aim import AimLoggerCustom
from .lr_scheduler import LRScheduler

__all__ = [
    'get_trainer', 'PLModule', 'PLDataModule'
]


def get_trainer(
    *,
    pl_module: 'PLModule',
    model_checkpoint='use_last',
    repo: PathLike,
    callbacks: list = None,
    logger: AimLoggerCustom,
    max_epochs,
):
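    """Build a ``pl.Trainer`` and wrap it in ``_Trainer``.

    With ``model_checkpoint='use_last'`` (the only supported value), a ModelCheckpoint
    writing to ``repo / logger.version / 'checkpoints'`` is appended to ``callbacks``.

    A minimal usage sketch (``my_module``, ``my_logger`` and ``my_data`` are
    hypothetical placeholders, not defined in this module)::

        trainer = get_trainer(
            pl_module=my_module,   # a PLModule instance
            repo='runs',           # root directory for checkpoints and logs
            logger=my_logger,      # an AimLoggerCustom instance
            max_epochs=10,
        )
        trainer.fit(my_data)       # my_data: a PLDataModule instance
    """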
    if callbacks is None:
        callbacks = []
    if model_checkpoint == 'use_last':
        model_checkpoint = ModelCheckpoint(Path(repo) / logger.version / 'checkpoints')
        callbacks.append(model_checkpoint)
    else:
        raise NotImplementedError(model_checkpoint)
    return _Trainer(
        pl_trainer=pl.Trainer(
            default_root_dir=str(repo),
            callbacks=callbacks,
            logger=logger,
            max_epochs=max_epochs,
            check_val_every_n_epoch=1,
            accelerator="gpu" if torch.cuda.is_available() else None,
            # Ref: https://pytorch-lightning.readthedocs.io/en/latest/advanced/training_tricks.html#batch-size-finder
            auto_scale_batch_size=False,
        ),
        pl_module=pl_module,
        model_checkpoint=model_checkpoint,
    )


class _Trainer:
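    """Thin wrapper around ``pl.Trainer`` that resumes from the last checkpoint when one exists."""
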
    def __init__(self, pl_trainer: pl.Trainer, pl_module, model_checkpoint: ModelCheckpoint):
        self.pl_trainer = pl_trainer
        self.pl_module = pl_module
        self.model_checkpoint = model_checkpoint

    def fit(self, data: 'PLDataModule'):
        if self.model_checkpoint.last_ckpt.exists():
            self.pl_trainer.fit(model=self.pl_module, datamodule=data, ckpt_path=str(self.model_checkpoint.last_ckpt))
        else:
            self.pl_trainer.fit(model=self.pl_module, datamodule=data)


# Ref: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningDataModule.html
class PLDataModule(pl.LightningDataModule):
    """
    A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage
    is consistent data splits, data preparation and transforms across models.
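
    Example (``train_ds`` and ``val_ds`` are hypothetical ``torch.utils.data.Dataset``
    objects; a minimal sketch)::

        dm = PLDataModule(
            datasets={'train': train_ds, 'val': val_ds},
            batch_size=32,
            num_workers=4,
        )
        train_loader = dm.train_dataloader()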
    """

    def __init__(self, datasets: dict, batch_size: int = 1, num_workers: int = 0, collate_fn=None):
        super(PLDataModule, self).__init__()
        self.datasets = datasets
        # this batch_size attribute can be used later for auto-tuning (e.g. the batch size finder).
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn

    def _helper_dataloader(self, split):
        if split not in ("train", "val", "test"):
            raise ValueError(split)
        try:
            dataset = self.datasets[split]
        except KeyError:
            raise KeyError(f'data for {split} not prepared.')
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self._helper_dataloader('train')

    def val_dataloader(self):
        return self._helper_dataloader('val')

    def test_dataloader(self):
        return self._helper_dataloader('test')


class PLModule(pl.LightningModule):
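    """LightningModule with manual optimization.

    Computes the loss with the given ``loss`` Metric, tracks loss and metrics to Aim via
    ``logger``, and optionally clips gradients and tracks per-parameter gradient norms.
    """
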
    def __init__(
        self,
        *,
        metrics: typing.List[Metric],
        loss: Metric,
        model: torch.nn.Module,
        opt_params: dict = None,
        logger: AimLoggerCustom,
        track_grad_norm=False,  # manually track the gradient norm, even though pytorch_lightning provides the same functionality.
        clip_gradient: typing.Union[None, typing.Tuple[str, float]] = None,
        lr_scheduler: LRScheduler = None,
    ):
        """

        Parameters
        ----------
        track_grad_norm: False/-1 for not tracking, other integer for tracking, and the norm used for tracking
        clip_gradient:
            None: not clip
            or: (method, value), this will clip by method with value.
            method: see torch.nn.utils.clip_grad_norm_ (assume 2 norm) and clip_grad_value_

        TODO:
        - currently, the lr_scheduler won't work properly after restart.
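
        Example (``my_model``, ``my_loss``, ``my_metrics`` and ``my_logger`` are
        hypothetical placeholders; a minimal sketch)::

            module = PLModule(
                model=my_model,
                loss=my_loss,
                metrics=my_metrics,
                logger=my_logger,
                opt_params={'lr': 1e-3},
                track_grad_norm=2,            # track the 2-norm of every parameter's gradient
                clip_gradient=('norm', 1.0),  # clip gradients to a maximum 2-norm of 1.0
            )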
        """
        super().__init__()
        # disable automatic optimization; this must be set after super().__init__()
        self.automatic_optimization = False
        self.main_model = model

        # note: save_hyperparameters() would store the arguments passed into the LightningModule,
        # but the hparams API seems to confuse many users:
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3981
        # therefore, we no longer call self.save_hyperparameters here
        # self.save_hyperparameters(hparams)

        self._metrics = metrics
        self._loss = loss
        self._logger = logger
        self._opt_params = opt_params
        self._track_grad_norm = track_grad_norm
        self._clip_gradient = clip_gradient
        self._lr_scheduler = lr_scheduler

    def forward(self, x):
        return self.main_model(x)

    def _shared_train_val_step(self, batch, split: str):
        x, y = batch
        y_pred = self.main_model(x)
        loss = self._loss(y_label=y, y_pred=y_pred)
        self._logger.experiment.track(
            name=self._loss.name, value=loss,
            step=self.global_step, epoch=self.current_epoch, context={'split': split}
        )
        return {
            # the "loss" key holds the value to optimize; training_step backpropagates it via manual_backward, so keep the key name fixed.
            "loss": loss,
            # store y and y_pred so that we can later accumulate them to compute other metrics
            "y_label": y,
            "y_pred": y_pred,
        }

    def training_step(self, batch, batch_idx):
        """manual training step

        Ref: https://pytorch-lightning.readthedocs.io/en/latest/common/optimization.html
        """
        opt = self.optimizers()
        # noinspection PyUnresolvedReferences
        opt.zero_grad()
        output = self._shared_train_val_step(batch, 'train')
        self.manual_backward(output['loss'])
        self._manual_track_grad()
        self._manual_clip_norm()

        opt.step()
        return output

    def validation_step(self, batch, batch_idx):
        output = self._shared_train_val_step(batch, 'val')
        return output

    def _manual_clip_norm(self):
        if self._clip_gradient is not None:
            method, value = self._clip_gradient
            f_ = {
                'norm': torch.nn.utils.clip_grad_norm_,
                'value': torch.nn.utils.clip_grad_value_,
            }[method]
            f_(self.parameters(), value)

    def _manual_track_grad(self) -> None:
        """manual tracking gradient norm"""
        if not self._track_grad_norm:
            return
        model = self.main_model
        run = self._logger.experiment
        name = f'Grad_{self._track_grad_norm}_norm'
        for tag, value in model.named_parameters():
            if value.grad is not None:
                grad_norm = torch.norm(value.grad.detach(), p=self._track_grad_norm).cpu()
                run.track(
                    name=f'{name}/{tag}', value=grad_norm, step=self.global_step, epoch=self.current_epoch,
                    context={'split': 'train'})

    def _shared_train_val_epoch_end(self, outputs, split: str):
        # noinspection PyUnresolvedReferences
        logs = {}
        # y_label and y_pred need to be accumulated first
        accumulated = {
            k: torch.concat([_[k] for _ in outputs], dim=0)
            for k in ('y_label', 'y_pred')
        }
        # then compute metrics
        for m in self._metrics:
            logs[m.name] = m(y_label=accumulated['y_label'], y_pred=accumulated['y_pred'])
        # now write output
        for k, v in logs.items():
            self._logger.experiment.track(
                name=k, value=v, step=self.global_step, epoch=self.current_epoch,
                context={'split': split},
            )
        return logs

    def training_epoch_end(self, outputs):
        self._shared_train_val_epoch_end(outputs, split="train")
        if self._lr_scheduler is not None:
            sch = self.lr_schedulers()
            # noinspection PyArgumentList
            sch.step()

    def validation_epoch_end(self, outputs) -> None:
        self._shared_train_val_epoch_end(outputs, split="val")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), **(self._opt_params or {}))  # opt_params may be None
        if self._lr_scheduler is None:
            return optimizer
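        # Note: LRScheduler.get_pytorch(optimizer) is assumed to return a scheduler
        # object (e.g. a torch.optim.lr_scheduler instance or a Lightning scheduler
        # dict) built around `optimizer`; Lightning accepts either in this return form.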
        return [optimizer], [self._lr_scheduler.get_pytorch(optimizer)]