import time
import io
from typing import Dict, List, Tuple, Any
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.autograd.profiler import record_function
from torch.distributed.rpc import RRef
from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key
from torch.futures import Future
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import (
dist_init,
get_function_event,
initialize_pg,
worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
def rref_isinstance(rref, cls_to_check):
    """Return whether the value held locally by ``rref`` is a ``cls_to_check``.

    The value is obtained with ``rref.local_value()`` and checked with
    ``isinstance``.
    """
    local_obj = rref.local_value()
    return isinstance(local_obj, cls_to_check)
def sleep(t):
    """Block the calling thread for ``t`` seconds using ``time.sleep``."""
    duration = t
    time.sleep(duration)
def rpc_return_rref(dst):
    """Return an RRef created by running ``torch.add(torch.ones(2, 2), 1)`` on ``dst``.

    The call is issued with ``rpc.remote``, so the result of the addition is
    held by ``dst`` and only a reference is returned to the caller.
    """
    operands = (torch.ones(2, 2), 1)
    return rpc.remote(dst, torch.add, args=operands)
@torch.jit.script
def rref_local_value(rref: RRef[Tensor]) -> Tensor:
    """Return ``rref.local_value()`` from within TorchScript."""
    held = rref.local_value()
    return held
@torch.jit.script
def list_create() -> List[int]:
    """Return a newly built list ``[1, 2, 3]``."""
    return [1, 2, 3]
@torch.jit.script
def rref_list_mutate(rref: RRef[List[int]]) -> None:
    """Append 4, 5 and 6 to the list behind ``rref`` through three accessors.

    The list is reached, in order, via ``local_value()``, ``to_here()`` and
    ``to_here(5.0)``; one element is appended through each.
    """
    via_local_value = rref.local_value()
    via_local_value.append(4)

    via_to_here = rref.to_here()
    via_to_here.append(5)

    # NOTE(review): 5.0 is presumably the to_here() timeout in seconds —
    # confirm against the RRef API documentation.
    via_to_here_timeout = rref.to_here(5.0)
    via_to_here_timeout.append(6)
def return_value(value: int) -> int:
    """Return ``value`` unchanged."""
    result = value
    return result
class RRefAPITest:
    """Tests that call RRef methods from TorchScript over RPC.

    NOTE(review): the methods read ``self.rank`` and ``self.world_size`` and
    use ``assertEqual`` / ``assertRaisesRegex``; these are presumably provided
    by the concrete test class this one is combined with — confirm.
    """

    @dist_init
    def test_rref_is_owner(self):
        """``is_owner()`` called in TorchScript is False for an RRef from ``rpc_return_rref``."""
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref_var = rpc_return_rref(dst_worker_name)

        @torch.jit.script
        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
            return rref_var.is_owner()

        res = rref_tensor_is_owner(rref_var)
        self.assertEqual(res, False)

    @dist_init
    def test_rref_local_value(self):
        """``local_value()`` raises on the caller but works when run on the destination."""
        # Only rank 0 drives this test.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        rref = rpc_return_rref(dst_worker_name)

        # Calling local_value() here, on the side that did not create the
        # value, must fail with the message below.
        with self.assertRaisesRegex(
            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
        ):
            rref_local_value(rref)

        # Running the same scripted function on the destination worker
        # returns the value the RRef was created with.
        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))

    @dist_init
    def test_local_rref_local_value(self):
        """``local_value()`` works on an RRef created on the calling worker itself."""
        # Only rank 0 drives this test.
        if self.rank != 0:
            return

        dst_worker_name = worker_name(self.rank)
        rref = rpc.remote(dst_worker_name, return_value, (5,), {})

        # NOTE(review): rref_local_value is annotated RRef[Tensor] while
        # return_value yields an int — confirm this mismatch is intentional.
        ret = rref_local_value(rref)
        self.assertEqual(ret, 5)

    def _create_rref(self):
        """Return an RRef to ``torch.add(torch.zeros(2, 2), 1)`` created on rank ``(self.rank + 2) % world_size``."""
        owner_rank = (self.rank + 2) % self.world_size
        return rpc.remote(
            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
        )

    @dist_init
    def test_user_rrefs_confirmed(self):
        """``script_check_rref_confirmed`` returns True when invoked with ``rpc_sync``."""
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret = rpc.rpc_sync(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret, True)

    @dist_init
    def test_user_rrefs_confirmed_remote(self):
        """``script_check_rref_confirmed`` returns True when invoked with ``rpc.remote``."""
        dst_rank = (self.rank + 1) % self.world_size
        rref = self._create_rref()
        ret_rref = rpc.remote(
            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
        )
        self.assertEqual(ret_rref.to_here(), True)

    @dist_init
    def test_rref_list_mutate(self):
        """Mutations made by ``rref_list_mutate`` on the destination are visible via ``to_here()``."""
        dst = worker_name((self.rank + 1) % self.world_size)
        list_rref = rpc.remote(dst, list_create)

        # rref_list_mutate appends 4, 5 and 6 to the [1, 2, 3] list.
        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])
