"""Tracing
This module contains functionality to support the JIT's tracing frontend, notably:
* torch.jit.trace
* torch.jit.trace_module
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch
import copy
import os
import contextlib
import functools
import warnings
import inspect
import re
from typing import Any, Dict, List, Optional, Set
from torch.jit._state import _python_cu, _enabled
from torch.jit._script import ScriptModule, _CachedForward, script
from torch._jit_internal import _qualified_name
from torch.autograd import function
from torch.nn import Module
# Module-level aliases for the C++ JIT flatten/unflatten helpers.
# As used in this module, ``_flatten(obj)`` returns a ``(flat_vars, desc)``
# pair, and ``_unflatten(flat_vars, desc)`` presumably rebuilds the nesting
# that ``desc`` describes from a flat list of tensors.
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten
def _create_interpreter_name_lookup_fn(frames_up=1):
def _get_interpreter_name_for_var(var):
frame = inspect.currentframe()
if not frame:
raise RuntimeError("failed to inspect frame")
i = 0
while i < frames_up + 1:
frame = frame.f_back
if not frame:
raise RuntimeError("failed to get frame")
i += 1
f_locals = frame.f_locals
f_globals = frame.f_globals
for k, v in f_locals.items():
if isinstance(v, torch.Tensor) and var is v:
return k if k != "self" else ""
return ""
return _get_interpreter_name_for_var
def _unique_state_dict(module, keep_vars=False):
# since Parameter.detach() always creates a new torch.Tensor instance,
# id(v) doesn't work with it. So we always get the Parameter or Buffer
# as values, and deduplicate the params using Parameters and Buffers
state_dict = module.state_dict(keep_vars=True)
filtered_dict = type(state_dict)()
seen_ids: Set[int] = set()
for k, v in state_dict.items():
if id(v) in seen_ids:
continue
seen_ids.add(id(v))
if keep_vars:
filtered_dict[k] = v
else:
filtered_dict[k] = v.detach()
return filtered_dict
class ONNXTracedModule(torch.nn.Module):
    """Module wrapper whose ``forward`` traces ``inner`` and returns the graph.

    ``forward(*args)`` runs ``inner`` through
    ``torch._C._create_graph_by_tracing``. It returns ``(graph, output)``, or a
    3-tuple when ``return_inputs`` or ``return_inputs_states`` is set. If both
    flags are set, ``return_inputs`` wins.

    Args:
        inner: a Module or an arbitrary callable to trace.
        strict: passed through to ``_create_graph_by_tracing``.
        force_outplace: passed through to ``_create_graph_by_tracing``.
        return_inputs: also return clones of every argument the tracer passed
            to the internal wrapper.
        return_inputs_states: also return a pair
            ``(unflattened inputs, inputs passed to inner)``. Both elements
            are built by ``_unflatten`` from the same tensors.
    """

    def __init__(
        self,
        inner,
        strict=True,
        force_outplace=False,
        return_inputs=False,
        return_inputs_states=False,
    ):
        super(ONNXTracedModule, self).__init__()
        # inner may be a Module, or it may be an arbitrary callable
        # If it's a Module, we get its parameters automatically, which lets
        # us avoid special-casing functions versus modules.
        self.inner = inner
        self.strict = strict
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs
        self._return_inputs_states = return_inputs_states

    def forward(self, *args: torch.Tensor):
        """Trace ``self.inner(*args)``.

        Returns:
            ``(graph, output)``, where ``output`` is the value ``self.inner``
            returned during tracing. Returns ``(graph, output, extra)`` when
            ``return_inputs`` or ``return_inputs_states`` is set (see the
            class docstring for what ``extra`` holds).
        """
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(_unique_state_dict(self, keep_vars=True).values())

        # These lists act as out-parameters: ``wrapper`` runs inside the
        # tracer and records what it saw here, so ``forward`` can return it.
        ret_inputs = []
        inputs_states = []
        outs = []

        def wrapper(*args):
            # Only the first len(in_vars) arguments (the user inputs) are
            # type-checked and forwarded to ``inner``. Any trailing arguments
            # (presumably the module state passed below) are not forwarded.
            in_args: List[torch.Tensor] = []
            for i in range(len(in_vars)):
                if not isinstance(args[i], torch.Tensor):
                    raise RuntimeError('Expected Tensor argument')
                in_args.append(args[i])

            # Re-nest the flat tensors using the descriptor from _flatten(args).
            trace_inputs = _unflatten(in_args, in_desc)

            # Clones of *all* wrapper arguments, including any trailing ones.
            ret_inputs.append(
                tuple(x.clone(memory_format=torch.preserve_format) for x in args)
            )
            if self._return_inputs_states:
                inputs_states.append(_unflatten(in_args, in_desc))
            outs.append(self.inner(*trace_inputs))
            if self._return_inputs_states:
                inputs_states[0] = (inputs_states[0], trace_inputs)
            # Give the tracer a flat result: a single tensor or a tuple.
            out_vars, _ = _flatten(outs)
            if len(out_vars) == 1:
                return out_vars[0]
            else:
                return tuple(out_vars)

        graph, out = torch._C._create_graph_by_tracing(
            wrapper,
            in_vars + module_state,
            _create_interpreter_name_lookup_fn(),
            self.strict,
            self._force_outplace,
        )

        # ``out`` from the tracer is not used. The value ``inner`` returned,
        # as recorded by ``wrapper`` in ``outs``, is returned instead.
        if self._return_inputs:
            return graph, outs[0], ret_inputs[0]
        if self._return_inputs_states:
            return graph, outs[0], inputs_states[0]
        else:
            return graph, outs[0]
def _clone_inputs(args):
def clone_input(a):
if a is None:
return None
elif isinstance(a, torch.Tensor):
# TODO: figure out one liner to .clone() and set requires_grad
v = (
a.detach()
.clone(memory_format=torch.preserve_format)
.requires_grad_(a.requires_grad)
)
if a.grad is not None:
v.grad = clone_input(v.grad)
return v
else:
return a.clone(memory_format=torch.preserve_format)
return function._nested_map(
lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
)(args)
# This is purely for developer debugging. We are not going to advertise it.
# Each flag is read once, when this module is imported. ``os.environ.get``
# returns the raw string when the variable is set, so any non-empty value
# (including "0") is truthy; the default ``False`` applies only when unset.
_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False)  # CUDA-only timing
_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)
@contextlib.contextmanager
def _time(trace_name, name, time=True):
    """Time the enclosed block on the current CUDA stream and print the result.

    Timing runs only when CUDA is available and either ``_JIT_TIME`` (the
    ``PYTORCH_JIT_TIME`` flag) or ``time`` is truthy. Because ``time``
    defaults to ``True``, the default is to time whenever CUDA is available.
    Otherwise the block simply runs. When timing, the report is printed even
    if the block raises.

    Args:
        trace_name: first label in the printed line.
        name: second label in the printed line.
        time: request timing independently of ``_JIT_TIME``.
    """
    timing_wanted = _JIT_TIME or time
    if timing_wanted and torch.cuda.is_available():
        stream = torch.cuda.current_stream()
        begin = torch.cuda.Event(enable_timing=True)
        finish = torch.cuda.Event(enable_timing=True)
        stream.record_event(begin)
        try:
            yield
        finally:
            stream.record_event(finish)
            # Block until the end event completes so elapsed_time is valid.
            finish.synchronize()
            elapsed_ms = begin.elapsed_time(finish)
            print("{} {} time: {} ms".format(trace_name, name, elapsed_ms))
    else:
        yield
def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Verify that a JIT compiled model has the same behavior as its uncompiled
    version along with its backwards pass. If your model returns multiple
    outputs, you must also specify a `loss_fn` to produce a loss for which
    the backwards will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Args:
        model (compiled torch.nn.Module or function): the module/function to be
            verified. The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified. A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked. By default,
            we assume that a model returns a single result, and we :func:`torch.sum`
            before calling backwards; if this is inappropriate, you can pass your
            own loss function. Note that if a model returns a tuple of results,
            these are passed as separate positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on. This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.

    Raises:
        TypeError: if ``model`` is not a ``torch._C.CompiledFunction``.
        ValueError: if the default ``loss_fn`` (``torch.sum``) is used and the
            model returns more than one output.
        AssertionError: if ``model`` reports no trace for ``args`` after the
            first run (a plain ``assert``, so it is skipped under ``python -O``).
        RuntimeError: if the second run does not increase ``model.hits``, or
            if ``_verify_equal`` reports a mismatch between the outputs or
            gradients of the two runs.
    """
    # TODO: In principle, we track device information in our trace, so it
    # should be possible to check if our execution actually obeyed the 'devices'
    # the user provided.

    # TODO: Consider adding a utility function to torch.jit to test
    # for this case
    if not isinstance(model, torch._C.CompiledFunction):  # type: ignore
        raise TypeError(
            "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it"
        )
    is_module = isinstance(model, Module)

    # Normalize a single positional argument into a 1-tuple.
    if not isinstance(args, tuple):
        args = (args,)

    # NOTE(review): saved_args is never read below; both runs receive
    # ``args`` itself. Confirm whether the second run was meant to use it.
    saved_args = _clone_inputs(args)
    if is_module:
        # Snapshot parameters/buffers so the second run can start from the
        # same state the first run saw (restored via load_state_dict below).
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        """Run ``model`` forward and backward once.

        Returns a pair of lists:
        - detached clones of the flattened outputs;
        - detached clones of the gradients of ``loss_fn(*outputs)`` with
          respect to ``_flatten((args, params))``, where ``params`` are the
          module's parameters (empty for a plain function).

        ``force_trace`` clears the compiled function's cache before running.
        ``assert_compiled`` raises RuntimeError unless this call increased
        ``model.hits``.
        """
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        # We use a special API to reset the trace and compile it from scratch.
        compiled_fn = model
        if force_trace:
            compiled_fn.clear_cache()
        if assert_compiled:
            hits = compiled_fn.hits
        out = model(*args)
        if assert_compiled and compiled_fn.hits == hits:
            raise RuntimeError("failed to use the compiled function")
        if not isinstance(out, tuple):
            out = (out,)
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(
                (
                    "Model returns {} outputs, but default loss function "
                    "(torch.sum) can only handle a single output"
                ).format(len(out))
            )
        out_vars, _ = _flatten(out)
        saved_outs = [
            v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
        ]
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        # TODO: I'm not sure if the clone here is necessary but it is safer
        saved_grads = [
            v.detach().clone(memory_format=torch.preserve_format) for v in grads
        ]
        return (saved_outs, saved_grads)

    # fork_rng restores the RNG state when the block exits, so the compiled
    # run below starts from the same RNG state as this first run.
    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        assert model.has_trace_for(*args)

    if is_module:
        model.load_state_dict(saved_state)
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)
def _verify_equal(xs, ys):
for x, y in zip(xs, ys):
if x.sub(y).abs().max() > 1e-6:
raise RuntimeError("JIT and real computation mismatch")
def indent(s):
    """Prefix every line of ``s`` with a tab and rejoin the lines with ``\\n``.

    Lines are split as ``str.splitlines`` splits them, so the result has no
    trailing newline, and an empty string stays empty.
    """
    lines = s.splitlines()
    return "\n".join("\t" + line for line in lines)
