import threading
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torch.testing._internal.dist_utils import dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
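
# Minimal module with a single 3x3 weight; the tests instantiate it remotely
# via rpc.remote and drive its forward pass over RPC.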
class MyModule:
lock = threading.Lock()
def __init__(self):
        # Cannot call torch.manual_seed(0) here because all RPC threads share
        # the default generator; racing draws from multiple threads would make
        # the initial weights non-deterministic. Use a dedicated, seeded RNG
        # instead.
g_cpu = torch.Generator()
g_cpu.manual_seed(0)
self.w = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
def forward(self, t1):
return torch.mm(self.w, t1)
def get_w(self):
return self.w
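
# Optimizer whose step() always raises, used to verify that errors raised on
# remote workers propagate through DistributedOptimizer.step().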
class FailingOptimizer(optim.Optimizer):
def __init__(self, params):
super().__init__(params, {})
def step(self, closure=None):
raise ValueError("Error running optimizer.")
class OptimizerFailingOnConstructor(optim.Optimizer):
def __init__(self, params):
super().__init__(params, {})
raise ValueError("Error creating optimizer.")
def step(self, closure=None):
raise NotImplementedError
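
# Helper executed on the RRef's owner: unwrap the RRef to the local object and
# invoke the requested method on it.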
def _call_method(method, obj_rref, *args, **kwargs):
return method(obj_rref.local_value(), *args, **kwargs)
def remote_method(method, obj_rref, *args, **kwargs):
"""
Call rpc.remote on a method in a remote object.
Args:
method: the method (for example, Class.method)
obj_rref (RRef): remote reference to the object
args: positional arguments to pass to the method
kwargs: keyword arguments to pass to the method
    Returns an RRef to the result of the remote method call.
"""
return rpc.remote(
obj_rref.owner(),
_call_method,
args=[method, obj_rref] + list(args),
kwargs=kwargs,
)
def rpc_async_method(method, obj_rref, *args, **kwargs):
"""
Call rpc.rpc_async on a method in a remote object.
Args:
method: the method (for example, Class.method)
obj_rref (RRef): remote reference to the object
args: positional arguments to pass to the method
kwargs: keyword arguments to pass to the method
Returns a Future to the method call result.
"""
return rpc.rpc_async(
obj_rref.owner(),
_call_method,
args=[method, obj_rref] + list(args),
kwargs=kwargs,
)
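
# DistributedOptimizer tests: remote failures must propagate back to the
# caller, and a distributed step must produce the same weights as an
# equivalent purely local run.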
class DistOptimizerTest(RpcAgentTestFixture):
@dist_init()
def test_dist_optim_exception(self):
        # Place one remote module on each of the next two workers.
owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
remote_module1 = rpc.remote(owner1, MyModule)
remote_module2 = rpc.remote(owner2, MyModule)
remote_param1 = remote_method(MyModule.get_w, remote_module1)
remote_param2 = remote_method(MyModule.get_w, remote_module2)
dist_optim = DistributedOptimizer(
FailingOptimizer, [remote_param1, remote_param2]
)
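        # One forward/backward pass through the two remote modules; step() is
        # expected to surface the failure raised by the remote FailingOptimizer.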
with dist_autograd.context() as context_id:
g_cpu = torch.Generator()
g_cpu.manual_seed(0)
t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
loss = torch.add(output2.wait(), t1).sum()
dist_autograd.backward(context_id, [loss])
with self.assertRaisesRegex(Exception, "Error running optimizer"):
dist_optim.step(context_id)
@dist_init()
def test_dist_optim_exception_on_constructor(self):
        # Place one remote module on each of the next two workers.
owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
remote_module1 = rpc.remote(owner1, MyModule)
remote_module2 = rpc.remote(owner2, MyModule)
remote_param1 = remote_method(MyModule.get_w, remote_module1)
remote_param2 = remote_method(MyModule.get_w, remote_module2)
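        # DistributedOptimizer constructs the per-worker local optimizers
        # eagerly, so the remote constructor failure surfaces here.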
with self.assertRaisesRegex(Exception, "Error creating optimizer."):
            DistributedOptimizer(
OptimizerFailingOnConstructor, [remote_param1, remote_param2]
)
def _test_dist_optim_base(self, optim_cls, *args, **kwargs):
# local version
module1 = MyModule()
module2 = MyModule()
params = [module1.get_w(), module2.get_w()]
local_optim = optim_cls(params, *args, **kwargs)
old_w1 = module1.w.clone().detach()
old_w2 = module2.w.clone().detach()
g_cpu = torch.Generator()
g_cpu.manual_seed(0)
t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
output1 = module1.forward(t2)
output2 = module2.forward(output1)
loss = torch.add(output2, t1).sum()
loss.backward()
local_optim.step()
# distributed version
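        # Mirror the local run on two remote workers: gradients flow back
        # through distributed autograd, and DistributedOptimizer runs the
        # per-worker local optimizers.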
owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
remote_module1 = rpc.remote(owner1, MyModule)
remote_module2 = rpc.remote(owner2, MyModule)
remote_param1 = remote_method(MyModule.get_w, remote_module1)
remote_param2 = remote_method(MyModule.get_w, remote_module2)
        old_w1_remote = remote_param1.to_here()
        old_w2_remote = remote_param2.to_here()
        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, old_w1_remote)
        self.assertEqual(old_w2, old_w2_remote)
dist_optim = DistributedOptimizer(
optim_cls, [remote_param1, remote_param2], *args, **kwargs
)
with dist_autograd.context() as context_id:
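            # Re-seed so the distributed pass sees the same t1/t2 inputs as the
            # local pass above.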
g_cpu.manual_seed(0)
t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
            loss = torch.add(output2.wait(), t1).sum()
            dist_autograd.backward(context_id, [loss])
dist_optim.step(context_id)
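            # Fetch the updated remote weights: they should have moved and
            # should match the locally optimized modules.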
new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()
# ensure optimizer changed weights
self.assertNotEqual(old_w1, new_w1)
self.assertNotEqual(old_w2, new_w2)
# ensure local equals remote
self.assertEqual(new_w1, module1.get_w())
self.assertEqual(new_w2, module2.get_w())
@dist_init()
def test_dist_optim(self):
self._test_dist_optim_base(optim.Adagrad, lr=0.05)
self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
self._test_dist_optim_base(optim.SGD, lr=0.05)
        self._test_dist_optim_base(
            optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
        )
self._test_dist_optim_base(optim.Adadelta, rho=0.95)
self._test_dist_optim_base(optim.RMSprop, lr=0.05)