"""Locally Optimal Block Preconditioned Conjugate Gradient methods.
# Author: Pearu Peterson
# Created: February 2020

from typing import Dict, Optional, Tuple

import torch
from torch import Tensor
from . import _linalg_utils as _utils
from .overrides import handle_torch_function, has_torch_function

# Names exported by a star-import of this module.
__all__ = ["lobpcg"]

def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
    """Gradient w.r.t. the input matrix when ``U`` is a complete eigenbasis.

    Evaluates ``U (diag(D_grad) + (U^T U_grad) * F) U^T`` where
    ``F_ij = 1 / (d_j - d_i)`` for ``i != j`` and ``F_ii = 0``.

    Args:
        D_grad (Tensor): gradient with respect to the eigenvalues ``D``.
        U_grad (Tensor): gradient with respect to the eigenvectors ``U``.
        A (Tensor): the input matrix. It is not read by this function.
        D (Tensor): eigenvalues, indexed along the last dimension.
        U (Tensor): eigenvectors stored as the columns of the last two
            dimensions.

    Returns:
        Tensor: the gradient with respect to the input matrix.

    Note:
        Repeated eigenvalues give ``d_j - d_i == 0`` off the diagonal and
        therefore infinite entries in ``F``.
    """
    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
    F = D.unsqueeze(-2) - D.unsqueeze(-1)
    # put `inf` on the diagonal so that the inversion below maps it to 0
    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
    F.pow_(-1)

    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
    Ut = U.mT.contiguous()
    res = torch.matmul(
        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
    )

    return res

def _polynomial_coefficients_given_roots(roots):
    """
    Given the `roots` of a polynomial, find the polynomial's coefficients.

    If roots = (r_1, ..., r_n), then the method returns
    coefficients (a_0, a_1, ..., a_n (== 1)) so that
    p(x) = (x - r_1) * ... * (x - r_n)
         = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0

    The roots are read along the last dimension of `roots`; any leading
    dimensions are treated as batch dimensions. The result has the same
    leading dimensions and `n + 1` entries along the last one.

    Note: for better performance requires writing a low-level kernel
    """
    poly_order = roots.shape[-1]
    poly_coeffs_shape = list(roots.shape)
    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
    # but we insert one extra coefficient to enable better vectorization below
    poly_coeffs_shape[-1] += 2
    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
    poly_coeffs[..., 0] = 1
    poly_coeffs[..., -1] = 1

    # perform the Horner's rule
    for i in range(1, poly_order + 1):
        # note that it is computationally hard to compute backward for this method,
        # because then given the coefficients it would require finding the roots and/or
        # calculating the sensitivity based on the Vieta's theorem.
        # So the code below tries to circumvent the explicit root finding by series
        # of operations on memory copies imitating the Horner's method.
        # The memory copies are required to construct nodes in the computational graph
        # by exploiting the explicit (not in-place, separate node for each step)
        # recursion of the Horner's method.
        # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
        # `out` is a view, so the in-place update writes into `poly_coeffs_new`
        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
            -1, poly_order - i + 1, i + 1
        )
        poly_coeffs = poly_coeffs_new

    # drop the extra leading coefficient inserted for vectorization
    return poly_coeffs.narrow(-1, 1, poly_order + 1)

def _polynomial_value(poly, x, zero_power, transition):
    """
    A generic method for computing poly(x) using the Horner's rule.

    Args:
      poly (Tensor): the (possibly batched) 1D Tensor representing
                     polynomial coefficients such that
                     poly[..., :] = (a_0, ..., a_n (== 1)), and
                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n

      x (Tensor): the value (possibly batched) to evaluate the polynomial `poly` at.

      zero_power (Tensor): the representation of `x^0`. It is application-specific.

      transition (Callable): the function that accepts some intermediate result `int_val`,
                             the `x` and a specific polynomial coefficient
                             `poly[..., k]` for some iteration `k`.
                             It basically performs one iteration of the Horner's rule
                             defined as `x * int_val + poly[..., k] * zero_power`.
                             Note that `zero_power` is not a parameter,
                             because the step `+ poly[..., k] * zero_power` depends on `x`,
                             whether it is a vector, a matrix, or something else, so this
                             functionality is delegated to the user.

    Returns:
      The value produced by the last call to `transition`, or a copy of
      `zero_power` when `poly` has fewer than two coefficients.
    """

    # The leading coefficient poly[..., -1] is never read: the accumulator
    # starts from `zero_power`, i.e. the polynomial is treated as monic.
    res = zero_power.clone()
    for k in range(poly.size(-1) - 2, -1, -1):
        res = transition(res, x, poly[..., k])
    return res

def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) matrix input `x`.
    Check out `_polynomial_value` function for more details.

    Args:
      poly (Tensor): polynomial coefficients, indexed along the last dimension.
      x (Tensor): the (batched) square matrix to evaluate the polynomial at.
      zero_power (Tensor, optional): the representation of `x^0`. When omitted,
                                     an identity matrix with singleton batch
                                     dimensions is used.
    """

    # matrix-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = x.matmul(curr_poly_val)
        # adding the coefficient to the diagonal is `+ poly_coeff * I`
        res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
        return res

    if zero_power is None:
        # identity with one singleton dimension per batch dimension of `x`
        zero_power = torch.eye(
            x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
        ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))

    return _polynomial_value(poly, x, zero_power, transition)

