Repository URL to install this package:
Version:
2.4.1 ▾
|
# mypy: allow-untyped-defs
import logging
from collections import abc, defaultdict
from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union
import torch
import torch.distributed as dist
from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup
logger = logging.getLogger(__name__)
def _refresh_per_optimizer_state() -> Dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
def _is_supported_device(tensor: torch.Tensor) -> bool:
return tensor.is_cuda or tensor.device.type in (
"xla",
"cpu",
"hpu",
torch._C._get_privateuse1_backend_name(),
)
class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
"""
Lazily serves tensor to request device. This class extends
_MultiDeviceReplicator to allow support for "cpu" as a device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert _is_supported_device(master_tensor)
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
class ShardedGradScaler(GradScaler):
"""
ShardedGradScaler helps perform gradient scaling in a shard aware manner. It extends
functionality from GradScaler:
* Supports Pytorch DDP and FSDP implementations
* Support CPU offloaded tensors (as used in fully sharded data parallel[FSDP])
* Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns
* Sync inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across
nodes
Example::
# Creates a ShardedGradScaler once at the beginning of training.
scaler = ShardedGradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
process group for sharding
"""
def __init__(
self,
device: str = "cuda",
init_scale: float = 2.0**16,
backoff_factor: float = 0.5,
growth_factor: float = 2.0,
growth_interval: int = 2000,
enabled: bool = True,
process_group: Optional[ProcessGroup] = dist.group.WORLD,
) -> None:
super().__init__(
device,
init_scale=init_scale,
backoff_factor=backoff_factor,
growth_factor=growth_factor,
growth_interval=growth_interval,
enabled=enabled,
)
if self._enabled:
self.process_group = process_group
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
...
@overload
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
...
@overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
...
def scale(
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
if not self._enabled:
return outputs
if isinstance(outputs, torch.Tensor):
assert _is_supported_device(outputs)
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
scaled_output = outputs * self._scale.to(
device=outputs.device, non_blocking=True
)
# Here we ensure the return dtype is the same as the outputs dtype.
# For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
return scaled_output.type(outputs.dtype)
stash: List[_GeneralMultiDeviceReplicator] = []
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
assert _is_supported_device(val)
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_GeneralMultiDeviceReplicator(self._scale))
scaled_val = val * stash[0].get(val.device)
# Here we ensure the return dtype is the same as the outputs dtype.
# For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
return scaled_val.type(val.dtype)
if isinstance(val, abc.Iterable):
iterator = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterator)
return iterator
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _foreach_non_finite_check_and_unscale_cpu_(
self,
grads: Sequence[torch.Tensor],
found_inf: torch.Tensor,
inv_scale: torch.Tensor,
) -> None:
if len(grads) == 0:
return
assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."
for grad in grads:
if grad.device.type != "cpu":
logger.error(
"tensor device is %s but was expected to be ``cpu``",
grad.device,
)
raise ValueError(
"Gradients were found on a non-CPU device when"
" expected to be on CPU."
)
if (
torch.isinf(grad).any().item() is True
or torch.isnan(grad).any().item() is True
):
found_inf.data = torch.tensor([1.0])
break
else:
grad.data *= inv_scale.item()
def _unscale_grads_(
self,
optimizer: torch.optim.Optimizer,
inv_scale: torch.Tensor,
found_inf: torch.Tensor,
allow_fp16: bool = True,
) -> Dict[torch.device, torch.Tensor]:
per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be thousands of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
# coalesce is not supported in torch.float16
param_grad_fp32 = param.grad.type(torch.float32).coalesce()
param.grad = param_grad_fp32.type(torch.float16)
to_unscale = param.grad._values()
else:
to_unscale = param.grad
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
if grads[0].device.type == "cpu":
self._foreach_non_finite_check_and_unscale_cpu_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
else:
torch._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
# There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
# ranks may have no (non-zero sized) parameter shards, necessitating the
# initialization of `per_device_found_inf._per_device_tensors` here
if not per_device_found_inf._per_device_tensors:
assert self._scale is not None
per_device_found_inf.get(self._scale.device)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, True
)
optimizer_state["stage"] = OptState.UNSCALED
# Synchronize the detected inf across the ranks
optimizer_state = self._per_optimizer_states[id(optimizer)]
works = []
found_inf_on_cpus = []
found_inf_on_devices = []
for found_inf in optimizer_state["found_inf_per_device"].values():
if self._device != "cpu" and found_inf.device.type == "cpu":
found_inf_on_cpus.append(found_inf)
found_inf_on_device = found_inf.to(self._device)
found_inf_on_devices.append(found_inf_on_device)
works.append(
dist.all_reduce(
found_inf_on_device, async_op=True, group=self.process_group
)
)
else:
works.append(
dist.all_reduce(found_inf, async_op=True, group=self.process_group)
)
for work in works:
work.wait()
if found_inf_on_cpus:
torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)
def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
"""
If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
"""
assert self._scale is not None and self._growth_tracker is not None
if found_inf.item() >= 1.0:
self._scale *= self._backoff_factor
self._growth_tracker.fill_(0)
else:
successful = self._growth_tracker + 1
if successful == self._growth_interval:
self._scale *= self._growth_factor
self._growth_tracker.fill_(0)
else:
self._growth_tracker = successful
def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update") # type: ignore[var-annotated]
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
torch.FloatTensor with requires_grad=False."
assert new_scale.device.type == self._device, reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
if _scale.device.type == "cpu":
self._amp_update_scale_cpu_(found_inf_combined)
else:
torch._amp_update_scale_(
self._scale, # type: ignore[arg-type]
self._growth_tracker, # type: ignore[arg-type]
found_inf_combined,
self._growth_factor, # type: ignore[arg-type]
self._backoff_factor, # type: ignore[arg-type]
self._growth_interval, # type: ignore[arg-type]
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)