Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
clu / metric_writers / torch_tensorboard_writer.py
Size: Mime:
# Copyright 2025 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter for PyTorch summary files.

Use this writer for PyTorch-based code. It writes TensorBoard event files
through `torch.utils.tensorboard.SummaryWriter`.
"""

from collections.abc import Mapping
from typing import Any, Optional
from absl import logging

from clu.metric_writers import interface
from torch.utils import tensorboard

# Re-export the array/scalar type aliases from the CLU metric-writer interface
# so the method signatures below can refer to them unqualified.
Array, Scalar = interface.Array, interface.Scalar


class TorchTensorboardWriter(interface.MetricWriter):
  """MetricWriter that writes PyTorch (TensorBoard) summary files.

  Every write is forwarded to a single `torch.utils.tensorboard.SummaryWriter`
  created for `logdir`. Raw summaries, videos and point clouds are not
  supported: those methods log a one-time warning and write nothing.
  """

  def __init__(self, logdir: str):
    """Creates the writer.

    Args:
      logdir: Directory passed as `log_dir` to the underlying `SummaryWriter`.
    """
    super().__init__()
    self._writer = tensorboard.SummaryWriter(log_dir=logdir)

  def write_summaries(
      self, step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None):
    """Not supported: logs a one-time warning and writes nothing."""
    logging.log_first_n(
        logging.WARNING,
        "TorchTensorboardWriter does not support writing raw summaries.", 1)

  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    """Writes one scalar summary per key at `step`.

    Args:
      step: Global step recorded with each scalar.
      scalars: Mapping from tag to scalar value.
    """
    for key, value in scalars.items():
      self._writer.add_scalar(key, value, global_step=step)

  def write_images(self, step: int, images: Mapping[str, Array]):
    """Writes image summaries at `step`.

    Args:
      step: Global step recorded with each image summary.
      images: Mapping from tag to image data. A rank-4 array is written as a
        batch of images in NHWC layout; any other array is written as a single
        image in HWC layout.
    """
    for key, value in images.items():
      if len(value.shape) == 4:
        # `add_image` accepts exactly one image, so an array with a leading
        # batch dimension has to go through `add_images` instead.
        self._writer.add_images(
            key, value, global_step=step, dataformats="NHWC")
      else:
        self._writer.add_image(key, value, global_step=step, dataformats="HWC")

  def write_videos(self, step: int, videos: Mapping[str, Array]):
    """Not supported: logs a one-time warning and writes nothing."""
    logging.log_first_n(
        logging.WARNING,
        "TorchTensorboardWriter does not support writing videos.", 1)

  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    """Writes one audio summary per key at `step`.

    Args:
      step: Global step recorded with each clip.
      audios: Mapping from tag to audio samples.
      sample_rate: Sample rate in Hz, forwarded to `add_audio`.
    """
    # NOTE(review): `add_audio` expects a single mono clip; multi-clip or
    # multi-channel arrays are forwarded unchanged here — confirm callers only
    # pass one mono clip per key.
    for key, value in audios.items():
      self._writer.add_audio(
          key, value, global_step=step, sample_rate=sample_rate)

  def write_texts(self, step: int, texts: Mapping[str, str]):
    """Writes one text summary per key at `step`.

    Args:
      step: Global step recorded with each text.
      texts: Mapping from tag to text content.
    """
    for key, value in texts.items():
      self._writer.add_text(key, value, global_step=step)

  def write_histograms(self,
                       step: int,
                       arrays: Mapping[str, Array],
                       num_buckets: Optional[Mapping[str, int]] = None):
    """Writes one histogram summary per key at `step`.

    Args:
      step: Global step recorded with each histogram.
      arrays: Mapping from tag to the values to histogram.
      num_buckets: Optional mapping from tag to the maximum number of bins,
        forwarded as `max_bins`. Tags missing from the mapping (or a `None`
        mapping) get `max_bins=None`.
    """
    for tag, values in arrays.items():
      bins = None if num_buckets is None else num_buckets.get(tag)
      # Bin edges are chosen with the "auto" strategy; `max_bins` only caps
      # how many of them are kept.
      self._writer.add_histogram(
          tag, values, global_step=step, bins="auto", max_bins=bins)

  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    """Not supported: logs a one-time warning and writes nothing."""
    logging.log_first_n(
        logging.WARNING,
        "TorchTensorboardWriter does not support writing point clouds.",
        1,
    )

  def write_hparams(self, hparams: Mapping[str, Any]):
    """Writes `hparams` through `add_hparams` with an empty metric dict.

    Args:
      hparams: Mapping from hyperparameter name to value.
    """
    # NOTE(review): `add_hparams` only accepts scalar-like values (bool, str,
    # int, float, None, tensor); nested mappings are not flattened here —
    # confirm callers pass a flat dict.
    self._writer.add_hparams(hparams, {})

  def flush(self):
    """Flushes pending events of the underlying `SummaryWriter` to disk."""
    self._writer.flush()

  def close(self):
    """Closes the underlying `SummaryWriter`."""
    self._writer.close()