# Torch
from torch.autograd import Variable
from torch.autograd.function import _nested_map
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
from torch.onnx import OperatorExportTypes
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
import zipfile
import functools
# Testing utils
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_WINDOWS, \
freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
# Standard library
from contextlib import contextmanager
from functools import reduce
from torch._six import StringIO
from collections import defaultdict
import inspect
import io
import math
import os
import pickle
import sys
import tempfile
import textwrap
from typing import Any, Dict, List
# Capability flags consulted by CUDA-dependent tests.
RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
# Half precision starts out enabled whenever CUDA is, and is vetoed below
# if any visible device reports a major compute capability under 6.
RUN_CUDA_HALF = RUN_CUDA
# HIP supports half, no version check necessary
if RUN_CUDA and not torch.version.hip:
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        RUN_CUDA_HALF = RUN_CUDA_HALF and major >= 6
def execWrapper(code, glob, loc):
    """Run ``code`` through ``exec`` with ``glob`` as globals and ``loc`` as locals."""
    exec(code, glob, loc)
def do_input_map(fn, input):
    """Apply ``fn`` to every ``torch.Tensor`` inside ``input``.

    Traversal of ``input`` is delegated to ``_nested_map``; this function only
    supplies the "is this a tensor" predicate and the mapping function.
    """
    def is_tensor(candidate):
        return isinstance(candidate, torch.Tensor)

    tensor_mapper = _nested_map(is_tensor, fn)
    return tensor_mapper(input)
def clear_class_registry():
    """Reset the JIT's class bookkeeping.

    Clears the C++ class registry, replaces the module-level
    ``concrete_type_store`` in ``torch.jit._recursive`` with a fresh
    ``ConcreteTypeStore`` and empties ``torch.jit._state._script_classes``.
    """
    recursive = torch.jit._recursive
    torch._C._jit_clear_class_registry()
    recursive.concrete_type_store = recursive.ConcreteTypeStore()
    torch.jit._state._script_classes.clear()
def get_execution_plan(graph_executor_state):
    """Return the single execution plan held by ``graph_executor_state``.

    Args:
        graph_executor_state: object exposing an ``execution_plans`` mapping.

    Returns:
        The only value of ``execution_plans``.

    Raises:
        RuntimeError: if the mapping does not contain exactly one plan.
    """
    plans = list(graph_executor_state.execution_plans.values())
    if len(plans) == 1:
        return plans[0]
    raise RuntimeError('This test assumes this GraphExecutor should '
                       'only have one execution plan, got: {}'.format(len(plans)))
class _AssertRaisesRegexWithHighlightContext(object):
    """
    Context manager that checks the wrapped code raises a matching exception
    and, when a highlight is given, that the error message highlights the
    expected part of the source code.
    """

    def __init__(self, test_case, exception, regex, highlight):
        self.test_case, self.exception_type = test_case, exception
        self.regex, self.highlight = regex, highlight

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Re-raise the captured exception inside the test case's own
        # assertRaisesRegex so that it performs the type/regex check.
        raises_check = self.test_case.assertRaisesRegex(self.exception_type, self.regex)
        with raises_check:
            if type:
                raise value

        if self.highlight:
            checker = FileCheck().check_source_highlighted(self.highlight)
            checker.run(str(value))

        # The exception has been verified above, so do not propagate it.
        return True
# Node kind that identifies a fusion group; assertAllFused searches graphs
# for nodes of this kind and also lists it among the allowed node kinds.
FUSION_GROUP = "prim::TensorExprGroup"
class JitTestCase(JitCommonTestCase):
_do_cuda_memory_leak_check = True
_restored_warnings = False
class capture_stdout(list):
    """
    Replace sys.stdout with a temporary StringIO

    On exit, everything written while the context was active is appended to
    this list as one string and the previous sys.stdout is put back.
    """

    def __enter__(self):
        original_stream, replacement = sys.stdout, StringIO()
        self.sys_stdout = original_stream
        self.stringio = replacement
        sys.stdout = replacement
        return self

    def __exit__(self, *args):
        captured_text = str(self.stringio.getvalue())
        self.append(captured_text)
        del self.stringio
        sys.stdout = self.sys_stdout
def setHooks(self):
    """Register this test case's module and function emit hooks with the JIT."""
    module_hook, function_hook = self.emitModuleHook, self.emitFunctionHook
    torch._C._jit_set_emit_hooks(module_hook, function_hook)
def clearHooks(self):
    """Unregister both JIT emit hooks by setting them to None."""
    no_module_hook = no_function_hook = None
    torch._C._jit_set_emit_hooks(no_module_hook, no_function_hook)
def setUp(self):
    """Run the base setUp, restore the tracer-warning filter once, install emit hooks."""
    super().setUp()
    # unittest overrides all warning filters and forces all of them to show up
    # after we install our own to silence those coming from inside PyTorch.
    # This will ensure that our filter still takes precedence.
    already_restored = JitTestCase._restored_warnings
    if not already_restored:
        torch.jit.TracerWarning.ignore_lib_warnings()
        # Class-level flag: the filter is only re-installed by the first test.
        JitTestCase._restored_warnings = True
    self.setHooks()
