Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch-wrapper / pl / loggers / aim.py
Size: Mime:
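"""Customized variant of aim.sdk.adapters.pytorch_lightning.AimLogger.

Differences from the upstream logger:
- A few functions are disabled (e.g. ``log_hyperparams``).
- An existing run hash can be reused, so a restarted experiment continues the same run.
"""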
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union

from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
from aim.sdk.adapters.pytorch_lightning import AimLogger
from aim.sdk.run import Run
from torch_wrapper.aim import Run
from pytorch_lightning.loggers.base import (
    rank_zero_experiment,
)
from pytorch_lightning.utilities import rank_zero_only
from . import LoggerInterface

__all__ = ['AimLoggerCustom']


class AimLoggerCustom(AimLogger, LoggerInterface):
    """Aim logger for PyTorch Lightning that can re-open an existing run.

    Compared with the parent ``AimLogger`` this class:

    - always passes ``log_system_params=False`` to the parent constructor,
    - disables ``log_hyperparams`` (use :meth:`log_hparams` instead),
    - accepts a ``run_hash`` so that a restarted experiment keeps logging
      into the same run,
    - requires every ``log_metrics`` call to carry an ``'epoch'`` entry.
    """

    # Use this set of prefixes so that the logger can properly assign the context of the logged value.
    # Ref: https://aimstack.io/aim-basics-using-context-and-subplots-to-compare-validation-and-test-metrics/
    PREFIX = {x: x + '_' for x in ('train', 'val', 'test')}

    def __init__(
        self,
        repo: Optional[str] = None,
        experiment: Optional[str] = None,  # a string shown in Aim dashboard identifying name of this experiment
        system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
        run_hash: Optional[str] = None,
    ):
        """Configure the logger; the run itself is created lazily by :attr:`experiment`.

        Args:
            repo: Forwarded unchanged to ``AimLogger`` as ``repo``.
            experiment: Experiment name, forwarded unchanged to ``AimLogger``.
            system_tracking_interval: Forwarded unchanged to ``AimLogger``.
            run_hash: Hash of an existing run to continue. When falsy, a new
                run is created on first access of :attr:`experiment` and its
                hash is remembered in ``self._run_hash``.
        """
        # log_system_params creates too much noise in the params tab.
        super().__init__(
            repo=repo,
            experiment=experiment,
            system_tracking_interval=system_tracking_interval,
            log_system_params=False,
            # Keep the metric prefixes in sync with the class-level PREFIX table.
            train_metric_prefix=self.PREFIX['train'],
            val_metric_prefix=self.PREFIX['val'],
            test_metric_prefix=self.PREFIX['test'],
        )
        self._run_hash = run_hash

    def log_hparams(self, hparams: dict) -> None:
        """Store ``hparams`` on the run under the ``'hparams'`` key."""
        self.experiment['hparams'] = hparams

    @property
    @rank_zero_experiment
    def experiment(self) -> Run:
        """Return the run object, creating it on first access.

        If a run hash was supplied to the constructor it is passed to ``Run``
        as ``run_hash``; otherwise a run is created without one and the hash
        of that run is stored for later use.
        """
        if self._run is None:
            kwargs = dict(
                repo=self._repo_path,
                experiment=self._experiment_name,
                system_tracking_interval=self._system_tracking_interval,
            )
            if self._run_hash:
                self._run = Run(**kwargs, run_hash=self._run_hash)
            else:
                self._run = Run(**kwargs)
                self._run_hash = self._run.hash
        return self._run

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]):
        """Disabled on purpose; always raises.

        Raises:
            NotImplementedError: Unconditionally when the body runs.
        """
        raise NotImplementedError("Disabled!")

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float],
                    step: Optional[int] = None):
        """Track every metric in ``metrics`` against the current run.

        The ``'epoch'`` entry is mandatory; it is converted to ``int`` and
        attached to every tracked value instead of being tracked itself.
        A metric whose name starts with one of the train/test/val prefixes is
        tracked under the name with the prefix stripped and with the context
        ``{'split': <split>}``; any other metric is tracked with an empty
        context. The ``metrics`` dict is left unmodified.

        Args:
            metrics: Mapping of metric name to value; must contain ``'epoch'``.
            step: Step passed through to ``track``.

        Raises:
            KeyError: If ``metrics`` has no ``'epoch'`` entry.
        """
        # noinspection PyUnresolvedReferences
        assert rank_zero_only.rank == 0, \
            'experiment tried to log from global_rank != 0'
        try:
            epoch = metrics['epoch']
        except KeyError as e:
            raise KeyError(f'{e}\nEpoch number missing, weird. Set a debugger here.') from e
        epoch = int(epoch)

        context_name = 'split'
        # Checked in this order (train, test, val); the first matching prefix wins.
        split_prefixes = (
            (self._train_metric_prefix, 'train'),
            (self._test_metric_prefix, 'test'),
            (self._val_metric_prefix, 'val'),
        )

        for key, value in metrics.items():
            # 'epoch' is skipped rather than deleted, so the caller's dict
            # stays intact for whoever consumes it next.
            if key == 'epoch':
                continue

            name = key
            context = {}
            for prefix, split in split_prefixes:
                if prefix and key.startswith(prefix):
                    name = key[len(prefix):]
                    context[context_name] = split
                    break
            self.experiment.track(value, name=name, step=step, epoch=epoch, context=context)

    # customized logger attributes
    def log_root(self) -> Path:
        """Return the root path of this aim logger. ".aim" is located inside this root path."""
        return Path(self.experiment.repo.root_path)

    def log_dir(self) -> Path:
        """Return a dir specific to this version of logged experiment."""
        # NOTE(review): ``self.version`` is not defined in this class; presumably inherited — confirm.
        return self.log_root() / self.version