"""Tracing

This module contains functionality to support the JIT's tracing frontend, notably:
    * torch.jit.trace
    * torch.jit.trace_module

This is not intended to be imported directly; please use the exposed
functionality in `torch.jit`.
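
A minimal usage sketch (illustrative only; the module and tensor shapes below
are made up)::

    import torch

    class Net(torch.nn.Module):
        def forward(self, x):
            return x.relu()

    # Record a single forward pass with example inputs.
    traced = torch.jit.trace(Net(), torch.randn(2, 3))
    # Or trace specific methods of a module by name.
    traced_mod = torch.jit.trace_module(Net(), {"forward": torch.randn(2, 3)})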
"""
import torch

import copy
import os
import contextlib
import functools
import warnings
import inspect
import re
from typing import Any, Dict, List, Optional, Set

from torch.jit._state import _python_cu, _enabled
from torch.jit._script import ScriptModule, _CachedForward, script
from torch._jit_internal import _qualified_name
from torch.autograd import function
from torch.nn import Module

_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten


def _create_interpreter_name_lookup_fn(frames_up=1):
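    # Returns a lookup function used by the tracer to give traced values
    # human-readable debug names: it walks ``frames_up + 1`` frames up the
    # stack and returns the name of the local variable bound to ``var``
    # (or "" if the variable is ``self`` or no name is found).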
    def _get_interpreter_name_for_var(var):
        frame = inspect.currentframe()
        if not frame:
            raise RuntimeError("failed to inspect frame")

        i = 0
        while i < frames_up + 1:
            frame = frame.f_back
            if not frame:
                raise RuntimeError("failed to get frame")
            i += 1

        f_locals = frame.f_locals
        f_globals = frame.f_globals

        for k, v in f_locals.items():
            if isinstance(v, torch.Tensor) and var is v:
                return k if k != "self" else ""
        return ""

    return _get_interpreter_name_for_var


def _unique_state_dict(module, keep_vars=False):
    # Parameter.detach() always creates a new torch.Tensor instance, so id()
    # on detached values cannot be used for deduplication. Instead, request the
    # state dict with keep_vars=True so we see the original Parameter/Buffer
    # objects, deduplicate on their ids, and detach afterwards if requested.
    state_dict = module.state_dict(keep_vars=True)
    filtered_dict = type(state_dict)()
    seen_ids: Set[int] = set()
    for k, v in state_dict.items():
        if id(v) in seen_ids:
            continue
        seen_ids.add(id(v))
        if keep_vars:
            filtered_dict[k] = v
        else:
            filtered_dict[k] = v.detach()
    return filtered_dict


class ONNXTracedModule(torch.nn.Module):
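    """Internal wrapper whose ``forward`` runs ``inner`` under the tracer and
    returns the recorded graph together with the captured outputs (and,
    optionally, the cloned inputs or input states)."""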
    def __init__(
        self,
        inner,
        strict=True,
        force_outplace=False,
        return_inputs=False,
        return_inputs_states=False,
    ):
        super(ONNXTracedModule, self).__init__()
        # inner may be a Module, or it may be an arbitrary callable
        # If it's a Module, we get its parameters automatically, which lets
        # us avoid special-casing functions versus modules.
        self.inner = inner
        self.strict = strict
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs
        self._return_inputs_states = return_inputs_states

    def forward(self, *args: torch.Tensor):
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(_unique_state_dict(self, keep_vars=True).values())

        ret_inputs = []
        inputs_states = []
        outs = []

        def wrapper(*args):
            in_args: List[torch.Tensor] = []
            for i in range(len(in_vars)):
                if not isinstance(args[i], torch.Tensor):
                    raise RuntimeError('Expected Tensor argument')
                in_args.append(args[i])

            trace_inputs = _unflatten(in_args, in_desc)

            ret_inputs.append(
                tuple(x.clone(memory_format=torch.preserve_format) for x in args)
            )
            if self._return_inputs_states:
                inputs_states.append(_unflatten(in_args, in_desc))
            outs.append(self.inner(*trace_inputs))
            if self._return_inputs_states:
                inputs_states[0] = (inputs_states[0], trace_inputs)
            out_vars, _ = _flatten(outs)
            if len(out_vars) == 1:
                return out_vars[0]
            else:
                return tuple(out_vars)

        graph, out = torch._C._create_graph_by_tracing(
            wrapper,
            in_vars + module_state,
            _create_interpreter_name_lookup_fn(),
            self.strict,
            self._force_outplace,
        )

        if self._return_inputs:
            return graph, outs[0], ret_inputs[0]
        if self._return_inputs_states:
            return graph, outs[0], inputs_states[0]
        else:
            return graph, outs[0]


def _clone_inputs(args):
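    # Deep-copies ``args``: every tensor is detached and cloned (preserving
    # ``requires_grad`` and memory format) so that running forward/backward on
    # the clones cannot mutate the caller's original inputs.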
    def clone_input(a):
        if a is None:
            return None
        elif isinstance(a, torch.Tensor):
            # TODO: figure out one liner to .clone() and set requires_grad
            v = (
                a.detach()
                .clone(memory_format=torch.preserve_format)
                .requires_grad_(a.requires_grad)
            )
            if a.grad is not None:
                v.grad = clone_input(a.grad)
            return v
        else:
            return a.clone(memory_format=torch.preserve_format)

    return function._nested_map(
        lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
    )(args)


# This is purely for developer debugging.  We are not going to advertise it.
_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False)  # CUDA-only timing
_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)
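# For example (illustrative; ``your_script.py`` is a placeholder), CUDA timing
# can be forced on from the shell; any non-empty value counts as truthy here:
#
#   PYTORCH_JIT_TIME=1 python your_script.py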


@contextlib.contextmanager
def _time(trace_name, name, time=True):
    if (not _JIT_TIME and not time) or not torch.cuda.is_available():
        yield
        return
    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    stream.record_event(start)
    try:
        yield
    finally:
        stream.record_event(end)
        end.synchronize()
        print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))


def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Verify that a JIT compiled model has the same behavior as its uncompiled
    version, along with its backwards pass.  If your model returns multiple
    outputs, you must also specify a `loss_fn` to produce a loss from which
    the backwards pass will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Args:
        model (compiled torch.nn.Module or function): the module/function to be
            verified.  The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked.  By default,
            we assume that a model returns a single result, and apply
            :func:`torch.sum` to it before calling backwards; if this is
            inappropriate, you can pass your own loss function.  Note that if a
            model returns a tuple of results, these are passed as separate
            positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on.  This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.
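
    Example (a rough sketch, not a tested recipe; it assumes the function has
    been compiled with the ``@torch.jit.compile`` decorator described above)::

        @torch.jit.compile
        def doubler(x):
            return x + x

        inputs = torch.randn(4, 4)
        torch.jit.verify(doubler, inputs, devices=[0])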
    """
    # TODO: In principle, we track device information in our trace, so it
    # should be possible to check if our execution actually obeyed the 'devices'
    # the user provided.

    # TODO: Consider adding a utility function to torch.jit to test
    # for this case
    if not isinstance(model, torch._C.CompiledFunction):  # type: ignore
        raise TypeError(
            "Cannot verify an uncompiled module.  Add @torch.jit.compile to compile it"
        )
    is_module = isinstance(model, Module)

    if not isinstance(args, tuple):
        args = (args,)

    saved_args = _clone_inputs(args)
    if is_module:
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        # We use a special API to reset the trace and compile it from scratch.
        compiled_fn = model
        if force_trace:
            compiled_fn.clear_cache()
        if assert_compiled:
            hits = compiled_fn.hits
        out = model(*args)
        if assert_compiled and compiled_fn.hits == hits:
            raise RuntimeError("failed to use the compiled function")
        if not isinstance(out, tuple):
            out = (out,)
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(
                (
                    "Model returns {} outputs, but default loss function "
                    "(torch.sum) can only handle a single output"
                ).format(len(out))
            )
        out_vars, _ = _flatten(out)
        saved_outs = [
            v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
        ]
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        # TODO: I'm not sure if the clone here is necessary but it is safer
        saved_grads = [
            v.detach().clone(memory_format=torch.preserve_format) for v in grads
        ]
        return (saved_outs, saved_grads)

    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        assert model.has_trace_for(*args)

    if is_module:
        model.load_state_dict(saved_state)
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)


def _verify_equal(xs, ys):
    for x, y in zip(xs, ys):
        if x.sub(y).abs().max() > 1e-6:
            raise RuntimeError("JIT and real computation mismatch")


def indent(s):
    return "\n".join(["\t" + line for line in s.splitlines()])


class TracingCheckError(Exception):
    def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
        self.message = "Tracing failed sanity checks!\n"
        if extra_msg is not None:
            self.message += extra_msg + "\n"
        if graph_diff_error is not None:
            self.message += "ERROR: Graphs differed across invocations!\n"
            self.message += indent(graph_diff_error) + "\n"
        if tensor_compare_error is not None:
            self.message += (
                "ERROR: Tensor-valued Constant nodes differed in value "
                "across invocations. This often indicates that the tracer has"
                " encountered untraceable code.\n"
            )
            self.message += indent(tensor_compare_error) + "\n"
        super(TracingCheckError, self).__init__(self.message)


# Check the traced module against a set of user-provided validation inputs
@torch.no_grad()
def _check_trace(
    check_inputs,
    func,
    traced_func,
    check_tolerance,
    strict,
    force_outplace,
    is_trace_module,
    _module_class,
):
    # Note: tracing is independent of optimizations, which consume the trace
    for inputs in check_inputs:

        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)

        if is_trace_module:
            copied_dict = {}
            for name, data in inputs.items():
                copied_dict[name] = _clone_inputs(data)
            check_mod = torch.jit.trace_module(
                func.__self__ if hasattr(func, "__self__") else func,
                copied_dict,
                check_trace=False,
                strict=strict,
                _force_outplace=force_outplace,
                _module_class=_module_class,
                _compilation_unit=torch._C.CompilationUnit(),
            )
            check_mod_func = check_mod._c._get_method(traced_func.name)
            inputs = inputs[traced_func.name]
            if isinstance(inputs, (torch.Tensor, dict)):
                inputs = (inputs,)
        else:
            check_mod = torch.jit.trace(
                func,
                _clone_inputs(inputs),
                check_trace=False,