#!/usr/bin/env python3
import contextlib
import enum
import logging
import os
import threading
from typing import NamedTuple, Optional
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.nn import RemoteModule
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
requires_gloo,
requires_nccl,
skip_if_lt_x_gpu,
skip_if_rocm,
)
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
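# Toy model dimensions and process layout used throughout the test.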
NUM_EM_ROW = 2
D_SPARSE = 3
D_DENSE = 2
D_HID = 3
D_OUT = 1
NUM_TRAINERS = 4
# Trainers + the master + the remote worker
WORLD_SIZE = NUM_TRAINERS + 2
TRAINER_RANKS = list(range(NUM_TRAINERS))
REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
MASTER_RANK = REMOTE_WORKER_RANK + 1
class DdpMode(enum.Enum):
# Don't apply DDP
NONE = enum.auto()
# Apply DDP to the top level nn.Module
OUTSIDE = enum.auto()
# Embed DDP inside the top level nn.Module
INSIDE = enum.auto()
def init_logger():
logger = logging.getLogger(__name__)
level = logging.DEBUG if "debug" in os.environ else logging.INFO
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
)
console.setFormatter(formatter)
console.setLevel(level)
    # Attach the console handler to the logger.
logger.addHandler(console)
logger.propagate = False
return logger
gLogger = init_logger()
class FeatureSet(NamedTuple):
""" A feature set has 2 types of features"""
dense_features: torch.Tensor
sparse_features: torch.LongTensor
values: torch.Tensor
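# Helpers for invoking a method on the module owned by an RRef: the method and
# the RRef are shipped to the RRef's owner via RPC and run there on the local
# value, e.g. _remote_method(RemoteEM.forward, remote_em_rref, sparse_features).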
def _call_method(method, rref, *args, **kwargs):
return method(rref.local_value(), *args, **kwargs)
def _remote_method(method, rref, *args, **kwargs):
args_tup = tuple([method, rref] + list(args))
return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
def _remote_method_async(method, rref, *args, **kwargs):
args_tup = tuple([method, rref] + list(args))
return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
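# Embedding module hosted on the remote worker. Its weights are initialized to
# a constant so that results are deterministic across runs.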
class RemoteEM(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int):
gLogger.info(f"Initing RemoteEM with {num_embeddings} {embedding_dim}")
super(RemoteEM, self).__init__()
init_em = [0.5] * embedding_dim
self.em = nn.EmbeddingBag(
num_embeddings,
embedding_dim,
_weight=torch.Tensor([init_em] * num_embeddings),
)
def forward(self, input: torch.Tensor):
gLogger.debug(f"Running RemoteEM.forward() on: {input}")
return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
# Return a linear module with predefined parameters.
def getLinear(d_in, d_out):
    linear = nn.Linear(d_in, d_out, bias=False)
    w = torch.ones((d_out, d_in))
    # Flip a single weight so the layer is not a uniform all-ones projection.
    w[0][0] = -1
    w.requires_grad_()
    linear.weight.data = w
    return linear
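# Small feed-forward module (linear layer followed by ReLU) hosted on the
# remote worker.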
class RemoteNet(nn.Module):
def __init__(self, d_in: int, d_out: int):
gLogger.info(f"Initing RemoteNet with {d_in} {d_out}")
super(RemoteNet, self).__init__()
self.fc = getLinear(d_in, d_out)
self.relu = nn.ReLU()
def forward(self, input: torch.Tensor):
gLogger.debug(f"Running RemoteNet.forward() on: {input}")
return self.relu(self.fc(input))
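# Local model combining the remote modules (accessed through RRefs) with two
# local linear layers. When a process group is provided, the second local
# layer is wrapped in DistributedDataParallel.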
class HybridModel(nn.Module):
def __init__(
self,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
        process_group_for_ddp: Optional[dist.ProcessGroup] = None,
):
super(HybridModel, self).__init__()
self.remote_em_rref = remote_em_rref
self.remote_net_rref = remote_net_rref
self.fc1 = getLinear(D_DENSE, D_DENSE)
self.fc2 = getLinear(D_HID, D_OUT)
self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
self.fc2.parameters()
)
self.ddp_params = ()
if process_group_for_ddp is not None:
self.non_ddp_params, self.ddp_params = (
tuple(self.fc1.parameters()),
tuple(self.fc2.parameters()),
)
gLogger.info("Use DDP for the second local net.")
self.fc2 = DistributedDataParallel(
self.fc2, check_reduction=True, process_group=process_group_for_ddp
)
gLogger.info(
f"HybridModel has {len(list(self.parameters()))} groups of parameters."
)
def forward(self, input: FeatureSet):
gLogger.debug(f"Running HybridModel.forward on {input}")
sparse = _remote_method(
RemoteEM.forward, self.remote_em_rref, input.sparse_features
)
        # The sparse and dense features must have the same mini-batch size.
assert sparse.shape[0] == input.dense_features.shape[0]
dense = self.fc1(input.dense_features)
x = torch.cat((dense, sparse), 1)
gLogger.debug(f"Concatenated feature: {x}")
x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
return self.fc2(x)
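# Per-rank trainer. Builds a HybridModel, optionally wraps it (or one of its
# local layers) in DistributedDataParallel depending on ddp_mode, and runs
# forward/backward passes under distributed autograd.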
class Trainer:
def __init__(
self,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
ddp_mode: DdpMode,
rank: int,
):
self.rank = rank
self.trainer_group = (
dist.new_group(TRAINER_RANKS)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
else None
)
self.remote_em_rref = remote_em_rref
self.remote_net_rref = remote_net_rref
self.hybrid_module = HybridModel(
self.remote_em_rref,
self.remote_net_rref,
self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
)
self.ddp_params, self.non_ddp_params = (
self.hybrid_module.ddp_params,
self.hybrid_module.non_ddp_params,
)
if ddp_mode == DdpMode.OUTSIDE:
gLogger.info("Wrapping the whole hybrid module into DDP.")
self.ddp_params += self.non_ddp_params
self.non_ddp_params = ()
self.hybrid_module = DistributedDataParallel(
self.hybrid_module,
check_reduction=True,
process_group=self.trainer_group,
)
gLogger.info(
f"Succeeded in creating a HybridModel instance with "
f"{len(self.ddp_params)} ddp params and {len(self.non_ddp_params)} "
f"other local params."
)
def destroy_pg(self):
        if self.trainer_group is not None:
dist.destroy_process_group(self.trainer_group)
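    # Run a forward/backward pass per (micro)batch under a distributed
    # autograd context and return the gradients of the DDP-managed and the
    # purely local parameters.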
def train_batch(
self,
mini_batch: FeatureSet,
trainer_has_less_inputs: bool,
simulate_uneven_inputs: bool,
):
grads_dict = None
if not simulate_uneven_inputs:
input_batches = [mini_batch]
else:
# Split into microbatches, and trim to simulate uneven inputs.
dense_features = mini_batch.dense_features
sparse_features = mini_batch.sparse_features
values = mini_batch.values
dense_microbatch = torch.split(dense_features, 2)
sparse_microbatch = torch.split(sparse_features, 2)
values_microbatch = torch.split(values, 2)
batches = []
for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch):
feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
batches.append(feature_set)
if trainer_has_less_inputs:
input_batches = batches[: len(batches) // 2]
                gLogger.info(
                    f"Trainer reduced input batches from {len(batches)} "
                    f"to {len(input_batches)} to simulate uneven inputs."
                )
else:
input_batches = batches
        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
for b in input_batches:
with dist_autograd.context() as context_id:
output = self.hybrid_module.forward(b)
loss = (output * mini_batch.values).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info(
f"Loss is {loss} for mini batch: {mini_batch}. "
f"Grads dict has {len(grads_dict)} entries: {grads_dict}"
)
return (
tuple(grads_dict[param] for param in self.ddp_params),
tuple(grads_dict[param] for param in self.non_ddp_params),
)
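# Build a fixed set of 16 examples and shard them evenly across the trainers.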
def get_training_examples():
n = 16
training_examples = FeatureSet(
dense_features=torch.zeros((n, D_DENSE)),
sparse_features=torch.zeros(n, dtype=torch.long),
values=torch.zeros(n),
)
idx = 0
# Every example has another one that has exactly the same features but an
# opposite value. Therefore, their grads cancel each other in all-reduce.
for value in (-1, 1):
for x in (-1 * value, 1 * value):
for y in (1 * value, -1 * value):
for z in (0, 1):
training_examples.dense_features[idx, :] = torch.Tensor((x, y))
training_examples.sparse_features[idx] = z
training_examples.values[idx] = value
idx += 1
# Split the examples among NUM_TRAINERS trainers
    assert n % NUM_TRAINERS == 0
    examples_per_trainer = n // NUM_TRAINERS
return [
FeatureSet(
dense_features=training_examples.dense_features[
start : start + examples_per_trainer, :
],
sparse_features=training_examples.sparse_features[
start : start + examples_per_trainer
],
values=training_examples.values[start : start + examples_per_trainer],
)
for start in range(0, n, examples_per_trainer)
]
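# Condition variable used to tell long-running processes (e.g. the remote
# worker) when it is safe to shut down.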
shutdown_signal = threading.Condition()
def set_shutdown_signal():
global shutdown_signal
with shutdown_signal:
shutdown_signal.notify()
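# Rank layout: ranks [0, NUM_TRAINERS) are trainers, followed by the remote
# worker and then the master.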
class DdpUnderDistAutogradTest(RpcAgentTestFixture):
@property
def world_size(self) -> int:
return WORLD_SIZE
def remote_worker_name(self) -> str:
        # The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{REMOTE_WORKER_RANK}"
def trainer_name(self, rank):
        # The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{rank}"
def _remote_worker_process(self, ddp_mode):
gLogger.info("The remote worker is running.")
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
            # new_group() must be called by every rank in the default group,
            # even ranks (like this remote worker) that are not members of the
            # new group.
dist.new_group(TRAINER_RANKS)
global shutdown_signal
with shutdown_signal:
shutdown_signal.wait()
gLogger.info("Exiting remote worker.")
dist.destroy_process_group()