# Learn more » Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, NuGet packages
#
# edgify / torch (python)
#
# Repository URL to install this package:
#
# File: _dynamo/replay_record.py

import dataclasses
from dataclasses import field
from types import CodeType, ModuleType
from typing import Any, Dict

try:
    import dill
except ImportError:
    dill = None


@dataclasses.dataclass
class ModuleRecord:
    """Pair a live module object with the attribute values recorded for it.

    Attributes:
        module: The module object this record describes.
        accessed_attrs: Mapping of attribute name to recorded value. Each
            instance starts with its own empty dict.
    """

    module: ModuleType
    accessed_attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class DummyModule:
    """Lightweight placeholder for a module, identified only by its name.

    ``name`` is the sole dataclass field, so equality and ``repr`` consider
    nothing else. Further attributes may still be attached to an instance
    after construction.
    """

    name: str


@dataclasses.dataclass
class ExecutionRecord:
    """A code object together with the namespaces recorded for it.

    Attributes:
        code: The recorded code object.
        globals: Recorded global variables, keyed by name.
        locals: Recorded local variables, keyed by name.
        builtins: Recorded builtins, keyed by name.
        code_options: Recorded code options, keyed by name.

    Instances can be written to and read from a binary file object with
    ``dump`` / ``load``; both require the optional ``dill`` package.
    """

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _require_dill():
        """Return the ``dill`` module, raising if it could not be imported.

        Raises:
            AssertionError: If ``dill`` is not installed.
        """
        if dill is None:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type and message
            # are the same as the assert would have produced.
            raise AssertionError("replay_record requires `pip install dill`")
        return dill

    def dump(self, f):
        """Serialize this record into the writable binary file object ``f``.

        Raises:
            AssertionError: If ``dill`` is not installed.
        """
        self._require_dill().dump(self, f)

    @classmethod
    def load(cls, f):
        """Deserialize and return the object stored in binary file ``f``.

        The loaded object is returned as-is; it is not checked to be an
        instance of ``cls``.

        Raises:
            AssertionError: If ``dill`` is not installed.
        """
        # SECURITY: dill builds on pickle, so loading can run arbitrary code.
        # Only load files that come from a trusted source.
        return cls._require_dill().load(f)


@dataclasses.dataclass
class ExecutionRecorder:
    """Collect variables and module attribute accesses for one code object.

    Plain values are stored directly in ``locals`` / ``globals``. Module
    values are wrapped in a ``ModuleRecord`` (one per module ``__name__``,
    tracked in ``name_to_modrec``) so that attribute accesses on them can be
    recorded through ``record_module_access``. ``get_record`` turns the
    collected state into an ``ExecutionRecord``.
    """

    # Deliberately left unannotated: annotated names would become dataclass
    # fields (constructor parameters) instead of plain class attributes.
    # Modules whose ``__name__`` equals one of these entries are not recorded.
    MOD_EXCLUDES = ["torch"]
    LOCAL_MOD_PREFIX = "___local_mod_"

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)
    name_to_modrec: Dict[str, Any] = field(default_factory=dict)

    def add_local_var(self, name, var):
        """Record ``var`` under ``name`` in ``locals``.

        Excluded modules are skipped; other modules are stored as their
        ``ModuleRecord``.
        """
        self._store_var(self.locals, name, var)

    def add_global_var(self, name, var):
        """Record ``var`` under ``name`` in ``globals``.

        Excluded modules are skipped; other modules are stored as their
        ``ModuleRecord``.
        """
        self._store_var(self.globals, name, var)

    def add_local_mod(self, name, mod):
        """Record the module ``mod`` under ``name``.

        NOTE(review): the module is stored through ``add_global_var``, i.e.
        in ``globals`` rather than ``locals`` - presumably intentional;
        confirm against the caller.
        """
        assert isinstance(mod, ModuleType)
        # ``add_global_var`` already skips excluded modules, so no separate
        # exclusion check is needed here.
        self.add_global_var(name, mod)

    def record_module_access(self, mod, name, val):
        """Record that attribute ``name`` of module ``mod`` had value ``val``.

        Accesses on excluded modules are ignored. If ``val`` is itself a
        module it is stored as its ``ModuleRecord``.
        """
        if self._is_excl(mod):
            return

        # Get-or-create the record: an access on a module that has not been
        # registered yet must not fail with ``KeyError``.
        modrec = self._add_mod(mod)
        if isinstance(val, ModuleType):
            val = self._add_mod(val)
        modrec.accessed_attrs[name] = val

    def get_record(self):
        """Build an ``ExecutionRecord`` from the state collected so far.

        ``globals`` and ``locals`` have their ``ModuleRecord`` values replaced
        by ``DummyModule`` trees; ``builtins`` and ``code_options`` are
        shallow-copied.
        """
        return ExecutionRecord(
            self.code,
            ExecutionRecorder._resolve_modules(self.globals),
            ExecutionRecorder._resolve_modules(self.locals),
            self.builtins.copy(),
            self.code_options.copy(),
        )

    def _store_var(self, target, name, var):
        """Store ``var`` in ``target[name]``; shared by local/global adders."""
        if isinstance(var, ModuleType):
            if self._is_excl(var):
                return
            var = self._add_mod(var)
        target[name] = var

    def _add_mod(self, mod):
        """Return the ``ModuleRecord`` for ``mod``, creating it on first use."""
        mod_name = mod.__name__
        modrec = self.name_to_modrec.get(mod_name)
        if modrec is None:
            modrec = ModuleRecord(mod)
            self.name_to_modrec[mod_name] = modrec
        return modrec

    @classmethod
    def _is_excl(cls, mod):
        """Return True if ``mod.__name__`` is listed in ``MOD_EXCLUDES``."""
        return mod.__name__ in cls.MOD_EXCLUDES

    # Convert ModuleRecords -> DummyModule tree
    @classmethod
    def _resolve_modules(cls, vars):
        """Return a copy of ``vars`` with ``ModuleRecord`` values resolved.

        Every ``ModuleRecord`` becomes a ``DummyModule`` carrying the
        record's accessed attributes, resolved recursively. Within one call a
        given record always maps to the same ``DummyModule``, which keeps
        shared references shared and lets records that reference each other
        (or themselves) resolve without infinite recursion.
        """
        # id(ModuleRecord) -> DummyModule. The records stay alive for the
        # whole call (they are reachable from ``vars``), so ids are stable.
        resolved = {}

        def resolve_module(var):
            if not isinstance(var, ModuleRecord):
                return var

            key = id(var)
            if key in resolved:
                return resolved[key]

            dummy_mod = DummyModule(var.module.__name__)
            # Register before descending so a cycle finds this entry instead
            # of recursing forever.
            resolved[key] = dummy_mod
            for attr_name, attr_value in var.accessed_attrs.items():
                setattr(dummy_mod, attr_name, resolve_module(attr_value))

            return dummy_mod

        return {k: resolve_module(v) for k, v in vars.items()}