Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import os
from typing import Any, Mapping, Union

import torch

import pytorch_lightning as pl
from lightning_fabric.utilities.spike import SpikeDetection as FabricSpikeDetection
from pytorch_lightning.callbacks.callback import Callback


class SpikeDetection(FabricSpikeDetection, Callback):
    """Lightning ``Callback`` adapter around Fabric's ``SpikeDetection``.

    Pulls the training loss out of the ``outputs`` that Lightning hands to
    ``on_train_batch_end`` and forwards it to the Fabric implementation, passing
    the ``Trainer`` in the position where the Fabric hook expects its ``fabric``
    argument.
    """

    @torch.no_grad()
    def on_train_batch_end(  # type: ignore
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Union[torch.Tensor, Mapping[str, torch.Tensor]],
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Run spike detection on the loss of the batch that just finished.

        Args:
            trainer: The running trainer. It is forwarded to the Fabric hook in
                place of a ``Fabric`` object; its ``default_root_dir`` supplies
                the default location of ``skip_batches.json``.
            pl_module: The module being trained (not used here).
            outputs: The training-step output: either the loss tensor itself or
                a mapping holding it under ``"loss"``. ``None``, or a mapping
                without a ``"loss"`` entry (or with ``"loss": None``), means no
                loss is available for this batch (e.g. ``training_step`` returned
                ``None`` to skip it). In that case detection is skipped.
            batch: The current batch, forwarded unchanged.
            batch_idx: Index of the current batch, forwarded unchanged.

        Raises:
            TypeError: If ``outputs`` is not ``None``, a tensor, or a mapping.
        """
        if outputs is None:
            raw_loss = None
        elif isinstance(outputs, torch.Tensor):
            raw_loss = outputs
        elif isinstance(outputs, Mapping):
            # ``.get`` instead of ``[...]``: a loss-less mapping (such as ``{}`` for a
            # skipped batch) used to crash here with a bare KeyError.
            raw_loss = outputs.get("loss")
        else:
            raise TypeError(f"outputs have to be of type torch.Tensor or Mapping, got {type(outputs).__qualname__}")

        # Default the exclusion file to the trainer's root dir (only if the user did not set one).
        if self.exclude_batches_path is None:
            self.exclude_batches_path = os.path.join(trainer.default_root_dir, "skip_batches.json")

        if raw_loss is None:
            # No loss to inspect for this batch, so there is nothing to detect.
            # NOTE(review): this rank does not enter the Fabric hook for this batch. If that
            # hook synchronizes across ranks, every rank must skip the same batch. Confirm
            # before relying on this in multi-process runs.
            return None

        loss = raw_loss.detach()
        return FabricSpikeDetection.on_train_batch_end(self, trainer, loss, batch, batch_idx)  # type: ignore