import abc
import enum
import functools
import inspect
import itertools
import types
from typing import Dict, List

import torch

from .. import variables
from ..bytecode_transformation import create_instruction
from ..exc import unimplemented
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import istensor, istype, make_cell
from .base import typestr, VariableTracker

def wrap_bound_arg(tx, val, options, source=None):
    """Wrap a raw Python value bound to a parameter in a VariableTracker.

    Containers (dict / tuple / list) are wrapped recursively.  Values that
    are already VariableTracker instances are returned unchanged.

    Args:
        tx: translator handed to VariableBuilder when a tensor is wrapped.
        val: the value to wrap.
        options: keyword options forwarded to every variable constructed
            here; must not contain "source".
        source: optional source describing where ``val`` came from.

    Returns:
        A VariableTracker representing ``val``.

    Raises:
        AssertionError: if ``options`` contains "source", or if ``val`` is
            neither a supported type nor already a VariableTracker.
    """
    # Source propagation is best effort since not every object we encounter has a source to begin with.
    assert (
        "source" not in options
    ), "Source needs to be separate from options due to recursive calls for lists/dicts"

    if isinstance(val, dict):
        # Each element carries its own source (if any); the container's
        # source is intentionally not reused for its items.
        return variables.ConstDictVariable(
            {
                k: wrap_bound_arg(tx, v, options, source=getattr(v, "source", None))
                for k, v in val.items()
            },
            dict,
            **options,
        )
    elif isinstance(val, (tuple, list)):
        cls = variables.BaseListVariable.cls_for(type(val))
        return cls(
            [
                wrap_bound_arg(tx, x, options, source=getattr(x, "source", None))
                for x in val
            ],
            **options,
        )

    if variables.ConstantVariable.is_literal(val) or istype(
        val, (torch.Size, torch.device, torch.dtype)
    ):
        return variables.ConstantVariable(val, **options)
    elif isinstance(val, types.FunctionType):
        return variables.UserFunctionVariable(val, source=source, **options)
    elif isinstance(val, enum.Enum):
        return variables.EnumVariable(val, source=source, **options)
    elif isinstance(val, (type, abc.ABCMeta)):
        return variables.UserDefinedClassVariable(val, source=source, **options)
    elif istensor(val):
        from torch._dynamo.variables.builder import VariableBuilder

        return VariableBuilder(tx, source=source, **options)(val)
    else:
        # Anything else must already be wrapped; fail loudly otherwise.
        assert isinstance(val, VariableTracker), typestr(val)
        return val

def wrap_args_kwargs(tx, result, options):
    """Wrap, in place, every tuple- or dict-valued entry of ``result``.

    Such entries are the collected ``*args`` / ``**kwargs``; each one is
    replaced by ``wrap_bound_arg(tx, value, options)``.  All other entries
    are left untouched.  Returns None.
    """
    # Decide which entries need wrapping before writing anything back, so
    # the mapping is never iterated while it is being modified.
    pending = [
        name for name, value in result.items() if isinstance(value, (tuple, dict))
    ]
    for name in pending:
        # args/kwargs
        result[name] = wrap_bound_arg(tx, result[name], options)

def init_cellvars(parent, result, code):
    """Create a tracked cell for every name in ``code.co_cellvars``.

    For each cell variable a new cell is requested from
    ``parent.output.side_effects``.  When ``result`` holds a value under the
    same name, that value is removed from ``result`` and stored into the new
    cell.

    Returns:
        dict mapping each cell-variable name to its newly tracked cell.
    """
    effects = parent.output.side_effects
    cells = {}

    for cell_name in code.co_cellvars:
        new_cell = effects.track_cell_new()
        cells[cell_name] = new_cell
        if cell_name in result:
            # The value now lives in the cell, not in the plain locals.
            effects.store_cell(new_cell, result.pop(cell_name))

    return cells

class BaseUserFunctionVariable(VariableTracker):
    """Shared behaviour for variables that wrap a user-defined callable.

    The methods here rely on ``get_code()``, ``get_function()`` and
    ``self_args()``, none of which is defined in this class; subclasses are
    expected to supply them.
    """

    def get_filename(self):
        """Return the file name recorded in the wrapped code object."""
        return self.get_code().co_filename

    def get_name(self):
        """Return the name recorded in the wrapped code object."""
        return self.get_code().co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Inline the wrapped function, prepending any implicit self args."""
        return tx.inline_user_function_return(
            self, list(self.self_args()) + list(args), kwargs
        )

    def num_parameters(self):
        """Return the number of parameters in the function's signature."""
        return len(inspect.signature(self.get_function()).parameters)

    def closure_vars(self, tx):
        """Return the closure variables; empty in this base implementation."""
        return {}

