import torch
from torch._C import ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple

import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils

from functools import partial
from functools import wraps

import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented

from typing import Optional

import numpy
import math
import warnings

# see Note [Edit Symbolic Files] in symbolic_helper.py

# This file exports ONNX ops for opset 9
# Opset 9 is supported by ONNX release 1.4.1
# released on 01/23/19

# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)?  There are
# some moving parts to implementing the ONNX translation in this case:
#   - By the time we get the scalar in a symbolic function here, it is no longer
#     a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
#     want it to be a zero dim tensor but this change has not happened yet.)
#     However, the type of this scalar is *exactly* what the user wrote in
#     Python, which may not match the tensor it is being added to.  PyTorch
#     will do implicit conversions on scalars; however, ONNX will not, so
#     we must do the conversion ourselves.  This is what _if_scalar_type_as
#     does.
#   - Dispatch to these functions takes advantage of an outrageous coincidence
#     between the tensor and scalar name.  When we add two tensors together,
#     you get the dispatch:
#       add(*[self, other], **{"alpha": alpha})
#     When you add a tensor and a scalar, you get the dispatch:
#       add(*[self], **{"other": other, "alpha": alpha})
#     By having the argument name line up with the name of the scalar attribute
#     if it exists, we can write a single function for both overloads.

# used to represent "missing" optional inputs
def unused(g):
    """Create the placeholder node that represents a "missing" optional input.

    The placeholder is a ``prim::Constant`` op emitted on ``g`` with no
    inputs and no attributes.
    """
    return g.op("prim::Constant")

def _shape_as_tensor(g, input):
    """Emit a ``Shape`` op on ``g`` for ``input`` and return its output."""
    shape_node = g.op('Shape', input)
    return shape_node

def _reshape_from_tensor(g, input, shape):
    """Emit a ``Reshape`` op on ``g`` that reshapes ``input`` to ``shape``.

    ``shape`` is forwarded as the second input of the op, not as an attribute.
    """
    reshaped = g.op('Reshape', input, shape)
    return reshaped

def reshape(g, self, shape):
    """Symbolic for ``reshape``; forwards its arguments unchanged to ``view``."""
    result = view(g, self, shape)
    return result

def reshape_as(g, self, other):
    """Symbolic for ``reshape_as``: reshape ``self`` to the shape of ``other``.

    A ``Shape`` op is emitted for ``other`` and its output is handed to
    ``reshape`` as the target shape.
    """
    target_shape = g.op('Shape', other)
    return reshape(g, self, target_shape)

