Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ nn / parallel / _replicated_tensor_ddp_utils.py

from contextlib import contextmanager

# Module-wide switch: when truthy, DDP's forward pass is expected to tag its
# tensors as ``ReplicatedTensor``. Starts disabled; toggled via
# ``_set_ddp_with_replicated_tensor`` or the ``_ddp_replicated_tensor``
# context manager, and read via ``_ddp_with_replicated_tensor_enabled``.
_DDP_WITH_REPLICATED_TENSOR = False

@contextmanager
def _ddp_replicated_tensor(val):
    """
    Temporarily set the module flag that tags tensors in DDP's forward pass
    as ``ReplicatedTensor``.

    ReplicatedTensor inter-op can consult this flag during the forward pass
    to apply appropriate optimizations.

    The context manager must wrap DDP construction. If the module passed to
    DDP is modified after the context has exited, the result is
    inconsistent: those changes are not picked up during the forward pass.

    Args:
        val: value the flag holds while the ``with`` block runs.

    On exit, whether normal or via an exception, the flag is restored to the
    value it had on entry. This makes nested uses unwind correctly. The flag
    is a plain module global, so it is shared by every thread in the process.
    """
    global _DDP_WITH_REPLICATED_TENSOR
    # Remember the current setting and install the new one in one statement.
    saved_flag, _DDP_WITH_REPLICATED_TENSOR = _DDP_WITH_REPLICATED_TENSOR, val
    try:
        yield
    finally:
        # Restore unconditionally so an exception raised inside the ``with``
        # body cannot leak the temporary setting.
        _DDP_WITH_REPLICATED_TENSOR = saved_flag

def _ddp_with_replicated_tensor_enabled():
    """Return the current value of the DDP ReplicatedTensor flag.

    The value is whatever was last stored by
    ``_set_ddp_with_replicated_tensor`` or by an active
    ``_ddp_replicated_tensor`` context. It is ``False`` until one of them
    changes it. The stored object is returned as-is, without coercion to
    ``bool``.
    """
    # Reading a module global needs no ``global`` declaration. That statement
    # is only required in functions that rebind the name.
    return _DDP_WITH_REPLICATED_TENSOR

def _set_ddp_with_replicated_tensor(value):
    """Permanently overwrite the DDP ReplicatedTensor flag with ``value``.

    Unlike the ``_ddp_replicated_tensor`` context manager, the previous value
    is neither saved nor restored.

    Args:
        value: object to store as the new flag value (stored without
            conversion).
    """
    # ``global`` is required here because the function rebinds the module name.
    global _DDP_WITH_REPLICATED_TENSOR
    _DDP_WITH_REPLICATED_TENSOR = value