Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ testing / _internal / distributed / rpc / rpc_agent_test_fixture.py

import os
from abc import ABC, abstractmethod

import torch.testing._internal.dist_utils


class RpcAgentTestFixture(ABC):
    @property
    def world_size(self):
        return 4

    @property
    def init_method(self):
        use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
        if use_tcp_init == "1":
            master_addr = os.environ["MASTER_ADDR"]
            master_port = os.environ["MASTER_PORT"]
            return f"tcp://{master_addr}:{master_port}"
        else:
            return self.file_init_method

    @property
    def file_init_method(self):
        return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format(
            file_name=self.file_name
        )

    @property
    @abstractmethod
    def rpc_backend(self):
        pass

    @property
    @abstractmethod
    def rpc_backend_options(self):
        pass

    def setup_fault_injection(self, faulty_messages, messages_to_delay):
        """Method used by dist_init to prepare the faulty agent.

        Does nothing for other agents.
        """
        pass

    # Shutdown sequence is not well defined, so we may see any of the following
    # errors when running tests that simulate errors via a shutdown on the
    # remote end.
    @abstractmethod
    def get_shutdown_error_regex(self):
        """
        Return various error message we may see from RPC agents while running
        tests that check for failures. This function is used to match against
        possible errors to ensure failures were raised properly.
        """
        pass

    @abstractmethod
    def get_timeout_error_regex(self):
        """
        Returns a partial string indicating the error we should receive when an
        RPC has timed out. Useful for use with assertRaisesRegex() to ensure we
        have the right errors during timeout.
        """
        pass