import copy
import inspect
import itertools
import logging
import os
import sys
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import Enum, auto
from typing import Callable, Any, Type
import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import (
Join,
Joinable,
JoinHook,
)
from torch.utils._pytree import tree_flatten, tree_unflatten
RPC_AVAILABLE = False
if dist.is_available():
from torch.distributed.utils import (
_verify_param_shape_across_processes,
_sync_module_states,
_to_kwargs,
)
from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
if torch.distributed.rpc.is_available():
RPC_AVAILABLE = True
from torch.distributed.rpc import RRef
from torch._utils import _get_device_index
from ..modules import Module
from ._replicated_tensor_ddp_utils import _ddp_with_replicated_tensor_enabled
from .scatter_gather import gather, scatter_kwargs # noqa: F401
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Public API: only the DDP wrapper is exported by ``from ... import *``.
__all__ = ["DistributedDataParallel"]
def _tree_flatten_with_rref(output):
    """Flatten ``output`` into its pytree leaves plus a spec to rebuild it.

    When RPC is available and ``output`` is an ``RRef``, the value it holds
    locally (``output.local_value()``) is flattened instead. The returned
    flag records this so ``_tree_unflatten_with_rref`` can re-wrap the
    rebuilt value.

    Returns:
        A ``(leaves, treespec, output_is_rref)`` triple.
    """
    is_rref = RPC_AVAILABLE and isinstance(output, RRef)
    target = output.local_value() if is_rref else output
    leaves, spec = tree_flatten(target)
    # The flag is returned alongside the flattened data so the caller can
    # reconstruct an RRef if that is what was originally passed in.
    return leaves, spec, is_rref
def _tree_unflatten_with_rref(output, treespec, output_is_rref):
    """Inverse of ``_tree_flatten_with_rref``.

    Rebuilds the original structure from the flat ``output`` leaves using
    ``treespec``. If ``output_is_rref`` is true, the rebuilt value is wrapped
    in a new ``RRef`` before being returned.
    """
    rebuilt = tree_unflatten(output, treespec)
    return RRef(rebuilt) if output_is_rref else rebuilt
