# Torch
from torch.autograd import Variable
from torch.autograd.function import _nested_map
from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401

from torch.onnx import OperatorExportTypes
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
import zipfile
import functools

# Testing utils
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_WINDOWS, \
    freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401

# Standard library
from contextlib import contextmanager
from functools import reduce
from torch._six import StringIO
from collections import defaultdict

import inspect
import io
import math
import os
import pickle
import sys
import tempfile
import textwrap
from typing import Any, Dict, List

RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
# Half-precision CUDA tests run whenever CUDA is available, unless (on non-HIP
# builds) some visible device has compute capability major version below 6.
# The default must be bound unconditionally; otherwise RUN_CUDA_HALF would be
# undefined on CPU-only machines and on machines whose GPUs are all >= 6.x.
RUN_CUDA_HALF = RUN_CUDA
# HIP supports half, no version check necessary
if RUN_CUDA and not torch.version.hip:
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        if major < 6:
            RUN_CUDA_HALF = False

def execWrapper(code, glob, loc):
    """Execute ``code`` using ``glob`` as globals and ``loc`` as locals.

    Thin wrapper around the builtin ``exec``; returns ``None``.
    """
    exec(code, glob, loc)
    return None

def do_input_map(fn, input):
    """Apply ``fn`` to every ``torch.Tensor`` leaf of a possibly nested ``input``.

    Traversal of the surrounding structure is delegated to ``_nested_map``.
    """
    def is_tensor(candidate):
        return isinstance(candidate, torch.Tensor)

    mapper = _nested_map(is_tensor, fn)
    return mapper(input)

def clear_class_registry():
    """Replace the JIT's recursive-scripting concrete type store with a fresh, empty one."""
    recursive = torch.jit._recursive
    fresh_store = recursive.ConcreteTypeStore()
    recursive.concrete_type_store = fresh_store

def get_execution_plan(graph_executor_state):
    """Return the sole execution plan recorded in ``graph_executor_state``.

    Raises:
        RuntimeError: if ``execution_plans`` does not hold exactly one plan.
    """
    plans = [*graph_executor_state.execution_plans.values()]
    if len(plans) == 1:
        return plans[0]
    raise RuntimeError('This test assumes this GraphExecutor should '
                       'only have one execution plan, got: {}'.format(len(plans)))

class _AssertRaisesRegexWithHighlightContext(object):
    """
    A context manager that is useful for checking that error messages highlight
    the correct part of the source code.

    The body of the ``with`` block must raise ``exception`` with a message
    matching ``regex`` (checked via ``test_case.assertRaisesRegex``). When
    ``highlight`` is truthy, the message must additionally show ``highlight``
    as highlighted source, as checked by FileCheck's ``check_source_highlighted``.
    """

    def __init__(self, test_case, exception, regex, highlight):
        self.test_case = test_case
        self.exception_type = exception
        self.regex = regex
        self.highlight = highlight

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Re-raise inside assertRaisesRegex so it validates type and message;
        # if nothing was raised, assertRaisesRegex itself fails the test.
        with self.test_case.assertRaisesRegex(self.exception_type, self.regex):
            if exc_type:
                raise exc_value

        if self.highlight:
            FileCheck().check_source_highlighted(self.highlight).run(str(exc_value))

        # The expected exception has been verified; suppress it for the caller.
        return True

# Node kind that assertAllFused searches for as the fusion group.
FUSION_GROUP = 'prim::TensorExprGroup'