def add(g, self, other, alpha=None):
    """Symbolic for ``add``: emit an ``Add`` op for ``self`` and ``other``.

    ``alpha`` defaults to ``None`` so that the overload without an alpha
    argument can dispatch here as well.  Two cases are reported through the
    helper error routines instead of being exported:

    - ``self`` is a list of tensors;
    - ``alpha`` is given and its scalar value is not 1.
    """
    self_is_tensor_list = sym_help._is_value(self) and sym_help._is_tensor_list(self)
    if self_is_tensor_list:
        return sym_help._onnx_opset_unsupported_detailed('Add', 9, 11, 'Add between list of tensors not supported')

    if alpha:
        alpha_value = sym_help._scalar(sym_help._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return _unimplemented("add", "alpha != 1")

    return g.op("Add", self, other)

def sub(g, self, other, alpha=None):
    """Symbolic for ``sub``: emit a ``Sub`` op computing ``self - other``.

    ``alpha`` defaults to ``None`` so that the overload without an alpha
    argument can dispatch here as well.  An ``alpha`` whose scalar value is
    not 1 is reported as unimplemented.
    """
    if alpha:
        alpha_value = sym_help._scalar(sym_help._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return _unimplemented("sub", "alpha != 1")

    return g.op("Sub", self, other)

def rsub(g, self, other, alpha=None):
    """Symbolic for ``rsub``: ``other - self``, via ``sub`` with swapped operands."""
    minuend, subtrahend = other, self
    return sub(g, minuend, subtrahend, alpha=alpha)

def mul(g, self, other):
    """Symbolic for ``mul``: emit a ``Mul`` op on ``self`` and ``other``."""
    product = g.op("Mul", self, other)
    return product

def div(g, self, other):
    """Symbolic for ``div``; forwards its arguments unchanged to ``true_divide``."""
    quotient = true_divide(g, self, other)
    return quotient

def floor_divide(g, self, other):
    """Symbolic for ``floor_divide``: ``Div``, a cast to ``Long``, then a final cast.

    The final cast picks the output type so that it matches PyTorch's
    behavior:

    - if ``self`` is floating point, the output has ``self``'s type;
    - if ``self`` is not floating point and ``other`` is, the output is
      ``'Float'``;
    - if neither is floating point, the output has ``self``'s type;
    - if ``self``'s scalar type is not available, the output defaults to
      ``'Float'``.
    """
    out = g.op('Div', self, other)
    # The correct operation is truncation, which ONNX does not provide.
    # Floor cannot be used because it behaves differently for negative
    # numbers (e.g. -0.1 should become -0), so the quotient is cast to Long.
    out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx['Long'])

    scalar_type = self.type().scalarType()

    if scalar_type is not None:
        if not sym_help._is_fp(self) and \
           other.type().scalarType() is not None and \
           sym_help._is_fp(other):
            # Integral self divided by a floating point other yields Float.
            out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx['Float'])
        else:
            # Every other typed case keeps self's scalar type.
            out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])
    else:
        # No scalar type information for self: treat the result as Float.
        out = g.op("Cast", out, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    return out

def floordiv(g, self, other):
    """Symbolic for ``floordiv``; forwards its arguments unchanged to ``floor_divide``."""
    result = floor_divide(g, self, other)
    return result

# Division where both inputs are cast to floating types
# If both inputs are floating, performs div as usual
# If only one input is a floating type, the other input is cast to its type
# If neither input is a floating type, both inputs are cast to the default scalar type
def true_divide(g, self, other):
    """Symbolic for ``true_divide``: a ``Div`` whose inputs are floating point.

    - Both inputs floating point: plain ``Div``.
    - Exactly one input floating point: the other one is cast to its scalar
      type before the ``Div``.
    - Neither input floating point: both are cast to the default scalar type
      (``Float``, or ``Double`` when the default dtype is ``torch.double``).
    """
    def cast_like(value, reference):
        # Cast ``value`` to the scalar type reported by ``reference``.
        onnx_type = sym_help.cast_pytorch_to_onnx[reference.type().scalarType()]
        return g.op("Cast", value, to_i=onnx_type)

    # Case 1: both values are floating point, divide as usual.
    if sym_help._is_fp(self) and sym_help._is_fp(other):
        return g.op("Div", self, other)

    # Case 2: only self is floating point, bring other to self's dtype.
    if sym_help._is_fp(self):
        return g.op("Div", self, cast_like(other, self))

    # Case 3: only other is floating point, bring self to other's dtype.
    if sym_help._is_fp(other):
        return g.op("Div", cast_like(self, other), other)

    # Case 4: neither is floating point, use the default scalar type for both.
    default_dtype = torch.get_default_dtype()
    assert default_dtype is torch.float or default_dtype is torch.double
    type_name = 'Double' if default_dtype is torch.double else 'Float'
    onnx_scalar_type = sym_help.cast_pytorch_to_onnx[type_name]

    numerator = g.op("Cast", self, to_i=onnx_scalar_type)
    denominator = g.op("Cast", other, to_i=onnx_scalar_type)
    return g.op("Div", numerator, denominator)

def reciprocal(g, self):
    """Symbolic for ``reciprocal``: emit ``Div(1, self)``.

    The constant numerator is passed through ``_if_scalar_type_as`` so that
    it is converted to ``self``'s scalar type.  ONNX does not perform
    implicit scalar conversions, so handing ``Div`` a raw ``torch.ones(1)``
    would leave its two inputs with different types whenever ``self`` is not
    of that tensor's dtype.  This matches how ``rsqrt`` builds its numerator.
    """
    one = sym_help._if_scalar_type_as(g, torch.ones(1), self)
    return g.op("Div", one, self)

@parse_args('v', 'i')
def cat(g, tensor_list, dim):
    """Symbolic for ``cat``: emit one ``Concat`` op along axis ``dim``.

    ``tensor_list`` is unpacked into its individual tensors, which become the
    inputs of the ``Concat`` op in their original order.
    """
    return g.op("Concat", *sym_help._unpack_list(tensor_list), axis_i=dim)

@parse_args('v', 'i')
def stack(g, tensor_list, dim):
    """Symbolic for ``stack``: unsqueeze every tensor at ``dim``, then ``Concat``.

    Each tensor unpacked from ``tensor_list`` goes through the unsqueeze
    helper with axes ``[dim]``; the results are concatenated along ``dim``.
    """
    unsqueezed = []
    for tensor in sym_help._unpack_list(tensor_list):
        unsqueezed.append(sym_help._unsqueeze_helper(g, tensor, [dim]))
    return g.op("Concat", *unsqueezed, axis_i=dim)

def _list(g, self):
    """Identity symbolic: return ``self`` untouched without emitting any op."""
    passthrough = self
    return passthrough

def mm(g, self, other):
    """Symbolic for ``mm``: a ``Gemm`` op with ``alpha=1.0`` and ``beta=0.0``.

    A dummy constant is supplied as the third ``Gemm`` input ``C``.  It is
    only needed for API purposes: since ``beta`` is 0, its value is ignored.
    """
    dummy_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)