def _find_tensors(obj):
    r"""
    Recursively find all tensors contained in the specified object.

    Traverses lists/tuples, dict values, dataclass instances and (when RPC
    is available) ``RRef``s owned by the current node. Any other object
    contributes no tensors.

    Returns:
        An iterable over the ``torch.Tensor`` leaves found, in traversal
        order. Wrap it in ``list()`` if a concrete sequence is needed.
    """
    if RPC_AVAILABLE and isinstance(obj, RRef):
        # If the current node is the owner of the RRef, unwrap it and try to
        # find Tensors.
        # TODO: Expand to remote RRefs.
        if obj.is_owner():
            return _find_tensors(obj.local_value())
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain.from_iterable(map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
    # ``is_dataclass`` is also true for dataclass *classes*. Only instances
    # hold per-field values; ``getattr`` on the class raises AttributeError
    # for any field without a default, so classes are skipped here.
    if is_dataclass(obj) and not isinstance(obj, type):
        return itertools.chain.from_iterable(
            _find_tensors(getattr(obj, f.name)) for f in fields(obj)
        )
    return []
def _dump_DDP_relevant_env_vars():
    """Print the current value of every environment variable relevant to DDP.

    Writes one ``env:NAME=VALUE`` line per variable to stdout, using
    ``N/A`` for variables that are unset. The list covers launcher
    (rank/world/master), CUDA, Gloo and NCCL settings.
    """
    relevant_env_vars = [
        "RANK",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "MASTER_PORT",
        "MASTER_ADDR",
        "CUDA_VISIBLE_DEVICES",
        "GLOO_SOCKET_IFNAME",
        "GLOO_DEVICE_TRANSPORT",
        "NCCL_SOCKET_IFNAME",
        "NCCL_BLOCKING_WAIT",
        "NCCL_DEBUG",
        "NCCL_DEBUG_SUBSYS",
        "NCCL_IB_DISABLE",
        # More NCCL env vars:
        "NCCL_P2P_DISABLE",
        "NCCL_P2P_LEVEL",
        "NCCL_SHM_DISABLE",
        "NCCL_SOCKET_NTHREADS",
        "NCCL_NSOCKS_PERTHREAD",
        "NCCL_BUFFSIZE",
        "NCCL_NTHREADS",
        "NCCL_RINGS",
        "NCCL_MAX_NCHANNELS",
        "NCCL_MIN_NCHANNELS",
        "NCCL_CHECKS_DISABLE",
        "NCCL_CHECK_POINTERS",
        "NCCL_LAUNCH_MODE",
        "NCCL_IB_HCA",
        "NCCL_IB_TIMEOUT",
        "NCCL_IB_RETRY_CNT",
        "NCCL_IB_GID_INDEX",
        "NCCL_IB_SL",
        "NCCL_IB_TC",
        "NCCL_IB_AR_THRESHOLD",
        "NCCL_IB_CUDA_SUPPORT",
        "NCCL_NET_GDR_LEVEL",
        "NCCL_NET_GDR_READ",
        "NCCL_SINGLE_RING_THRESHOLD",
        "NCCL_LL_THRESHOLD",
        "NCCL_TREE_THRESHOLD",
        "NCCL_ALGO",
        "NCCL_PROTO",
        "NCCL_IGNORE_CPU_AFFINITY",
        "NCCL_DEBUG_FILE",
        "NCCL_COLLNET_ENABLE",
        "NCCL_TOPO_FILE",
        "NCCL_TOPO_DUMP_FILE",
        "NCCL_ASYNC_ERROR_HANDLING",
    ]
    # Every line keeps its own trailing newline. Together with print()'s
    # newline, the output therefore ends in a blank line, as it always has.
    formatted_output = "".join(
        "env:%s=%s\n" % (var, os.environ.get(var, "N/A"))
        for var in relevant_env_vars
    )
    print(formatted_output)
class _BufferCommHookLocation(Enum):
    """When a buffer communication hook runs relative to ``forward``.

    Values are the integers ``auto()`` would assign, written out explicitly.
    """

    PRE_FORWARD = 1
    POST_FORWARD = 2
@dataclass
class _BufferCommHook:
    """Bundles a buffer communication hook with its state and run location.

    A plain ``@dataclass``: the fields below are stored as given, with no
    behavior beyond the generated ``__init__``/``__repr__``/``__eq__``.
    """

    # The hook callable itself.
    buffer_comm_hook: Callable
    # State object associated with the hook (presumably passed to the hook
    # when it is invoked -- NOTE(review): confirm at the call site).
    buffer_comm_hook_state: Any
    # Whether the hook runs before or after the forward pass.
    buffer_comm_hook_location: _BufferCommHookLocation
# Add a DDPSink to run various functions when the backward pass starts, such
# as queueing a callback on the outermost backward/graph task; this ensures
# the callback fires only after all gradients have been computed.
class _DDPSink(Function):
    """Pass-through autograd function used to hook DDP logic into backward.

    ``forward`` returns ``inputs`` with tensors cloned. ``backward`` hands
    the incoming gradients straight back to those inputs. On the first
    iteration of static-graph training, ``backward`` also queues
    ``reducer._delay_all_reduce`` on the autograd engine as a callback.
    """

    @staticmethod
    def forward(ctx, reducer, state_dict, *inputs):
        """Store ``reducer``/``state_dict`` on ``ctx`` and pass ``inputs`` through.

        Args:
            reducer: DDP reducer; only its ``_delay_all_reduce`` attribute is
                used (in ``backward``).
            state_dict: mapping read in ``backward`` for its
                ``"static_graph"`` and ``"num_iterations"`` keys.
            *inputs: values to pass through. Tensors are cloned; any other
                object is returned unchanged.

        Returns:
            A tuple with one entry per element of ``inputs``.
        """
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        ctx.reducer = reducer
        ctx.state_dict = state_dict
        return tuple(
            inp.clone() if isinstance(inp, torch.Tensor) else inp
            for inp in inputs
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return ``grad_outputs`` unchanged as the gradients of ``inputs``.

        The two leading ``None``s are the gradients for the non-tensor
        ``reducer`` and ``state_dict`` arguments of ``forward``.
        """
        state_dict = ctx.state_dict
        # Enqueue delay allreduce for static graph training on the first
        # iteration. The callback runs once the current backward pass ends.
        if state_dict["static_graph"] and state_dict["num_iterations"] == 1:
            Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
                ctx.reducer._delay_all_reduce
            )
        return (None, None, *grad_outputs)
class _DDPJoinHook(JoinHook):
    """``JoinHook`` that lets a joined DDP rank shadow the collectives of a
    ``DistributedDataParallel`` training iteration (uneven-inputs support).

    ``main_hook`` issues the collectives a forward/backward iteration would
    issue, and ``post_hook`` syncs the final model across processes.
    """

    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.

        Args:
            ddp: the ``DistributedDataParallel`` instance this hook shadows.
            divide_by_initial_world_size: stored on ``ddp`` as
                ``_divide_by_initial_world_size``.

        Raises:
            AssertionError: if ``ddp`` is not a ``DistributedDataParallel``
                instance or its ``logger`` is ``None``.
        """
        assert isinstance(ddp, DistributedDataParallel), (
            "DDP join hook requires passing in a DistributedDataParallel "
            "instance as the state"
        )
        assert ddp.logger is not None
        # Presumably records in DDP's logger that the uneven-input join
        # path is in use -- NOTE(review): confirm.
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.

        NOTE: the order of the calls below is significant. Each call
        presumably has to line up with a collective issued by the
        non-joined ranks in the same order, so do not reorder them.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        # TODO: make DDP uneven inputs context manager support buffer
        # comm hook (https://github.com/pytorch/pytorch/issues/65436)
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(
            is_joined_rank=True
        )
        work.wait()
        # A nonzero first element of the result means a backward sync is
        # required this iteration (presumably because at least one rank
        # requested it -- NOTE(review): confirm the reduction used).
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.

        Args:
            is_last_joiner: passed through unchanged to
                ``self.ddp._sync_final_model``.
        """
        self.ddp._sync_final_model(is_last_joiner)
class DistributedDataParallel(Module, Joinable):
r"""Implements distributed data parallelism that is based on
``torch.distributed`` package at the module level.
This container provides data parallelism by synchronizing gradients
across each model replica. The devices to synchronize across are
specified by the input ``process_group``, which is the entire world
by default. Note that ``DistributedDataParallel`` does not chunk or
otherwise shard the input across participating GPUs; the user is
responsible for defining how to do so, for example through the use
of a :class:`DistributedSampler`.
See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
The same constraints on input as in :class:`torch.nn.DataParallel` apply.
Creation of this class requires that ``torch.distributed`` to be already
initialized, by calling :func:`torch.distributed.init_process_group`.
``DistributedDataParallel`` is proven to be significantly faster than
:class:`torch.nn.DataParallel` for single-node multi-GPU data
parallel training.
To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
up ``N`` processes, ensuring that each process exclusively works on a single
GPU from 0 to N-1. This can be done by either setting
``CUDA_VISIBLE_DEVICES`` for every process or by calling:
>>> # xdoctest: +SKIP("undefined variables")
>>> torch.cuda.set_device(i)
where i is from 0 to N-1. In each process, you should refer the following
to construct this module:
>>> # xdoctest: +SKIP("undefined variables")
>>> torch.distributed.init_process_group(
>>> backend='nccl', world_size=N, init_method='...'
>>> )
>>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
In order to spawn up multiple processes per node, you can use either
``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
.. note::
Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
for a brief introduction to all features related to distributed training.
.. note::
``DistributedDataParallel`` can be used in conjunction with
:class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
per-rank optimizer states memory footprint. Please refer to
`ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
for more details.
.. note:: ``nccl`` backend is currently the fastest and highly recommended
backend when using GPUs. This applies to both single-node and
multi-node distributed training.
.. note:: This module also supports mixed-precision distributed training.
This means that your model can have different types of parameters such
as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these
mixed types of parameters will just work fine.
.. note:: If you use ``torch.save`` on one process to checkpoint the module,
and ``torch.load`` on some other processes to recover it, make sure that
``map_location`` is configured properly for every process. Without
``map_location``, ``torch.load`` would recover the module to devices
where the module was saved from.
.. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
gradient will be ``M`` times smaller when compared to the same model
trained on a single node with ``batch=M*N`` if the loss is summed (NOT
averaged as usual) across instances in a batch (because the gradients
between different nodes are averaged). You should take this into
consideration when you want to obtain a mathematically equivalent
training process compared to the local training counterpart. But in most
cases, you can just treat a DistributedDataParallel wrapped model, a
DataParallel wrapped model and an ordinary model on a single GPU as the
same (E.g. using the same learning rate for equivalent batch size).
.. note::
Parameters are never broadcast between processes. The module performs
an all-reduce step on gradients and assumes that they will be modified
by the optimizer in all processes in the same way. Buffers
(e.g. BatchNorm stats) are broadcast from the module in process of rank
0, to all other replicas in the system in every iteration.
.. note::
If you are using DistributedDataParallel in conjunction with the
:ref:`distributed-rpc-framework`, you should always use
:meth:`torch.distributed.autograd.backward` to compute gradients and
Loading ...