import logging
import torch
import torch.distributed as dist
from . import default_hooks as default
logger = logging.getLogger(__name__)
class PostLocalSGDState:
r"""
Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
and all-reducing gradients locally using ``subgroup`` afterwards.
If ``process_group`` is ``None``, the global process group will be used.
If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.
    Additionally, ``post_local_gradient_allreduce`` may be worth tuning,
    because either ``True`` or ``False`` can lead to faster convergence, depending on the workload.
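
    A minimal construction sketch (the values below are purely illustrative) is shown here;
    setting ``post_local_gradient_allreduce=False`` skips even the intra-node gradient
    allreduce once local SGD has started.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PostLocalSGDState(
        >>>     process_group=None,  # use the global process group for the initial phase
        >>>     subgroup=None,  # lazily create one intra-node subgroup per machine
        >>>     start_localSGD_iter=100,
        >>>     post_local_gradient_allreduce=False,
        >>> )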
"""
__slots__ = [
"process_group",
"subgroup",
"start_localSGD_iter",
"post_local_gradient_allreduce",
"iter",
]
def __init__(
self,
process_group,
subgroup,
start_localSGD_iter,
post_local_gradient_allreduce=True,
):
        logger.info(
            "Local SGD will be started after %s iterations", start_localSGD_iter
        )
# The group used for all-reducing gradients globally.
self.process_group = process_group
# The group used for all-reducing gradients locally.
self.subgroup = subgroup
self.start_localSGD_iter = start_localSGD_iter
# Allreduce gradients locally since iteration `start_localSGD_iter`.
# This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication.
self.post_local_gradient_allreduce = post_local_gradient_allreduce
# Iteration/step in the training loop.
self.iter = 0
def maybe_increase_iter(self, bucket):
        # Since bucket 0 is the last bucket to be all-reduced in an iteration,
        # only increase `iter` when bucket 0 is processed.
if bucket.is_last():
self.iter += 1
if self.iter == self.start_localSGD_iter:
            logger.info("Start to apply local SGD after %s iterations.", self.iter)
def post_localSGD_hook(
state: PostLocalSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
"""
    This DDP communication hook is used for running the post-localSGD algorithm,
    in combination with a model averaging component (e.g.,
:class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
that runs after the optimizer step.
Args:
state (PostLocalSGDState): State information to run post-localSGD.
Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD.
bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hooks only support single-process single-device mode,
            exactly one tensor is stored in this bucket.
Returns:
Future handler of the communication, which updates the gradients in place.
Example::
>>> # xdoctest: +SKIP
        >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup,
        >>>                           start_localSGD_iter=10)
>>> ddp_model.register_comm_hook(state, post_localSGD_hook)
>>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``.
>>> # Please refer to the examples in ``torch.distributed.algorithms.model_averaging.averagers`` module.
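        >>> # A sketch of that follow-up step (illustrative; assumes ``data_loader``,
        >>> # ``loss_fn``, and ``optimizer`` are already defined, and that ``warmup_steps``
        >>> # matches ``start_localSGD_iter``):
        >>> from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
        >>> averager = PeriodicModelAverager(period=4, warmup_steps=10)
        >>> for inp, labels in data_loader:
        >>>     optimizer.zero_grad()
        >>>     loss_fn(ddp_model(inp), labels).backward()
        >>>     optimizer.step()
        >>>     # Parameters are averaged across the global group every ``period`` steps
        >>>     # once ``warmup_steps`` iterations have elapsed.
        >>>     averager.average_parameters(ddp_model.parameters())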
"""
global_group_to_use = (
state.process_group if state.process_group is not None else dist.group.WORLD
)
# The input tensor is a flattened 1D tensor.
input_tensor = bucket.buffer()
# Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
if state.iter < state.start_localSGD_iter:
state.maybe_increase_iter(bucket)
return default._allreduce_fut(global_group_to_use, input_tensor)
    # If `post_local_gradient_allreduce` is not set, then no gradient synchronization
    # is performed after the first `start_localSGD_iter` iterations.
if not state.post_local_gradient_allreduce:
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
fut.set_result(input_tensor)
return fut
# Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
    # Note that by default, a separate subgroup is created for each node, which causes
    # an intra-node allreduce to be done at each training step.
    # From this point on, model averaging should run after the optimizer step,
# to globally allreduce all the parameters.
if state.subgroup is None:
state.subgroup, _ = dist.new_subgroups()
return default._allreduce_fut(state.subgroup, input_tensor)
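

# A minimal wiring sketch (illustrative; not part of the public API of this module):
# explicitly create one intra-node subgroup per machine and register the hook on an
# existing ``DistributedDataParallel`` instance, instead of relying on the lazy
# ``dist.new_subgroups()`` call inside ``post_localSGD_hook``.
def _example_register_post_localSGD_hook(ddp_model, start_localSGD_iter=100):
    # One subgroup per node: after `start_localSGD_iter` iterations, gradients are
    # all-reduced only among the ranks on the same machine.
    intra_node_subgroup, _ = dist.new_subgroups()
    state = PostLocalSGDState(
        process_group=None,  # fall back to the global WORLD group for the initial phase
        subgroup=intra_node_subgroup,
        start_localSGD_iter=start_localSGD_iter,
    )
    ddp_model.register_comm_hook(state, post_localSGD_hook)
    return state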