import functools
import inspect
import itertools
import types
from contextlib import contextmanager
from typing import Dict, List
import torch.nn
from .. import skipfiles, variables
from ..allowed_functions import is_allowed
from ..exc import RestartAnalysis, unimplemented
from ..guards import GuardBuilder
from ..mutation_guard import GenerationTracker
from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
from ..utils import (
is_lazy_module,
is_safe_constant,
istensor,
istype,
proxy_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .functions import invoke_and_store_as_constant
from .lists import SliceVariable
from .user_defined import UserDefinedObjectVariable
class NNModuleVariable(VariableTracker):
_nonvar_fields = ["module_type", "module_key"]
def __init__(self, module_type: type, module_key: str, **kwargs):
    """Create a tracker for the submodule registered under ``module_key``.

    Args:
        module_type: Python class of the wrapped module (reported by
            ``python_type``).
        module_key: key used with ``tx.output.get_submodule`` to fetch the
            live module object.
        **kwargs: forwarded unchanged to ``VariableTracker.__init__``
            (must include a truthy ``source``).
    """
    super().__init__(**kwargs)
    self.module_key = module_key
    self.module_type = module_type
    # Attribute lookups and inlined calls build guard sources from
    # self.source, so a tracker without one is rejected up front.
    assert self.source
def python_type(self):
    """Return the class recorded for the wrapped module."""
    recorded_type = self.module_type
    return recorded_type
def _wrap_submodule(self, tx, source, submod, *key_extra, **options):
    """Placeholder hook: ignores all of its arguments and returns None."""
    return None
def unpack_var_sequence(self, tx):
    """Expand a container module into one tracked variable per element.

    Supports list/iter/tuple-style calls on ``ModuleList``,
    ``ParameterList`` and ``Sequential``. Each element is registered via
    ``tx.output.register_attr_or_module`` under (module_key, index), with a
    ``GetItemSource`` rooted at this variable's source.
    """
    container = tx.output.get_submodule(self.module_key)
    propagated = VariableTracker.propagate([self])
    assert isinstance(
        container,
        (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential),
    ), typestr(container)
    assert self.source
    return [
        tx.output.register_attr_or_module(
            element,
            self.module_key,
            position,
            source=NNModuleSource(GetItemSource(self.source, position)),
            **propagated,
        )
        for position, element in enumerate(container)
    ]
def call_hasattr(self, tx, name: str) -> "VariableTracker":
    """Evaluate ``hasattr(module, name)`` on the live module as a constant.

    The result carries a HASATTR guard on ``<self.source>.<name>`` so the
    constant is re-checked when the attribute's presence changes.
    """
    propagated = VariableTracker.propagate(self)
    live_module = tx.output.get_submodule(self.module_key)
    present = hasattr(live_module, name)
    answer = variables.ConstantVariable(present, **propagated)
    attr_source = NNModuleSource(AttrSource(self.source, name))
    return answer.add_guard(attr_source.make_guard(GuardBuilder.HASATTR))
def is_training(self, tx):
    """Return the live module's ``training`` attribute, or False if absent."""
    live_module = tx.output.get_submodule(self.module_key)
    try:
        return live_module.training
    except AttributeError:
        return False
def convert_to_unspecialized(self, tx):
    """Abort and restart analysis so this module is re-traced as an
    UnspecializedNNModuleVariable.

    Always raises ``RestartAnalysis``.
    """
    live_module = tx.output.get_submodule(self.module_key)
    GenerationTracker.tag(live_module)
    # Mark the module's class dynamic, except when the current frame is an
    # ``__init__`` (i.e. the module is still being initialized).
    if tx.f_code.co_name == "__init__":
        raise RestartAnalysis()
    GenerationTracker.mark_class_dynamic(type(live_module))
    raise RestartAnalysis()
def var_getattr(self, tx, name):
    """Resolve ``getattr(module, name)`` for the wrapped ``nn.Module``.

    The attribute is looked up, in order, in:
      1. the instance ``__dict__``;
      2. registered submodules (``_modules``), unless some class in the MRO
         defines an attribute with the same name;
      3. ``_parameters``;
      4. ``_buffers``;
      5. the class hierarchy via ``inspect.getattr_static``, which does not
         invoke descriptors.

    Instance-level hits (1-4) are wrapped with ``VariableBuilder``.
    Class-level hits (5) are handled by kind:
      * ``__class__`` becomes a ``UserDefinedClassVariable``;
      * a ``property`` getter is traced with this variable as ``self``;
      * ``classmethod``/``staticmethod``/plain functions become
        method/function variables;
      * safe constants and tensors are wrapped with ``VariableBuilder``.
    Any other class-level object, or a variable without a source, is
    reported via ``unimplemented``.
    """
    from .builder import VariableBuilder

    options = VariableTracker.propagate(self)
    guards = options.get("guards", set())
    if self.source:
        source = AttrSource(self.source, name)
        options["source"] = source
    else:
        source = None

    base = tx.output.get_submodule(self.module_key)
    # Fetch the raw instance dict, bypassing any __getattribute__ override
    # on the module class.
    base_dict = object.__getattribute__(base, "__dict__")
    object_member = True

    if not self.source:
        unimplemented("GETATTR with no source")

    if name in base_dict:
        subobj = base_dict[name]
    elif (
        "_modules" in base_dict
        and name in base_dict["_modules"]
        # A same-named class attribute anywhere in the MRO shadows the
        # submodule. The MRO is only scanned when such a submodule actually
        # exists, rather than on every attribute lookup.
        and not any(
            name in klass.__dict__ for klass in inspect.getmro(base.__class__)
        )
    ):
        subobj = base_dict["_modules"][name]
    elif "_parameters" in base_dict and name in base_dict["_parameters"]:
        subobj = base_dict["_parameters"][name]
    elif "_buffers" in base_dict and name in base_dict["_buffers"]:
        subobj = base_dict["_buffers"][name]
    else:
        subobj = inspect.getattr_static(base, name)
        object_member = False

    if name == "__class__" and not object_member:
        return variables.UserDefinedClassVariable(base.__class__, **options)

    if object_member:
        return VariableBuilder(tx, NNModuleSource(source))(subobj)

    if istype(subobj, property):
        # Trace the getter with this module variable bound as ``self``.
        return variables.UserFunctionVariable(
            subobj.fget,
            guards=guards,
            source=source,
        ).call_function(tx, [self], {})
    elif istype(subobj, classmethod):
        return variables.UserMethodVariable(
            subobj.__func__,
            variables.UserDefinedObjectVariable(type(base), guards=guards),
            **options,
        )
    elif istype(subobj, staticmethod):
        return variables.UserFunctionVariable(subobj.__get__(base), **options)
    elif istype(subobj, types.FunctionType):
        return variables.UserMethodVariable(subobj, self, **options)
    elif is_safe_constant(subobj) or istensor(subobj):
        # Support possibly common cases of class members
        return VariableBuilder(tx, NNModuleSource(source))(subobj)
    else:
        unimplemented(f"class property {typestr(base)} {typestr(subobj)}")

    # Presumably unreachable if unimplemented() always raises.
    return variables.GetAttrVariable(self, name, **options)
def call_function(
    self,
    tx,
    args: "List[VariableTracker]",
    kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
    """Trace a call of the wrapped module, i.e. ``module(*args, **kwargs)``.

    The first matching strategy is used:

    * ``nn.Sequential`` whose class does not override ``forward``: the call
      is unrolled. Each child is registered and called in turn, and its
      output (retrieved with ``tx.pop()``) feeds the next child.
    * Module class accepted by ``is_allowed``: a single ``call_module``
      proxy node is created and wrapped with ``wrap_fx_proxy``.
    * Otherwise: ``forward`` is inlined, or ``__call__`` for lazy modules,
      via ``tx.inline_user_function_return``.

    For the duration of the call, ``tx.nn_module_stack[self.module_key]``
    holds the module's type; the entry is deleted on exit.
    """
    options = VariableTracker.propagate(self, args, kwargs.values())
    mod = tx.output.get_submodule(self.module_key)

    @contextmanager
    def record_nn_module_stack():
        # Record this module on the tracer's nn_module_stack while the call
        # is traced; the finally-block removes it even if tracing raises.
        # NOTE(review): if the same module_key is already present (re-entrant
        # call of the same module), the inner exit deletes the shared entry
        # and the outer `del` would raise KeyError -- confirm this can't occur.
        try:
            tx.nn_module_stack[self.module_key] = type(mod)
            yield
        finally:
            del tx.nn_module_stack[self.module_key]

    with record_nn_module_stack():
        is_lazy = is_lazy_module(mod)
        if (
            isinstance(mod, torch.nn.Sequential)
            and mod.__class__.forward is torch.nn.Sequential.forward
        ):
            # unroll Sequential()
            assert not kwargs
            (arg,) = args
            for idx, submod in enumerate(mod):
                tx.call_function(
                    tx.output.register_attr_or_module(
                        submod,
                        self.module_key,
                        idx,
                        source=NNModuleSource(GetItemSource(self.source, idx)),
                        **options,
                    ),
                    [arg],
                    {},
                )
                # The child's output becomes the next child's input and,
                # after the last child, the return value.
                arg = tx.pop()
            return arg
        elif is_allowed(mod.__class__):
            # The module type will change after it is called
            if is_lazy:
                self.module_type = mod.cls_to_become
            from .builder import wrap_fx_proxy
            # Emit the module call as one opaque "call_module" graph node
            # instead of tracing into its forward.
            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_module",
                    self.module_key,
                    *proxy_args_kwargs(args, kwargs),
                ),
                **options,
            )
        else:
            # for lazy modules, run the pre-hooks which will update the type
            # TODO mlazos: we don't fully support all of the hooks that exist,
            # so restrict using __call__ only to lazy modules for now
            assert self.source, (
                "Must provide a valid source in order to inline, "
                "since inlined function may have default args which must be guarded."
            )
            if is_lazy:
                if istype(mod.__call__, types.FunctionType):
                    # Plain function (not bound to the module): call as-is.
                    fn = mod.__call__
                    fn_source = AttrSource(self.source, "__call__")
                else:
                    assert istype(mod.__call__, types.MethodType)
                    # Bound method: inline the underlying function and pass
                    # this module variable explicitly as its first argument.
                    fn = mod.__call__.__func__
                    fn_source = AttrSource(
                        AttrSource(self.source, "__call__"), "__func__"
                    )
                    args = [self] + args
            else:
                if istype(mod.forward, types.FunctionType):
                    # Plain function (not bound to the module): call as-is.
                    fn = mod.forward
                    fn_source = AttrSource(self.source, "forward")
                else:
                    assert istype(mod.forward, types.MethodType)
                    # Bound method: inline the underlying function and pass
                    # this module variable explicitly as its first argument.
                    fn = mod.forward.__func__
                    fn_source = AttrSource(
                        AttrSource(self.source, "forward"), "__func__"
                    )
                    args = [self] + args
            # Attach the inlined function's source so guards (e.g. on its
            # default args, per the assert above) can be built against it.
            options["source"] = fn_source
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, **options),
                args,
                kwargs,
            )
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
constant=False,
) -> "VariableTracker":
from . import ConstantVariable, ListIteratorVariable, TupleVariable
options = VariableTracker.propagate(self, args, kwargs.values())
key = self.module_key
module = tx.output.get_submodule(key)
if name == "forward":
return self.call_function(tx, args, kwargs)
if name == "_check_input_dim" and skipfiles.is_torch_inline_allowed(
inspect.getfile(module.__class__._check_input_dim)
):
return ConstantVariable(True, **options)
if name == "_get_item_by_idx":
assert args[1].is_python_constant()
assert isinstance(args[0], TupleVariable)
mod_var = args[0].items[args[1].value]
key = mod_var.module_key
submod = tx.output.get_submodule(key)
return tx.output.register_attr_or_module(
submod,
key,
key,
source=NNModuleSource(GetItemSource(self.source, key)),
**options,
)
if constant:
fn = getattr(module, name)
name = f"{module.__class__.__name__}_{name}_result"
return invoke_and_store_as_constant(tx, fn, name, options, args, kwargs)
def assert_all_args_kwargs_const():
if not all(
x.is_python_constant() for x in itertools.chain(args, kwargs.values())
):
raise unimplemented(f"non-const NNModule method {name}")
def get_kwargs(*names):
assert_all_args_kwargs_const()
fn = getattr(module, name)
bound_args = inspect.signature(fn).bind(
*([x.as_python_constant() for x in args]),
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
bound_args.apply_defaults()
bound_args = bound_args.arguments
return {k: bound_args[k] for k in names}
def wrap_values(items):
result = []
for name, submod in items:
result.append(
tx.output.register_attr_or_module(
submod,
key,
name,
source=NNModuleSource(gen_source(self.source, name)),
**options,
)
)
return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
def named_embed(name, obj):
return TupleVariable(
[
ConstantVariable(name, **options),
tx.output.register_attr_or_module(
obj,
key,
name,
source=NNModuleSource(gen_source(self.source, name)),
**options,
),
]
)
def gen_source(source, name):
name_split = name.split(".")
if name_split[0] == "":
return source
while len(name_split) > 0:
x = name_split.pop(0)
source = AttrSource(source, x)
return source
if name == "children":
Loading ...