#!/usr/bin/python3
import enum
from typing import Tuple
import torch
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch import Tensor, nn
from torch._jit_internal import Future
from torch.distributed.nn import RemoteModule
from torch.distributed.nn.api.remote_module import _RemoteModule
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
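
# Tests for torch.distributed.nn.RemoteModule: construction on a remote
# worker, sync/async forward (eager and TorchScript), remote parameter
# access, and device/argument validation. Rank 0 drives each test; the
# next-ranked worker hosts the module.
#
# Usage sketch: RemoteModule("worker1/cpu", MyModule, args, kwargs)
# instantiates MyModule on worker1 and returns a local proxy whose
# forward()/forward_async() execute on that worker.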

_PARAM_VAL = torch.nn.Parameter(torch.ones(1))


# RPC handler for querying the device on the destination worker.
def remote_device(module_rref):
for param in module_rref.local_value().parameters():
return param.device


# Whether to construct the remote module directly from its nn.Module ctor,
# or through a factory function exposed via a TorchScript module interface.
class ModuleCreationMode(enum.Enum):
MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
MODULE_CTOR = "module_ctor"


@torch.jit.interface
class MyModuleInterface:
def forward(
self, tensor: Tensor, number: int, word: str = "default"
) -> Tuple[str, int, Tensor]:
pass


@torch.jit.interface
class RemoteMyModuleInterface:
def forward(
self, tensor: Tensor, number: int, word: str = "default"
) -> Tuple[str, int, Tensor]:
pass

    def forward_async(
self, tensor: Tensor, number: int, word: str = "default"
) -> Future[Tuple[str, int, Tensor]]:
pass


class MyModule(nn.Module):
    def __init__(self, first_arg, first_kwarg=-1):
        super().__init__()
        # The ctor args are intentionally unused; they only exercise
        # argument forwarding through RemoteModule.
        self.param1 = _PARAM_VAL

    def forward(
self, tensor: Tensor, number: int, word: str = "default"
) -> Tuple[str, int, Tensor]:
return word, number, tensor


# Deliberately not an nn.Module subclass; used to verify that RemoteModule
# rejects module classes that do not produce an nn.Module instance.
class BadModule:
    def __init__(self, first_arg, first_kwarg=-1):
        pass


def create_scripted_module(first_arg, first_kwarg=-1):
module = MyModule(first_arg, first_kwarg=first_kwarg)
scripted_module = torch.jit.script(module)
return scripted_module


class RemoteModuleTest(RpcAgentTestFixture):
@property
    def world_size(self):  # Override setting in RpcAgentTestFixture
return 2

    @staticmethod
def _create_remote_module_iter(remote_device, modes=None):
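        # Yield one RemoteModule per requested creation mode: a plain
        # nn.Module ctor and/or a scripted module behind a jit interface.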
if modes is None:
modes = ModuleCreationMode.__members__.values()
args = (1,)
kwargs = dict(first_kwarg=2)
if ModuleCreationMode.MODULE_CTOR in modes:
remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
yield remote_module
if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
remote_module = _RemoteModule(
remote_device,
create_scripted_module,
args,
kwargs,
_module_interface_cls=MyModuleInterface,
)
scripted_remote_module = torch.jit.script(remote_module)
yield scripted_remote_module

    @dist_utils.dist_init
def test_bad_module(self):
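        # Constructing a RemoteModule from a class that is not an nn.Module
        # must raise a ValueError.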
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
remote_device = "{}/cpu".format(dst_worker_name)
args = (1,)
kwargs = dict(first_kwarg=2)
with self.assertRaisesRegex(
ValueError,
r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
):
RemoteModule(remote_device, BadModule, args, kwargs)

    @dist_utils.dist_init
def test_forward_async(self):
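        # forward_async returns a Future; its value must match what a local
        # forward would return (the args in reverse order).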
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
args = (torch.ones(1), 2, "3")
for remote_module in self._create_remote_module_iter(dst_worker_name):
ret_fut = remote_module.forward_async(*args)
ret = ret_fut.wait()
self.assertEqual(ret, tuple(reversed(args)))

    @dist_utils.dist_init
def test_forward_async_script(self):
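        # Same as test_forward_async, but forward_async is invoked from
        # TorchScript through RemoteMyModuleInterface.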
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
scripted_remote_module = next(
self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
)
)

        @torch.jit.script
def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
ret = ret_fut.wait()
return ret

        ret = run_forward_async(scripted_remote_module)
self.assertEqual(ret, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
def test_forward_sync(self):
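        # Synchronous forward returns (word, number, tensor), i.e. the args
        # in reverse order.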
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
args = (torch.ones(1), 2, "3")
for remote_module in self._create_remote_module_iter(dst_worker_name):
ret = remote_module.forward(*args)
self.assertEqual(ret, tuple(reversed(args)))

    @dist_utils.dist_init
def test_forward_sync_script(self):
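        # Synchronous forward invoked from TorchScript through
        # MyModuleInterface.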
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
scripted_remote_module = next(
self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
)
)

        @torch.jit.script
def run_forward(scripted_remote_module: MyModuleInterface):
ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
return ret

        ret = run_forward(scripted_remote_module)
self.assertEqual(ret, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
def test_forward_with_kwargs(self):
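        # forward/forward_async must forward keyword arguments when the
        # remote module is a plain Python nn.Module.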
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
args = (torch.ones(1), 2)
kwargs = dict(word="3")
# Only test Python nn.Module, because script module methods don't support taking kwargs.
for remote_module in self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
):
ret_fut = remote_module.forward_async(*args, **kwargs)
ret = ret_fut.wait()
self.assertEqual(ret, tuple(reversed(args + ("3",))))
ret = remote_module.forward(*args, **kwargs)
self.assertEqual(ret, tuple(reversed(args + ("3",))))

    @dist_utils.dist_init
def test_remote_parameters(self):
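        # remote_parameters returns RRefs to the parameters living on the
        # remote worker.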
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
# Only test Python nn.Module, because script module methods don't support ``remote_parameters``.
for remote_module in self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
):
param_rrefs = remote_module.remote_parameters()
self.assertEqual(len(param_rrefs), 1)
self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL))

    @dist_utils.dist_init
def test_get_module_rref(self):
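        # get_module_rref exposes the RRef that wraps the remote module
        # instance.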
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
# Only test Python nn.Module, because script module methods don't support ``get_module_rref``.
for remote_module in self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
):
rref = remote_module.get_module_rref()
self.assertEqual(rref, remote_module.module_rref)
for param in rref.to_here().parameters():
self.assertTrue(torch.equal(param, _PARAM_VAL))

    @skip_if_lt_x_gpu(1)
@dist_utils.dist_init
def test_valid_device(self):
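        # A "<workername>/cuda:0" remote_device must place the module's
        # parameters on that GPU of the destination worker.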
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
for remote_module in self._create_remote_module_iter(
"{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR]
):
device = rpc.rpc_sync(
dst_worker_name, remote_device, (remote_module.module_rref,)
)
self.assertEqual(device.type, "cuda")
self.assertEqual(device.index, 0)

    @skip_if_lt_x_gpu(1)
@dist_utils.dist_init
def test_invalid_devices(self):
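        # Malformed remote_device strings must fail fast with descriptive
        # RuntimeErrors.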
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
with self.assertRaisesRegex(
RuntimeError,
r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
" device type at start of device string",
):
list(
self._create_remote_module_iter(
"{}/foo".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
with self.assertRaisesRegex(
RuntimeError, r"CUDA error: invalid device ordinal"
):
list(
self._create_remote_module_iter(
"{}/cuda:100".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
list(
self._create_remote_module_iter(
"{}/cpu2".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"):
list(
self._create_remote_module_iter(
"{}/".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
with self.assertRaisesRegex(
RuntimeError,
r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
):
list(
self._create_remote_module_iter(
"{}/cuda:0/cuda:1".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
with self.assertRaisesRegex(
RuntimeError,
r"The workername in remote_device '/' cannot be empty. The valid format is '<workername>/<device>'",
):
list(
self._create_remote_module_iter(
"/",
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
with self.assertRaisesRegex(
RuntimeError,
r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '<workername>/<device>'",
):
list(
self._create_remote_module_iter(
"/cuda:0",
modes=[ModuleCreationMode.MODULE_CTOR],
)
)

    @dist_utils.dist_init
def test_unsupported_methods(self):
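        # nn.Module mutators such as register_buffer/register_parameter are
        # rejected on the RemoteModule proxy.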
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
for remote_module in self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
):
with self.assertRaisesRegex(
ValueError, r"Method ``register_buffer`` not supported for RemoteModule"
):
remote_module.register_buffer("buffer", torch.ones(5))
with self.assertRaisesRegex(
ValueError,
r"Method ``register_parameter`` not supported for RemoteModule",
):
                remote_module.register_parameter(
                    # Any parameter value works here; the call must be rejected.
                    "param", torch.nn.Parameter(torch.ones(1))
                )