# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import collections
import copy
import enum
import inspect
import io
import logging
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim.utils import functional_optim_map
from torch.optim import Optimizer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Public API of this module: only the optimizer class is exported via ``import *``.
__all__ = ["ZeroRedundancyOptimizer"]
# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(
    value: Any,
    non_blocking: bool,
    device: torch.device,
) -> Any:
    r"""
    Move every tensor nested inside ``value`` to ``device``.

    Lists, tuples, and mappings (:class:`collections.abc.Mapping`) are walked
    recursively and rebuilt. Each tensor is passed through
    ``tensor.to(device, non_blocking=non_blocking)``. Any other value is
    returned as-is.

    Arguments:
        value: the (possibly nested) value to transfer.
        non_blocking (bool): forwarded to :meth:`torch.Tensor.to`.
        device (torch.device): target device.

    Returns:
        A structure mirroring ``value``. Lists (including list subclasses)
        become plain :class:`list` s. Tuples (including subclasses such as
        named tuples) become plain :class:`tuple` s. Every mapping becomes a
        plain :class:`dict`.

    .. note:: Aliasing is not preserved. If two entries reference the same
        tensor, each one is transferred separately, so after the call they
        can be two distinct tensors on ``device``.
    """

    def _move(item: Any) -> Any:
        # Recurse through the entry point so every nesting level uses the
        # same dispatch rules.
        return _recursive_copy_to_device(
            item, non_blocking=non_blocking, device=device
        )

    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)
    if isinstance(value, list):
        return [_move(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_move(item) for item in value)
    if isinstance(value, collections.abc.Mapping):
        return {key: _move(item) for key, item in value.items()}
    return value
def _is_trainable(param: torch.Tensor) -> bool:
    r"""
    Report whether ``param`` is trainable.

    A parameter counts as trainable exactly when its ``requires_grad`` flag
    is set.
    """
    requires_grad = param.requires_grad
    return requires_grad
def _broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcasts an object to the given group, sending the object if called from
    the source rank and receiving the object otherwise.

    The object is serialized with :func:`torch.save`. Its byte length is
    broadcast first so that receivers can allocate a buffer of the right
    size, and then the bytes themselves are broadcast.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank (compared against ``dist.get_rank()``).
        group (``ProcessGroup``, optional): group used for the broadcast
            (default: ``dist.group.WORLD``).
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).

    Returns:
        The broadcasted object. On the source rank this is ``obj`` itself.
        On other ranks it is the deserialized copy, with tensors mapped onto
        ``device``.
    """
    if dist.get_rank() == src_rank:
        # Send the object: length first, then the serialized bytes
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(device)
        data_send_tensor = torch.ByteTensor(data).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receive the object: learn the length, allocate, then receive bytes
        length_tensor = torch.LongTensor([0]).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty(
            [int(length_tensor.item())], dtype=torch.uint8, device=device
        )
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        # ``weights_only=False`` must be explicit: since PyTorch 2.6,
        # ``torch.load`` defaults to ``weights_only=True``, which refuses to
        # unpickle arbitrary objects. SECURITY: this unpickles data sent by a
        # peer rank of the same process group (a trusted source); never feed
        # it bytes from an untrusted origin.
        obj = torch.load(buffer, map_location=device, weights_only=False)
    return obj
class _ZeROJoinHook(JoinHook):
    r"""
    Join hook that runs :class:`ZeroRedundancyOptimizer` steps on behalf of a
    joined process.

    Arguments:
        zero (ZeroRedundancyOptimizer): the optimizer whose :meth:`step` the
            hook invokes; stored as ``self.zero``.
    """

    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), (
            "ZeRO join hook requires passing in a "
            "ZeroRedundancyOptimizer instance as the state"
        )
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Performs one optimizer step for the joined process, which updates its
        shard of the parameters and broadcasts those parameters.
        """
        optimizer = self.zero
        optimizer.step()