def tearDown(self):
    """Run the base tearDown, then remove the emit hooks and reset the class registry."""
    super().tearDown()
    # needs to be cleared because python might be unloaded before
    # the callback gets destructed
    self.clearHooks()
    clear_class_registry()
def assertAllFused(self, graph, except_for=()):
    """Assert that ``graph`` was fused into exactly one FUSION_GROUP node.

    Exactly one block must contain fusion-group nodes, that block must hold
    exactly one of them, and every node in that block must be of an allowed
    kind (a fixed set plus anything listed in ``except_for``).
    """
    # Kinds of the node producing the condition of a guarding prim::If; for
    # those Ifs only the first block is searched.
    guard_kinds = ('aten::all', 'prim::TypeCheck', 'prim::RequiresGradCheck')

    # note this helper collects nodes on 'fast path' only
    # i.e. the true blocks of specialized checks
    def get_nodes_and_parents_recursively(block, kind, acc):
        for node in block.nodes():
            node_kind = node.kind()
            if node_kind == kind:
                acc[block].append(node)
            elif node_kind == 'prim::DifferentiableGraph':
                get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
            elif node_kind == 'prim::If' and next(node.inputs()).node().kind() in guard_kinds:
                get_nodes_and_parents_recursively(next(node.blocks()), kind, acc)
            else:
                for inner_block in node.blocks():
                    get_nodes_and_parents_recursively(inner_block, kind, acc)

    allowed_nodes = set(except_for) | {
        'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
        'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck',
    }

    fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list)
    get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
    self.assertTrue(len(fusion_groups) == 1, 'got {}'.format(graph))

    # From here on the checks apply to the block that owns the fusion group.
    fused_block, fusion_nodes = list(fusion_groups.items())[0]
    # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes`
    self.assertTrue(len(fusion_nodes) == 1, 'got {}'.format(fused_block))
    only_allowed = all(member.kind() in allowed_nodes for member in fused_block.nodes())
    self.assertTrue(only_allowed, 'got {}'.format(fused_block))
def _isHookExceptionOk(self, e):
    """Return True if ``str(e)`` contains one of the tolerated export-failure messages."""
    message = str(e)
    tolerated_fragments = ("Could not export Python function",
                           "closures are not exportable")
    return any(fragment in message for fragment in tolerated_fragments)
def _compared_saved_loaded(self, m):
    """Round-trip ``m`` through save/load and check the serialized code is stable.

    ``m`` is saved to an in-memory buffer, loaded back, and saved again; the
    ``.py`` files under ``archive/code/`` of both archives are then compared
    pairwise with ``assertMultiLineEqual``. For a ``torch._C.ScriptModule``
    the ivalue tags of ``m`` and of the re-imported module are also compared.

    Returns early (without checking anything) when ``m.code`` is empty, when
    ``m`` is a ``torch._C.ScriptModule`` with no method names, or when the
    first save/extract raises a ``RuntimeError`` that ``_isHookExceptionOk``
    accepts. Any other ``RuntimeError`` is re-raised.
    """
    def extract_files(buffer):
        # Returns two lazy generators over the archive in ``buffer``: the
        # decoded text of each code file and the unpickled debug files.
        # crack open the zip format to get at the main module code
        archive = zipfile.ZipFile(buffer)
        # check that we have no duplicate names
        self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
        files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
        # unwrap all the code files into strings
        code_files_str = filter(lambda x: x.endswith('.py'), files)
        code_files_stream = (archive.open(f) for f in code_files_str)
        code_files = ("".join([line.decode() for line in file]) for file in code_files_stream)

        # unpickle all the debug files
        # (generator only; nothing in this method iterates it)
        debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
        debug_files_stream = (archive.open(f) for f in debug_files_str)
        debug_files = (pickle.load(f) for f in debug_files_stream)
        return code_files, debug_files

    # disable the hook while we parse code, otherwise we will re-enter the hook
    with torch._jit_internal._disable_emit_hooks():
        try:
            # short-circuit if this is an empty function or module
            if len(m.code) == 0:
                return
            if isinstance(m, torch._C.ScriptModule):
                if len(m._method_names()) == 0:
                    return

            # save the module to a buffer
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            # copy the data in the buffer so we can restore it later. This
            # is because py2 and py3 have different semantics with zipfile
            # and it's easier to just work with a fresh copy each time.
            buffer_copy = buffer.getvalue()

            code_files, debug_files = extract_files(buffer)

        except RuntimeError as e:
            # Only failures with a known, tolerated message are swallowed.
            if not self._isHookExceptionOk(e):
                raise
            else:
                return

        # import the model again (from the copy we made of the original)
        buffer2 = io.BytesIO(buffer_copy)
        imported = torch.jit.load(buffer2)

        # save it again
        saved_module_buffer_2 = io.BytesIO()
        torch.jit.save(imported, saved_module_buffer_2)

        saved_module_buffer_2.seek(0)
        code_files_2, debug_files_2 = extract_files(saved_module_buffer_2)

        # NOTE(review): zip stops at the shorter sequence, so a differing
        # number of code files between the two archives is not detected here.
        for a, b in zip(code_files, code_files_2):
            self.assertMultiLineEqual(a, b)

        if isinstance(m, torch._C.ScriptModule):
            self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))