class TracingCheckError(Exception):
    """Raised when a traced function fails its post-trace sanity checks.

    The exception message, also stored on ``self.message``, starts with a
    fixed header line. It is followed by one section for each optional
    argument that is not ``None``.

    Args:
        graph_diff_error: text describing how the graphs differed across
            invocations, or ``None``.
        tensor_compare_error: text describing Tensor-valued Constant nodes
            whose values differed across invocations, or ``None``.
        extra_msg: optional free-form text placed right after the header.
    """

    def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
        sections = ["Tracing failed sanity checks!\n"]
        if extra_msg is not None:
            sections.append(extra_msg + "\n")
        if graph_diff_error is not None:
            sections.append("ERROR: Graphs differed across invocations!\n")
            sections.append(indent(graph_diff_error) + "\n")
        if tensor_compare_error is not None:
            sections.append(
                "ERROR: Tensor-valued Constant nodes differed in value "
                "across invocations. This often indicates that the tracer has"
                " encountered untraceable code.\n"
            )
            sections.append(indent(tensor_compare_error) + "\n")
        self.message = "".join(sections)
        super().__init__(self.message)
# Check the traced module against a set of user-provided validation inputs
@torch.no_grad()
def _check_trace(
check_inputs,
func,
traced_func,
check_tolerance,
strict,
force_outplace,
is_trace_module,
_module_class,
):
# Note: tracing is independent of optimizations, which consume the trace
for inputs in check_inputs:
if isinstance(inputs, torch.Tensor):
inputs = (inputs,)
if is_trace_module:
copied_dict = {}
for name, data in inputs.items():
copied_dict[name] = _clone_inputs(data)
check_mod = torch.jit.trace_module(
func.__self__ if hasattr(func, "__self__") else func,
copied_dict,
check_trace=False,
strict=strict,
_force_outplace=force_outplace,
_module_class=_module_class,
_compilation_unit=torch._C.CompilationUnit(),
)
check_mod_func = check_mod._c._get_method(traced_func.name)
inputs = inputs[traced_func.name]
if isinstance(inputs, (torch.Tensor, dict)):
inputs = (inputs,)
else:
check_mod = torch.jit.trace(
func,
_clone_inputs(inputs),
check_trace=False,
Loading ...