Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ distributed / rpc / _testing / faulty_agent_backend_registry.py

#!/usr/bin/env python3

import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.distributed.distributed_c10d as dc10d
from torch.distributed.rpc import constants as rpc_constants

from datetime import timedelta

def _faulty_process_group_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_send_recv_threads,
    messages_to_fail,
    messages_to_delay,
    num_fail_sends,
    **kwargs
):
    """Build the options object for the faulty process group backend.

    The six named settings are forwarded, by keyword, to
    ``FaultyProcessGroupRpcBackendOptions``. Any further keyword arguments
    collected in ``kwargs`` are accepted but not used.

    Returns:
        The newly constructed ``FaultyProcessGroupRpcBackendOptions``.
    """
    # Resolved when the handler is called, not when this module is imported.
    from . import FaultyProcessGroupRpcBackendOptions as options_cls

    forwarded_settings = {
        "rpc_timeout": rpc_timeout,
        "init_method": init_method,
        "num_send_recv_threads": num_send_recv_threads,
        "messages_to_fail": messages_to_fail,
        "messages_to_delay": messages_to_delay,
        "num_fail_sends": num_fail_sends,
    }
    return options_cls(**forwarded_settings)

def _faulty_process_group_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    """Initialize a Gloo process group and build a ``FaultyProcessGroupAgent``.

    Args:
        store: Passed through to ``dist.init_process_group`` as ``store``.
        name: First argument handed to ``FaultyProcessGroupAgent``.
        rank: Rank passed to ``dist.init_process_group``. Unless it is ``-1``
            it must equal the rank reported by the default process group.
        world_size: World size passed to ``dist.init_process_group``. Unless
            it is ``-1`` it must equal the size of the default process group.
        rpc_backend_options: Object providing ``num_send_recv_threads``,
            ``rpc_timeout`` (converted with ``timedelta(seconds=...)``),
            ``messages_to_fail``, ``messages_to_delay`` and
            ``num_fail_sends``.

    Returns:
        The constructed ``FaultyProcessGroupAgent``.

    Raises:
        RuntimeError: If a process group is already initialized when this
            handler runs, or if ``rank`` / ``world_size`` do not match the
            default process group.
        AssertionError: If the default process group is ``None`` after
            initialization.
    """
    # Resolved when the handler is called, not when this module is imported.
    from . import FaultyProcessGroupAgent

    # This handler creates the default process group itself, so one must not
    # already exist.
    if dist.is_initialized():
        raise RuntimeError(
            "Process group must not be initialized before init_rpc."
        )

    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    dist.init_process_group(
        backend=dist.Backend.GLOO,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=process_group_timeout,
    )

    try:
        group = dc10d._get_default_group()
        assert group is not None, "Failed to initialize default ProcessGroup."

        # -1 is treated as "not specified" and skips the consistency check.
        if (rank != -1) and (rank != group.rank()):
            raise RuntimeError(
                "rank argument {} doesn't match pg rank {}".format(rank, group.rank())
            )
        if (world_size != -1) and (world_size != group.size()):
            raise RuntimeError(
                "world_size argument {} doesn't match pg size {}".format(
                    world_size, group.size()
                )
            )

        return FaultyProcessGroupAgent(
            name,
            group,
            rpc_backend_options.num_send_recv_threads,
            timedelta(seconds=rpc_backend_options.rpc_timeout),
            rpc_backend_options.messages_to_fail,
            rpc_backend_options.messages_to_delay,
            rpc_backend_options.num_fail_sends,
        )
    except Exception:
        # Tear down the process group created above so a failed setup does
        # not leave it initialized, then re-raise the original exception
        # with its traceback untouched.
        dist.destroy_process_group()
        raise

# Module-level registration: executing this module registers the
# "FAULTY_PROCESS_GROUP" backend name together with the options-construction
# handler and the backend-initialization handler defined in this file.
rpc.backend_registry.register_backend(
    "FAULTY_PROCESS_GROUP",
    _faulty_process_group_construct_rpc_backend_options_handler,
    _faulty_process_group_init_backend_handler,
)