Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ distributed / _composable / replicate.py

from typing import List, Tuple

import torch
import torch.nn as nn

from . import _ddp
from .contract import _get_registry, contract


@contract()
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module.

    Attaches a fresh ``_ReplicateState`` to ``module``. That state registers
    forward hooks that, on the first forward pass, gather the module tree's
    trainable parameters and build a ``_ddp.DistributedDataParallel`` over
    them. Every forward pass then runs that wrapper's pre-/post-forward steps.

    Args:
        module (torch.nn.Module): module to replicate
        **kwargs: forwarded unchanged to ``_ddp.DistributedDataParallel``
            when it is lazily constructed.

    Returns:
        The same ``module`` object, with the hooks registered on it.

    Raises:
        AssertionError: if ``module`` is already managed by ``fully_shard``.

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    state = _ReplicateState()
    state.mark_modules(module, **kwargs)
    return module


def _can_compose(module: nn.Module) -> bool:
    r"""Return whether ``replicate`` may be applied to ``module``.

    ``module`` is considered composable unless ``"fully_shard"`` is one of
    the entries in its composable-API registry.
    """
    applied_apis = _get_registry(module)
    return "fully_shard" not in applied_apis


class _ReplicateState:
    """Bookkeeping shared by the modules passed to one ``replicate`` call.

    Tracks the marked modules and the keyword arguments for
    ``_ddp.DistributedDataParallel``. The DDP instance (``self._ddp``) is
    built lazily the first time a marked module runs its forward pre-hook.
    """

    def __init__(self) -> None:
        # Modules registered through ``mark_modules``.
        self.modules: List[nn.Module] = []
        # Flipped to True when ``init_helper`` starts building the DDP wrapper.
        self.has_initialized: bool = False
        # Trainable parameters gathered from the marked module trees.
        self._param_list: nn.ParameterList = nn.ParameterList()
        # Keyword arguments forwarded to ``_ddp.DistributedDataParallel``.
        self.kwargs: dict = {}

    def mark_modules(self, *modules: nn.Module, **kwargs) -> None:
        """Attach this state and the forward hooks to every module in ``modules``.

        Each module's ``replicate`` state is pointed back at ``self`` and
        flagged as not yet collected. ``kwargs`` replaces any previously
        stored DDP keyword arguments.

        Raises:
            AssertionError: if a module is already managed by ``fully_shard``.
        """
        for mod in modules:
            if not _can_compose(mod):
                raise AssertionError(
                    "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
                )
            self.modules.append(mod)
            mod_state = replicate.state(mod)
            mod_state._distributed_state = self
            mod_state._params_collected = False
            mod.register_forward_pre_hook(self.forward_pre_hook)
            # TODO(@yhcharles): fix type error
            mod.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
        self.kwargs = kwargs

    def _recursive_collect_params(self, module: nn.Module) -> None:
        """Append ``module``'s trainable parameters, then visit its children (pre-order).

        The walk stops at any subtree managed by another composable API.
        It also stops at any module whose ``replicate`` state says its
        parameters were already collected.
        """
        if not _can_compose(module):
            # Parameters below here belong to a different API (e.g. fully_shard).
            return

        mod_state = replicate.state(module)
        # Only modules marked by ``replicate`` carry the flag; plain
        # submodules are always collected.
        if hasattr(mod_state, "_params_collected"):
            if mod_state._params_collected:
                return
            mod_state._params_collected = True

        own_trainable = [
            p for p in module.parameters(recurse=False) if p.requires_grad
        ]
        self._param_list.extend(own_trainable)
        for submodule in module.children():
            self._recursive_collect_params(submodule)

    def init_helper(self) -> None:
        """Collect parameters and build ``self._ddp`` once; later calls do nothing."""
        if not self.has_initialized:
            # The flag is set before any work so re-entrant calls return early.
            self.has_initialized = True
            for tracked in self.modules:
                self._recursive_collect_params(tracked)
            self._ddp = _ddp.DistributedDataParallel(self._param_list, **self.kwargs)

    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor, ...]
    ) -> None:
        """Forward pre-hook: lazily initialize, then run DDP's pre-forward step."""
        self.init_helper()
        self._ddp.pre_forward()

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Forward hook: return whatever DDP's post-forward step makes of ``output``."""
        return self._ddp.post_forward(output)