import contextlib
import logging
import pickle
import torch
import warnings
import time
from torch._six import string_classes
from datetime import timedelta
from typing import Dict, Optional, Tuple, Union
# This module is wildcard imported from torch.distributed.
# TODO: specify __all__
from .constants import default_pg_timeout
from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401
from torch._C._distributed_c10d import (
AllreduceOptions,
AllreduceCoalescedOptions,
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
GatherOptions,
PrefixStore,
ProcessGroup,
ReduceOptions,
ReduceOp,
ReduceScatterOptions,
ScatterOptions,
Store,
)
# Probe the optional C++ process-group classes. Each flag records whether the
# corresponding class could be imported from ``torch._C._distributed_c10d``.
try:
    from torch._C._distributed_c10d import ProcessGroupMPI
except ImportError:
    _MPI_AVAILABLE = False
else:
    _MPI_AVAILABLE = True

try:
    from torch._C._distributed_c10d import ProcessGroupNCCL
except ImportError:
    _NCCL_AVAILABLE = False
else:
    _NCCL_AVAILABLE = True

try:
    from torch._C._distributed_c10d import ProcessGroupGloo
except ImportError:
    _GLOO_AVAILABLE = False
else:
    _GLOO_AVAILABLE = True
# Some reduce ops are not supported for complex numbers.
# We currently provide complex support in the distributed API by viewing
# complex tensors as real (torch.view_as_real), so calling one of these
# unsupported ops would silently return garbage values instead of raising an
# error (e.g. an elementwise max gives max(2+3i, 3+2i) = 3+3i).
# We'd like calls to unsupported ops to error out instead of returning
# garbage values; supports_complex() below identifies the unsupported ops.
def supports_complex(reduceOp: ReduceOp) -> bool:
    """Return whether ``reduceOp`` may be applied to complex tensors.

    Complex tensors are reduced through their real view, so ops that compare
    or combine the real/imaginary components independently (MAX, MIN,
    PRODUCT and the bitwise ops) would produce meaningless results and are
    reported as unsupported.
    """
    unsupported_ops = (
        ReduceOp.MAX,
        ReduceOp.MIN,
        ReduceOp.PRODUCT,
        ReduceOp.BAND,
        ReduceOp.BOR,
        ReduceOp.BXOR,
    )
    return reduceOp not in unsupported_ops
class Backend(object):
    """
    An enum-like class of available backends: GLOO, NCCL, MPI, and other registered
    backends.

    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
    be accessed as attributes, e.g., ``Backend.NCCL``.

    This class can be directly called to parse the string, e.g.,
    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
    return the parsed lowercase string if so. It also accepts uppercase strings,
    e.g., ``Backend("GLOO")`` returns ``"gloo"``. For a third-party backend
    added with :meth:`register_backend`, the name is returned exactly as it
    was passed in (its case is not normalized).

    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
        initial value of some fields. Users should neither use it directly
        nor assume its existence.
    """

    UNDEFINED = "undefined"
    GLOO = "gloo"
    NCCL = "nccl"
    MPI = "mpi"
    TCP = "tcp"

    def __new__(cls, name: str):
        """Validate ``name`` and return the canonical backend string.

        Raises:
            ValueError: if ``name`` is not a ``str``, names the removed TCP
                backend, or does not match any built-in or registered backend.
        """
        # Only ``str`` is accepted: ``bytes`` would otherwise get past this
        # check and then fail inside ``getattr`` with an unrelated TypeError.
        if not isinstance(name, str):
            raise ValueError("Backend name must be a string, but got: {}".format(name))
        # Backends are looked up as uppercase class attributes. Built-ins hold
        # their lowercase string; registered backends hold their creator.
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if value == Backend.TCP:
            raise ValueError("TCP backend has been deprecated. Please use "
                             "Gloo or MPI backend for collective operations "
                             "on CPU tensors.")
        elif value == Backend.UNDEFINED:
            raise ValueError("Invalid backend: '{}'".format(name))
        elif value not in (Backend.GLOO, Backend.NCCL, Backend.MPI):
            # Registered third-party backend: the attribute is the creator
            # function, so hand back the caller-supplied name instead.
            value = name
        return value

    @classmethod
    def register_backend(cls, name, func):
        """
        Registers a new backend.

        This class method is used by 3rd party cpp extension to register new backend.

        Args:
            name (str): Backend name matching with the one in `init_process_group()`.
                It is stored as the uppercase class attribute ``name.upper()``.
            func (function): Function handler that instantiates the backend.
                The function should be implemented in the backend cpp extension
                and takes four arguments, including prefix_store, rank,
                world_size, and timeout.

        Raises:
            ValueError: if ``name`` collides with a built-in entry
                (``undefined``, ``gloo``, ``nccl``, ``mpi`` or ``tcp``).

        .. note:: This support of 3rd party backend is experimental and subject to change.
        """
        # Overwriting a built-in entry (e.g. registering "gloo") would replace
        # ``Backend.GLOO`` with ``func`` and break every comparison against it.
        if name.upper() in ("UNDEFINED", "GLOO", "NCCL", "MPI", "TCP"):
            raise ValueError(
                "Cannot register backend '{}': the name is reserved for a "
                "built-in backend.".format(name))
        setattr(Backend, name.upper(), func)
