"""Copy of aim.sdk.adapters.pytorch_lightning.AimLogger.
- Disabled a few functions
- Enabled reusing run hash (restarting experiment)
"""
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union
from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
from aim.sdk.adapters.pytorch_lightning import AimLogger
from aim.sdk.run import Run
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from . import LoggerInterface
__all__ = ['AimLoggerCustom']
class AimLoggerCustom(AimLogger, LoggerInterface):
    # Use this set of prefixes so the logger can assign the proper context to
    # each logged value.
    # Ref: https://aimstack.io/aim-basics-using-context-and-subplots-to-compare-validation-and-test-metrics/
    PREFIX = {x: x + '_' for x in ('train', 'val', 'test')}
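    # i.e. {'train': 'train_', 'val': 'val_', 'test': 'test_'}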
def __init__(
self,
repo: Optional[str] = None,
        experiment: Optional[str] = None,  # experiment name shown in the Aim dashboard
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
        run_hash: Optional[str] = None,
):
# log_system_params creates too much noise in the params tab.
super().__init__(
repo=repo,
experiment=experiment,
system_tracking_interval=system_tracking_interval,
log_system_params=False,
train_metric_prefix='train_',
val_metric_prefix='val_',
test_metric_prefix='test_',
)
self._run_hash = run_hash
    def log_hparams(self, hparams: dict):
        """Store the given hyperparameters under the run's 'hparams' key."""
        self.experiment['hparams'] = hparams
@property
@rank_zero_experiment
    def experiment(self) -> Run:
        """Return the underlying aim ``Run``, creating it lazily on first use.

        When a ``run_hash`` was passed in, attach to that existing run
        (restarting the experiment); otherwise create a new run and remember
        its hash.
        """
        if self._run is None:
kwargs = dict(
repo=self._repo_path,
experiment=self._experiment_name,
system_tracking_interval=self._system_tracking_interval,
)
if self._run_hash:
self._run = Run(**kwargs, run_hash=self._run_hash)
else:
self._run = Run(**kwargs)
self._run_hash = self._run.hash
return self._run
@rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]):
        # Disabled on purpose (see module docstring); use log_hparams() instead.
        raise NotImplementedError("Disabled!")
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float],
step: Optional[int] = None):
# noinspection PyUnresolvedReferences
assert rank_zero_only.rank == 0, \
'experiment tried to log from global_rank != 0'
        try:
            epoch = int(metrics.pop('epoch'))
        except KeyError as e:
            raise KeyError(
                f'{e}\nEpoch number missing from the logged metrics; '
                'PyTorch Lightning normally includes it.'
            ) from e
        CONTEXT_NAME = 'split'
        for k, v in metrics.items():
            name = k
            context = {}
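            # e.g. 'train_loss' is tracked as name='loss' with
            # context={'split': 'train'}, so Aim can overlay train/val/test
            # curves of the same metric.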
if self._train_metric_prefix \
and name.startswith(self._train_metric_prefix):
name = name[len(self._train_metric_prefix):]
context[CONTEXT_NAME] = 'train'
elif self._test_metric_prefix \
and name.startswith(self._test_metric_prefix):
name = name[len(self._test_metric_prefix):]
context[CONTEXT_NAME] = 'test'
elif self._val_metric_prefix \
and name.startswith(self._val_metric_prefix):
name = name[len(self._val_metric_prefix):]
context[CONTEXT_NAME] = 'val'
self.experiment.track(v, name=name, step=step, epoch=epoch, context=context)
    # Customized logger attributes.
    def log_root(self) -> Path:
        """Return the root path of this aim logger; ".aim" is located inside it."""
        return Path(self.experiment.repo.root_path)
    def log_dir(self) -> Path:
        """Return a directory specific to this version of the logged experiment."""
        return self.log_root() / self.version
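# A small sketch (illustration only; names are assumptions) of using
# log_dir() to co-locate checkpoints with the Aim run:
#
#     from pytorch_lightning.callbacks import ModelCheckpoint
#     logger = AimLoggerCustom(repo='.', experiment='demo')
#     ckpt = ModelCheckpoint(dirpath=logger.log_dir() / 'checkpoints')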