import contextlib
import logging
import pickle
import torch
import warnings
import time
from torch._six import string_classes
from datetime import timedelta
from typing import Dict, Optional, Tuple, Union

# This module is wildcard imported from torch.distributed.
# TODO: specify __all__

from .constants import default_pg_timeout
from .rendezvous import rendezvous, register_rendezvous_handler  # noqa: F401
from torch._C._distributed_c10d import (


    from torch._C._distributed_c10d import ProcessGroupMPI
except ImportError:
    _MPI_AVAILABLE = False

    from torch._C._distributed_c10d import ProcessGroupNCCL
except ImportError:

    from torch._C._distributed_c10d import ProcessGroupGloo
except ImportError:

# Several reduce ops have no meaningful definition for complex numbers. The
# distributed API handles complex tensors by reducing their real view
# (torch.view_as_real), so running one of those ops would silently produce
# garbage instead of raising (e.g. max(2+3i, 3+2i) = 3+3i). This predicate
# exists so such calls can be rejected explicitly rather than returning
# garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
    """Return ``False`` if ``reduceOp`` must not be applied to complex tensors.

    The rejected ops are MAX, MIN, PRODUCT, BAND, BOR and BXOR; every other
    op is reported as supported.
    """
    complex_unsafe_ops = (
        ReduceOp.MAX,
        ReduceOp.MIN,
        ReduceOp.PRODUCT,
        ReduceOp.BAND,
        ReduceOp.BOR,
        ReduceOp.BXOR,
    )
    if reduceOp in complex_unsafe_ops:
        return False
    return True

class Backend(object):
    """
    An enum-like class of available backends: GLOO, NCCL, MPI, and other registered
    backends.

    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
    be accessed as attributes, e.g., ``Backend.NCCL``.

    This class can be directly called to parse the string, e.g.,
    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
    return the parsed lowercase string if so. It also accepts uppercase strings,
    e.g., ``Backend("GLOO")`` returns ``"gloo"``.

    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
              initial value of some fields. Users should neither use it directly
              nor assume its existence.
    """
    UNDEFINED = "undefined"
    GLOO = "gloo"
    NCCL = "nccl"
    MPI = "mpi"
    TCP = "tcp"

    def __new__(cls, name: str):
        """Validate ``name`` and return the canonical lowercase backend string.

        Raises:
            ValueError: if ``name`` is not a string, names the deprecated TCP
                backend, or matches no built-in or registered backend.
        """
        if not isinstance(name, string_classes):
            raise ValueError("Backend name must be a string, but got: {}".format(name))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)

        if value == Backend.TCP:
            raise ValueError("TCP backend has been deprecated. Please use "
                             "Gloo or MPI backend for collective operations "
                             "on CPU tensors.")
        elif value == Backend.UNDEFINED:
            raise ValueError("Invalid backend: '{}'".format(name))
        elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.MPI:
            # A third-party backend added via ``register_backend``: the class
            # attribute holds its constructor, not a string, so return the
            # name itself -- lowercased to honour the documented contract.
            value = name.lower()
        return value

    @classmethod
    def register_backend(cls, name, func):
        """
        Registers a new backend.

        This class method is used by 3rd party cpp extension to register new backend.

        Args:
            name (str): Backend name matching with the one in `init_process_group()`.
            func (function): Function handler that instantiates the backend.
                             The function should be implemented in the backend cpp extension
                             and takes four arguments, including prefix_store, rank,
                             world_size, and timeout.

        .. note:: This support of 3rd party backend is experimental and subject to change.
        """
        # Stored under the upper-cased name so ``Backend(name)`` (which looks
        # up ``name.upper()``) can find it regardless of the caller's casing.
        setattr(Backend, name.upper(), func)

# `_backend`, `dist_backend`, and `reduce_op` are here to maintain backward
# compatibility with pre-c10d distributed package.
# TODO: remove them when users are ready to take a hard dependency on PyTorch 1.
# Legacy backend-name slot; starts out as the UNDEFINED placeholder.
_backend: str = Backend.UNDEFINED
# Legacy alias: ``dist_backend`` is the ``Backend`` class itself, not a copy.
dist_backend = Backend

class _reduce_op(object):
    r"""
    Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
    ``MIN``, and ``MAX``.

    :class:`~torch.distributed.ReduceOp` is recommended to use instead.
    """

    def __init__(self):
        # __members__ is a dict storing key-value pairs for enum classes;
        # mirror every ReduceOp member as a plain attribute on this object.
        for k, v in ReduceOp.__members__.items():
            setattr(self, k, v)
        self.__members__ = ReduceOp.__members__

    def __getattribute__(self, key):
        # Every attribute read on the legacy object (including
        # ``__members__``) passes through here, so any use of it warns.
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this method.
        warnings.warn("torch.distributed.reduce_op is deprecated, please use "
                      "torch.distributed.ReduceOp instead", stacklevel=2)
        return object.__getattribute__(self, key)

