Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ distributed / algorithms / model_averaging / hierarchical_model_averager.py

# Copyright 2022 Cruise LLC
import logging
import warnings
from collections import OrderedDict
from typing import Union, Iterable, Dict

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as utils

logger = logging.getLogger(__name__)


class HierarchicalModelAverager(averagers.ModelAverager):
    r"""
    Runs hierarchical model averaging (`hierarchical SGD <https://arxiv.org/pdf/2010.12998.pdf>`_).
    Process groups of different sizes are organized in a hierarchy, and they average parameters
    by using different periods concurrently after the warm-up stage.
    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
    that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
    Similarly, the process groups within this class do not have such an intra-machine process
    subgroup, which should be embedded by the post-local SGD communication hook instead.

    Args:
        period_group_size_dict: An ordered dict mapping keys of model averaging period to
                                process group size, used for initializing process groups of
                                different sizes in a hierarchy to average parameters concurrently.
                                Periods are expected in ascending order, each being a positive integer.
                                At each iteration, at most a single process group runs
                                averaging: the one whose period is the largest period that
                                evenly divides the current step.
                                For example, if the dict has three keys: 2, 4, and 8,
                                then three process groups in total will be created to
                                average parameters every 2, 4, and 8 iterations, respectively.
                                At the 4th iteration, only the second process group will run
                                averaging, because the first process group should be a
                                subset of the second process group, and there is no need to run
                                the first process group redundantly.
                                On the other hand, the third process group can only be triggered
                                every 8 iterations, so it will not be triggered at the 4th iteration.
        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
                                                If ``None``, the default process group, which is created
                                                by :func:`torch.distributed.init_process_group`, will be used.
                                                (default: ``None``)

    Raises:
        ValueError: If ``period_group_size_dict`` is empty, contains a non-positive period,
                    or its last group size differs from the size of ``process_group``;
                    or if ``warmup_steps`` is negative.

    Example::
        >>> # xdoctest: +SKIP('undefined rank')
        >>> from collections import OrderedDict
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).to(rank)
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
        >>> subgroup, _ = dist.new_subgroups()
        >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
        >>> # the 16 processes every 16 iterations.
        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging at two levels.
        >>> for step in range(0, 200):
        >>>    optimizer.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    optimizer.step()
        >>>    # Average parameters after ``optimizer.step()``.
        >>>    # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
        >>>    averager.average_parameters(model.parameters())

    .. warning ::
        The last group size in the dict must be the size of the provided ``process_group``,
        which indicates model averaging at the highest level of the hierarchy.
        If ``process_group`` is not provided, then the last group size should be equal to the world size.

    .. warning ::
        `HierarchicalModelAverager` is experimental and subject to change.
    """

    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
        super().__init__(process_group)
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
        self._periods = list(period_group_size_dict.keys())
        # Use min()/max() instead of the first/last key so that the checks hold even
        # if the periods were not given in ascending order. A zero period would
        # otherwise only surface later as ``step % 0`` in ``_find_process_group``.
        if min(self._periods) <= 0:
            raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
        elif max(self._periods) == 1:
            warnings.warn(
                "When the maximum period in arg ``period_group_size_dict`` is 1, "
                "no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        overall_group_size = dist.get_world_size(group=self.process_group)
        last_group_size = list(period_group_size_dict.values())[-1]
        if last_group_size != overall_group_size:
            raise ValueError(
                f"The last value in arg ``period_group_size_dict`` {last_group_size} "
                f"must be equal to the size of arg ``process_group`` {overall_group_size}."
            )
        # Validate every argument before creating any subgroups below, so an invalid
        # ``warmup_steps`` is reported without ``dist.new_subgroups`` having been called.
        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")

        self.period_process_group_dict = OrderedDict()
        logger.info("Model averaging hierarchy:")
        for period, group_size in period_group_size_dict.items():
            logger.info(
                "\tEach group that has %s processes average parameters every %s iterations, "
                "if no higher-level averaging.",
                group_size,
                period,
            )
            if group_size != overall_group_size:
                # ``new_subgroups`` returns (this rank's subgroup, all subgroups);
                # only the subgroup containing the current rank is kept.
                self.period_process_group_dict[period], _ = dist.new_subgroups(
                    group_size=group_size, group=self.process_group)
            else:
                # The top level of the hierarchy reuses the overall process group.
                self.period_process_group_dict[period] = self.process_group

        self.warmup_steps = warmup_steps

    def _find_process_group(self):
        """
        Returns the process group that should average parameters at the current ``step``.

        Among the periods (keys of ``period_process_group_dict``) that evenly divide ``step``,
        the largest one is selected and its process group is returned, since that group
        covers the lower levels of the hierarchy at this ``step``.
        Returns ``None`` if no period divides ``step``.
        """
        # Select by max() rather than by dict insertion order, so the largest matching
        # period is chosen even if the periods were not given in ascending order.
        period = max((p for p in self._periods if self.step % p == 0), default=None)
        if period is None:
            return None
        return self.period_process_group_dict[period]

    def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
        """
        Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``
        and it can be divided by a period in the keys of ``period_process_group_dict``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        only the largest period is used, and the corresponding process group is used for averaging parameters.

        Args:
            params: The parameters of a model or parameter groups of an optimizer.

        The step counter is advanced on every call, including warm-up steps and
        steps where no averaging happens.
        """
        if self.step >= self.warmup_steps:
            group = self._find_process_group()
            if group is not None:
                utils.average_parameters_or_parameter_groups(params, group)
        self.step += 1