Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch-wrapper / pl / loggers / aim.py
Size: Mime:
"""Copy of aim.sdk.adapters.pytorch_lightning.AimLogger.
- Disabled a few functions
- Enabled reusing run hash (restarting experiment)
"""
from argparse import Namespace
from typing import Any, Dict, Optional, Union

from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
from aim.sdk.adapters.pytorch_lightning import AimLogger
from aim.sdk.run import Run
from torch_wrapper.aim import Run
from pytorch_lightning.loggers.base import (
    rank_zero_experiment,
)
from pytorch_lightning.utilities import rank_zero_only
from . import LoggerInterface

__all__ = ['AimLoggerCustom']


class AimLoggerCustom(AimLogger, LoggerInterface):
    """Aim logger for PyTorch Lightning with a few project-specific tweaks.

    Differences from ``AimLogger``:

    * system-parameter logging is always disabled (``log_system_params=False``);
    * ``log_hyperparams`` is disabled and raises ``NotImplementedError`` --
      use ``log_hparams`` instead;
    * an existing run can be resumed by passing its ``run_hash``.

    NOTE(review): within this module ``Run`` resolves to
    ``torch_wrapper.aim.Run`` -- that import shadows ``aim.sdk.run.Run``.
    """

    # Key of the Aim ``context`` dict that records which data split
    # ('train' / 'test' / 'val') a tracked metric belongs to.
    _SPLIT_CONTEXT_KEY = 'split'

    def __init__(
        self,
        repo: Optional[str] = None,
        experiment: Optional[str] = None,
        system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
        run_hash: Optional[str] = None,
    ):
        """
        Args:
            repo: Aim repository path, forwarded to ``AimLogger``.
            experiment: Experiment name, forwarded to ``AimLogger``.
            system_tracking_interval: Forwarded to ``AimLogger``; read back
                as ``self._system_tracking_interval`` when the run is created.
            run_hash: Hash of an existing run to resume. When falsy, a new
                run is created on first access to ``experiment``.
        """
        # log_system_params creates too much noise in the params tab.
        super().__init__(
            repo=repo,
            experiment=experiment,
            system_tracking_interval=system_tracking_interval,
            log_system_params=False,
        )
        self._run_hash = run_hash

    def log_hparams(self, hparams: dict):
        """Store ``hparams`` under the ``'hparams'`` key of the Aim run.

        Replacement for the disabled ``log_hyperparams``.
        """
        self.experiment['hparams'] = hparams

    @property
    @rank_zero_experiment
    def experiment(self) -> Run:
        """Return the Aim run, creating or resuming it on first access.

        If a run hash is known, the run with that hash is opened; otherwise
        a new run is created and its hash is stored in ``self._run_hash``.
        """
        if self._run is None:
            kwargs = dict(
                repo=self._repo_path,
                experiment=self._experiment_name,
                system_tracking_interval=self._system_tracking_interval,
            )
            if self._run_hash:
                self._run = Run(**kwargs, run_hash=self._run_hash)
            else:
                self._run = Run(**kwargs)
                self._run_hash = self._run.hash
        return self._run

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]):
        """Disabled: raises ``NotImplementedError`` whenever the call runs.

        Use ``log_hparams`` instead. (``rank_zero_only`` presumably skips the
        call entirely on non-zero ranks -- confirm against Lightning docs.)
        """
        raise NotImplementedError("Disabled!")

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float],
                    step: Optional[int] = None):
        """Track every entry of ``metrics`` on the Aim run.

        ``metrics`` must contain an ``'epoch'`` entry. It is converted to
        ``int`` and passed to ``track`` as the epoch, not tracked as a metric.
        Metric names that start with the train/test/val prefix have that
        prefix stripped, and the split is recorded in the track context.
        The caller's ``metrics`` dict is left unmodified.

        Raises:
            KeyError: if ``metrics`` has no ``'epoch'`` entry.
        """
        import warnings
        warnings.warn(
            'NOT tested. Also, the epoch step seems to fail frequently.',
            UserWarning,
        )
        # noinspection PyUnresolvedReferences
        assert rank_zero_only.rank == 0, \
            'experiment tried to log from global_rank != 0'

        # Work on a shallow copy: the caller may reuse this dict (e.g. hand
        # the same one to several loggers), so removing 'epoch' in place
        # would hide it from everyone else.
        metrics = dict(metrics)
        try:
            epoch = metrics.pop('epoch')
        except KeyError as e:
            raise KeyError(
                f'{e}\nEpoch number missing, weird. Set a debugger here.'
            ) from e
        epoch = int(epoch)

        for k, v in metrics.items():
            name, context = self._split_metric_name(k)
            self.experiment.track(v, name=name, step=step, epoch=epoch, context=context)

    def _split_metric_name(self, name: str) -> tuple:
        """Strip a train/test/val prefix from ``name``.

        Returns ``(name, context)``. When a prefix matched, ``context`` maps
        the split key to ``'train'``, ``'test'`` or ``'val'``; otherwise it is
        empty. Prefixes are checked in train, test, val order, the first
        match wins, and empty/``None`` prefixes are skipped.
        """
        candidates = (
            (self._train_metric_prefix, 'train'),
            (self._test_metric_prefix, 'test'),
            (self._val_metric_prefix, 'val'),
        )
        for prefix, split in candidates:
            if prefix and name.startswith(prefix):
                return name[len(prefix):], {self._SPLIT_CONTEXT_KEY: split}
        return name, {}