import sys
import threading
import time
import unittest
from enum import Enum
import torch
from datetime import timedelta
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributed.rpc import RRef
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.dist_utils import (
dist_init,
initialize_pg,
wait_until_node_failure,
worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
# Nested rpc calls are exercised up to three hops deep, so each of the two
# lists below has four slots, indexed by the distance between ranks:
#   index 0: the current rank itself (mostly unused)
#   index 1: the previous rank
#   index 2: the previous rank of the previous rank
#   index 3: the rank three hops back
# rpc_done[d] records that the rpc is done in the rank at distance d, and
# ctx_ids[d] holds the context id sent from that rank.
rpc_done = [False] * 4
ctx_ids = [-1] * 4
# Every context id that has been passed to _set_rpc_done.
known_context_ids = set()
# Module-level tensor returned by ret_requires_grad.
requires_grad_tensor = torch.ones(3, 3, requires_grad=True)
# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size
# we don't need a lock here since the GIL is held while executing remote
# python UDFs, so access is serialized across several workers.
def _set_rpc_done(ctx_id, rank_distance):
    """Record that the rank ``rank_distance`` hops away has finished its rpc.

    Args:
        ctx_id: Context id reported by that rank; stored in
            ``ctx_ids[rank_distance]`` and added to ``known_context_ids``.
        rank_distance: Index into the module-level ``rpc_done`` and
            ``ctx_ids`` lists.

    Raises:
        IndexError: If ``rank_distance`` is outside the range of those lists
            (nothing is modified in that case).
    """
    # Store the context id *before* raising the flag. _check_rpc_done polls
    # rpc_done (possibly from another thread), so everything a waiter reads
    # after seeing the flag must already be in place when the flag flips.
    ctx_ids[rank_distance] = ctx_id
    known_context_ids.add(ctx_id)
    rpc_done[rank_distance] = True
def _check_rpc_done(rank_distance):
    """Block until ``rpc_done[rank_distance]`` becomes truthy.

    The flag is polled, sleeping 0.1 seconds between checks.
    """
    while True:
        if rpc_done[rank_distance]:
            return
        time.sleep(0.1)
def _torch_ones(sizes, requires_grad=False):
return torch.ones(sizes, requires_grad=requires_grad)
# This method must be called on the rref owner; it verifies that the grad of
# the rref tensor equals the given grad.
def _compare_owner_value(context_id, rref, grad):
    """Check the gradient recorded for an rref's local tensor.

    Looks up the gradients of ``context_id`` via
    ``dist_autograd.get_gradients`` and returns whether the entry keyed by
    ``rref.local_value()`` is ``torch.equal`` to ``grad``.
    """
    grads_by_tensor = dist_autograd.get_gradients(context_id)
    owner_tensor = rref.local_value()
    recorded_grad = grads_by_tensor[owner_tensor]
    return torch.equal(recorded_grad, grad)
def create_tensor():
    """Return a new 3x3 tensor of ones that requires grad."""
    shape = (3, 3)
    return torch.ones(shape, requires_grad=True)
# TorchScript counterpart of create_tensor: builds a 3x3 tensor of ones and
# switches requires_grad on in place.
@torch.jit.script
def create_torchscript_tensor() -> torch.Tensor:
    ones = torch.ones((3, 3))
    return ones.requires_grad_()
def my_py_add(t1, t2):
    """Return ``torch.add(t1, t2)``."""
    total = torch.add(t1, t2)
    return total
def my_scalar_add(a, b):
    """Return ``a + b``."""
    total = a + b
    return total
def my_rref_add(rref_t1, t2):
    """Add ``t2`` to the value returned by ``rref_t1.local_value()``."""
    local_t1 = rref_t1.local_value()
    return torch.add(local_t1, t2)
# TorchScript version of my_py_add: returns torch.add(t1, t2).
@torch.jit.script
def my_script_add(t1, t2):
    total = torch.add(t1, t2)
    return total
# TorchScript function that fetches the tensor behind ref_t1 with to_here()
# and returns torch.add of that tensor and t2.
@torch.jit.script
def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor:
    return torch.add(ref_t1.to_here(), t2)
def my_nested_rref_add(dst, rref_t1, t2):
    """Run ``my_rref_add(rref_t1, t2)`` on ``dst`` via ``rpc.rpc_sync`` and return its result."""
    forwarded_args = (rref_t1, t2)
    return rpc.rpc_sync(dst, my_rref_add, args=forwarded_args)
def ret_requires_grad():
    """Return the module-level ``requires_grad_tensor``."""
    shared_tensor = requires_grad_tensor
    return shared_tensor
def my_py_nested_call(t1, t2, dst, world_size, hops):
    """Forward ``t1`` and ``t2`` through a chain of rpc_sync calls.

    Each call targets the worker after ``dst`` (wrapping at ``world_size``).
    While ``hops`` is positive the call recurses with ``hops - 1``; otherwise
    ``my_py_add(t1, t2)`` is run on that next worker. The rpc result is
    returned.
    """
    next_dst = (dst + 1) % world_size
    if not hops > 0:
        # Last hop: perform the actual addition on the next worker.
        return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2))
    nested_args = (t1, t2, next_dst, world_size, hops - 1)
    return rpc.rpc_sync(worker_name(next_dst), my_py_nested_call, args=nested_args)
