Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ nn / parallel / _replicated_tensor_ddp_interop.py

import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

class ReplicatedTensorFunction(torch.autograd.Function):
    """
    Autograd bridge between a tensor and its ``ReplicatedTensor`` wrapper.

    The forward pass wraps the input in a ``ReplicatedTensor``; the backward
    pass hands the incoming gradient back unchanged, so the original tensor
    and the replicated one end up with the same gradient.
    """
    @staticmethod
    def forward(ctx, inp, process_group=None):
        # Turning gradient materialization off keeps None gradients as None
        # in backward instead of having them filled with zeros.
        ctx.set_materialize_grads(False)
        replicated = ReplicatedTensor(inp, process_group)
        return replicated

    @staticmethod
    def backward(ctx, grad_output):
        # One slot per forward input: the gradient passes straight through
        # for ``inp``, and ``process_group`` gets no gradient.
        grad_inp, grad_process_group = grad_output, None
        return grad_inp, grad_process_group

def _make_replicated_tensor(tensor, process_group):
    """
    Wrap ``tensor`` through ``ReplicatedTensorFunction`` and return the result.

    The wrapper's ``grad`` attribute is set to ``tensor.grad`` before it is
    returned.
    """
    wrapped = ReplicatedTensorFunction.apply(tensor, process_group)
    # Carry over whatever gradient the source tensor currently holds.
    wrapped.grad = tensor.grad
    return wrapped

def _replicate_module_recurse(module, process_group):
    replica = module._replicate_for_data_parallel()
    for param_name, param in module._parameters.items():
        if param is not None:
            setattr(replica, param_name, _make_replicated_tensor(param, process_group))
        else:
            setattr(replica, param_name, param)

    for buffer_name, buffer in module._buffers.items():
        setattr(replica, buffer_name, buffer)

    for module_name, child in module._modules.items():
        setattr(replica, module_name, _replicate_module_recurse(child, process_group))
    return replica

def _replicate_module(network, process_group):
    """
    Replicate ``network`` after checking that it is replicatable.

    Args:
        network: root module to replicate.
        process_group: forwarded unchanged to ``_replicate_module_recurse``.

    Returns:
        The replica produced by ``_replicate_module_recurse``.

    Raises:
        RuntimeError: if ``_replicatable_module(network)`` is falsy.
    """
    from torch.nn.parallel.replicate import _replicatable_module  # type: ignore[attr-defined]

    # Happy path first: a replicatable network goes straight to the
    # recursive replication.
    if _replicatable_module(network):
        return _replicate_module_recurse(network, process_group)

    raise RuntimeError("Cannot replicate network where python modules are "
                       "childrens of ScriptModule")