import inspect
import torch
import collections
import textwrap
import functools
import warnings
from typing import Dict, List, Set, Type
import torch._jit_internal as _jit_internal
from torch.jit.frontend import get_default_args, get_jit_def, get_class_properties
from torch.jit._builtins import _find_builtin
from torch.nn import Module
from torch._six import get_function_from_type, bind_method
# Compilation stub for a method: the resolution callback used to resolve free
# names, the parsed JIT definition, and the original Python callable.
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
# Compilation stub for a property: its resolution callback and parsed definition.
# The typename matches the variable it is bound to, so repr() is accurate and
# pickle can locate the class by name (it looks up `<module>.<typename>`).
PropertyStub = collections.namedtuple('PropertyStub', ('resolution_callback', 'def_'))
# Names of entries in an nn.Module's __dict__ that infer_concrete_type_builder
# skips when collecting attributes (in addition to any name starting with "__").
# They appear to be nn.Module bookkeeping state rather than user data.
# TODO: there should be a more principled way of doing this.
ignored_attributes = [
    "_version",
    "_parameters",
    "_buffers",
    "_modules",
    "_initializing",
    "_backward_hooks",
    "_forward_hooks",
    "_forward_pre_hooks",
    "_state_dict_hooks",
    "_load_state_dict_pre_hooks",
    "dump_patches",
]
def make_stub(func, name):
    """
    Create a ScriptMethodStub for ``func``.

    ``name`` is the name given to the parsed definition (it may differ from
    ``func.__name__``). The stub also carries a resolution callback built from
    ``func``'s closure, plus ``func`` itself.
    """
    # Arguments are evaluated left to right: callback first, then the AST.
    return ScriptMethodStub(
        _jit_internal.createResolutionCallbackFromClosure(func),
        get_jit_def(func, name, self_name="RecursiveScriptModule"),
        func,
    )
def make_stub_from_method(nn_module, method_name):
    """
    Return a ScriptMethodStub for the attribute ``method_name`` of ``nn_module``.

    An attribute that is already a ScriptMethodStub is returned unchanged.
    """
    candidate = getattr(nn_module, method_name)
    if not isinstance(candidate, ScriptMethodStub):
        # Pass the requested name rather than candidate.__name__ so that the
        # name in the resulting AST matches what was asked for. They differ
        # for aliases such as:
        #
        #     def _forward(self):
        #         pass
        #     forward = _forward
        #
        # where the function object is named `_forward` even though we
        # requested a stub for `forward`.
        candidate = make_stub(candidate, method_name)
    return candidate
def make_stubs_from_exported_methods(mod):
    """
    Return a stub for every attribute of ``mod`` whose TorchScript modifier is EXPORT.

    Attributes are visited in ``dir(mod)`` order. Names that cannot be
    fetched are treated as ``None``.
    """
    export_modifier = _jit_internal.FunctionModifiers.EXPORT
    return [
        make_stub_from_method(mod, attr_name)
        for attr_name in dir(mod)
        if _jit_internal.get_torchscript_modifier(getattr(mod, attr_name, None)) is export_modifier
    ]
# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (bool, float, int, str, type(None), torch.device, torch.layout, torch.dtype)
def _get_valid_constant(attr, v, owner_type):
if isinstance(v, _constant_types):
return v
elif isinstance(v, tuple) or isinstance(v, list):
return tuple(_get_valid_constant(attr, x, owner_type) for x in v)
constants = ", ".join(torch.typename(typ) for typ in _constant_types)
raise TypeError(textwrap.dedent("""
'{}' object in attribute '{}.{}' is not a valid constant.
Valid constants are:
1. a nn.ModuleList
2. a value of type {{{}}}
3. a list or tuple of (2)
""".format(torch.typename(type(v)), owner_type, attr, constants)))
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    """SourceRangeFactory subclass whose constructor forwards every argument unchanged."""

    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module. This
    ConcreteModuleType doesn't have a JIT type associated with it yet, it
    must be filled in by the caller.

    The builder is populated, in order, with:

    1. The module's parameters, buffers and submodules.
    2. Its constants, from ``__constants__`` and ``Final``-annotated class
       attributes.
    3. Its overloads, from ``__overloads__`` plus annotated overloads.
    4. Every remaining entry of ``nn_module.__dict__``. Each becomes a
       function attribute, a builtin function, a data attribute, or a
       "failed" attribute carrying a diagnostic hint.
    5. Its forward hooks and forward pre-hooks.

    Names listed in ``__jit_ignored_attributes__`` are registered as ignored
    and skipped.

    Args:
        nn_module: the ``torch.nn.Module`` instance to describe.
        share_types: passed through to ``get_module_concrete_type`` when a
            submodule's concrete type has to be inferred.

    Returns:
        The populated ``torch._C.ConcreteModuleTypeBuilder``.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, torch.nn.ModuleDict):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()

    class_annotations = getattr(nn_module, '__annotations__', {})
    if isinstance(nn_module, torch.quantization.QuantWrapper):
        # NOTE(review): QuantWrapper's class annotations are deliberately
        # discarded here; the reason is not visible in this file chunk.
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # Returns (InferredType, inferred). `inferred` is True only when the
        # type came from the value itself, not from a class annotation or a
        # torch.jit.Attribute wrapper.
        #
        # The forward function from Module is special; never use its annotation;
        # we need to infer the type directly using JIT. I originally wanted to
        # write this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        inferred = False
        if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
            ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
            attr_type = torch._C.InferredType(ann_to_type)
        elif isinstance(item, torch.jit.Attribute):
            ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
            attr_type = torch._C.InferredType(ann_to_type)
        else:
            attr_type = torch._C._jit_try_infer_type(item)
            inferred = True

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        # The two booleans presumably mean (is_parameter, is_buffer); compare
        # with the buffer loop below. TODO confirm against the C++ builder.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so we register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type.type())
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    # Copy into a fresh set. `__constants__` may be declared as a list or
    # tuple, which has no `.add`. When it is a set it is typically shared via
    # the class, and adding the Final-annotated names below would otherwise
    # mutate it in place.
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError("added_names must be submodule, parameter, or buffer")

            warnings.warn("'{}' was found in ScriptModule constants, "
                          " but it is a non-constant {}. Consider removing it.".format(name, hint))
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but was not actually set in __init__. "
                          "Consider removing it.".format(name))
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(name, _get_valid_constant(name, value, type(nn_module).__name__))
        added_names.add(name)

    # populate overloads
    # Copy so that merging in the annotated overloads below does not mutate a
    # `__overloads__` dict owned by the module (or its class) in place.
    overloads = dict(getattr(nn_module, "__overloads__", {}))
    # update with any annotated overloads
    overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name,
                    torch._C._jit_try_infer_type(scripted_fn).type(),
                    value)
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = ("(This function exists as an attribute on the Python module, "
                        "but we failed to compile it to a TorchScript function. "
                        "\nThe error stack is reproduced here:\n{}").format(e)
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name,
                torch._C._jit_try_infer_type(value).type(),
                value)
            continue

        # If we got here, this is a regular "data" attribute, Add it to the concrete type
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = "Its type was inferred; try adding a type annotation for the attribute." if inferred else ""
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = "(This attribute exists on the Python module, " \
                f"but we failed to convert Python type: '{torch.typename(type(value))}' " \
                f"to a TorchScript type. {additional_info})"
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder
class ConcreteTypeStore:
    """
    Cache of ConcreteModuleTypes keyed by the Python type of the nn.Module
    they were inferred from, so that equal concrete types are reused.
    """

    # Python module type => list of ConcreteModuleTypes built for it so far
    type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
    # ConcreteTypes that have had their methods already compiled
    methods_compiled: Set[torch._C.ConcreteModuleType]

    def __init__(self):
        self.type_store = {}
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """
        Infer a ConcreteType from this `nn.Module` instance. Underlying JIT
        types are re-used if possible.

        Returns the first stored type for ``type(nn_module)`` that ``equals``
        the freshly inferred builder. Otherwise builds a new type, records it
        under ``type(nn_module)``, and returns it.
        """
        builder = infer_concrete_type_builder(nn_module)
        candidates = self.type_store.setdefault(type(nn_module), [])

        for candidate in candidates:
            if candidate.equals(builder):
                return candidate

        # Nothing matched: generate a new JIT type from this builder.
        built = builder.build()
        candidates.append(built)
        return built


# Module-level ConcreteTypeStore instance.
concrete_type_store = ConcreteTypeStore()
def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs):
    """
    Compile the given method and property stubs onto ``concrete_type``.

    Each stub is unpacked into its parsed definition and resolution callback.
    For methods, the default argument values of the original Python function
    are also collected. The resulting parallel lists are passed, properties
    first, to ``concrete_type._create_methods_and_properties``.
    """
    def _defs_and_rcbs(stubs):
        # Split stubs into parallel lists of definitions and resolution callbacks.
        return [s.def_ for s in stubs], [s.resolution_callback for s in stubs]

    method_defs, method_rcbs = _defs_and_rcbs(method_stubs)
    method_defaults = [get_default_args(s.original_method) for s in method_stubs]
    property_defs, property_rcbs = _defs_and_rcbs(property_stubs)

    concrete_type._create_methods_and_properties(
        property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
    """
    Compile forward hook and forward pre-hook stubs onto ``concrete_type``.

    Each stub list is split into parallel lists of parsed definitions and
    resolution callbacks, which are passed to ``concrete_type._create_hooks``.
    """
    def _defs_and_rcbs(stubs):
        # Split stubs into parallel lists of definitions and resolution callbacks.
        return [s.def_ for s in stubs], [s.resolution_callback for s in stubs]

    hook_defs, hook_rcbs = _defs_and_rcbs(hook_stubs)
    pre_hook_defs, pre_hook_rcbs = _defs_and_rcbs(pre_hook_stubs)

    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)
Loading ...