def bmm(g, self, other):
    """Symbolic for ``bmm``: emit a ``MatMul`` op on ``self`` and ``other``."""
    product = g.op("MatMul", self, other)
    return product

def matmul(g, self, other):
    """Symbolic for ``matmul``: emit a ``MatMul`` op on ``self`` and ``other``."""
    product = g.op("MatMul", self, other)
    return product

@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
    """Symbolic for ``addmm``: ``beta * self + alpha * (mat1 @ mat2)``.

    Two export strategies are used:

    - If a scalar type is known for at least one operand and either matrix
      has a known rank other than 2, the result is built from ``MatMul``,
      optional ``Mul`` ops for ``alpha``/``beta`` (skipped when the factor
      is 1), and a final ``Add``.
    - Otherwise a single ``Gemm`` op is emitted with ``alpha`` and ``beta``
      as attributes.
    """
    # The first operand with a known scalar type decides the dtype of the
    # alpha/beta constants.
    dtype = None
    self_dtype = sym_help._try_get_scalar_type(self)
    mat1_dtype = sym_help._try_get_scalar_type(mat1)
    mat2_dtype = sym_help._try_get_scalar_type(mat2)
    if self_dtype is not None:
        dtype = self_dtype
    elif mat1_dtype is not None:
        dtype = mat1_dtype
    elif mat2_dtype is not None:
        dtype = mat2_dtype

    mat1_rank = sym_help._get_tensor_rank(mat1)
    mat2_rank = sym_help._get_tensor_rank(mat2)

    def is_known_and_differs(value, expected):
        return value is not None and value != expected

    if dtype is not None and (is_known_and_differs(mat1_rank, 2) or is_known_and_differs(mat2_rank, 2)):
        # Translate the scalar type name into the corresponding torch dtype
        # by going through the index shared by the two lookup tables.
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]

        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha = sym_help._scalar(alpha)
        beta = sym_help._scalar(beta)

        if alpha != 1:
            alpha = g.op("Constant",
                         value_t=torch.tensor(alpha, dtype=dtype))
            res1 = g.op("Mul", res1, alpha)
        if beta != 1:
            # beta is already a Python value here, so it is used directly
            # (exactly like alpha above) instead of being converted again.
            beta = g.op("Constant",
                        value_t=torch.tensor(beta, dtype=dtype))
            res2 = g.op("Mul", res2, beta)

        return g.op("Add", res1, res2)

    return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))

def neg(g, self):
    """Symbolic for ``neg``: emit a ``Neg`` op on ``self``."""
    result = g.op("Neg", self)
    return result

def sqrt(g, self):
    """Symbolic for ``sqrt``: emit a ``Sqrt`` op on ``self``."""
    result = g.op("Sqrt", self)
    return result

def rsqrt(g, self):
    """Symbolic for ``rsqrt``: emit ``Div(1, sqrt(self))``.

    The constant numerator goes through ``_if_scalar_type_as`` together with
    ``self`` before being used; the denominator is produced by ``sqrt``.
    """
    one = sym_help._if_scalar_type_as(g, torch.ones(1), self)
    root = sqrt(g, self)
    return g.op("Div", one, root)

