# Copyright 2019 Kakao Brain
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, Sequence, Tuple, cast

import torch
from torch import Tensor, nn
from torch.distributed.rpc import RRef
import torch.autograd
import torch.cuda

from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream

__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]

# Anything accepted where a device is expected: a torch.device, a device
# index, or a device string.
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

# A single tensor or a sequence of tensors.
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
    # Typechecking: nn.Module is not a Generic
    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
    NamedModules = OrderedDict[str, Module]
else:
    # The subscripted forms above are only meaningful to the type checker;
    # at runtime the plain classes are used.
    Module = nn.Module
    NamedModules = OrderedDict

def _recommend_auto_balance(message: str) -> str:
    """Expands a message with recommendation to
    :mod:`torch.distributed.pipeline.sync.balance`.

    Args:
        message: The original message, placed verbatim at the start of the
            result.

    Returns:
        ``message`` followed by a blank line and a fixed hint that shows how
        to compute a balance with ``balance_by_time``.
    """
    return f"""{message}

If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for
naive automatic balancing:

  from torch.distributed.pipeline.sync import Pipe
  from torch.distributed.pipeline.sync.balance import balance_by_time

  partitions = torch.cuda.device_count()
  sample = torch.empty(...)
  balance = balance_by_time(partitions, model, sample)

  model = Pipe(model, balance, ...)
"""

def _verify_module(module: nn.Sequential) -> None:
    """Check that ``module`` is in a shape that can be partitioned.

    Args:
        module: The candidate container.

    Raises:
        TypeError: If ``module`` is not an ``nn.Sequential``.
        ValueError: If ``named_children()`` yields a different number of
            entries than ``len(module)``, which is reported as duplicate
            children.
    """
    if isinstance(module, nn.Sequential):
        # Count the entries produced by named_children() without keeping them.
        child_count = sum(1 for _ in module.named_children())
        if child_count == len(module):
            return
        raise ValueError("module with duplicate children is not supported")

    raise TypeError("module must be nn.Sequential to be partitioned")

def _verify_splitting(
    module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device]
) -> None:
    """Reject a split in which one parameter object lives in partitions that
    are assigned to different devices.

    Args:
        module: The original, unsplit sequential module.
        partitions: The partitions produced from ``module``.
        devices: The device of each partition; ``devices[i]`` belongs to
            ``partitions[i]``.

    Raises:
        ValueError: If two partitions on distinct devices hold the very same
            parameter object.
    """
    num_parameters = len(list(module.parameters()))
    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
    if num_parameters == num_child_parameters:
        # Fast path: the per-child counts add up to the module-wide count, so
        # no parameter is shared between children and there is nothing to check.
        return

    # Sharing is decided by object identity (``p is q``), so compare ids.
    # Building one id set per partition avoids rescanning parameters for
    # every pair of partitions.
    parameter_ids = [{id(p) for p in partition.parameters()} for partition in partitions]

    for i in range(len(partitions)):
        for j in range(i + 1, len(partitions)):
            if devices[i] == devices[j]:
                # Sharing a parameter within one device is allowed.
                continue
            if not parameter_ids[i].isdisjoint(parameter_ids[j]):
                raise ValueError("module with duplicate parameters on distinct devices is not supported")

class BalanceError(ValueError):
    """Error type for balance-related failures.

    It subclasses ``ValueError``, so handlers written for ``ValueError`` also
    catch it. It adds no behaviour of its own.
    """

def _retrieve_device(module: nn.Module) -> torch.device:
    """Validates all parameters in the Module have the same device and returns
    the appropriate device.

    Args:
        module: An ``nn.Module`` to process.

    Returns:
        ``torch.device`` for the entire module. A module without any
        parameters is reported as ``torch.device("cpu")``.

    Raises:
        ValueError:
            If devices for ``nn.Module`` parameters are not all same.
    """
    device = None
    for parameter in module.parameters():
        if device is None:
            # The first parameter decides the device; the rest must match it.
            device = parameter.device
        elif device != parameter.device:
            raise ValueError(
                f'nn.Module: {module}, should have all parameters on a single device,'
                ' please use .to() to place the module on a single device')

    return device if device is not None else torch.device("cpu")