# After a dist autograd context is cleaned up, it should also be cleaned up on
# the other nodes. This helper allows timeout_seconds for those RPCs to
# complete, and checks that all the contexts have been cleaned up in that
# timeframe.
def _all_contexts_cleaned_up(timeout_seconds=10):
    """Poll until every id in ``known_context_ids`` has been cleaned up.

    A context id counts as cleaned up once
    ``dist_autograd._retrieve_context`` raises ``RuntimeError`` for it.

    Args:
        timeout_seconds: Maximum time, in seconds, to keep polling.

    Returns:
        True if every known context id was observed as cleaned up before the
        timeout elapsed, False otherwise.
    """
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    start = time.monotonic()
    cleaned_up = set()
    while (
        time.monotonic() - start < timeout_seconds
        and cleaned_up != known_context_ids
    ):
        # The set difference is a snapshot, so an id added to
        # known_context_ids by _set_rpc_done while we iterate cannot break the
        # iteration. It also skips ids already confirmed as cleaned up.
        for context_id in known_context_ids - cleaned_up:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                cleaned_up.add(context_id)
        if cleaned_up != known_context_ids:
            # Back off briefly instead of spinning at full speed while the
            # remaining contexts are still being released.
            time.sleep(0.01)
    # All contexts have been cleaned up if retrieving each of them resulted
    # in a RuntimeError.
    return cleaned_up == known_context_ids
# This function creates a dist autograd context, runs rpc_sync on the given ps,
# and then blocks until the ps has verified the grads are correctly accumulated.
def _run_trainer(rref_t1, t2, ps, rank_diff):
    """Run one forward/backward pass against ``ps`` inside a dist autograd context.

    Calls ``my_rref_add(rref_t1, t2)`` on ``ps`` with rpc_sync, runs
    ``dist_autograd.backward`` on the sum of the result, then reports the
    context id to ``ps`` through ``_set_rpc_done`` (at slot ``rank_diff``) and
    blocks in ``_check_rpc_done(0)`` on ``ps``.
    """
    with dist_autograd.context() as context_id:
        forward_out = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
        loss = forward_out.sum()
        dist_autograd.backward(context_id, [loss])
        # Stay inside the ``with`` block for the two calls below so the dist
        # autograd context is not deleted while ps still needs it.
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
# This function is the same as _run_trainer, except that the rpc call invokes the
# torchscript function "my_script_ref_add" instead of the python function "my_rref_add".
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff):
    """Run one forward/backward pass against ``ps`` using a TorchScript function.

    Calls the scripted ``my_script_ref_add(rref_t1, t2)`` on ``ps`` with
    rpc_sync, runs ``dist_autograd.backward`` on the sum of the result, then
    reports the context id to ``ps`` through ``_set_rpc_done`` (at slot
    ``rank_diff``) and blocks in ``_check_rpc_done(0)`` on ``ps``.
    """
    with dist_autograd.context() as context_id:
        forward_out = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
        loss = forward_out.sum()
        dist_autograd.backward(context_id, [loss])
        # Stay inside the ``with`` block for the two calls below so the dist
        # autograd context is not deleted while ps still needs it.
        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