# Legacy module-level handle kept for backward compatibility; attribute reads
# on it emit a deprecation warning via ``_reduce_op.__getattribute__``.
reduce_op = _reduce_op()

class group(object):
    """Namespace holding the handle of the default process group."""
    # Points to the default PG once initialized; ``None`` until then.
    # NOTE(review): the assignment happens outside this view (presumably in
    # init_process_group) -- confirm.
    WORLD: Optional[ProcessGroup] = None

class GroupMember(object):
    """Legacy group constants: the default group and a non-member sentinel."""
    # Alias to group.WORLD for backward compatibility.
    # NOTE(review): this copies group.WORLD's *value* at class-creation time
    # (None); it does not follow later assignments to group.WORLD, so whatever
    # initializes the default group presumably must update both -- confirm.
    WORLD = group.WORLD
    # Unique sentinel used as the "group" for ranks outside a group; detected
    # by _rank_not_in_group.
    NON_GROUP_MEMBER = object()

# Cached process groups.
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping: for each group, maps a
# global rank to that rank's index inside the group. Read by _get_group_rank,
# _get_global_rank and _get_group_size (group size = number of entries).
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

# Default process group state.
# NOTE(review): not assigned anywhere in this view; presumably records the
# init method used for the default group -- confirm against the setter.
_default_pg_init_method = None

# Process group count for default naming. Also embedded in the store-based
# barrier key (see _store_based_barrier); presumably incremented per new group
# so each barrier uses a fresh key -- confirm.
_group_count: int = 0

# Prefix of the store keys used by _store_based_barrier.
STORE_BASED_BARRIER_PREFIX: str = "store_based_barrier_key"

def _store_based_barrier(rank, store, timeout):
    """
    Barrier based on store which is used for synchronizing processes after
    ``init_process_group`` or ``new_group``. Intended to be used only with
    those two methods and is not a generic alternative to ``barrier()``.

    Args:
        rank: Rank of the calling process; only used in log and error messages.
        store: Store shared by all participants; only its ``add`` method is used.
        timeout (timedelta): Maximum time to wait for every rank to check in.

    Raises:
        RuntimeError: if fewer than ``get_world_size()`` ranks have checked in
            once ``timeout`` has elapsed.
    """
    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
    store.add(store_key, 1)
    logging.info('Added key: %s to store for rank: %s', store_key, rank)

    # Now wait for all workers to check in with the store.
    world_size = get_world_size()
    # Use 'add' instead of 'get' since for some store implementations 'add'
    # doesn't work well with 'get'. Ideally the store implementations should
    # be fixed, but for backward compatiblity reasons it is risky to change
    # the store implementations. Once, we completely migrate away from these
    # legacy stores, we can use 'get' here instead.
    worker_count = store.add(store_key, 0)
    # Monotonic clock: the timeout and log interval must not be affected by
    # wall-clock adjustments while waiting.
    start = time.monotonic()
    log_time = start
    while worker_count != world_size:
        # Short pause between polls so waiting ranks don't flood the store
        # with requests while slower ranks are still starting up.
        time.sleep(0.01)
        worker_count = store.add(store_key, 0)
        now = time.monotonic()

        # Print status periodically to keep track.
        if now - log_time > 10:
            logging.info(
                "Waiting in store based barrier to initialize process group for "
                "rank: %s, key: %s (world_size=%s, worker_count=%s, timeout=%s)",
                rank, store_key, world_size, worker_count, timeout)
            log_time = now

        if timedelta(seconds=(now - start)) > timeout:
            raise RuntimeError(
                "Timed out initializing process group in store based barrier on "
                "rank: {}, for key: {} (world_size={}, worker_count={}, timeout={})".format(
                    rank, store_key, world_size, worker_count, timeout))

def _rank_not_in_group(group: ProcessGroup):
    """
    Helper that checks if the current process's rank is not in a given group.

    ``None`` always yields ``False``. Otherwise the result is ``True`` exactly
    when ``group`` is the ``GroupMember.NON_GROUP_MEMBER`` sentinel.
    """
    if group is None:
        return False
    # NON_GROUP_MEMBER is a unique object() sentinel, so identity is the
    # correct (and cheapest) comparison.
    return group is GroupMember.NON_GROUP_MEMBER