class _DDPBucketAssignment:
    r"""
    The slice of one :class:`DistributedDataParallel` bucket that a single
    rank is responsible for updating. The slice may be a non-strict subset of
    the bucket's parameters.

    Attributes:
        bucket_index (int): index of the bucket in DDP's gradient-bucket
            all-reduce order.
        parameters (List[torch.Tensor]): the bucket's model parameters that
            are assigned to this rank; must be non-empty.
        offset (int): position, within the :class:`GradBucket` 's
            :meth:`parameters` (and equivalently its :meth:`gradients`), of
            the first element of ``parameters``.
        device (torch.device): device of the parameters, taken from the
            first one.
        tensor (torch.Tensor, optional): flattened tensor holding the data of
            this rank's parameter subset; ``None`` until it is set elsewhere.

    Raises:
        ValueError: if ``parameters`` is empty.
    """

    def __init__(
        self,
        bucket_index: int,
        parameters: List[torch.Tensor],
        offset: int,
    ):
        self.bucket_index = bucket_index
        self.parameters = parameters
        self.offset = offset
        if not self.parameters:
            raise ValueError("Empty bucket assignment")
        # All parameters in one DDP bucket share a device, so the first
        # parameter's device stands for the whole assignment.
        first_param = self.parameters[0]
        self.device: torch.device = first_param.device
        self.tensor: Optional[torch.Tensor] = None
class _OverlapStatus(enum.IntEnum):
    r"""
    The three statuses a :class:`ZeroRedundancyOptimizer` moves through when
    it overlaps its step with :class:`DistributedDataParallel`.

    ``UNINITIALIZED``: the ZeRO instance is effectively uninitialized and is
        waiting for DDP to finalize its bucketing.

    ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, so its
        bucketing is final and the ZeRO instance can collect the information
        it needs about that bucketing.

    ``INITIALIZED``: the ZeRO instance is fully initialized and can optimize
        parameters.
    """

    # Waiting for DDP to finalize its bucketing.
    UNINITIALIZED = 0
    # DDP bucketing is final; bucket information can now be collected.
    DDP_HAS_REBUILT_BUCKETS = 1
    # Ready to optimize parameters.
    INITIALIZED = 2