class SimulateBackwardError(Function):
    """Identity autograd function whose backward pass can be made to fail.

    The forward pass returns its input unchanged. The backward pass raises
    while the class attribute ``_simulate_error`` is truthy and otherwise
    passes the incoming gradient straight through.
    """

    # Toggle for the simulated failure in ``backward``.
    _simulate_error = True

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, input):
        if not SimulateBackwardError._simulate_error:
            return input
        raise Exception("Simulate error on backward pass")
class ExecMode(Enum):
    """Ways in which a test operation can be executed."""

    # Run the operation locally.
    LOCAL = 1
    # Run the operation using rpc_sync.
    RPC_SYNC = 2
    # Run the operation using remote.
    REMOTE = 3
    # Run the operation using rpc_async.
    RPC_ASYNC = 4
class DistAutogradTest(RpcAgentTestFixture):
def _exec_func_with_dst(self, dst, exec_mode, method, *args):
    """Run ``method(*args)`` according to ``exec_mode`` and return its result.

    ``ExecMode.LOCAL`` calls ``method`` directly; when the only positional
    argument is a list, that list is unpacked as the argument list. The
    other modes send the call to ``worker_name(dst)`` with rpc_sync, remote
    (followed by ``to_here()``) or rpc_async (followed by ``wait()``).

    Raises:
        ValueError: If ``exec_mode`` is none of the ``ExecMode`` members.
    """
    if ExecMode.LOCAL == exec_mode:
        single_list_arg = len(args) == 1 and isinstance(args[0], list)
        call_args = args[0] if single_list_arg else args
        return method(*call_args)

    if ExecMode.RPC_SYNC == exec_mode:
        return rpc.rpc_sync(worker_name(dst), method, args=args)

    if ExecMode.REMOTE == exec_mode:
        remote_ref = rpc.remote(worker_name(dst), method, args=args)
        return remote_ref.to_here()

    if ExecMode.RPC_ASYNC == exec_mode:
        return rpc.rpc_async(worker_name(dst), method, args=args).wait()

    raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
def _exec_func(self, exec_mode, method, *args):
return self._exec_func_with_dst(
self._next_rank(), exec_mode, method, *args
)
def _next_rank(self):
if hasattr(self, "dst_rank"):
self.dst_rank = (self.dst_rank + 1) % self.world_size
if self.dst_rank == self.rank:
return self._next_rank()
else:
self.dst_rank = (self.rank + 1) % self.world_size
return self.dst_rank
def _check_rpc_done(self, rank_distance):
    """Method wrapper that calls the module-level ``_check_rpc_done`` with ``rank_distance``."""
    # Inside a method body this name resolves to the module-level function,
    # not to this method.
    wait_for_flag = _check_rpc_done
    wait_for_flag(rank_distance)