@torch.jit.script
def no_arg():
    """Return the constant ``0``."""
    zero = 0
    return zero
@torch.jit.script
def one_arg(value):
    """Return ``value + 1``."""
    incremented = value + 1
    return incremented
@torch.jit.script
def script_add_ones(x):
    """Return ``torch.add(x, torch.ones(1))``."""
    increment = torch.ones(1)
    return torch.add(x, increment)
@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
    """Return ``torch.add(x, torch.ones(1))``, computed inside ``record_function(block)``."""
    with record_function(block):
        increment = torch.ones(1)
        return torch.add(x, increment)
@torch.jit.script
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
    """Issue two ``rpc_async`` calls inside a ``record_function`` scope and sum the results.

    Args:
        dst_worker_name: worker that runs ``script_add_ones``.
        block: name passed to ``record_function`` for the enclosing scope.

    Returns:
        The sum of the two awaited results of ``script_add_ones(torch.ones(1))``.
    """
    t: Tensor = torch.ones(1)
    # The context manager's value is never used, so it is not bound to a name.
    with record_function(block):
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        # Both futures are awaited while still inside the recorded scope.
        res = fut1.wait() + fut2.wait()
    return res
@torch.jit.script
def script_fork_wait_udf(tensor):
    """Run ``script_add_ones(tensor)`` through ``torch.jit._fork`` and return the awaited result."""
    forked = torch.jit._fork(script_add_ones, tensor)
    return torch.jit._wait(forked)
@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
    """Return ``rref_var.to_here()`` from within TorchScript."""
    fetched = rref_var.to_here()
    return fetched
@torch.jit.script
def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    """Return the RRef that was passed in, unchanged."""
    same_rref = rref_var
    return same_rref
@torch.jit.script
def script_raise_func(value):
    """Return ``value + 1``, or raise ``ValueError`` when ``value`` has exactly two elements."""
    has_two_elements = value.numel() == 2
    if has_two_elements:
        raise ValueError("Expected error")
    return value + 1
@torch.jit.script
def script_fork_wait_throw(invalue):
    """Run ``script_raise_func(invalue)`` through ``torch.jit._fork`` and return the awaited result."""
    forked = torch.jit._fork(script_raise_func, invalue)
    return torch.jit._wait(forked)
@torch.jit.script
def call_rpc_with_profiling(handle: Tensor, dst_worker_name: str) -> Tensor:
    """Call ``one_arg`` via ``rpc_async`` and attach profiling end-callbacks to its future.

    This exercises ``rpc_async`` from inside a ScriptFunction and checks that
    profiling callbacks can be attached. ``handle`` is a Tensor representation
    of RecordFunction.
    """
    rpc_args = (torch.tensor(1),)
    pending = rpc.rpc_async(dst_worker_name, one_arg, rpc_args)
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, pending)
    return pending.wait()
@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor:
    """Run ``script_add_ones_with_record_function(torch.tensor(1), block)`` on ``dst_worker_name``.

    The call is made with ``rpc_async`` and the awaited result is returned.
    """
    rpc_args = (torch.tensor(1), block)
    pending = rpc.rpc_async(
        dst_worker_name, script_add_ones_with_record_function, rpc_args
    )
    result = pending.wait()
    return result
@torch.jit.script
def call_fork_with_profiling(handle: Tensor) -> Tensor:
    """Fork ``one_arg(torch.tensor(1))`` and attach profiling end-callbacks to its future.

    This exercises ``torch.jit._fork`` from inside a ScriptFunction and checks
    that profiling callbacks can be attached to the resulting future.
    ``handle`` is a Tensor representation of RecordFunction.
    """
    forked = torch.jit._fork(one_arg, torch.tensor(1))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, forked)
    return forked.wait()
class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
    """Script module that holds four RRefs and sums their values in ``forward``."""

    def __init__(self, dst_worker):
        super().__init__()
        # Four RRefs, each produced by rpc_return_rref(dst_worker).
        self.rrefs = [rpc_return_rref(dst_worker) for _ in range(4)]

    @torch.jit.script_method
    def forward(self) -> Tensor:
        """Return ``torch.ones(2, 2)`` plus the ``to_here()`` value of every stored RRef."""
        total = torch.ones(2, 2)
        for remote_ref in self.rrefs:
            total += remote_ref.to_here()
        return total
@torch.jit.ignore
def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
    """Return ``rref_var`` unchanged; marked ``torch.jit.ignore`` so it is not compiled."""
    same_rref = rref_var
    return same_rref