def emitFunctionHook(self, func):
    """Emit hook for functions: run the save/load comparison unless the name is not exportable."""
    # func has invalid names for export, skip the jitter check
    has_exportable_name = func.name != "<lambda>" and "aten::" not in func.name
    if has_exportable_name:
        self._compared_saved_loaded(func)
def emitModuleHook(self, module):
    """Emit hook for modules: always run the save/load comparison on ``module``."""
    compare = self._compared_saved_loaded
    compare(module)
def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
    """Return a copy of ``m`` obtained by saving and loading it.

    Submodules that have a ``_pack`` method are packed before saving, and
    ``_unpack`` is applied afterwards to ``m`` and to every loaded copy.

    Args:
        m: module to round-trip; saved with ``torch.jit.save``.
        also_test_file: when true, the in-memory copy is additionally saved
            to a temporary file and reloaded, and that second copy is returned.
        map_location: forwarded to ``torch.jit.load``.

    Returns:
        The module loaded from the buffer, or from the temporary file when
        ``also_test_file`` is true.
    """
    def call_if_present(method_name):
        # Build a callback for ``apply`` that invokes ``method_name`` on each
        # submodule whose ``_c`` reports having that method.
        def visit(s):
            return getattr(s, method_name)() if s._c._has_method(method_name) else None
        return visit

    pack = call_if_present('_pack')
    unpack = call_if_present('_unpack')

    buffer = io.BytesIO()
    m.apply(pack)
    torch.jit.save(m, buffer)
    m.apply(unpack)
    buffer.seek(0)
    imported = torch.jit.load(buffer, map_location=map_location)
    imported.apply(unpack)

    if not also_test_file:
        return imported

    # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
    # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
    # close the file after creation and try to remove it manually
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        f.close()
        imported.save(f.name)
        result = torch.jit.load(f.name, map_location=map_location)
    finally:
        os.unlink(f.name)

    result.apply(unpack)
    return result
def assertGraphContains(self, graph, kind):
    """Assert that at least one node yielded by ``graph.nodes()`` has kind ``kind``."""
    kinds_present = (node.kind() for node in graph.nodes())
    self.assertTrue(kind in kinds_present)
def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
    """Assert that ``graph`` contains exactly ``num_kind_nodes`` nodes of ``kind``.

    With ``consider_subgraphs`` the count is taken from the textual form of
    the graph: occurrences of ``kind`` minus occurrences of ``'with ' + kind``.
    Otherwise nodes are counted by walking ``graph.nodes()`` and, recursively,
    the blocks of every node.

    Raises:
        AssertionError: if the count differs from ``num_kind_nodes``.
    """
    def perform_assert(graph, kind, actual, expected, consider_subgraphs):
        if actual != expected:
            subgraph = 'including' if consider_subgraphs else 'excluding'
            raise AssertionError(
                '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
                    graph, actual, kind, subgraph, expected))

    def count_matching(block):
        total = 0
        for node in block.nodes():
            if node.kind() == kind:
                total += 1
            for inner in node.blocks():
                total += count_matching(inner)
        return total

    if consider_subgraphs:
        strgraph = str(graph)
        actual = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
    else:
        actual = count_matching(graph)

    perform_assert(graph, kind, actual, num_kind_nodes, consider_subgraphs)
def assertExpectedONNXGraph(self, g, *args, **kwargs):
    """Run ``torch.onnx._optimize_trace`` on ``g`` for ONNX export and check it via assertExpectedGraph."""
    optimized = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
    self.assertExpectedGraph(optimized, *args, **kwargs)
def assertExpectedGraph(self, trace, *args, **kwargs):
    """Clean up a graph and compare its string form against the expected output.

    ``trace`` may be a ``torch._C.Graph`` or an object with a ``graph()``
    method. The graph is linted, dead-code-eliminated, linted, canonicalized
    and linted again; ``str(graph)`` is then passed to ``self.assertExpected``
    together with any extra positional/keyword arguments.
    """
    graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()

    lint = torch._C._jit_pass_lint
    lint(graph)
    torch._C._jit_pass_dce(graph)
    lint(graph)
    graph = torch._C._jit_pass_canonicalize(graph)
    lint(graph)

    rendered = str(graph)
    self.assertExpected(rendered, *args, **kwargs)
def run_pass(self, name, trace):
if isinstance(trace, torch._C.Graph):
graph = trace
set_graph = False
else:
set_graph = True
graph = trace.graph()
torch._C._jit_pass_lint(graph)
result = getattr(torch._C, '_jit_pass_' + name)(graph)
if result is not None:
graph = result
torch._C._jit_pass_lint(graph)
if set_graph:
trace.set_graph(graph)
return graph
def get_frame_vars(self, frames_up):
frame = inspect.currentframe()
if not frame:
raise RuntimeError("failed to inspect frame")
Loading ...