class JitTestCase(JitCommonTestCase):
    _do_cuda_memory_leak_check = True
    _restored_warnings = False

    class capture_stdout(list):
        Replace sys.stdout with a temporary StringIO
        def __enter__(self):
            self.sys_stdout = sys.stdout
            self.stringio = StringIO()
            sys.stdout = self.stringio
            return self

        def __exit__(self, *args):
            del self.stringio
            sys.stdout = self.sys_stdout

    def setHooks(self):
        torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)

    def clearHooks(self):
        torch._C._jit_set_emit_hooks(None, None)

    def setUp(self):
        """Per-test setup: base-class setup, one-time warning filter, emit hooks."""
        super().setUp()
        # unittest overrides all warning filters and forces all of them to show up
        # after we install our own to silence those coming from inside PyTorch.
        # This will ensure that our filter still takes precedence.
        if not JitTestCase._restored_warnings:
            torch.jit.TracerWarning.ignore_lib_warnings()
            JitTestCase._restored_warnings = True
        # Every script compilation during the test triggers the save/load checks.
        self.setHooks()

    def tearDown(self):
        """Per-test teardown: base-class teardown, then remove JIT state set up for the test."""
        super().tearDown()
        # needs to be cleared because python might be unloaded before
        # the callback gets destructed
        self.clearHooks()
        clear_class_registry()

    def assertAllFused(self, graph, except_for=()):
        """Assert that exactly one fusion group was formed on the graph's fast path.

        Nodes of kind ``FUSION_GROUP`` are collected per owning block. Exactly
        one block may contain them, it must contain exactly one, and every
        node of that block must be one of the allowed bookkeeping kinds or a
        kind listed in ``except_for``.
        """
        # Kinds of the condition feeding a ``prim::If`` that guards a
        # specialized fast path; only that If's first (true) block is searched.
        guard_kinds = {'aten::all', 'prim::TypeCheck', 'prim::RequiresGradCheck'}

        # note this helper collects nodes on 'fast path' only
        # i.e. the true blocks of specialized checks
        def get_nodes_and_parents_recursively(block, kind, acc):
            for node in block.nodes():
                if node.kind() == kind:
                    acc[block].append(node)
                elif node.kind() == 'prim::DifferentiableGraph':
                    get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
                elif node.kind() == 'prim::If' and next(node.inputs()).node().kind() in guard_kinds:
                    get_nodes_and_parents_recursively(next(node.blocks()), kind, acc)
                else:
                    # Other nested control flow: search all of its blocks. This
                    # must be an `else`, or a guarded fast-path block would be
                    # visited twice and its fusion nodes double-counted.
                    for inner_block in node.blocks():
                        get_nodes_and_parents_recursively(inner_block, kind, acc)

        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
                         'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck',
                         'prim::RequiresGradCheck'} | set(except_for)

        fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list)
        get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
        self.assertTrue(len(fusion_groups) == 1, 'got {}'.format(graph))
        (graph, fusion_nodes) = list(fusion_groups.items())[0]
        # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes`
        self.assertTrue(len(fusion_nodes) == 1, 'got {}'.format(graph))
        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                        'got {}'.format(graph))

    def _isHookExceptionOk(self, e):
        se = str(e)
        allowed = ("Could not export Python function",
                   "closures are not exportable")
        for a in allowed:
            if a in se:
                return True
        return False

    def _compared_saved_loaded(self, m):
        """Check that saving, loading and re-saving ``m`` reproduces the same code.

        ``m`` is whatever the emit hooks pass in (a compiled function or a
        ``torch._C.ScriptModule``). It is saved to an in-memory archive,
        re-loaded from a copy of that archive and saved again. The
        ``archive/code/*.py`` sources of both archives are then compared
        pairwise. For ScriptModules, the IValue tags of ``m`` and the
        re-imported module must also match. Empty functions/modules are
        skipped, as are save failures accepted by ``_isHookExceptionOk``; any
        other ``RuntimeError`` propagates.
        """
        def extract_files(buffer):
            # crack open the zip format to get at the main module code
            archive = zipfile.ZipFile(buffer)
            # check that we have no duplicate names
            self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
            files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
            # unwrap all the code files into strings
            code_files_str = filter(lambda x: x.endswith('.py'), files)
            code_files_stream = (archive.open(f) for f in code_files_str)
            code_files = ("".join([line.decode() for line in file]) for file in code_files_stream)

            # unpickled all the debug files
            debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
            debug_files_stream = (archive.open(f) for f in debug_files_str)
            debug_files = (pickle.load(f) for f in debug_files_stream)
            # Both results are lazy generators; entries are read only when iterated.
            return code_files, debug_files

        # disable the hook while we parse code, otherwise we will re-enter the hook
        with torch._jit_internal._disable_emit_hooks():
            try:
                # short-circuit if this is an empty function or module
                if len(m.code) == 0:
                    return
                if isinstance(m, torch._C.ScriptModule):
                    if len(m._method_names()) == 0:
                        return

                # save the module to a buffer
                buffer = io.BytesIO()
                torch.jit.save(m, buffer)
                # copy the data in the buffer so we can restore it later. This
                # is because py2 and py3 have different semantics with zipfile
                # and it's easier to just work with a fresh copy each time.
                buffer_copy = buffer.getvalue()

                code_files, _ = extract_files(buffer)

            except RuntimeError as e:
                if not self._isHookExceptionOk(e):
                    raise
                # A known export limitation: there is nothing to compare.
                return

            # import the model again (from a the copy we made of the original)
            buffer2 = io.BytesIO(buffer_copy)
            imported = torch.jit.load(buffer2)

            # save it again
            saved_module_buffer_2 = io.BytesIO()
            torch.jit.save(imported, saved_module_buffer_2)
            saved_module_buffer_2.seek(0)

            code_files_2, _ = extract_files(saved_module_buffer_2)

            for a, b in zip(code_files, code_files_2):
                self.assertMultiLineEqual(a, b)

            if isinstance(m, torch._C.ScriptModule):
                self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))

    def emitFunctionHook(self, func):
        # func has invalid names for export, skip the jitter check
        if func.name == "<lambda>" or "aten::" in func.name:

    def emitModuleHook(self, module):

    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None)
        torch.jit.save(m, buffer)
        m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
        imported = torch.jit.load(buffer, map_location=map_location)
        imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)

        if not also_test_file:
            return imported

        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually
        f = tempfile.NamedTemporaryFile(delete=False)
            result = torch.jit.load(f.name, map_location=map_location)

        result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
        return result

    def assertGraphContains(self, graph, kind):
        self.assertTrue(any(n.kind() == kind for n in graph.nodes()))

    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
            if actual == expected:
            subgraph = 'including' if consider_subgraphs else 'excluding'
            raise AssertionError(
                '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
                    graph, actual, kind, subgraph, expected))

        if consider_subgraphs:
            strgraph = str(graph)
            count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
            perform_assert(graph, kind, count, num_kind_nodes,

        def nodes(block):
            out = []
            for node in block.nodes():
                if node.kind() == kind:
                for block in node.blocks():
                    out += nodes(block)
            return out

        out_nodes = nodes(graph)
        perform_assert(graph, kind, len(out_nodes), num_kind_nodes,

    def assertExpectedONNXGraph(self, g, *args, **kwargs):
        """Optimize ``g`` for ONNX export, then compare it with ``assertExpectedGraph``."""
        export_type = OperatorExportTypes.ONNX
        onnx_graph = torch.onnx._optimize_trace(g, operator_export_type=export_type)
        self.assertExpectedGraph(onnx_graph, *args, **kwargs)

    def assertExpectedGraph(self, trace, *args, **kwargs):
        if isinstance(trace, torch._C.Graph):
            graph = trace
            graph = trace.graph()

        graph = torch._C._jit_pass_canonicalize(graph)
        self.assertExpected(str(graph), *args, **kwargs)

    def run_pass(self, name, trace):
        if isinstance(trace, torch._C.Graph):
            graph = trace
            set_graph = False
            set_graph = True
            graph = trace.graph()

        result = getattr(torch._C, '_jit_pass_' + name)(graph)
        if result is not None:
            graph = result

        if set_graph:
        return graph

    def get_frame_vars(self, frames_up):
        frame = inspect.currentframe()
        if not frame:
            raise RuntimeError("failed to inspect frame")