@torch.jit.script
def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
    """Pass ``rref_var`` through ``rref_python_annotation`` and return its ``to_here()`` value."""
    passed_through = rref_python_annotation(rref_var)
    return passed_through.to_here()
class RRefTypingTest:
    """Tests that pass RRefs into and out of TorchScript functions over RPC.

    NOTE(review): ``self.rank``, ``self.world_size`` and ``assertEqual`` are
    presumably provided by the concrete test class this one is combined with —
    confirm.
    """

    @dist_init
    def test_rref_as_arg_and_return(self):
        """An RRef can be an argument and a return value for both ``rpc_sync`` and ``remote``."""
        dst = worker_name((self.rank + 1) % self.world_size)
        expected = one_arg(torch.ones(2, 2))

        # The RRef is created on the current rank.
        own_rref = rpc.remote(
            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
        )

        # rpc_sync: the RRef is passed in as an argument.
        fetched = rpc.rpc_sync(dst, rref_to_here, args=(own_rref,))
        self.assertEqual(fetched, expected)

        # rpc_sync: the RRef is handed back as the result.
        echoed_rref = rpc.rpc_sync(dst, return_rref, args=(own_rref,))
        self.assertEqual(echoed_rref.to_here(), expected)

        # remote: the RRef is passed in as an argument.
        remote_value = rpc.remote(dst, rref_to_here, args=(own_rref,))
        self.assertEqual(remote_value.to_here(), expected)

        # remote: the RRef is handed back, so the result is an RRef to an RRef.
        nested_rref = rpc.remote(dst, return_rref, args=(own_rref,))
        self.assertEqual(nested_rref.to_here().to_here(), expected)

    @dist_init
    def test_my_script_module_with_rrefs(self):
        """``MyScriptModuleWithRRefs`` yields a 2x2 tensor filled with 9."""
        dst = worker_name((self.rank + 1) % self.world_size)
        module = MyScriptModuleWithRRefs(dst)
        output = module()

        # ones(2, 2) plus four RRef values of ones(2, 2) + 1 gives 1 + 4 * 2 = 9.
        self.assertEqual(output, torch.ones(2, 2) * 9)

    @dist_init
    def test_rref_python_annotation(self):
        """A scripted function can call an ignored Python function annotated with RRef types."""
        dst = worker_name((self.rank + 1) % self.world_size)
        remote_ref = rpc_return_rref(dst)
        value = rref_script_annotation(remote_ref)
        self.assertEqual(value, torch.ones(2, 2) + 1)
class FutureTypingTest:
    """Tests that move ``Future`` objects between Python and TorchScript.

    NOTE(review): ``self.rank``, ``self.world_size`` and ``assertEqual`` are
    presumably provided by the concrete test class this one is combined with —
    confirm.
    """

    @dist_init
    def test_future_passed_between_python_and_jit(self):
        """A Future can be passed from Python into script and returned from script to Python."""
        dst_rank = (self.rank + 1) % self.world_size
        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        expected = torch.tensor([10, 10])

        # Python -> script: the future is created in Python and awaited in script.
        python_future = rpc.rpc_async(
            worker_name(dst_rank), two_args_two_kwargs, args=inputs
        )

        @torch.jit.script
        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
            return fut.wait()

        self.assertEqual(future_wait_in_script(python_future), expected)

        # Script -> Python: the future is created in script and awaited in Python.
        @torch.jit.script
        def future_return_to_python(
            dst_rank: int, inputs: Tuple[Tensor, Tensor]
        ) -> Future[Tensor]:
            return rpc.rpc_async(
                "worker{}".format(dst_rank), two_args_two_kwargs, inputs
            )

        script_future = future_return_to_python(dst_rank, inputs)
        self.assertEqual(script_future.wait(), expected)

    @dist_init
    def test_future_python_annotation(self):
        """A scripted function can wait on a Future returned by an ignored Python function."""
        # Only rank 0 drives this test.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        lhs = torch.ones(2, 2)
        rhs = 1
        expected = torch.add(lhs, rhs)

        @torch.jit.ignore
        def python_return_future() -> Future[Tensor]:
            return rpc.rpc_async(dst_worker_name, torch.add, (lhs, rhs), {})

        @torch.jit.script
        def script_use_future() -> Tensor:
            pending = python_return_future()
            return pending.wait()

        self.assertEqual(script_use_future(), expected)
@torch.jit.script
class MyScriptClass:
    """TorchScript class that stores a single integer ``a``."""

    def __init__(self, a: int):
        """Store ``a`` on the instance."""
        self.a = a

    def get_value(self) -> int:
        """Return the stored integer ``a``."""
        stored = self.a
        return stored
Loading ...