class UserFunctionVariable(BaseUserFunctionVariable):
    """Some unsupported user-defined global function"""

    def __init__(self, fn, is_constant=False, **kwargs):
        # Forward the remaining options (e.g. ``source``) to the base class;
        # bind_args() reads ``self.source``.
        super().__init__(**kwargs)
        # A function carrying ``_dynamo_marked_constant`` should be treated
        # as a constant for the purposes of compilation.
        # NOTE(review): the ``is_constant`` argument itself is not consulted;
        # the flag is always derived from ``fn``.  Presumably the parameter
        # exists so that instances can be re-created from their attributes --
        # confirm before changing.
        self.is_constant = bool(getattr(fn, "_dynamo_marked_constant", False))

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        """A plain function has no implicit leading arguments."""
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        """Return True when the wrapped callable has a non-None ``__self__``."""
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind ``args``/``kwargs`` to the signature of the wrapped function.

        Returns:
            A ``(result, closure_cells)`` pair: ``result`` maps parameter and
            free-variable names to their variables, ``closure_cells`` maps
            the function's cell-variable names to newly tracked cells.
        """
        assert not self.is_constant
        options = VariableTracker.propagate([self])
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx, options=options)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # Build a twin of ``fn`` whose defaults are already wrapped, so that
        # binding against its signature yields wrapped values for omitted
        # arguments.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                wrap(val=arg, source=source)
                for arg, source in zip(defaults, defaults_sources)
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        # Without this the wrapped defaults above would never reach `result`.
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(tx, result, options)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, (name, cell) in enumerate(
            zip(self.fn.__code__.co_freevars, closure)
        ):
            if name == "__class__":
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        out = side_effects[cell]
                    else:
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        contents_var = VariableBuilder(parent, closure_cell_contents)(
                            cell.cell_contents
                        )

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects.  This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects.  If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    unimplemented("inline with __closure__")

        return result, closure_cells

    def export_freevars(self, parent, child):
        """Nothing to export for a plain user function."""
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call the function: run it eagerly when constant, else inline it."""
        if self.is_constant:
            options = VariableTracker.propagate(self, args, kwargs.values())
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), options, args, kwargs
            )

        return super().call_function(tx, args, kwargs)

class UserMethodVariable(UserFunctionVariable):
    """Some unsupported user-defined method"""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        # Variable standing for the object the method is bound to.
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        """The bound object is passed as the implicit first argument."""
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call the method, dispatching to the module variable when possible.

        When the bound object is an NNModuleVariable and the method is either
        defined in a ``torch.nn.`` module or marked constant, the call goes
        through ``self.obj.call_method``; otherwise the method is handled
        like a plain user function.
        """
        if isinstance(self.obj, variables.NNModuleVariable):
            module_attr = getattr(self.fn, "__module__", "")
            # Parentheses only make the existing and/or precedence explicit.
            if (
                module_attr is not None and module_attr.startswith("torch.nn.")
            ) or self.is_constant:
                # NOTE(review): add_options(self) is assumed to carry this
                # variable's options over to the result -- confirm.
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                ).add_options(self)
        return super().call_function(tx, args, kwargs)

    def num_parameters(self):
        """Parameter count excluding the implicit bound object."""
        return super().num_parameters() - 1

class WrappedUserMethodVariable(UserMethodVariable):
    """A user method whose call is wrapped by a context variable."""

    def __init__(self, wrapped, context, **kwargs):
        # `fn` and `obj` always come from `wrapped`; drop any stale copies
        # that may be present in kwargs.
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call the wrapped method with ``self.context`` entered around it."""
        # NOTE(review): assumes `context` exposes enter(tx) / exit(tx) --
        # confirm against the context-wrapping variable classes.
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result

class WrappedUserFunctionVariable(UserFunctionVariable):
    """A user function whose call is wrapped by a context variable."""

    def __init__(self, wrapped, context, **kwargs):
        # `fn` always comes from `wrapped`; drop any stale `fn`/`obj` entries
        # that may be present in kwargs.
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call the wrapped function with ``self.context`` entered around it."""
        # NOTE(review): assumes `context` exposes enter(tx) / exit(tx) --
        # confirm against the context-wrapping variable classes.
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result

def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
    def convert(x):
        if isinstance(x, variables.TensorVariable):
            return x.get_real_value()
        return x.as_python_constant()

    args = [convert(x) for x in args]
    kwargs = {k: convert(v) for k, v in kwargs.items()}
    res = fn(*args, **kwargs)
    return tx.output.register_attr_or_module(