class _OverlapInfo:
    r"""
    This contains the information needed by :class:`ZeroRedundancyOptimizer`
    to overlap with :class:`DistributedDataParallel`.

    Arguments:
        world_size (int): world size of the process group being used; it
            determines the number of lists in ``params_per_rank``.

    Attributes:
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity following
            a threshold given by the total parameter size divided by the world
            size; if ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
            this should be set to the value passed into the hook constructor
            (initially ``False``).
        status (_OverlapStatus): current status; see :class:`_OverlapStatus`
            for more information (initially ``UNINITIALIZED``).
        params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
            gives the model parameters in the ``i``th bucket.
        params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
            gives the model parameters assigned to the ``i``th rank, where the
            parameters are grouped by increasing bucket indices.
        offsets (Dict[int, int]): maps from bucket index to the offset in
            ``self.params_per_rank[rank]`` giving the index of the first
            parameter in that bucket, where ``rank`` is this process's own
            rank; the keys of this :class:`dict` are the bucket indices
            assigned to this rank.
        assigned_ranks_per_bucket (List[Set[int]]): one set of ranks per
            bucket; judging by the name, ``assigned_ranks_per_bucket[i]``
            holds the ranks assigned to (part of) the ``i``th bucket. It is
            populated outside this class; confirm against the
            bucket-assignment code.
        num_bucket_assignments (int): total number of bucket assignments across
            all ranks; this is equal to the number of
            :class:`DistributedDataParallel` gradient buckets if
            ``shard_buckets=False`` and possibly greater otherwise.
        total_size (int, optional): total size of all buckets (i.e. sum of
            ``param.numel()`` for all ``param`` across all buckets) if
            ``shard_buckets=True``; otherwise, ``None``.
        broadcast_handles (List[Work]): :class:`list` of async work handles for
            the parameter broadcasts.
        bucket_index_to_future (Dict[int, torch.futures.Future]):
            :class:`dict` mapping bucket index to the corresponding all-reduce
            future.
        bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
            mapping bucket index to the corresponding bucket.
        bucket_indices_seen (List[int]): :class:`list` of the bucket indices
            seen on this iteration.
    """

    def __init__(self, world_size: int) -> None:
        self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
        self.shard_buckets: bool = False

        # Modified per bucket reconstruction
        self.params_per_bucket: List[List[torch.Tensor]] = []
        self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
        self.offsets: Dict[int, int] = {}
        # Group Ranks
        self.assigned_ranks_per_bucket: List[Set[int]] = []
        self.num_bucket_assignments: int = 0
        self.total_size: Optional[int] = None

        # Modified per iteration
        self.broadcast_handles: List[Any] = []
        self.bucket_indices_seen: List[int] = []
        # Used by `hook_with_zero_step()`
        self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
        self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}

    def wait_for_broadcasts(self) -> None:
        r"""
        Waits for all parameter broadcasts. This should be called once all
        broadcasts have been scheduled, meaning ``self.broadcast_handles`` is
        filled. This clears ``self.broadcast_handles`` in preparation for the
        next iteration.

        Raises:
            AssertionError: if the number of handles differs from
                ``self.num_bucket_assignments``.
        """
        assert (
            len(self.broadcast_handles) == self.num_bucket_assignments
        ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
        # Plain loop: the waits are executed only for their side effect, so
        # there is no result list to build.
        for handle in self.broadcast_handles:
            handle.wait()
        self.broadcast_handles.clear()

    def clear_per_iter_info(self) -> None:
        r"""
        Clears the data structures that are modified per-iteration. This should
        be called at the end of an iteration. ``broadcast_handles`` is not
        touched here; :meth:`wait_for_broadcasts` clears it.
        """
        self.bucket_indices_seen.clear()
        self.bucket_index_to_future.clear()
        self.bucket_index_to_bucket.clear()
class ZeroRedundancyOptimizer(Optimizer, Joinable):
r"""
This class wraps an arbitrary :class:`optim.Optimizer
<torch.optim.Optimizer>` and shards its states across ranks in the group as
described by ZeRO_. The local optimizer instance in each rank is only
responsible for updating approximately ``1 / world_size`` parameters and
hence only needs to keep ``1 / world_size`` optimizer states. After
parameters are updated locally, each rank will broadcast its parameters to
all other peers to keep all model replicas in the same state.
``ZeroRedundancyOptimizer`` can be used in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
memory consumption.
``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
of parameters at each rank. Each parameter belongs to a single rank and is
not divided among ranks. The partition is arbitrary and might not match
the parameter registration or usage order.
Arguments:
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
or :class:`dict` s giving all parameters, which will be sharded
across ranks.
Keyword Args:
optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
optimizer.
process_group (``ProcessGroup``, optional): ``torch.distributed``
``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
:meth:`torch.distributed.init_process_group`).
parameters_as_bucket_view (bool, optional): if ``True``, parameters are
packed into buckets to speed up communication, and ``param.data``
fields point to bucket views at different offsets; if ``False``,
each individual parameter is communicated separately, and each
``param.data`` stays intact (default: ``False``).
overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
overlapped with :class:`DistributedDataParallel` 's gradient
synchronization; this requires (1) either a functional optimizer
for the ``optimizer_class`` argument or one with a functional
equivalent and (2) registering a DDP communication hook
constructed from one of the functions in ``ddp_zero_hook.py``;
parameters are packed into buckets matching those in
:class:`DistributedDataParallel`, meaning that the
``parameters_as_bucket_view`` argument is ignored.
If ``False``, :meth:`step` runs disjointly after the backward pass
(per normal).
(default: ``False``)
**defaults: any trailing arguments, which are forwarded to the local
optimizer.
Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>> ddp.parameters(),
>>> optimizer_class=torch.optim.Adam,
>>> lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()
.. warning::
Currently, ``ZeroRedundancyOptimizer`` requires that all of the
passed-in parameters are the same dense type.
.. warning::
If you pass ``overlap_with_ddp=True``, be wary of the following: Given
Loading ...