# Backward-compatibility shims for code written against the pre-c10d
# distributed package: `dist_backend` is the old name of the `Backend` class,
# `_backend` starts out as `Backend.UNDEFINED`, and `reduce_op` (defined
# further below) is the deprecated alias of `ReduceOp`.
# TODO: remove them when users are ready to take a hard dependency on PyTorch 1.
dist_backend = Backend
_backend: str = dist_backend.UNDEFINED
class _reduce_op(object):
r"""
Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, and ``MAX``.
:class:`~torch.distributed.ReduceOp` is recommended to use instead.
"""
def __init__(self):
# __members__ is a dict storing key-value pairs for enum classes
for k, v in ReduceOp.__members__.items():
setattr(self, k, v)
self.__members__ = ReduceOp.__members__
def __getattribute__(self, key):
warnings.warn("torch.distributed.reduce_op is deprecated, please use "
"torch.distributed.ReduceOp instead")
return object.__getattribute__(self, key)
reduce_op = _reduce_op()
class group:
    """Namespace that holds the handle of the default process group."""

    # ``None`` until the default process group is initialized; afterwards it
    # points to that group.
    WORLD: Optional[ProcessGroup] = None
class GroupMember:
    """Legacy namespace of process-group handles and sentinels."""

    # Alias of ``group.WORLD`` kept for backward compatibility. Note that it
    # is read once, when this class body executes.
    WORLD = group.WORLD
    # Unique sentinel object; ``_rank_not_in_group`` compares a group against
    # it to decide that the current rank is not part of that group.
    NON_GROUP_MEMBER = object()
# --- Module-level process-group bookkeeping --------------------------------
# Cached process groups: ProcessGroup -> (backend, store). The store slot
# holds the Store for NCCL and Gloo groups and ``None`` for MPI groups.
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = dict()
# ProcessGroup -> the group's name.
_pg_names: Dict[ProcessGroup, str] = dict()
# ProcessGroup -> {global rank: rank within that group}.
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = dict()
# Default process group state (presumably the init method it was created
# with; TODO confirm against init_process_group).
_default_pg_init_method = None
# Running count of process groups, used for default naming and as part of
# the ``_store_based_barrier`` key.
_group_count = 0
# Prefix of the store key that ``_store_based_barrier`` counts on.
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
def _store_based_barrier(rank, store, timeout):
    """
    Barrier based on store which is used for synchronizing processes after
    ``init_process_group`` or ``new_group``. Intended to be used only with
    those two methods and is not a generic alternative to ``barrier()``.

    Each rank increments a counter stored under a key derived from
    ``_group_count`` and then polls that counter until it equals the default
    group's world size (``get_world_size()``).

    Args:
        rank: Rank of the calling process; only used in log and error messages.
        store: Store shared by all ranks; must support ``add(key, amount)``.
        timeout (timedelta): Maximum time to wait for all ranks to check in.

    Raises:
        RuntimeError: if not every rank has checked in within ``timeout``.
    """
    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
    store.add(store_key, 1)
    logging.info("Added key: %s to store for rank: %s", store_key, rank)

    # Now wait for all workers to check in with the store.
    world_size = get_world_size()
    # Use 'add' instead of 'get' since for some store implementations 'add'
    # doesn't work well with 'get'. Ideally the store implementations should
    # be fixed, but for backward compatibility reasons it is risky to change
    # the store implementations. Once, we completely migrate away from these
    # legacy stores, we can use 'get' here instead.
    worker_count = store.add(store_key, 0)

    # Use a monotonic clock for elapsed time: wall-clock adjustments (e.g. NTP
    # corrections) must neither trigger a spurious timeout nor stretch the wait.
    start = time.monotonic()
    log_time = start
    log_interval = timedelta(seconds=10)
    while worker_count != world_size:
        time.sleep(0.01)
        worker_count = store.add(store_key, 0)
        now = time.monotonic()

        # Print status periodically to keep track.
        if timedelta(seconds=(now - log_time)) > log_interval:
            logging.info(
                "Waiting in store based barrier to initialize process group for "
                "rank: %s, key: %s (world_size=%s, worker_count=%s, timeout=%s)",
                rank, store_key, world_size, worker_count, timeout)
            log_time = now

        if timedelta(seconds=(now - start)) > timeout:
            raise RuntimeError(
                "Timed out initializing process group in store based barrier on "
                "rank: {}, for key: {} (world_size={}, worker_count={}, timeout={})".format(
                    rank, store_key, world_size, worker_count, timeout))
