from typing import Dict, Tuple
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import rpc_async
from torch.testing import FileCheck
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
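
# TorchScript helpers used by the tests below; `local_add` executes on the
# callee side of the RPCs issued by `remote_add` and `fork_add`.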
@torch.jit.script
def local_add(t1: Tensor, t2: Tensor) -> Tensor:
    return torch.add(t1, t2)
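
# Issues an RPC from inside TorchScript and blocks on the returned future.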
@torch.jit.script
def remote_add(t1: Tensor, t2: Tensor, dst: str) -> Tensor:  # noqa: E999
    return rpc_async(dst, local_add, (t1, t2)).wait()
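
# Same RPC as remote_add, but issued from a JIT fork thread via torch.jit._fork.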
@torch.jit.script
def fork_add(t1: Tensor, t2: Tensor, dst: str) -> Tensor:
fut = torch.jit._fork(remote_add, t1, t2, dst)
return torch.jit._wait(fut)
class JitDistAutogradTest(RpcAgentTestFixture):
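    """Tests for torch.distributed.autograd driven from TorchScript."""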
@dist_init
def test_get_gradients(self):
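        """dist_autograd.get_gradients should be scriptable and return the
        expected gradients."""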
@torch.jit.script
        def dist_get_gradients(context_id: int) -> Dict[Tensor, Tensor]:
return dist_autograd.get_gradients(context_id)
FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
t3 = torch.add(t1, t2)
dist_autograd.backward(context_id, [t3.sum()])
grads = dist_get_gradients(context_id)
self.assertEqual(2, len(grads))
self.assertIn(t1, grads)
self.assertIn(t2, grads)
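            # d(sum(t1 + t2))/dt1 and d(sum(t1 + t2))/dt2 are both all ones.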
self.assertEqual(torch.ones(3, 3), grads[t1])
self.assertEqual(torch.ones(3, 3), grads[t2])
@dist_init
def test_dist_backward(self):
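        """dist_autograd.backward should be scriptable."""
        # Only rank 0 drives the test; the remaining ranks just serve the RPCs.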
if self.rank != 0:
return
@torch.jit.script
def dist_backward_script(context_id: int, loss: torch.Tensor):
dist_autograd.backward(context_id, [loss])
FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
with dist_autograd.context() as context_id:
t1 = torch.rand(3, 3, requires_grad=True)
t2 = torch.rand(3, 3, requires_grad=True)
dst_worker_name = worker_name((self.rank + 1) % self.world_size)
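            # This RPC records SendRpcBackward/RecvRpcBackward functions in the
            # current distributed autograd context.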
loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
dist_backward_script(context_id, loss)
@dist_init
def test_jit_fork_within_context(self):
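        """RPCs issued through torch.jit._fork should still be recorded in the
        surrounding distributed autograd context."""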
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
dst_worker_name = worker_name((self.rank + 1) % self.world_size)
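            # The RPC runs on a JIT fork thread; both inputs should still
            # receive gradients through the surrounding context.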
res = fork_add(t1, t2, dst_worker_name)
loss = res.sum()
dist_autograd.backward(context_id, [loss])
grads = dist_autograd.get_gradients(context_id)
self.assertEqual(2, len(grads))
self.assertIn(t1, grads)
self.assertIn(t2, grads)
@dist_init
    def test_restore_context_after_switch_to_jit_thread(self):
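        """The distributed autograd context should carry over to the new JIT
        thread that resumes the script after a future wait."""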
if self.rank != 0:
return
@torch.jit.script
def forward_script(
context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
) -> Tuple[Tensor, Tensor]:
res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
loss1 = res1.sum()
            # Without restoring the DistAutogradContext on the new thread,
            # SendRpcBackward would not be attached to this second RPC.
res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
res2 = res2_fut.wait()
loss2 = res2.sum()
return loss1, loss2
with dist_autograd.context() as context_id:
t1 = torch.ones((2, 3), requires_grad=True)
t2 = torch.ones((2, 3), requires_grad=True)
dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss1, loss2 = forward_script(context_id, dst_worker_name, t1, t2)
            dist_autograd.backward(context_id, [loss1, loss2])
            # Both inputs are identical, so their gradients must match; indexing
            # by tensor would raise a KeyError if either gradient were missing.
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(grads[t1], grads[t2])