Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ _dynamo / side_effects.py

import collections
import dataclasses
import inspect
from typing import Any, Dict, List, Optional

import torch.nn

from . import utils, variables
from .bytecode_transformation import create_instruction
from .codegen import PyCodegen
from .source import LocalSource, Source
from .utils import object_new
from .variables.base import VariableTracker


@dataclasses.dataclass
class MutableSideEffects:
    """
    Marker stored in ``VariableTracker.mutable_local`` for a container (such
    as a list) that was passed in as an input.

    If the traced code mutates such a container, the mutations must be
    re-applied to the real object after the graph runs.
    """

    # Where the original input object can be reloaded from.
    source: Source
    # True once this marker describes a value that has been mutated.
    is_modified: bool = False

    # Identity semantics: a marker only equals itself, so every tracked input
    # gets its own dict key / set entry even if the field values coincide.
    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)


@dataclasses.dataclass
class AttributeMutation:
    """
    Marker stored in ``VariableTracker.mutable_local`` for an object whose
    attribute writes are recorded so they can be replayed later.
    """

    # Where the underlying object can be loaded from. Objects created during
    # tracing start with None here.
    source: Source


class AttributeMutationExisting(AttributeMutation):
    """
    Attribute-mutation marker for an object that already existed before
    tracing and is reachable through ``source`` (objects, cells and globals
    tracked by the ``track_*_existing`` helpers).
    """

    # Identity semantics: each marker is a distinct dict key.
    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)


@dataclasses.dataclass
class AttributeMutationNew(AttributeMutation):
    """
    Attribute-mutation marker for an object created during tracing.

    ``source`` starts out as None and is pointed at a codegen temporary once
    the object has been rebuilt. ``cls_source`` loads the class passed to
    ``object_new`` for that rebuild; it is None for new cells.
    """

    cls_source: Source

    # Identity semantics: each marker is a distinct dict key.
    def __eq__(self, other):
        return other is self

    def __hash__(self):
        return id(self)


