#!/usr/bin/env python3
import contextlib
import enum
import logging
import os
import threading
from typing import NamedTuple, Optional
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.nn import RemoteModule
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
requires_gloo,
requires_nccl,
skip_if_lt_x_gpu,
skip_if_rocm,
)
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
NUM_EM_ROW = 2
D_SPARSE = 3
D_DENSE = 2
D_HID = 3
D_OUT = 1
NUM_TRAINERS = 4
# Trainers + the master + the remote worker
WORLD_SIZE = NUM_TRAINERS + 2
TRAINER_RANKS = list(range(NUM_TRAINERS))
REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
MASTER_RANK = REMOTE_WORKER_RANK + 1
class DdpMode(enum.Enum):
# Don't apply DDP
NONE = enum.auto()
# Apply DDP to the top level nn.Module
OUTSIDE = enum.auto()
# Embed DDP inside the top level nn.Module
INSIDE = enum.auto()
def init_logger():
logger = logging.getLogger(__name__)
level = logging.DEBUG if "debug" in os.environ else logging.INFO
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
)
console.setFormatter(formatter)
console.setLevel(level)
# add the handlers to the logger
logger.addHandler(console)
logger.propagate = False
return logger
gLogger = init_logger()
class FeatureSet(NamedTuple):
""" A feature set has 2 types of features"""
dense_features: torch.Tensor
sparse_features: torch.LongTensor
values: torch.Tensor
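# RPC helpers: invoke `method` on the object held by `rref`, running on the
# rref's owner worker, either synchronously or asynchronously.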
def _call_method(method, rref, *args, **kwargs):
return method(rref.local_value(), *args, **kwargs)
def _remote_method(method, rref, *args, **kwargs):
args_tup = tuple([method, rref] + list(args))
return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
def _remote_method_async(method, rref, *args, **kwargs):
args_tup = tuple([method, rref] + list(args))
return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
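# An EmbeddingBag hosted on the remote worker; trainers hold it via an RRef
# and call its forward() through the _remote_method helpers above.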
class RemoteEM(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int):
gLogger.info(f"Initing RemoteEM with {num_embeddings} {embedding_dim}")
super().__init__()
init_em = [0.5] * embedding_dim
self.em = nn.EmbeddingBag(
num_embeddings,
embedding_dim,
_weight=torch.tensor([init_em] * num_embeddings),
)
def forward(self, input: torch.Tensor):
gLogger.debug(f"Running RemoteEM.forward() on: {input}")
return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
# Return a linear module with predefined parameters.
def getLinear(d_in, d_out):
linear = nn.Linear(d_in, d_out, bias=False)
w = torch.ones((d_out, d_in))
w[0][0] = -1
w.requires_grad_()
linear.weight.data = w
return linear
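# A small Linear + ReLU block hosted on the remote worker, also accessed via RRef.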
class RemoteNet(nn.Module):
def __init__(self, d_in: int, d_out: int):
gLogger.info(f"Initing RemoteNet with {d_in} {d_out}")
super().__init__()
self.fc = getLinear(d_in, d_out)
self.relu = nn.ReLU()
def forward(self, input: torch.Tensor):
gLogger.debug(f"Running RemoteNet.forward() on: {input}")
return self.relu(self.fc(input))
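# A model that mixes local layers (fc1, fc2) with remote modules (RemoteEM,
# RemoteNet). When a process group is given, fc2 is wrapped in DDP so that its
# gradients are all-reduced across trainers while fc1 stays purely local.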
class HybridModel(nn.Module):
def __init__(
self,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
process_group_for_ddp: Optional[dist.ProcessGroup] = None,
):
super().__init__()
self.remote_em_rref = remote_em_rref
self.remote_net_rref = remote_net_rref
self.fc1 = getLinear(D_DENSE, D_DENSE)
self.fc2 = getLinear(D_HID, D_OUT)
self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
self.fc2.parameters()
)
self.ddp_params = ()
if process_group_for_ddp is not None:
self.non_ddp_params, self.ddp_params = (
tuple(self.fc1.parameters()),
tuple(self.fc2.parameters()),
)
gLogger.info("Use DDP for the second local net.")
self.fc2 = DistributedDataParallel(
self.fc2, check_reduction=True, process_group=process_group_for_ddp
)
gLogger.info(
f"HybridModel has {len(list(self.parameters()))} groups of parameters."
)
def forward(self, input: FeatureSet):
gLogger.debug(f"Running HybridModel.forward on {input}")
sparse = _remote_method(
RemoteEM.forward, self.remote_em_rref, input.sparse_features
)
# The sparse embeddings and dense features must have the same mini-batch size.
assert sparse.shape[0] == input.dense_features.shape[0]
dense = self.fc1(input.dense_features)
x = torch.cat((dense, sparse), 1)
gLogger.debug(f"Concatenated feature: {x}")
x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
return self.fc2(x)
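# Each trainer owns a HybridModel. Depending on ddp_mode, DDP is applied to
# nothing (NONE), only to fc2 inside the model (INSIDE), or to the whole
# hybrid module (OUTSIDE).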
class Trainer:
def __init__(
self,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
ddp_mode: DdpMode,
rank: int,
):
self.rank = rank
self.trainer_group = (
dist.new_group(TRAINER_RANKS)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
else None
)
self.remote_em_rref = remote_em_rref
self.remote_net_rref = remote_net_rref
self.hybrid_module = HybridModel(
self.remote_em_rref,
self.remote_net_rref,
self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
)
self.ddp_params, self.non_ddp_params = (
self.hybrid_module.ddp_params,
self.hybrid_module.non_ddp_params,
)
if ddp_mode == DdpMode.OUTSIDE:
gLogger.info("Wrapping the whole hybrid module into DDP.")
self.ddp_params += self.non_ddp_params
self.non_ddp_params = ()
self.hybrid_module = DistributedDataParallel(
self.hybrid_module,
check_reduction=True,
process_group=self.trainer_group,
)
gLogger.info(
f"Succeeded in creating a HybridModel instance with "
f"{len(self.ddp_params)} ddp params and {len(self.non_ddp_params)} "
f"other local params."
)
def destroy_pg(self):
if self.trainer_group:
dist.destroy_process_group(self.trainer_group)
def train_batch(
self,
mini_batch: FeatureSet,
trainer_has_less_inputs: bool,
simulate_uneven_inputs: bool,
):
grads_dict = None
if not simulate_uneven_inputs:
input_batches = [mini_batch]
else:
# Split into microbatches, and trim to simulate uneven inputs.
dense_features = mini_batch.dense_features
sparse_features = mini_batch.sparse_features
values = mini_batch.values
dense_microbatch = torch.split(dense_features, 2)
sparse_microbatch = torch.split(sparse_features, 2)
values_microbatch = torch.split(values, 2)
batches = []
for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch):
feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
batches.append(feature_set)
if trainer_has_less_inputs:
input_batches = batches[: len(batches) // 2]
gLogger.info(
f"""Trainer reduced input patches from {len(batches)}
to {len(input_batches)} to simulate uneven inputs."""
)
else:
input_batches = batches
with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
for b in input_batches:
with dist_autograd.context() as context_id:
output = self.hybrid_module.forward(b)
loss = (output * mini_batch.values).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info(
f"Loss is {loss} for mini batch: {mini_batch}. "
f"Grads dict has {len(grads_dict)} entries: {grads_dict}"
)
return (
tuple(grads_dict[param] for param in self.ddp_params),
tuple(grads_dict[param] for param in self.non_ddp_params),
)
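# Build a deterministic dataset of 16 examples and split it evenly across the
# trainers. Examples come in pairs with identical features but opposite
# values, so DDP-managed gradients cancel out after all-reduce.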
def get_training_examples():
n = 16
training_examples = FeatureSet(
dense_features=torch.zeros((n, D_DENSE)),
sparse_features=torch.zeros(n, dtype=torch.long),
values=torch.zeros(n),
)
idx = 0
# Every example has another one that has exactly the same features but an
# opposite value. Therefore, their grads cancel each other in all-reduce.
for value in (-1, 1):
for x in (-1.0 * value, 1.0 * value):
for y in (1.0 * value, -1.0 * value):
for z in (0, 1):
training_examples.dense_features[idx, :] = torch.tensor((x, y))
training_examples.sparse_features[idx] = z
training_examples.values[idx] = value
idx += 1
# Split the examples among NUM_TRAINERS trainers
assert n % NUM_TRAINERS == 0
examples_per_trainer = n // NUM_TRAINERS
return [
FeatureSet(
dense_features=training_examples.dense_features[
start : start + examples_per_trainer, :
],
sparse_features=training_examples.sparse_features[
start : start + examples_per_trainer
],
values=training_examples.values[start : start + examples_per_trainer],
)
for start in range(0, n, examples_per_trainer)
]
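# The master notifies this condition over RPC to tell trainer and remote
# worker processes that they may shut down.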
shutdown_signal = threading.Condition()
def set_shutdown_signal():
global shutdown_signal
with shutdown_signal:
shutdown_signal.notify()
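# End-to-end test: the master creates remote modules and Trainer objects over
# RPC, drives a few training iterations, and checks the gradients collected by
# distributed autograd on each trainer.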
class DdpUnderDistAutogradTest(RpcAgentTestFixture):
@property
def world_size(self) -> int:
return WORLD_SIZE
def remote_worker_name(self) -> str:
# The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{REMOTE_WORKER_RANK}"
def trainer_name(self, rank):
# The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{rank}"
def _remote_worker_process(self, ddp_mode):
gLogger.info("The remote worker is running.")
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
# new_group needs to be called by all ranks in the default group, even those not in the new group.
dist.new_group(TRAINER_RANKS)
global shutdown_signal
with shutdown_signal:
shutdown_signal.wait()
gLogger.info("Exiting remote worker.")
dist.destroy_process_group()
def _trainer_process(self, rank: int):
gLogger.info(f"Running the trainer #{rank}...")
gLogger.info(
f"Initing trainer process group by trainer #{rank} with ranks {TRAINER_RANKS}"
)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
gLogger.info(f"Waiting for shutdown signal on trainer #{rank}...")
global shutdown_signal
with shutdown_signal:
shutdown_signal.wait()
gLogger.info(f"Exiting the trainer #{rank}...")
dist.destroy_process_group()
def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
gLogger.info("Running the master process...")
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
remote_em_rref = rpc.remote(
self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
)
remote_net_rref = rpc.remote(
self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
)
gLogger.info("Created remote rrefs on master")
self.do_test_on_master(
ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
)
def do_test_on_master(
self,
ddp_mode: DdpMode,
simulate_uneven_inputs: bool,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
):
if simulate_uneven_inputs:
gLogger.info(
"Running DDP + RPC test with simulating uneven inputs across trainers."
)
trainer_rrefs = []
for rank in TRAINER_RANKS:
trainer = self.trainer_name(rank)
trainer_rrefs.append(
rpc.remote(
trainer,
Trainer,
args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
)
)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
# new_group needs to be called by all ranks in the default group, even those not in the new group.
dist.new_group(TRAINER_RANKS)
training_examples = get_training_examples()
for _ in range(3):
futures = []
num_trainers = len(trainer_rrefs)
for idx, trainer_rref in enumerate(trainer_rrefs):
# Half the trainers will deplete inputs earlier than the rest.
trainer_has_less_inputs = (
simulate_uneven_inputs and idx < num_trainers // 2
)
futures.append(
_remote_method_async(
Trainer.train_batch,
trainer_rref,
training_examples[idx],
trainer_has_less_inputs,
simulate_uneven_inputs,
)
)
for future in futures:
ddp_grads, non_ddp_grads = future.wait()
# When there are uneven inputs, it is not necessary that grads
# cancel each other out, since some trainers contribute 0 grad.
if not simulate_uneven_inputs:
for grad in ddp_grads:
self.assertEqual(
grad,
torch.zeros_like(grad),
msg=f"The grad for any ddp parameter should be zeros, because "
"the training examples' grads cancel each other. Received "
f"gradient {grad}",
)
for grad in non_ddp_grads:
self.assertNotEqual(
grad,
torch.zeros_like(grad),
msg="The grad for any non-ddp parameter shouldn't be zeros",
)
# Destroy process groups
for idx, trainer_rref in enumerate(trainer_rrefs):
_remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
# Send shutdown signals.
for rank in TRAINER_RANKS:
trainer = self.trainer_name(rank)
rpc.rpc_sync(trainer, set_shutdown_signal, args=())
rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())
def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
if self.rank == MASTER_RANK:
self._master_process(ddp_mode, simulate_uneven_inputs)
elif self.rank == REMOTE_WORKER_RANK:
self._remote_worker_process(ddp_mode)
elif self.rank in TRAINER_RANKS:
self._trainer_process(self.rank)
else:
raise RuntimeError(f"Unknow process rank: {self.rank}")
@requires_gloo()
@dist_init
def test_backward_no_ddp(self):
self._do_test(DdpMode.NONE)
@requires_gloo()
@dist_init
def test_backward_ddp_outside(self):
self._do_test(DdpMode.OUTSIDE)
@requires_gloo()
@dist_init
def test_backward_ddp_outside_uneven_inputs(self):
self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)
@requires_gloo()
@dist_init
def test_backward_ddp_inside(self):
self._do_test(DdpMode.INSIDE)
# Common utils for both CPU and CUDA test suites
class CommonDdpComparisonTest(RpcAgentTestFixture):
@property
def world_size(self) -> int:
return NUM_TRAINERS
def trainer_name(self, rank):
# The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{rank}"
@staticmethod
def get_remote_grads(rref, context_id):
return dist_autograd.get_gradients(context_id)[rref.local_value().weight]
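# Compares gradients computed under distributed autograd with those computed
# by plain local autograd for the same DDP-wrapped model and inputs.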
class DdpComparisonTest(CommonDdpComparisonTest):
def _run_test_ddp_comparison(self, simulate_uneven_inputs=False):
gLogger.info(f"Running trainer rank: {self.rank}")
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
# Append "pg" to file_name since file_name is also used by the RPC agent.
init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
world_size=self.world_size,
rank=self.rank,
)
net = nn.Linear(2, 3)
ddp_net = DistributedDataParallel(net)
# Odd ranks join early if simulate_uneven_inputs.
num_inputs = 1
if simulate_uneven_inputs:
if self.rank % 2 == 0:
num_inputs += 2
inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
if simulate_uneven_inputs:
gLogger.info(f"Rank {self.rank} training with {len(inputs_list)} inputs.")
# Use distributed autograd. The gradients will be in the dist autograd context map.
grads_dict = {}
with ddp_net.join(simulate_uneven_inputs):
for i, inputs in enumerate(inputs_list):
with dist_autograd.context() as context_id:
loss = ddp_net(inputs).norm()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info(f"Trainer #{self.rank} got grad dict: {grads_dict}")
# Use local autograd. The gradients will be in each variable's '.grad'.
ddp_net.zero_grad()
loss = ddp_net(inputs).norm()
loss.backward()
# The gradients should be the same
for param in net.parameters():
self.assertTrue(
param in grads_dict,
msg=f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}",
)
self.assertEqual(
grads_dict[param],
param.grad,
msg=f"The grads for param {param} are different under local "
f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
)
dist.destroy_process_group()
@requires_gloo()
@dist_init
def test_ddp_comparison(self):
self._run_test_ddp_comparison()
@requires_gloo()
@dist_init
def test_ddp_comparison_uneven_inputs(self):
# test with simulating uneven inputs in DDP
self._run_test_ddp_comparison(simulate_uneven_inputs=True)
@requires_gloo()
@dist_init
def test_ddp_dist_autograd_sparse_grads(self):
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
model = nn.EmbeddingBag(10, 3, sparse=True)
ddp_model = DistributedDataParallel(model)
# Different inputs for each trainer.
input = torch.LongTensor(10).random_(0, 10)
offsets = torch.LongTensor([0, 4])
# Run local.
loss = ddp_model(input, offsets).sum()
loss.backward()
with dist_autograd.context() as context_id:
loss = ddp_model(input, offsets).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
self.assertEqual(1, len(grads_dict))
self.assertEqual(model.weight.grad, grads_dict[model.weight])
@requires_gloo()
@dist_init
def test_ddp_dist_autograd_local_vs_remote(self):
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
# Use two different remote device strings, with and without the default
# device "cpu", respectively.
for remote_device in ["worker0/cpu", "worker0"]:
remote_layer1 = RemoteModule(
remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
)
layer1 = nn.Linear(10, 5, False)
# Start with the same parameters for remote and local
layer1.weight = remote_layer1.module_rref.to_here().weight
# Run local case.
layer2 = nn.Linear(5, 1)
inputs = torch.rand((10, 10))
ddp_model = DistributedDataParallel(layer2)
loss = ddp_model(layer1(inputs)).sum()
loss.backward()
# Run remote case.
with dist_autograd.context() as context_id:
loss = ddp_model(remote_layer1(inputs)).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
dist.barrier()
self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
self.assertEqual(
layer1.weight.grad,
rpc.rpc_sync(
"worker0",
CommonDdpComparisonTest.get_remote_grads,
args=(remote_layer1.module_rref, context_id),
),
)
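# Same local-vs-remote gradient comparison, but with a pipeline that moves
# activations between CPU-hosted RemoteModules and CUDA-resident DDP layers.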
class CudaDdpComparisonTest(CommonDdpComparisonTest):
@skip_if_lt_x_gpu(NUM_TRAINERS)
@requires_nccl()
@dist_init
@skip_if_rocm
def test_ddp_dist_autograd_local_vs_remote_gpu(self):
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
remote_layer1 = RemoteModule(
remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
)
layer1 = nn.Linear(10, 7, False)
# Start with the same parameters for remote and local
layer1.weight = remote_layer1.module_rref.to_here().weight
layer2 = nn.Linear(7, 5).cuda(self.rank)
ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])
remote_layer3 = RemoteModule(
remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
)
layer3 = nn.Linear(5, 3, False)
# Start with the same parameters for remote and local
layer3.weight = remote_layer3.module_rref.to_here().weight
layer4 = nn.Linear(3, 1).cuda(self.rank)
ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])
# Run local case.
inputs = torch.rand((10, 10))
loss = ddp_layer4(
layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
).sum()
loss.backward()
# Run remote case.
with dist_autograd.context() as context_id:
loss = ddp_layer4(
remote_layer3(
ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
).cuda(self.rank)
).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
dist.barrier()
self.assertEqual(
layer1.weight.grad,
rpc.rpc_sync(
"worker0",
CommonDdpComparisonTest.get_remote_grads,
args=(remote_layer1.module_rref, context_id),
),
)
self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
self.assertEqual(
layer3.weight.grad,
rpc.rpc_sync(
"worker0",
CommonDdpComparisonTest.get_remote_grads,
args=(remote_layer3.module_rref, context_id),
),
)
self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])