def _rank_not_in_group(group: ProcessGroup):
    """
    Return whether ``group`` is the ``GroupMember.NON_GROUP_MEMBER`` sentinel,
    i.e. whether the current process's rank is outside that group.

    ``None`` is never reported as "not in group".
    """
    return group is not None and group == GroupMember.NON_GROUP_MEMBER
def _get_group_rank(group: ProcessGroup, rank):
    """
    Translate the global ``rank`` into its local rank inside ``group``.

    Raises:
        RuntimeError: if ``group`` is ``group.WORLD``, is not a known group,
            or does not contain ``rank``.
    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    global_to_local = _pg_group_ranks[group]
    try:
        return global_to_local[rank]
    except KeyError:
        raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
def _get_global_rank(group, group_rank):
    """
    Helper that gets a given group's global rank from a given local rank in the
    group.

    Args:
        group: A process group registered in ``_pg_group_ranks``.
        group_rank: Local rank inside ``group``.

    Returns:
        The global rank whose local rank in ``group`` equals ``group_rank``.

    Raises:
        RuntimeError: if ``group`` is ``group.WORLD``, is not a known group,
            or has no member with local rank ``group_rank``.
    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    # Same check as _get_group_rank/_get_group_size: report an unknown group
    # as a RuntimeError instead of leaking a bare KeyError from the lookup.
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    # The mapping is global -> local, so a reverse lookup needs a scan.
    for rank, grp_rank in _pg_group_ranks[group].items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError("The group rank is not part of the group")
def _get_group_size(group):
    """
    Return the number of ranks in ``group``.

    ``None`` and ``group.WORLD`` both refer to the default process group.

    Raises:
        RuntimeError: if ``group`` is not a known group.
    """
    if group is None or group is GroupMember.WORLD:
        return _get_default_group().size()
    if group in _pg_group_ranks:
        return len(_pg_group_ranks[group])
    raise RuntimeError("The given group does not exist")
def _check_single_tensor(param, param_name):
    """
    Raise ``RuntimeError`` unless ``param`` is a ``torch.Tensor``.

    ``param_name`` is only used to build the error message.
    """
    if isinstance(param, torch.Tensor):
        return
    raise RuntimeError(
        "Invalid function argument. Expected parameter `{}` "
        "to be of type torch.Tensor.".format(param_name)
    )
def _check_tensor_list(param, param_name):
    """
    Raise ``RuntimeError`` unless ``param`` is a ``list`` whose every element
    is a ``torch.Tensor`` (an empty list is accepted).

    ``param_name`` is only used to build the error message.
    """
    is_tensor_list = isinstance(param, list) and all(
        isinstance(item, torch.Tensor) for item in param
    )
    if not is_tensor_list:
        raise RuntimeError(
            "Invalid function argument. Expected parameter `{}` "
            "to be of type List[torch.Tensor].".format(param_name)
        )
def _check_op(op):
    """Raise ``RuntimeError`` unless ``op`` is ``isend`` or ``irecv``."""
    if op in (isend, irecv):
        return
    raise RuntimeError("Invalid ``op``. Expected ``op`` "
                       "to be of type ``torch.distributed.isend`` or "
                       "``torch.distributed.irecv``.")
def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a non-empty list of P2POp
    instances and that all ops use the same backend.

    Raises:
        RuntimeError: if ``p2p_op_list`` is not a list, is empty, contains
            anything other than ``P2POp`` instances, or contains ops whose
            groups use different backends.
    """
    if not isinstance(p2p_op_list, list) or \
            not all(isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "be of type ``torch.distributed.P2POp``.")
    # An empty list passes the check above (``all`` of nothing is True);
    # reject it explicitly rather than crash with IndexError on [0] below.
    if not p2p_op_list:
        raise RuntimeError("Invalid ``p2p_op_list``. Expected a non-empty "
                           "list of ``torch.distributed.P2POp``.")

    backend = get_backend(p2p_op_list[0].group)
    if not all(backend == get_backend(p2p_op.group) for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")
def is_mpi_available():
    """Return ``True`` if ``ProcessGroupMPI`` could be imported, else ``False``."""
    return bool(_MPI_AVAILABLE)
def is_nccl_available():
    """Return ``True`` if ``ProcessGroupNCCL`` could be imported, else ``False``."""
    return bool(_NCCL_AVAILABLE)
def is_gloo_available():
    """Return ``True`` if ``ProcessGroupGloo`` could be imported, else ``False``."""
    return bool(_GLOO_AVAILABLE)
def is_initialized():
    """Return whether the default process group has been set up.

    This is the case once ``GroupMember.WORLD`` is no longer ``None``.
    """
    default_pg = GroupMember.WORLD
    return default_pg is not None
def _get_default_group():
"""
Getting the default process group created by init_process_group
"""
Loading ...