Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ distributed / pipeline / sync / microbatch.py

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Union, cast, Sequence

import torch
from torch import Tensor
import torch.cuda.comm

# Names exported by ``from <this module> import *``.
__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]


# A sequence of tensors.
Tensors = Sequence[Tensor]
# Either a single tensor or a sequence of tensors.
TensorOrTensors = Union[Tensor, Tensors]
# Signature of the callables accepted by ``Batch.call``: they return either a
# list of values or a single tensor, which ``Batch.call`` wraps in a ``Batch``.
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]


class NoChunk:
    """
    Marker wrapper around a Tensor passed to :meth:`Pipe.forward`.

    A wrapped tensor is not split along the batch dimension; the very same
    tensor is handed to every micro-batch instead. Use it for tensors that
    carry no 'batch' semantics for the model.

    Raises:
        TypeError: the wrapped object is not a tensor.
    """
    def __init__(self, inp: Tensor):
        if torch.is_tensor(inp):
            self._tensor = inp
            return
        raise TypeError(f'NoChunk only supported for tensors, found: {inp}')

    @property
    def tensor(self):
        """The wrapped tensor, exactly as it was passed in."""
        return self._tensor


class Batch:
    """
    An abstraction representing a microbatch in the pipeline.

    A batch is *atomic* when it wraps a single tensor. Otherwise it wraps a
    sequence of values (list or tuple) that must contain at least one tensor
    at construction time.
    """

    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        """
        Args:
            values: a single tensor (atomic batch) or a sequence of values.

        Raises:
            TypeError: ``values`` is a sequence without any tensor in it.
        """
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify at least one tensor
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor.

        Raises:
            AttributeError: the batch is not atomic.
        """
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch"""
        return self._values

    def find_tensor_idx(self):
        """
        Retrieves the index of first tensor found.

        Raises:
            TypeError: no value in the batch is a tensor.
        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i

        raise TypeError("No tensor found!")

    def get_device(self):
        """
        Retrieves the device for this microbatch.

        For a non-atomic batch this is the device of the first tensor among
        the values; ``None`` is returned when no value is a tensor.
        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]

        for value in self._values:
            if torch.is_tensor(value):
                return value.device

        # No tensor among the values.
        return None

    def call(self, function: "Function") -> "Batch":
        """Calls a function on the microbatch. It also wraps
        the output with :class:`Batch`.

        An atomic batch passes its tensor as the single argument; a
        non-atomic batch unpacks its values as positional arguments.
        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        # An atomic batch iterates as a one-element sequence.
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: "Tensors") -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        """Replaces one value (``batch[i] = v``) or all values (``batch[:] = vs``)."""
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        """Replaces the value at ``index``.

        Raises:
            IndexError: ``index`` is out of range (or not 0 for an atomic batch).
        """
        if not self.atomic:
            # Work on a copy so a sequence handed in by the caller is never
            # mutated, and keep the container type: batches may be backed by
            # a list as well as by a tuple.
            updated = list(self._values)  # type: ignore[arg-type]
            updated[index] = value  # IndexError when out of range
            if isinstance(self._values, tuple):
                self._values = tuple(updated)  # type: ignore[assignment]
            else:
                self._values = updated
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        """Replaces all values; only the full slice ``[:]`` is supported.

        Raises:
            NotImplementedError: the slice is anything other than ``[:]``.
            IndexError: an atomic batch is given anything but exactly one value.
        """
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self._values = value[0]


def check(first_device, *inputs) -> None:
    """
    Validates the inputs of a pipeline: at least one of them must be a
    tensor, and every tensor must live on ``first_device`` (the device of
    the first partition). Non-tensor inputs are ignored by the device check.

    Raises:
        TypeError: none of the inputs is a tensor
        ValueError: some tensor is on a device other than ``first_device``

    """

    tensors = [item for item in inputs if torch.is_tensor(item)]
    if not tensors:
        raise TypeError(f'inputs do not have any tensors: {inputs}')

    for tensor in tensors:
        if tensor.device != first_device:
            raise ValueError('All inputs should be on the same device as the first partition')


def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches.

    A single tensor input yields atomic batches, one per chunk. Otherwise
    every tensor input is chunked and every other input (including the
    tensor inside a :class:`NoChunk`) is replicated as-is into each
    micro-batch.

    Args:
        inputs: the values making up the mini-batch.
        chunks: requested number of micro-batches. Chunking may produce
            fewer; the result then has that smaller number of batches.

    Returns:
        The micro-batches, in chunk order.

    Raises:
        RuntimeError: tensor inputs were split into differing numbers of chunks.
        TypeError: raised by :class:`Batch` when the inputs hold no tensor.
    """
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced; stays None until a tensor is chunked.
    num_chunks = None
    for item in inputs:
        if torch.is_tensor(item):
            # Chunk only tensors.
            tensors = item.chunk(chunks)

            # Validate number of chunks equal across all inputs.
            if num_chunks is not None and num_chunks != len(tensors):
                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for batch, tensor in zip(batches, tensors):
                batch.append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'
            # (extracting the tensor out of the wrapper).
            replicated = item.tensor if isinstance(item, NoChunk) else item
            for batch in batches:
                batch.append(replicated)

    # Truncate to actual number of chunks. When nothing was chunked (e.g. all
    # tensors are wrapped with 'NoChunk') keep all `chunks` replicas rather
    # than dropping one.
    if num_chunks is not None:
        batches = batches[:num_chunks]

    return [Batch(x) for x in batches]


def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch.

    Atomic micro-batches are concatenated into one tensor. Otherwise a tuple
    is returned with one entry per position: tensors at that position are
    concatenated, anything else is collected into a list (one item per
    micro-batch).

    Raises:
        TypeError: the value at some position has a different type than the
            value at the same position in the first micro-batch.
    """
    first = outputs[0]

    if first.atomic:
        return torch.cat(tuple(microbatch.tensor for microbatch in outputs))

    gathered: List[Any] = []
    for position in range(len(first)):
        expected_type = type(first[position])

        # Collect the value at this position from every micro-batch,
        # insisting that they all share the type seen in the first one.
        column = []
        for microbatch in outputs:
            item = microbatch[position]
            if expected_type != type(item):
                raise TypeError(f'Types for microbatch outputs do not match, found: {expected_type} and {type(item)}')
            column.append(item)

        gathered.append(torch.cat(column) if torch.is_tensor(first[position]) else column)

    return tuple(gathered)