Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

/ distributed / _composable_state.py

from typing import cast, Dict, Optional

import torch.nn as nn


class _State:
    pass


_module_state_mapping: Dict[nn.Module, _State] = {}


def _insert_module_state(module: nn.Module, state: _State) -> None:
    global _module_state_mapping
    assert module not in _module_state_mapping, f"Inserting {module} more than once."
    _module_state_mapping[module] = state


def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Given a ``module``, this API finds out if the module is also a ``_State``
    instance or if the module is managed by a composable API. If the module
    is also a ``_State``, ``module`` will be casted to ``_State` and returned.
    If it is managed by a composable API, the corresponding ``_State`` will
    be returned.
    """

    global _module_state_mapping
    if isinstance(module, _State):
        return cast(_State, module)
    else:
        return _module_state_mapping.get(module, None)