from typing import Dict, Tuple
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import RRef
from torch.testing._internal.dist_utils import (
dist_init,
worker_name,
wait_until_pending_futures_and_users_flushed
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    """Return the element-wise sum of all four arguments.

    ``first_kwarg`` and ``second_kwarg`` default to ``tensor([3, 3])`` and
    ``tensor([4, 4])`` respectively when the caller does not supply them.
    """
    # Accumulate left to right, in the same order as a single chained sum.
    positional_sum = first_arg + second_arg
    with_first_kwarg = positional_sum + first_kwarg
    return with_first_kwarg + second_kwarg
@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    """Invoke ``two_args_two_kwargs`` on ``dst_worker_name`` and block for it.

    The RPC is issued with ``rpc.rpc_async`` without an explicit timeout, and
    the returned future is waited on inside TorchScript.

    Args:
        dst_worker_name: name of the worker that should run the function.
        args: positional arguments forwarded to ``two_args_two_kwargs``.
        kwargs: keyword arguments forwarded to ``two_args_two_kwargs``.

    Returns:
        The value produced by waiting on the RPC future.
    """
    pending = rpc.rpc_async(
        dst_worker_name, two_args_two_kwargs, args, kwargs
    )
    return pending.wait()
@torch.jit.script
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    """Invoke ``two_args_two_kwargs`` remotely with a per-call timeout and wait.

    Args:
        dst_worker_name: name of the worker that should run the function.
        args: positional arguments forwarded to ``two_args_two_kwargs``.
        kwargs: keyword arguments forwarded to ``two_args_two_kwargs``.
        timeout: value passed to ``rpc.rpc_async`` as its timeout argument.

    Returns:
        The value produced by waiting on the RPC future.
    """
    pending = rpc.rpc_async(
        dst_worker_name, two_args_two_kwargs, args, kwargs, timeout
    )
    return pending.wait()
@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    """Start a remote ``two_args_two_kwargs`` call with a timeout; do not wait.

    Args:
        dst_worker_name: name of the worker that should run the function.
        args: positional arguments forwarded to ``two_args_two_kwargs``.
        kwargs: keyword arguments forwarded to ``two_args_two_kwargs``.
        timeout: value passed to ``rpc.rpc_async`` as its timeout argument.

    Returns:
        The future returned by ``rpc.rpc_async``; the caller is responsible
        for waiting on it.
    """
    return rpc.rpc_async(
        dst_worker_name, two_args_two_kwargs, args, kwargs, timeout
    )
@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    """Start a remote ``two_args_two_kwargs`` call and hand back its future.

    No explicit timeout is passed to ``rpc.rpc_async``.

    Args:
        dst_worker_name: name of the worker that should run the function.
        args: positional arguments forwarded to ``two_args_two_kwargs``.
        kwargs: keyword arguments forwarded to ``two_args_two_kwargs``.

    Returns:
        The future returned by ``rpc.rpc_async``; the caller is responsible
        for waiting on it.
    """
    return rpc.rpc_async(
        dst_worker_name, two_args_two_kwargs, args, kwargs
    )
@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    """Fetch the tensor referenced by ``rref_var`` via ``RRef.to_here()``.

    No timeout argument is passed to ``to_here``.
    """
    fetched = rref_var.to_here()
    return fetched
@torch.jit.script
def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
    """Fetch the tensor referenced by ``rref_var``, passing ``timeout`` on.

    Args:
        rref_var: remote reference whose value should be fetched.
        timeout: value passed to ``RRef.to_here`` as its timeout argument.
    """
    fetched = rref_var.to_here(timeout)
    return fetched
@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: Tuple[RRef[Tensor]]) -> Tensor:
    """Run ``rref_to_here`` on ``dst_worker_name`` with an RRef argument and wait.

    Args:
        dst_worker_name: name of the worker that should run ``rref_to_here``.
        args: one-element tuple holding the RRef to forward.

    Returns:
        The tensor produced by waiting on the RPC future.
    """
    pending = rpc.rpc_async(dst_worker_name, rref_to_here, args)
    return pending.wait()
class JitFaultyAgentRpcTest(RpcAgentTestFixture):
    """
    Run tests for rpc_async in JIT under the faulty agent test fixture to test
    arbitrary timeouts.
    """

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_torchscript_function(self):
        # Call rpc_async + fut.wait() in torchscript function and ensure that
        # timeout is raised.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()
        # Ensure that we get a timeout if we override the default timeout and
        # the RPC takes longer to execute.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)

        # Ensure that we timeout if we don't specify a timeout but the default
        # is less than the RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        try:
            with self.assertRaisesRegex(RuntimeError, expected_error):
                script_rpc_async_call(dst_worker_name, args, kwargs)

            # Ensure that we run to completion if zero timeout is specified.
            ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
            self.assertEqual(ret, torch.tensor([8, 8]))
        finally:
            # Reset for clean shutdown. Done in ``finally`` so that a failed
            # assertion above does not leave the tiny default timeout in
            # place for the rest of the test's teardown.
            rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_python(self):
        # Ensures timeouts are raised if we call rpc_async from within a
        # torchscript function, but wait on the future in python.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = self.get_timeout_error_regex()

        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure timeout if we don't specify but the default is less than the
        # RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        try:
            fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
            with self.assertRaisesRegex(RuntimeError, expected_error):
                fut.wait()

            # Ensure run to completion if zero timeout is specified
            fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
            result = fut.wait()
            self.assertEqual(result, torch.tensor([8, 8]))
        finally:
            # Reset for clean shutdown, even if an assertion above failed.
            rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_remote_timeout_to_here_in_jit(self):
        # Test that calling to_here() in JIT will raise timeout error if
        # rpc.remote failed.
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call to_here() within a ScriptFunction and ensure it raises
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref_to_here(rref)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
    def test_rref_to_here_timeout_in_jit(self):
        # Test that to_here() in JIT times out when the RRef fetch message is
        # delayed for longer than the timeout passed to to_here().
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref_to_here_with_timeout(rref, 0.01)

        # Fetch again with a much larger timeout; this call is expected to
        # return without raising.
        rref_to_here_with_timeout(rref, 100)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_in_jit(self):
        # Test that passing an RRef whose creation failed as an RPC argument
        # from within JIT raises an error.
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call RPC with RRef arg in JIT, which will go through JIT pickling and
        # ensure error is raised.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc_async_with_rref_arg(dst_worker, (rref,))

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_rref_timeout_pickle_script_func(self):
        # Similar to above test, but calls python rpc with script function.
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        # Will ensure error handling callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Call RPC with script function that takes RRef, ensure timeout during pickling
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))