def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) vector input `x`.
    Check out `_polynomial_value` function for more details.

    Args:
      poly (Tensor): polynomial coefficients, indexed along the last dimension.
      x (Tensor): the (batched) vector to evaluate the polynomial at, elementwise.
      zero_power (Tensor, optional): the representation of `x^0`. When omitted,
                                     a tensor of ones expanded to the shape of
                                     `x` is used.
    """

    # vector-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        # poly_coeff + x * curr_poly_val, elementwise
        res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
        return res

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, transition)

def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    """Gradient w.r.t. ``A`` when ``U`` spans only part of the eigenspace.

    The result is the sum of two parts: the contribution inside ``span(U)``,
    obtained from `_symeig_backward_complete_eigenspace`, and a correction in
    the orthogonal complement of ``span(U)``, obtained from the explicit
    polynomial solution of a Sylvester equation.

    Args:
        D_grad (Tensor): gradient with respect to the eigenvalues ``D``.
        U_grad (Tensor): gradient with respect to the eigenvectors ``U``.
        A (Tensor): the input matrix.
        D (Tensor): the computed eigenvalues, indexed along the last dimension.
        U (Tensor): the computed eigenvectors, stored as columns.
        largest (bool): whether ``D`` holds the largest eigenvalues; it only
            selects the sign used to make the Cholesky input positive-definite.

    Returns:
        Tensor: the gradient with respect to ``A``.
    """
    # compute a projection operator onto an orthogonal subspace spanned by the
    # columns of U defined as (I - UU^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement to the span(U),
    # by projecting a random [..., m, m - k] matrix onto the subspace spanned
    # by the columns of U.
    #
    # fix generator for determinism
    gen = torch.Generator(A.device)

    # orthogonal complement to the span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @  p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied mathematics letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    # chr_poly_D_at_A = A.new_zeros(A.shape)
    # for k in range(chr_poly_D.size(-1)):
    #     chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, for better performance we use the Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )
    # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and `k` is even
    # under the assumption that `A` has distinct eigenvalues.
    #
    # `k` above is the number of computed eigenpairs; read it from `D`
    # instead of relying on the loop variable leaking out of the loop.
    num_eigenpairs = D.size(-1)
    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
    chr_poly_D_at_A_to_U_ortho_sign = (
        -1 if (largest and (num_eigenpairs % 2 == 1)) else +1
    )
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution into the full gradient
    # it resides in span(U_ortho)
    # (the sign undoes the flip applied before the Cholesky factorization)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res

def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    """Dispatch the eigendecomposition backward on the shape of ``U``.

    Args:
        D_grad (Tensor): gradient with respect to the eigenvalues ``D``.
        U_grad (Tensor): gradient with respect to the eigenvectors ``U``.
        A (Tensor): the input matrix.
        D (Tensor): the computed eigenvalues.
        U (Tensor): the computed eigenvectors, stored as columns.
        largest (bool): forwarded to the partial-eigenspace backward only.

    Returns:
        Tensor: the gradient with respect to ``A``.
    """
    # if `U` is square, then the columns of `U` is a complete eigenspace
    if U.size(-1) == U.size(-2):
        return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
    # otherwise only some of the eigenpairs are available
    return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)

class LOBPCGAutogradFunction(torch.autograd.Function):
    """Autograd wrapper around `_lobpcg` providing a gradient w.r.t. ``A``.

    `forward` runs `_lobpcg` and saves ``A, B, D, U`` for `backward`.
    `backward` only supports dense, real ``A`` with ``B is None`` and raises
    `ValueError` otherwise.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Run `_lobpcg` and stash what `backward` needs on ``ctx``.

        All arguments after ``ctx`` are forwarded to `_lobpcg` unchanged,
        except that dense ``A`` and ``B`` are made contiguous first.

        Returns:
            The pair ``(D, U)`` returned by `_lobpcg`.
        """
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        # NOTE(review): `_lobpcg` is defined outside this excerpt; the
        # arguments are forwarded positionally in the order of this
        # signature -- confirm against its definition.
        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        """Compute the gradient with respect to ``A``.

        Args:
            ctx: the context populated by `forward`.
            D_grad (Tensor): gradient with respect to the returned ``D``.
            U_grad (Tensor): gradient with respect to the returned ``U``.

        Returns:
            A tuple of 14 entries, one per `forward` input; only the entry
            for ``A`` (index 0) can be non-``None``.

        Raises:
            ValueError: if ``A`` is sparse, if ``B`` is sparse and its
                gradient is requested, if ``A`` or ``B`` is complex, or if
                ``B`` is not ``None``.
        """
        A_grad = B_grad = None
        # one slot per forward input (excluding ctx)
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Checks for unsupported input
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet."
                "Note that lobpcg.forward does though."
            )
        if A.dtype in (torch.complex64, torch.complex128) or (
            B is not None and B.dtype in (torch.complex64, torch.complex128)
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet."
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        if largest is None:
            largest = True

        # symeig backward
        if B is None:
            A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)

        # A has index 0
        grads[0] = A_grad
        # B has index 2
        grads[2] = B_grad
        return tuple(grads)
