Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ nn / parallel / distributed.py

import copy
import inspect
import itertools
import logging
import os
import sys
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import Enum, auto
from typing import Callable, Any, Type

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import (

from torch.utils._pytree import tree_flatten, tree_unflatten

if dist.is_available():
    from torch.distributed.utils import (
    from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
if torch.distributed.rpc.is_available():
    from torch.distributed.rpc import RRef

from torch._utils import _get_device_index

from ..modules import Module
from ._replicated_tensor_ddp_utils import _ddp_with_replicated_tensor_enabled
from .scatter_gather import gather, scatter_kwargs  # noqa: F401

__all__ = ["DistributedDataParallel"]

logger = logging.getLogger(__name__)

def _tree_flatten_with_rref(output):
    """Flatten ``output`` into its leaves, unwrapping a local RRef first.

    Args:
        output: The value to flatten. When RPC is available and ``output`` is
            an ``RRef``, the value returned by ``output.local_value()`` is
            flattened instead of the ``RRef`` itself.

    Returns:
        A 3-tuple ``(output_tensor_list, treespec, output_is_rref)``: the
        flattened leaves, the spec needed to re-pack them, and whether the
        original ``output`` was an ``RRef``.
    """
    output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
    if output_is_rref:
        output_tensor_list, treespec = tree_flatten(output.local_value())
    else:
        # Without this branch the RRef result above would be overwritten and
        # the non-RRef path would leave ``output_tensor_list`` unbound.
        output_tensor_list, treespec = tree_flatten(output)
    # Need to return flattened tensors, spec to re-pack them, as well
    # as if the return type was actually an RRef to reconstruct.
    return output_tensor_list, treespec, output_is_rref

def _tree_unflatten_with_rref(output, treespec, output_is_rref):
    output = tree_unflatten(output, treespec)
    if output_is_rref:
        output = RRef(output)
    return output

def _find_tensors(obj):
    r"""
    Recursively find all tensors contained in the specified object.

    Args:
        obj: The value to search. Tensors, lists, tuples, dicts (values only),
            dataclass instances, and RRefs owned by the current node are
            searched; anything else contributes no tensors.

    Returns:
        An iterable of the tensors found: a list for a single tensor or for
        an unsupported object, and an ``itertools.chain`` for containers.
    """
    if RPC_AVAILABLE and isinstance(obj, RRef):
        # If the current node is the owner of the RRef, unwrap it and try to
        # find Tensors.
        # TODO: Expand to remote RRefs.
        if obj.is_owner():
            return _find_tensors(obj.local_value())
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain(*map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain(*map(_find_tensors, obj.values()))
    # ``is_dataclass`` is also true for dataclass *types*; reading fields off
    # the class itself can raise AttributeError for fields without defaults,
    # so only instances are traversed.
    if is_dataclass(obj) and not isinstance(obj, type):
        return itertools.chain(
            *map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
        )

    return []

def _dump_DDP_relevant_env_vars():
    # Builds a text report with one "env:NAME=VALUE" line per name in
    # `relevant_env_vars`, substituting "N/A" for variables that are not
    # present in `os.environ`.
    # NOTE(review): this block is truncated in this copy of the file. The
    # entries of `relevant_env_vars` and its closing `]` are missing, and
    # nothing visible consumes `formatted_output` after the loop. Restore the
    # missing lines from the upstream file before relying on this function.
    relevant_env_vars = [
        # More NCCL env vars:
    formatted_output = ""
    for var in relevant_env_vars:
        value = os.environ[var] if var in os.environ else "N/A"
        formatted_output += "env:%s=%s\n" % (var, value)

class _BufferCommHookLocation(Enum):
    """Location marker stored alongside a buffer communication hook.

    Used as the type of ``_BufferCommHook.buffer_comm_hook_location``.
    NOTE(review): presumably selects whether the hook runs before or after
    the forward pass; confirm against the code that reads it.
    """

    PRE_FORWARD = auto()
    POST_FORWARD = auto()

@dataclass
class _BufferCommHook:
    """Bundle of a buffer communication hook, its state, and its location.

    Declared as a dataclass so the three annotated fields become constructor
    arguments and instance attributes; a plain class with annotation-only
    fields has no way to receive them.
    """

    # The callable to invoke as the hook.
    buffer_comm_hook: Callable
    # Opaque state kept with the hook.
    buffer_comm_hook_state: Any
    # Which ``_BufferCommHookLocation`` the hook is registered for.
    buffer_comm_hook_location: _BufferCommHookLocation

# Add a DDPSink to run various functions when backwards starts, such as
# queueing call back of out-most backward/graph task,
# this helps call back is fired after all gradients' calculation
# is completed.
class _DDPSink(Function):
    # Autograd Function that passes its inputs through (cloning tensors) in
    # forward and passes the incoming gradients straight back in backward.
    # NOTE(review): this class is truncated in this copy of the file. Several
    # closing parentheses, the first operand of the `if (` condition in
    # `backward`, and the arguments of `queue_callback(` are missing. The
    # methods take `ctx` as their first parameter but no `@staticmethod`
    # decorators are visible. Restore from the upstream file.
    def forward(ctx, reducer, state_dict, *inputs):
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        # NOTE(review): the call this comment describes is not visible here.
        # Stash the reducer and state dict on ctx so backward can reach them.
        ctx.reducer = reducer
        ctx.state_dict = state_dict
        # Tensor inputs are cloned; non-tensor inputs are passed through as-is.
        ret = tuple(
            inp.clone() if isinstance(inp, torch.Tensor) else inp
            for inp in inputs
        return ret

    def backward(ctx, *grad_outputs):
        state_dict = ctx.state_dict
        # Enqueue delay allreduce for static graph training on the first
        # iteration.
        # NOTE(review): the condition and the callback call below are
        # incomplete in this copy.
        if (
            and ctx.state_dict["num_iterations"] == 1
            Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]

        # No gradients for `reducer` and `state_dict`; the gradients of the
        # passed-through inputs are returned unchanged.
        return (None, None, *grad_outputs)

class _DDPJoinHook(JoinHook):
    # Join hook bound to a DistributedDataParallel instance.
    # NOTE(review): this class is truncated in this copy of the file. The
    # docstring quotes are gone (the prose lines directly under each `def`
    # were docstrings), several call argument lists and closing parentheses
    # are missing, and the bodies of some branches and of `post_hook` are
    # absent. Restore from the upstream file.
    def __init__(self, ddp, divide_by_initial_world_size):
        Sets config variables for internal usage.
        # Requires a DistributedDataParallel instance whose logger is set.
        assert isinstance(ddp, DistributedDataParallel), (
            "DDP join hook requires passing in a DistributedDataParallel "
            "instance as the state"
        assert ddp.logger is not None
        self.ddp = ddp
        # Records the caller's choice on the wrapped DDP instance itself.
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size

    def main_hook(self):
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        # NOTE(review): the statement this comment describes is not visible.

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        # TODO: make DDP uneven inputs context manager support buffer
        # comm hook (https://github.com/pytorch/pytorch/issues/65436)
        # NOTE(review): the statement this comment describes is not visible.

        # Check if need to sync in the backward pass
        # NOTE(review): the arguments and closing parenthesis of this call
        # are missing.
        work = ddp._check_global_requires_backward_grad_sync(
        # A non-zero first element of the result means gradients are synced.
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        # NOTE(review): the body of this branch is missing.
        if not should_sync_backwards:

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        # NOTE(review): the statement this comment describes is not visible.

        # Check if we need to allreduce locally unused parameters
        # NOTE(review): the body of this branch is missing.
        if ddp.find_unused_parameters:

        # Rebuilt parameters are pushed only once during a training period
        # NOTE(review): the statement this comment describes is not visible.

    def post_hook(self, is_last_joiner: bool):
        # NOTE(review): the docstring below is cut off mid-sentence and the
        # method body is missing.
        Syncs the final model to ensure that the model is the same across all

class DistributedDataParallel(Module, Joinable):
    r"""Implements distributed data parallelism that is based on
    ``torch.distributed`` package at the module level.

    This container provides data parallelism by synchronizing gradients
    across each model replica. The devices to synchronize across are
    specified by the input ``process_group``, which is the entire world
    by default. Note that ``DistributedDataParallel`` does not chunk or
    otherwise shard the input across participating GPUs; the user is
    responsible for defining how to do so, for example through the use
    of a :class:`DistributedSampler`.

    See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
    The same constraints on input as in :class:`torch.nn.DataParallel` apply.

    Creation of this class requires that ``torch.distributed`` to be already
    initialized, by calling :func:`torch.distributed.init_process_group`.

    ``DistributedDataParallel`` is proven to be significantly faster than
    :class:`torch.nn.DataParallel` for single-node multi-GPU data
    parallel training.

    To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
    up ``N`` processes, ensuring that each process exclusively works on a single
    GPU from 0 to N-1. This can be done by either setting
    ``CUDA_VISIBLE_DEVICES`` for every process or by calling:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> torch.cuda.set_device(i)

    where i is from 0 to N-1. In each process, you should refer the following
    to construct this module:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> torch.distributed.init_process_group(
        >>>     backend='nccl', world_size=N, init_method='...'
        >>> )
        >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)

    In order to spawn up multiple processes per node, you can use either
    ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.

    .. note::
        Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
        for a brief introduction to all features related to distributed training.

    .. note::
        ``DistributedDataParallel`` can be used in conjunction with
        :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
        per-rank optimizer states memory footprint. Please refer to
        `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
        for more details.

    .. note:: ``nccl`` backend is currently the fastest and highly recommended
        backend when using GPUs. This applies to both single-node and
        multi-node distributed training.

    .. note:: This module also supports mixed-precision distributed training.
        This means that your model can have different types of parameters such
        as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these
        mixed types of parameters will just work fine.

    .. note:: If you use ``torch.save`` on one process to checkpoint the module,
        and ``torch.load`` on some other processes to recover it, make sure that
        ``map_location`` is configured properly for every process. Without
        ``map_location``, ``torch.load`` would recover the module to devices
        where the module was saved from.

    .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
        gradient will be ``M`` times smaller when compared to the same model
        trained on a single node with ``batch=M*N`` if the loss is summed (NOT
        averaged as usual) across instances in a batch (because the gradients
        between different nodes are averaged). You should take this into
        consideration when you want to obtain a mathematically equivalent
        training process compared to the local training counterpart. But in most
        cases, you can just treat a DistributedDataParallel wrapped model, a
        DataParallel wrapped model and an ordinary model on a single GPU as the
        same (E.g. using the same learning rate for equivalent batch size).

    .. note::
        Parameters are never broadcast between processes. The module performs
        an all-reduce step on gradients and assumes that they will be modified
        by the optimizer in all processes in the same way. Buffers
        (e.g. BatchNorm stats) are broadcast from the module in process of rank
        0, to all other replicas in the system in every iteration.

    .. note::
        If you are using DistributedDataParallel in conjunction with the
        :ref:`distributed-rpc-framework`, you should always use
        :meth:`torch.distributed.autograd.backward` to compute gradients and
Loading ...