import typing
from os import PathLike
from pathlib import Path
import pytorch_lightning as pl
import torch
import torch.utils.data
from torch_wrapper.metrics import Metric
from .callbacks import ModelCheckpoint
from .loggers.aim import AimLoggerCustom
from .lr_scheduler import LRScheduler
__all__ = [
'get_trainer', 'PLModule', 'PLDataModule'
]
def get_trainer(
    *,
    pl_module: 'PLModule',
    model_checkpoint='use_last',
    repo: PathLike,
    callbacks: list = None,
    logger: AimLoggerCustom,
    max_epochs,
):
    if callbacks is None:
        callbacks = []
    if model_checkpoint == 'use_last':
        model_checkpoint = ModelCheckpoint(Path(repo) / logger.version / 'checkpoints')
        callbacks.append(model_checkpoint)
    else:
        raise NotImplementedError(model_checkpoint)
    return _Trainer(
        pl_trainer=pl.Trainer(
            default_root_dir=str(repo),
            callbacks=callbacks,
            logger=logger,
            max_epochs=max_epochs,
            check_val_every_n_epoch=1,
            accelerator="gpu" if torch.cuda.is_available() else None,
            # Ref: https://pytorch-lightning.readthedocs.io/en/latest/advanced/training_tricks.html#batch-size-finder
            auto_scale_batch_size=False,
        ),
        pl_module=pl_module,
        model_checkpoint=model_checkpoint,
    )
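# Example (sketch, not part of the package): one way get_trainer could be wired up.
# The AimLoggerCustom constructor arguments below are assumed for illustration; the
# PLModule and PLDataModule instances are built as shown further down in this module.
#
#   logger = AimLoggerCustom(repo='runs/')           # assumed constructor signature
#   trainer = get_trainer(
#       pl_module=module,                            # a PLModule instance
#       repo='runs/',
#       logger=logger,
#       max_epochs=10,
#   )
#   trainer.fit(data=datamodule)                     # datamodule: a PLDataModule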
class _Trainer:
    def __init__(self, pl_trainer: pl.Trainer, pl_module, model_checkpoint: ModelCheckpoint):
        self.pl_trainer = pl_trainer
        self.pl_module = pl_module
        self.model_checkpoint = model_checkpoint

    def fit(self, data: 'PLDataModule'):
        if self.model_checkpoint.last_ckpt.exists():
            self.pl_trainer.fit(model=self.pl_module, datamodule=data, ckpt_path=str(self.model_checkpoint.last_ckpt))
        else:
            self.pl_trainer.fit(model=self.pl_module, datamodule=data)
# Ref: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningDataModule.html
class PLDataModule(pl.LightningDataModule):
    """
    A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage
    is consistent data splits, data preparation and transforms across models.
    """
    def __init__(self, datasets: dict, batch_size: int = 1, num_workers: int = 0, collate_fn=None):
        super().__init__()
        self.datasets = datasets
        # the batch_size attribute can be used later for auto-tuning.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn

    def _helper_dataloader(self, split):
        if split not in ("train", "val", "test"):
            raise ValueError(split)
        try:
            dataset = self.datasets[split]
        except KeyError:
            raise KeyError(f'data for {split} not prepared.')
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self._helper_dataloader('train')

    def val_dataloader(self):
        return self._helper_dataloader('val')

    def test_dataloader(self):
        return self._helper_dataloader('test')
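# Example (sketch, not part of the package): a PLDataModule built from in-memory
# tensors. The dataset sizes and shapes are illustrative only.
#
#   import torch.utils.data as tud
#   x, y = torch.randn(100, 8), torch.randint(0, 2, (100,))
#   datamodule = PLDataModule(
#       datasets={
#           'train': tud.TensorDataset(x[:80], y[:80]),
#           'val': tud.TensorDataset(x[80:], y[80:]),
#       },
#       batch_size=16,
#   )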
class PLModule(pl.LightningModule):
    def __init__(
        self,
        *,
        metrics: typing.List[Metric],
        loss: Metric,
        model: torch.nn.Module,
        opt_params: dict = None,
        logger: AimLoggerCustom,
        track_grad_norm=False,  # manually track gradient norms even though pytorch_lightning offers the same functionality.
        clip_gradient: typing.Union[None, typing.Tuple[str, float]] = None,
        lr_scheduler: LRScheduler = None,
    ):
        """
        Parameters
        ----------
        track_grad_norm: False or -1 to disable tracking; any other integer enables tracking and is used as the norm order.
        clip_gradient:
            None: no clipping.
            (method, value): clip gradients using the given method and value.
                method: 'norm' uses torch.nn.utils.clip_grad_norm_ (with its default 2-norm); 'value' uses torch.nn.utils.clip_grad_value_.

        TODO:
            - currently, the lr_scheduler won't resume properly after a restart.
        """
        super().__init__()
        # disable automatic optimization; must be set after super().__init__()
        self.automatic_optimization = False
        self.main_model = model
        # saves the arguments passed into the LightningModule
        # note: the hparams API seems to be confusing for many:
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3981
        # therefore, we don't use self.save_hyperparameters any more
        # self.save_hyperparameters(hparams)
        self._metrics = metrics
        self._loss = loss
        self._logger = logger
        self._opt_params = opt_params
        self._track_grad_norm = track_grad_norm
        self._clip_gradient = clip_gradient
        self._lr_scheduler = lr_scheduler

    def forward(self, x):
        return self.main_model(x)

    def _shared_train_val_step(self, batch, split: str):
        x, y = batch
        y_pred = self.main_model(x)
        loss = self._loss(y_label=y, y_pred=y_pred)
        self._logger.experiment.track(
            name=self._loss.name, value=loss,
            step=self.global_step, epoch=self.current_epoch, context={'split': split}
        )
        return {
            # the "loss" key tells pytorch_lightning which value to optimize; the key name is fixed.
            "loss": loss,
            # store y and y_pred so that we can later accumulate them to compute other metrics
            "y_label": y,
            "y_pred": y_pred,
        }

    def training_step(self, batch, batch_idx):
        """manual training step
        Ref: https://pytorch-lightning.readthedocs.io/en/latest/common/optimization.html
        """
        opt = self.optimizers()
        # noinspection PyUnresolvedReferences
        opt.zero_grad()
        output = self._shared_train_val_step(batch, 'train')
        self.manual_backward(output['loss'])
        self._manual_track_grad()
        self._manual_clip_norm()
        opt.step()
        return output

    def validation_step(self, batch, batch_idx):
        output = self._shared_train_val_step(batch, 'val')
        return output

    def _manual_clip_norm(self):
        if self._clip_gradient is not None:
            method, value = self._clip_gradient
            f_ = {
                'norm': torch.nn.utils.clip_grad_norm_,
                'value': torch.nn.utils.clip_grad_value_,
            }[method]
            f_(self.parameters(), value)

    def _manual_track_grad(self) -> None:
        """manually track the gradient norm"""
        if not self._track_grad_norm:
            return
        model = self.main_model
        run = self._logger.experiment
        name = f'Grad_{self._track_grad_norm}_norm'
        for tag, value in model.named_parameters():
            if value.grad is not None:
                grad_norm = torch.norm(value.grad.detach(), p=self._track_grad_norm).cpu()
                run.track(
                    name=f'{name}/{tag}', value=grad_norm, step=self.global_step, epoch=self.current_epoch,
                    context={'split': 'train'})

    def _shared_train_val_epoch_end(self, outputs, split: str):
        # noinspection PyUnresolvedReferences
        logs = {}
        # y_label and y_pred need to be accumulated first
        accumulated = {
            k: torch.concat([_[k] for _ in outputs], dim=0)
            for k in ('y_label', 'y_pred')
        }
        # then compute metrics
        for m in self._metrics:
            logs[m.name] = m(y_label=accumulated['y_label'], y_pred=accumulated['y_pred'])
        # now write output
        for k, v in logs.items():
            self._logger.experiment.track(
                name=k, value=v, step=self.global_step, epoch=self.current_epoch,
                context={'split': split},
            )
        return logs

    def training_epoch_end(self, outputs):
        self._shared_train_val_epoch_end(outputs, split="train")
        if self._lr_scheduler is not None:
            sch = self.lr_schedulers()
            # noinspection PyArgumentList
            sch.step()

    def validation_epoch_end(self, outputs) -> None:
        self._shared_train_val_epoch_end(outputs, split="val")

    def configure_optimizers(self):
        # opt_params defaults to None, so guard against unpacking it directly
        optimizer = torch.optim.Adam(self.parameters(), **(self._opt_params or {}))
        if self._lr_scheduler is None:
            return optimizer
        return [optimizer], [self._lr_scheduler.get_pytorch(optimizer)]
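# Example (sketch, not part of the package): assembling a PLModule. The metric and
# loss objects are assumed to follow the torch_wrapper.metrics.Metric interface
# imported above; the names below are illustrative.
#
#   module = PLModule(
#       metrics=[accuracy_metric],          # assumed Metric instance(s)
#       loss=cross_entropy_metric,          # assumed Metric instance
#       model=torch.nn.Linear(8, 2),
#       opt_params={'lr': 1e-3},
#       logger=logger,                      # the AimLoggerCustom passed to get_trainer
#       track_grad_norm=2,                  # track the 2-norm of each parameter's gradient
#       clip_gradient=('norm', 1.0),        # clip via torch.nn.utils.clip_grad_norm_
#   )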