class PipeSequential(nn.Sequential):
    """
    Pipe variant of ``nn.Sequential`` which supports multiple inputs.

    The output of each child is fed to the next one. A tuple is unpacked into
    positional arguments; any other value is passed as a single argument.
    """

    def forward(self, *inputs):
        # ``inputs`` starts as the tuple of positional arguments and is then
        # rebound to whatever each child returns.
        for module in self:
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                # Don't expand single variables (ex: lists/Tensor)
                inputs = module(inputs)
        return inputs

class WithDevice(nn.Module):
    """
    Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe`
    that overrides the device for that module. In cases where :class:`Pipe`
    can't implicitly determine the device for the module and places it on CPU,
    this wrapper can be used to override the implicit behavior and explicitly
    specify which device a module should run on.

    The provided module is also moved to the given device via ``.to(device)``
    by :class:`Pipe`

    Args:
        module(:class:`torch.nn.Module`): The module to be wrapped.
        device(:class:`torch.device`): The device to run the module on.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> fc1 = nn.Linear(16, 8).cuda(0)
        >>> fc2 = nn.Linear(8, 4).cuda(1)
        >>> dropout = nn.Dropout()
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
        >>> # Dropout does not have any parameters/buffers, but we want to
        >>> # run it on cuda:1 to avoid any GPU to CPU transfers.
        >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1'))
        >>> # xdoctest: +SKIP("Needs RPC framework init")
        >>> model = Pipe(model, chunks=8)
    """
    def __init__(self, module: nn.Module, device: torch.device):
        # nn.Module has to be initialised before a submodule can be assigned
        # to an attribute.
        super().__init__()
        self._module = module
        # Normalise the argument so that values such as 'cuda:1' are stored
        # as torch.device objects.
        self._device = torch.device(device)

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped module unchanged."""
        return self._module(*args, **kwargs)

    @property
    def module(self):
        """The wrapped module."""
        return self._module

    @property
    def device(self):
        """The device this module should run on, as a ``torch.device``."""
        return self._device

def _assemble_partition(modules: List[nn.Module]):
    """Build one partition from a list of modules.

    Args:
        modules: The modules that make up the partition, in execution order.

    Returns:
        A ``PipeSequential`` holding the modules. Any ``nn.Sequential`` in
        ``modules`` is replaced by its direct children, so the result is one
        level flatter than the input.
    """
    modules_list: List[nn.Module] = []
    for module in modules:
        if isinstance(module, nn.Sequential):
            # Inline the children rather than nesting the container.
            modules_list.extend(module.children())
        else:
            modules_list.append(module)
    return PipeSequential(*modules_list)

def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]:
    """Group consecutive children of ``modules`` into per-device partitions.

    A child wrapped in ``WithDevice`` is unwrapped and moved to the wrapper's
    device; for every other child the device comes from ``_retrieve_device``.
    A new partition starts whenever the device changes. A CPU child never
    joins the preceding partition, even when that partition is also on CPU.

    Args:
        modules: The sequential module whose children are to be grouped.

    Returns:
        A pair ``(partitions, devices)`` where ``devices[i]`` is the device of
        ``partitions[i]``. ``partitions`` is an ``nn.ModuleList`` cast to
        ``List[nn.Sequential]``.
    """
    partitions = []
    devices = []

    current_partition = []
    current_device = None
    for module in modules.children():
        if isinstance(module, WithDevice):
            # Process device override and move module to appropriate device.
            device = module.device
            module = module.module
            module.to(device)
        else:
            device = _retrieve_device(module)

        if current_device is not None and (current_device != device or device.type == 'cpu'):
            # Close the partition collected so far before starting a new one.
            partitions.append(_assemble_partition(current_partition))
            devices.append(current_device)
            current_partition = []
        current_device = device
        current_partition.append(module)

    if current_device is not None:
        # Flush the trailing partition; this is skipped when there were no children.
        partitions.append(_assemble_partition(current_partition))
        devices.append(current_device)

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))

    return partitions, devices

# Pre-built exception instance shared at module level.
# NOTE(review): presumably raised by Pipe's device-moving methods, which are
# not visible in this part of the file; confirm against the rest of the class.
MOVING_DENIED = TypeError(
    "denied to move parameters and buffers, because Pipe should manage device placement"
)

class Pipe(Module):
    """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
    to train on using synchronous pipeline parallelism. If the module requires
    lots of memory and doesn't fit on a single GPU, pipeline parallelism is a
    useful technique to employ for training.

    The implementation is based on the torchgpipe_ paper.

    .. _torchgpipe: https://arxiv.org/abs/2004.09910

    Pipe combines pipeline parallelism with checkpointing to reduce peak
    memory required to train while minimizing device under-utilization.

    You should place all the modules on the appropriate devices and wrap them
    into an :class:`nn.Sequential <torch.nn.Sequential>` module defining the
    desired order of execution. If a module does not contain any
    parameters/buffers, it is assumed this module should be executed on CPU
    and appropriate input tensors to the module are moved to CPU before
    execution. This behavior can be overridden by the :class:`WithDevice`
    wrapper which can be used to explicitly specify which device a module
    should run on.

    Args:
        module (:class:`nn.Sequential <torch.nn.Sequential>`):
            sequential module to be parallelized using pipelining. Each module
            in the sequence has to have all of its parameters on a single
            device. Each module in the sequence has to either be an nn.Module
            or :class:`nn.Sequential <torch.nn.Sequential>` (to combine multiple
            sequential modules on a single device)
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
            when to enable checkpointing, one of ``'always'``,
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``).
            ``'never'`` disables checkpointing completely, ``'except_last'``
            enables checkpointing for all micro-batches except the last one
            and ``'always'`` enables checkpointing for all micro-batches.
        deferred_batch_norm (bool):
            whether to use deferred ``BatchNorm`` moving statistics (default:
            :data:`False`). If set to :data:`True`, we track statistics across
            multiple micro-batches to update the running statistics per
            mini-batch.

    Raises:
        TypeError:
            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
        ValueError:
            invalid arguments

    Example::
        Pipeline of two FC layers across GPUs 0 and 1.

        >>> # Need to initialize RPC framework first.
        >>> # xdoctest: +SKIP
        >>> os.environ['MASTER_ADDR'] = 'localhost'
        >>> os.environ['MASTER_PORT'] = '29500'
        >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)
        >>> # Build pipe.
        >>> fc1 = nn.Linear(16, 8).cuda(0)
        >>> fc2 = nn.Linear(8, 4).cuda(1)
        >>> model = nn.Sequential(fc1, fc2)
        >>> model = Pipe(model, chunks=8)
        >>> input = torch.rand(16, 16).cuda(0)
        >>> output_rref = model(input)

    .. note::
        You can wrap a :class:`Pipe` model with
        :class:`torch.nn.parallel.DistributedDataParallel` only when the
        checkpoint parameter of :class:`Pipe` is ``'never'``.

    .. note::
        :class:`Pipe` only supports intra-node pipelining currently, but
        will be expanded to support inter-node pipelining in the future.
        The forward function returns an :class:`~torch.distributed.rpc.RRef`
        to allow for inter-node pipelining in the future, where the output
        might be on a remote host. For intra-node pipelining you can use
        :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the
        output locally.

    .. warning::
        :class:`Pipe` is experimental and subject to change.

    def __init__(
        module: nn.Sequential,
        chunks: int = 1,
        checkpoint: str = "except_last",
        deferred_batch_norm: bool = False,
    ) -> None:

        # Check if RPC framework is initialized.
        if not torch.distributed.rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                'Please initialize RPC framework for Pipe using '

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")


        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)

        self.partitions, self.devices = _split_module(module)
        _verify_splitting(module, self.partitions, self.devices)