def _get_group_rank(group: ProcessGroup, rank):
    """
    Helper that gets a given group's local rank in the group from a given global
    rank.

    Raises:
        RuntimeError: if ``group`` is ``GroupMember.WORLD`` (which has no
            mapping), is not a known group, or does not contain global ``rank``.
    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    try:
        group_rank = _pg_group_ranks[group][rank]
    except KeyError:
        raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
    return group_rank

def _get_global_rank(group, group_rank):
    """
    Helper that gets a given group's global rank from a given local rank in the
    group.

    Raises:
        RuntimeError: if ``group`` is ``GroupMember.WORLD`` (which has no
            mapping), is not a known group, or has no member with local rank
            ``group_rank``.
    """
    if group is GroupMember.WORLD:
        raise RuntimeError("group.WORLD does not have local rank to global "
                           "rank mapping")
    # Same unknown-group error as _get_group_rank/_get_group_size instead of
    # leaking a bare KeyError from the dict lookup below.
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    group_rank_map = _pg_group_ranks[group]
    # The map is keyed by global rank, so find the inverse by scanning.
    for rank, grp_rank in group_rank_map.items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError("The group rank is not part of the group")

def _get_group_size(group):
    """
    Helper that gets a given group's world size.

    ``None`` and ``GroupMember.WORLD`` both resolve to the default process
    group, whose ``size()`` is returned. For any other group the size is the
    number of ranks recorded for it in ``_pg_group_ranks``.

    Raises:
        RuntimeError: if ``group`` is not a known process group.
    """
    if group is GroupMember.WORLD or group is None:
        default_pg = _get_default_group()
        return default_pg.size()
    if group not in _pg_group_ranks:
        raise RuntimeError("The given group does not exist")
    return len(_pg_group_ranks[group])

def _check_single_tensor(param, param_name):
    """
    Helper to check that the parameter ``param_name`` is a single tensor.

    Raises:
        RuntimeError: if ``param`` is not a ``torch.Tensor``.
    """
    if not isinstance(param, torch.Tensor):
        raise RuntimeError("Invalid function argument. Expected parameter `{}` "
                           "to be of type torch.Tensor.".format(param_name))

def _check_tensor_list(param, param_name):
    """
    Helper to check that the parameter ``param_name`` is a list of tensors.

    Only an actual ``list`` is accepted (not a tuple or other iterable); an
    empty list passes.

    Raises:
        RuntimeError: if ``param`` is not a list or any element is not a
            ``torch.Tensor``.
    """
    if not isinstance(param, list) or \
       not all(isinstance(p, torch.Tensor) for p in param):
        raise RuntimeError("Invalid function argument. Expected parameter `{}` "
                           "to be of type List[torch.Tensor].".format(param_name))

def _check_op(op):
    """
    Helper to check that the ``op`` is either isend or irecv.

    Raises:
        RuntimeError: if ``op`` is neither ``isend`` nor ``irecv``.
    """
    if op not in [isend, irecv]:
        raise RuntimeError("Invalid ``op``. Expected ``op`` "
                           "to be of type ``torch.distributed.isend`` or "
                           "``torch.distributed.irecv``.")

def _check_p2p_op_list(p2p_op_list):
    """
    Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
    all ops use the same backend.

    Raises:
        RuntimeError: if ``p2p_op_list`` is not a non-empty list of ``P2POp``,
            or if its ops' groups do not all use the same backend.
    """
    if not isinstance(p2p_op_list, list) or \
       not all(isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
        raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
                           "be of type ``torch.distributed.P2POp``.")
    # Indexing [0] below would otherwise fail with an opaque IndexError.
    if not p2p_op_list:
        raise RuntimeError("Invalid ``p2p_op_list``. Expected a non-empty list "
                           "of ``torch.distributed.P2POp``.")

    backend = get_backend(p2p_op_list[0].group)
    if not all(backend == get_backend(p2p_op.group) for p2p_op in p2p_op_list):
        raise RuntimeError("All groups need to use the same backend.")

def is_mpi_available():
    """
    Checks if the MPI backend is available.

    Returns the module-level ``_MPI_AVAILABLE`` flag, which is set to ``False``
    when importing ``ProcessGroupMPI`` fails at module load.
    """
    return _MPI_AVAILABLE

def is_nccl_available():
    """
    Checks if the NCCL backend is available.

    Returns the module-level ``_NCCL_AVAILABLE`` flag (presumably set according
    to whether ``ProcessGroupNCCL`` could be imported at module load).
    """
    return _NCCL_AVAILABLE

def is_gloo_available():
    """
    Checks if the Gloo backend is available.

    Returns the module-level ``_GLOO_AVAILABLE`` flag (presumably set according
    to whether ``ProcessGroupGloo`` could be imported at module load).
    """
    return _GLOO_AVAILABLE

def is_initialized():
    """
    Checking if the default process group has been initialized.

    Returns ``True`` once ``GroupMember.WORLD`` has been set to something other
    than ``None``.
    """
    return GroupMember.WORLD is not None

def _get_default_group():
    Getting the default process group created by init_process_group
Loading ...