@dist_init
def test_autograd_context(self):
    # Verify max possible id: the low 48 bits all set, with worker_id in the
    # bits above them.
    max_auto_increment = (1 << 48) - 1  # 281474976710655
    expected_max_id = max_auto_increment + (self.worker_id << 48)
    self.assertEqual(expected_max_id, dist_autograd._get_max_id())

    seen_ids = []
    for _ in range(200):
        with dist_autograd.context() as context_id:
            retrieved = dist_autograd._retrieve_context(context_id)
            self.assertEqual(context_id, retrieved._context_id())
            # First 16 bits should be worker_id.
            self.assertEqual(self.worker_id, context_id >> 48)
            seen_ids.append(context_id)

    # Once its ``with`` block has exited, a context can no longer be
    # retrieved.
    for stale_id in seen_ids:
        expected_error = "Could not find autograd context with id: {}".format(
            stale_id
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            dist_autograd._retrieve_context(stale_id)
@dist_init
def test_nested_context(self):
    expected_error = "Already have an autograd context id for this thread"
    with dist_autograd.context() as _outer_context_id:
        # Nested contexts not supported: entering a second context while the
        # first one is still open must raise.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            with dist_autograd.context() as _inner_context_id:
                pass
# For the current context, this rank sends tensors t1 and t2 to dst_rank and
# then gets back the result tensor t3 = torch.add(t1, t2).
# For the current context in this rank, the expected graph looks like this:
#  send function:
#              rpcSendBackward
#                /          \
#  t1.AccumulateGrad         t2.AccumulateGrad
#
#  recv function:
#
#            |
#          t3.rpcRecvBackward
#
def _verify_graph_for_first_rpc_call(
self, send_function, recv_function, t1, t2, ret
):
# Retrieve the next functions in the graph.
next_funcs = send_function.next_functions
self.assertEqual(2, len(next_funcs))
# We should now hit t1 and t2 in the autograd graph.
self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
self.assertEqual(t1, next_funcs[0][0].variable)
self.assertEqual(0, next_funcs[0][1])
self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
self.assertEqual(t2, next_funcs[1][0].variable)
self.assertEqual(0, next_funcs[1][1])
# Test recv functions.
self.assertEqual(ret.grad_fn, recv_function)
# For a context passed from previous nested chain calls, this rank
# receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
# the result tensor t3 back.
# For this context in this rank, the expected graph looks like this:
#  send and recv functions:
#       rpcSendBackward
#           |
#          t3.AddBackward0
#          /             \
# t1.recvRpcBackward    t2.recvRpcBackward
def _verify_graph_for_rpc_call_exec(self, send_function):
# Verify next function is AddBackward0
next_funcs = send_function.next_functions
self.assertEqual(1, len(next_funcs))
add_backward_fn = next_funcs[0][0]
self.assertEqual("AddBackward0", add_backward_fn.name())
# Verify the next two functions are the same recv backward function.
next_funcs = add_backward_fn.next_functions
self.assertEqual(2, len(next_funcs))
self.assertEqual(
"torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
)
self.assertEqual(
"torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
)
self.assertEqual(next_funcs[0][0], next_funcs[1][0])
# For a context passed from previous nested chain calls, this rank
# receives two tensors t1 and t2 and forwards them to the next dst using a
# nested rpc call. On the return route, it receives the result tensor t3
# from the next dst and forwards t3 back to the previous calls.
# For this context in this rank, the expected graph looks like this:
#  send and recv functions for receiving and forwarding t1 and t2:
#            rpcSendBackward
#           /               \
# t1.recvRpcBackward    t2.recvRpcBackward
#  send and recv functions for receiving and forwarding t3:
#       rpcSendBackward
#             |
#           t3.recvRpcBackward
def _verify_graph_for_nested_rpc_call(self, ctx):
send_functions = ctx._send_functions()
self.assertEqual(2, len(send_functions))
# For send function when making nest rpc call,
# next functions of the send function are two recv functions
# for received two tensors from previous call
next_funcs = list(send_functions.values())[0].next_functions
self.assertEqual(2, len(next_funcs))
self.assertEqual(
"torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
)
self.assertEqual(
"torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
)
self.assertEqual(next_funcs[0][0], next_funcs[1][0])
# For send function when returning resonpose to previous call
Loading ...