from contextlib import contextmanager
_DDP_WITH_REPLICATED_TENSOR = False
@contextmanager
def _ddp_replicated_tensor(val):
"""
A context manager to tag tensors in the forward pass of DDP to be
``ReplicatedTensor``. This can be used by ReplicatedTensor inter-op
during the forward pass to perform appropriate optimizations.
This context manager needs to wrap DDP creation and modifying the underlying
module passed into DDP after leaving this context manager would cause
inconsitencies and the changes will not be picked up during the forward
pass.
"""
global _DDP_WITH_REPLICATED_TENSOR
old_val = _DDP_WITH_REPLICATED_TENSOR
_DDP_WITH_REPLICATED_TENSOR = val
try:
yield
finally:
_DDP_WITH_REPLICATED_TENSOR = old_val
def _ddp_with_replicated_tensor_enabled():
global _DDP_WITH_REPLICATED_TENSOR
return _DDP_WITH_REPLICATED_TENSOR
def _set_ddp_with_replicated_tensor(value):
global _DDP_WITH_REPLICATED_TENSOR
_DDP_WITH_REPLICATED_TENSOR = value