class SideEffects:
    """
    Track side effects (list mutation, setattr, etc) that need to be
    applied after an FX graph is run.
    """

    id_to_variable: Dict[int, VariableTracker]
    store_attr_mutations: Dict[AttributeMutation, Dict[str, VariableTracker]]
    keepalive: List[Any]

    def __init__(self, id_to_variable=None, store_attr_mutations=None, keepalive=None):
        """
        Create a tracker, optionally seeded with existing state.

        Args:
            id_to_variable: maps ``id()`` of a tracked Python object to the
                VariableTracker that models it.
            store_attr_mutations: maps an AttributeMutation marker to an
                ordered ``{attr_name: VariableTracker}`` dict of recorded
                attribute writes.
            keepalive: objects whose ``id()`` is used as a key above; holding
                references to them keeps them alive so those ids stay unique.

        Each argument defaults to a fresh, empty container when omitted.
        """
        super().__init__()
        # Test for None explicitly instead of relying on truthiness, so that
        # an empty container supplied by the caller is adopted (and stays
        # shared with the caller) rather than silently replaced by a new one.
        self.id_to_variable = (
            collections.OrderedDict() if id_to_variable is None else id_to_variable
        )
        self.store_attr_mutations = (
            collections.OrderedDict()
            if store_attr_mutations is None
            else store_attr_mutations
        )
        self.keepalive = [] if keepalive is None else keepalive

    def __eq__(self, other: object) -> bool:
        """
        Compare tracked variables and recorded attribute mutations.

        ``keepalive`` is deliberately excluded: it only pins objects in memory
        and does not describe a side effect.

        Comparing against something that is not a SideEffects returns
        ``NotImplemented``. Python then falls back to identity and ``==``
        yields False. Previously this raised AssertionError (or, under ``-O``,
        AttributeError).
        """
        if not isinstance(other, SideEffects):
            return NotImplemented
        # NB: do NOT test keepalive
        return (
            self.id_to_variable == other.id_to_variable
            and self.store_attr_mutations == other.store_attr_mutations
        )

    def diff(self, other: "SideEffects") -> Optional[str]:
        """
        Describe how ``self`` differs from ``other``, as a debugging aid.

        Checks ``id_to_variable`` first, then ``store_attr_mutations``. These
        are the same fields ``__eq__`` compares; ``keepalive`` is ignored.

        Returns:
            ``None`` if both fields compare equal. Otherwise, a short message
            naming the first differing field and, where it can tell, whether
            the key sets differ or only their iteration order does.
        """

        def describe(label, mine, theirs):
            # The caller has already established that ``mine != theirs``.
            mine_keys = mine.keys()
            their_keys = theirs.keys()
            # Keys views compare like sets, so this detects only added or
            # missing keys...
            if mine_keys != their_keys:
                return f"{label} keys: {mine_keys} != {their_keys}"
            # ...but equality between two OrderedDicts is also
            # order-sensitive. If the mappings match once order is ignored,
            # ordering is the only difference, and the set-like check above
            # cannot see it.
            if dict(mine) == dict(theirs):
                return f"{label} order: {list(mine)} != {list(theirs)}"
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return f"{label}: unknown diff"

        if self.id_to_variable != other.id_to_variable:
            return describe(
                "id_to_variable", self.id_to_variable, other.id_to_variable
            )
        if self.store_attr_mutations != other.store_attr_mutations:
            return describe(
                "store_attr_mutations",
                self.store_attr_mutations,
                other.store_attr_mutations,
            )
        return None

    def clone(self):
        """
        Return a shallow copy of this tracker.

        The outer mappings, each per-object attribute-write dict, and the
        keepalive list are new containers. The VariableTrackers, markers and
        kept-alive objects inside them are shared with ``self``.
        """
        copied_mutations = collections.OrderedDict()
        for marker, attr_writes in self.store_attr_mutations.items():
            copied_mutations[marker] = collections.OrderedDict(attr_writes)
        return self.__class__(
            id_to_variable=collections.OrderedDict(self.id_to_variable),
            store_attr_mutations=copied_mutations,
            keepalive=list(self.keepalive),
        )

    def apply(self, fn, cache=None, skip_fn=lambda _: False):
        """
        Replace the tracked state with the result of running
        ``VariableTracker.apply(fn, ..., cache, skip_fn)`` over every tracked
        variable and every per-object attribute-write dict.

        A single ``cache`` dict is shared by all of those calls. A new empty
        dict is used when none is supplied. Keys of both mappings are kept
        as-is.
        """
        memo = {} if cache is None else cache

        def remap(mapping):
            return collections.OrderedDict(
                (key, VariableTracker.apply(fn, value, memo, skip_fn))
                for key, value in mapping.items()
            )

        self.id_to_variable = remap(self.id_to_variable)
        self.store_attr_mutations = remap(self.store_attr_mutations)

    def __contains__(self, item):
        """Return True if the Python object ``item`` is tracked (looked up by ``id``)."""
        key = id(item)
        return key in self.id_to_variable

    def __getitem__(self, item):
        """Return the VariableTracker for ``item``; raises KeyError if it is untracked."""
        key = id(item)
        return self.id_to_variable[key]

    def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
        """
        Record the write ``item.<name> = value``.

        ``item`` must carry an AttributeMutation marker. Writes are grouped
        per marker, and a later write to the same name overwrites the earlier
        one.
        """
        assert self.is_attribute_mutation(item)
        pending = self.store_attr_mutations.setdefault(
            item.mutable_local, collections.OrderedDict()
        )
        pending[name] = value

    def load_attr(self, item, name):
        """
        Return the last value recorded for ``item.<name>``.

        Raises KeyError if no write to that attribute has been recorded.
        """
        assert self.is_attribute_mutation(item)
        pending = self.store_attr_mutations[item.mutable_local]
        return pending[name]

    def store_cell(self, cellvar, value):
        """Record a write to a cell, modelled as an attribute write to ``cell_contents``."""
        assert isinstance(cellvar, variables.NewCellVariable)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(cellvar, "cell_contents", value)

    def load_cell(self, cellvar):
        """Return the last value recorded for the cell's ``cell_contents``."""
        assert isinstance(cellvar, variables.NewCellVariable)
        return self.load_attr(cellvar, "cell_contents")

    def load_global(self, gvar: VariableTracker, name: str):
        """Return the last value recorded for global ``name``, modelled as an attribute of ``gvar``."""
        assert isinstance(gvar, variables.VariableTracker)
        return self.load_attr(gvar, name)

    def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
        """Record a write of ``value`` to global ``name``, modelled as an attribute of ``gvar``."""
        assert isinstance(gvar, variables.VariableTracker)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(gvar, name, value)

    @staticmethod
    def cls_supports_mutation_side_effects(cls):
        """
        Return True if ``cls`` resolves ``__setattr__`` to either
        ``object.__setattr__`` or ``torch.nn.Module.__setattr__``.

        ``inspect.getattr_static`` performs the lookup without triggering
        descriptors or ``__getattr__`` hooks on ``cls``.
        """
        setattr_impl = inspect.getattr_static(cls, "__setattr__", None)
        supported = (object.__setattr__, torch.nn.Module.__setattr__)
        return setattr_impl in supported

    def is_attribute_mutation(self, item):
        """Return True if ``item``'s marker is an AttributeMutation (existing or new)."""
        marker = item.mutable_local
        return isinstance(marker, AttributeMutation)

    def is_modified(self, item):
        """
        Decide whether ``item`` has side effects that must be replayed.

        - Objects created during tracing (AttributeMutationNew) always do.
        - Other attribute-mutation objects do once at least one attribute
          write has been recorded for them.
        - Any other item defers to its marker's own ``is_modified`` flag.
        """
        marker = item.mutable_local
        if isinstance(marker, AttributeMutationNew):
            return True
        if self.is_attribute_mutation(item):
            return marker in self.store_attr_mutations
        return marker.is_modified

    def _track_obj(
        self,
        source: Source,
        item: Any,
        variable: VariableTracker,
        mutable_cls=MutableSideEffects,
    ):
        """
        Start tracking ``item`` for mutation.

        Returns a clone of ``variable`` whose ``mutable_local`` is a fresh
        ``mutable_cls(source)`` marker and whose ``source`` is ``source``.
        The clone is registered under ``id(item)``, and ``item`` is appended to
        ``keepalive`` so that id stays valid.
        """
        tracked = variable.clone(mutable_local=mutable_cls(source), source=source)
        self.id_to_variable[id(item)] = tracked
        self.keepalive.append(item)
        return tracked

    # Lists and dicts are tracked identically: both get the default
    # MutableSideEffects marker from _track_obj.
    track_list = track_dict = _track_obj

    def track_object_existing(
        self,
        source: Source,
        item: Any,
        variable: VariableTracker,
    ):
        """
        Track a pre-existing object whose attribute writes must be replayed.

        Same as ``_track_obj``, but the variable gets an
        AttributeMutationExisting marker.
        """
        marker_cls = AttributeMutationExisting
        return self._track_obj(source, item, variable, mutable_cls=marker_cls)

    def track_object_new(
        self,
        cls_source: Source,
        user_cls: Any,
        variable_cls: Any,
        options,
    ):
        """
        Create a new instance of ``user_cls`` during tracing and track it.

        The instance comes from ``object_new(user_cls)`` (presumably
        ``object.__new__``, i.e. ``__init__`` is not run; confirm in utils).
        It is wrapped as ``variable_cls(obj, mutable_local=..., **options)``
        with an AttributeMutationNew marker whose ``source`` is None and
        whose ``cls_source`` is ``cls_source``.
        """
        new_obj = object_new(user_cls)
        marker = AttributeMutationNew(None, cls_source)
        tracked = variable_cls(new_obj, mutable_local=marker, **options)
        self.id_to_variable[id(new_obj)] = tracked
        self.keepalive.append(new_obj)
        return tracked

    def track_cell_new(self):
        """Create and track a cell variable that is born during tracing."""
        # There is no real cell yet, so a throwaway object supplies a unique
        # id() key. Keeping it in keepalive prevents that id from being reused.
        placeholder = object()
        cell = variables.NewCellVariable(
            mutable_local=AttributeMutationNew(None, None),
        )
        self.id_to_variable[id(placeholder)] = cell
        self.keepalive.append(placeholder)
        return cell

    def track_cell_existing(self, source: Source, item: Any):
        """Track the existing cell ``item``, reachable via ``source``, for later writes."""
        marker = AttributeMutationExisting(source)
        cell = variables.NewCellVariable(mutable_local=marker)
        self.id_to_variable[id(item)] = cell
        self.keepalive.append(item)
        return cell

    def track_global_existing(self, source: Source, item: Any):
        """Track ``item`` (reachable via ``source``) as the holder of global-variable writes."""
        marker = AttributeMutationExisting(source)
        holder = variables.NewGlobalVariable(mutable_local=marker)
        self.id_to_variable[id(item)] = holder
        self.keepalive.append(item)
        return holder

    def prune_dead_object_new(self, tx):
        """
        Forget objects created during tracing (``AttributeMutationNew``) that
        are no longer reachable, so no side effects are emitted for them.

        A new object is treated as live when its marker is reached while
        walking any of:
        - ``tx.stack`` and ``tx.symbolic_locals``;
        - every tracked variable that is not itself a new object;
        - the recorded attribute writes of every object, except that an
          object's own writes never mark that same object live.

        Dead new objects are removed from ``id_to_variable`` and
        ``store_attr_mutations``. Entries for objects that are not new are
        always kept.

        NOTE(review): this is a single pass, not a fixed-point computation.
        Attribute writes of *all* objects are walked, including new objects
        that end up pruned. Something referenced only by a dead object is
        therefore still kept. This over-approximates liveness, assuming
        ``VariableTracker.apply`` visits nested trackers (TODO confirm).
        """
        live_new_objects = set()
        # Marker whose attribute writes are currently being walked. The
        # ``for skip_obj, ...`` loop below rebinds it, and the ``visit``
        # closure reads it at call time. As a result, a self-reference
        # (e.g. ``obj.me = obj``) does not keep ``obj`` alive.
        skip_obj = None

        def visit(var: VariableTracker):
            # Visitor for VariableTracker.apply: records each reachable
            # new-object marker (except skip_obj) and returns var unchanged.
            if (
                isinstance(var.mutable_local, AttributeMutationNew)
                and var.mutable_local is not skip_obj
            ):
                live_new_objects.add(var.mutable_local)
            return var

        def is_live(var: VariableTracker):
            # Accepts either a marker (a store_attr_mutations key) or a
            # VariableTracker (an id_to_variable value). Only new-object
            # markers can be dead; anything else is always kept.
            if isinstance(var, AttributeMutationNew):
                return var in live_new_objects
            if isinstance(var, VariableTracker):
                return is_live(var.mutable_local)
            return True

        # Roots: the tracer's current stack and symbolic locals.
        VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals))
        # Pre-existing tracked objects are always kept (see is_live), so
        # anything they reference is live. New objects are skipped here and
        # only count as live once something else reaches them.
        for var in self.id_to_variable.values():
            if not isinstance(var.mutable_local, AttributeMutationNew):
                VariableTracker.apply(visit, var)

        # Walk every object's recorded attribute writes. skip_obj is rebound
        # to the owning marker for each iteration (see visit).
        for skip_obj, setattrs in self.store_attr_mutations.items():
            VariableTracker.apply(visit, setattrs)

        # Rebuild both mappings without the dead new objects, preserving order.
        self.id_to_variable = collections.OrderedDict(
            (k, v) for k, v in self.id_to_variable.items() if is_live(v)
        )
        self.store_attr_mutations = collections.OrderedDict(
            (k, v) for k, v in self.store_attr_mutations.items() if is_live(k)
        )

    def mutation(self, oldvar, newvar):
        """
        Return a clone of ``newvar`` marked as the mutated form of ``oldvar``.

        The clone gets a new MutableSideEffects marker that carries
        ``oldvar``'s source and has ``is_modified`` set to True.
        """
        marker = MutableSideEffects(oldvar.mutable_local.source, is_modified=True)
        return newvar.clone(mutable_local=marker)

    def _get_modified_vars(self):
        """Return, in tracking order, the tracked variables for which ``is_modified`` is true."""
        return list(filter(self.is_modified, self.id_to_variable.values()))

    def codegen_save_tempvars(self, cg: PyCodegen):
        """
        Emit bytecode that places each modified tracked object into a codegen
        temporary, so later generated code can refer to it.

        For every variable returned by ``_get_modified_vars`` (tracking order):

        - Cells (``NewCellVariable`` with an existing or new attribute-mutation
          marker): emit ``utils.make_cell()`` and cache the result. For a cell
          created during tracing, the marker's ``source`` is then redirected
          to the temporary.
        - Other objects created during tracing: emit
          ``utils.object_new(<cls_source>)``, cache the result, and redirect
          the marker's ``source`` to the temporary.
        - Any other variable that already has an (unfilled) entry in
          ``cg.tempvars``: load it from its original ``source`` and cache it.

        NOTE(review): ``cg.add_cache`` appears to store the top-of-stack value
        in a temporary local whose name is then ``cg.tempvars[var]``. This is
        inferred from the ``LocalSource(...)`` lines; confirm in PyCodegen.
        """
        for var in self._get_modified_vars():
            if isinstance(
                var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
            ) and isinstance(var, variables.NewCellVariable):
                # make_cell() takes no arguments, hence CALL_FUNCTION argc=0.
                cg.load_import_from(utils.__name__, "make_cell")
                cg.extend_output([create_instruction("CALL_FUNCTION", 0)])
                cg.add_cache(var)
                if isinstance(var.mutable_local, AttributeMutationNew):
                    # A new cell's marker source starts as None; point it at
                    # the temporary so later code reloads the cell from there.
                    var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif isinstance(var.mutable_local, AttributeMutationNew):
                # object_new(cls): the class is loaded from cls_source and
                # passed as the single argument, hence CALL_FUNCTION argc=1.
                cg.load_import_from(utils.__name__, "object_new")
                cg(var.mutable_local.cls_source)
                cg.extend_output([create_instruction("CALL_FUNCTION", 1)])
                cg.add_cache(var)
                var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif var in cg.tempvars:
                # A temp slot was reserved for this var but not yet filled.
                assert cg.tempvars.get(var) is None
                # subsequent usage should point to the original variable
                cg(var.mutable_local.source)
                cg.add_cache(var)

    def codegen_update_mutated(self, cg: PyCodegen):
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)
                cg(var.mutable_local.source)
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", 2),
                    ]
                )
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.ConstDictVariable):
                cg.tx.output.update_co_names("clear")
                cg.tx.output.update_co_names("update")

                cg(var.mutable_local.source)
                cg.extend_output([create_instruction("LOAD_METHOD", "update")])
                cg(var, allow_cache=False)

                cg(var.mutable_local.source)
                cg.extend_output([create_instruction("LOAD_METHOD", "clear")])

                suffixes.append(
                    [
                        create_instruction("CALL_METHOD", 0),  # clear
                        create_instruction("POP_TOP"),
                        create_instruction("CALL_METHOD", 1),  # update
                        create_instruction("POP_TOP"),
                    ]
Loading ...