def tanh(g, self):
    """Symbolic for ``tanh``: emit a ``Tanh`` op on ``self``."""
    result = g.op("Tanh", self)
    return result

def sin(g, self):
    """Symbolic for ``sin``: emit a ``Sin`` op on ``self``."""
    result = g.op("Sin", self)
    return result

def cos(g, self):
    """Symbolic for ``cos``: emit a ``Cos`` op on ``self``."""
    result = g.op("Cos", self)
    return result

def tan(g, self):
    """Symbolic for ``tan``: emit a ``Tan`` op on ``self``."""
    result = g.op("Tan", self)
    return result

def asin(g, self):
    """Symbolic for ``asin``: emit an ``Asin`` op on ``self``."""
    result = g.op("Asin", self)
    return result

def acos(g, self):
    """Symbolic for ``acos``: emit an ``Acos`` op on ``self``."""
    result = g.op("Acos", self)
    return result

def atan(g, self):
    """Symbolic for ``atan``: emit an ``Atan`` op on ``self``."""
    result = g.op("Atan", self)
    return result

def sigmoid(g, self):
    """Symbolic for ``sigmoid``: emit a ``Sigmoid`` op on ``self``."""
    result = g.op("Sigmoid", self)
    return result

def sign(g, self):
    """Symbolic for ``sign``: emit a ``Sign`` op on ``self``."""
    result = g.op("Sign", self)
    return result

def _slice(g, input, axes, starts, ends):
    """Emit a ``Slice`` op, or return ``input`` itself for a trivial slice.

    ``starts`` and ``ends`` must have the same length.  When there is exactly
    one range and it runs from 0 to 9223372036854775807 (2**63 - 1), no op is
    emitted and ``input`` is returned unchanged.
    """
    assert len(starts) == len(ends)
    int64_max = 9223372036854775807
    is_full_range = len(starts) == 1 and starts[0] == 0 and ends[0] == int64_max
    if is_full_range:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)

def _maybe_cast_reduce_op_input(g, self):
    """Cast a non-floating-point, non-``Long`` reduce-op input to ``Long``.

    PyTorch reduce ops cast all other integral types to int64.  The check
    only takes effect when the scalar type of ``self`` is present (traced
    modules); otherwise ``self`` is returned unchanged.
    """
    dtype = self.type().scalarType()
    if dtype is None:
        return self
    if sym_help._is_fp(self) or dtype == 'Long':
        return self
    return _cast_Long(g, self, False)  # type: ignore

def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
    """Build the symbolic function for a reduction exported as ``onnx_op_name``.

    Args:
        onnx_op_name: name of the op to emit.
        allow_multi_dim_support: when true, ``dim`` is read as a list of ints
            and used directly as the axes; otherwise it is read as a single
            int and wrapped in a one-element list.

    Returns:
        ``symbolic(g, self, dim=None, keepdim=None)``.  Without ``dim`` it
        reduces over everything with ``keepdims_i=0``; with ``dim`` it
        reduces over the given axes and honours ``keepdim``.
    """
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return g.op(onnx_op_name, self, keepdims_i=0)

        # dim-reduce path: both dim and keepdim must be constants.
        desc = 'is' if allow_multi_dim_support else 'i'
        dim = sym_help._get_const(dim, desc, 'dim')
        keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
        dim_list = dim if allow_multi_dim_support else [dim]
        return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
    return symbolic

def overload_by_arg_count(fn):
    """Decorator that dispatches to one of several overloads by argument count.

    ``fn(g, *args)`` must return an iterable of candidate functions, each
    carrying an ``_arg_descriptors`` attribute (presumably attached by
    ``parse_args`` -- confirm).  The first candidate whose number of
    descriptors equals ``len(args)`` is called with ``(g, *args)`` and its
    result is returned.

    Raises:
        NotImplementedError: if no candidate matches the argument count.
    """
    @wraps(fn)
    def wrapper(g, *args):
        for overload in fn(g, *args):
            if len(overload._arg_descriptors) == len(args):
                return overload(g, *args)
        raise NotImplementedError("Unknown aten::{} signature".format(fn.__name__))
    # Without this return the decorator would replace every decorated
    # function with None.
    return wrapper
