import os
from abc import ABC, abstractmethod
import torch.testing._internal.dist_utils
class RpcAgentTestFixture(ABC):
    """Abstract base describing how an RPC agent is configured for tests.

    Concrete fixtures supply the backend, its options, and the error-message
    patterns that tests match against. This base supplies the world size and
    the rendezvous ("init method") URL.
    """

    @property
    def world_size(self):
        """Number of workers taking part in a test; always 4 here."""
        return 4

    @property
    def init_method(self):
        """Rendezvous URL used to initialize RPC.

        When the environment variable ``RPC_INIT_WITH_TCP`` equals ``"1"``,
        a ``tcp://`` URL is built from ``MASTER_ADDR`` and ``MASTER_PORT``
        (a ``KeyError`` is raised if either is missing). In every other case
        the value of ``file_init_method`` is returned.
        """
        tcp_flag = os.environ.get("RPC_INIT_WITH_TCP")
        if tcp_flag != "1":
            return self.file_init_method

        # Read the address before the port, so a missing MASTER_ADDR is the
        # first error reported.
        host = os.environ["MASTER_ADDR"]
        port = os.environ["MASTER_PORT"]
        return f"tcp://{host}:{port}"

    @property
    def file_init_method(self):
        """File-based rendezvous URL built from ``self.file_name``.

        ``file_name`` is not defined in this class; it is expected to be
        provided by the concrete test class (assumption; confirm against
        the classes that mix this fixture in).
        """
        template = torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE
        return template.format(file_name=self.file_name)

    @property
    @abstractmethod
    def rpc_backend(self):
        """The RPC backend under test; must be provided by subclasses."""

    @property
    @abstractmethod
    def rpc_backend_options(self):
        """Options for the RPC backend; must be provided by subclasses."""

    def setup_fault_injection(self, faulty_messages, messages_to_delay):
        """Hook used by dist_init to prepare the faulty agent.

        The base implementation ignores both arguments and does nothing,
        which is the behavior wanted for all other agents.
        """

    # The shutdown sequence is not well defined, so tests that simulate
    # errors by shutting down the remote end may observe any of several
    # different errors.
    @abstractmethod
    def get_shutdown_error_regex(self):
        """
        Return the error messages that RPC agents may produce while running
        tests that check for failures. The result is matched against the
        errors actually raised to make sure failures surfaced properly.
        """

    @abstractmethod
    def get_timeout_error_regex(self):
        """
        Return a partial string identifying the error expected when an RPC
        has timed out. Intended for use with assertRaisesRegex() to verify
        that timeouts raise the right error.
        """