"""Locally Optimal Block Preconditioned Conjugate Gradient methods.
"""
# Author: Pearu Peterson
# Created: February 2020
from typing import Dict, Tuple, Optional
import torch
from torch import Tensor
from . import _linalg_utils as _utils
from .overrides import has_torch_function, handle_torch_function
__all__ = ['lobpcg']
def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
# compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
F = D.unsqueeze(-2) - D.unsqueeze(-1)
F.diagonal(dim1=-2, dim2=-1).fill_(float('inf'))
F.pow_(-1)
# A.grad = U (D.grad + (U^T U.grad * F)) U^T
Ut = U.transpose(-1, -2).contiguous()
res = torch.matmul(
U,
torch.matmul(
torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F,
Ut
)
)
return res
def _polynomial_coefficients_given_roots(roots):
"""
Given the `roots` of a polynomial, find the polynomial's coefficients.
If roots = (r_1, ..., r_n), then the method returns
coefficients (a_0, a_1, ..., a_n (== 1)) so that
p(x) = (x - r_1) * ... * (x - r_n)
= x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
Note: for better performance requires writing a low-level kernel
"""
poly_order = roots.shape[-1]
poly_coeffs_shape = list(roots.shape)
# we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
# so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
# but we insert one extra coefficient to enable better vectorization below
poly_coeffs_shape[-1] += 2
poly_coeffs = roots.new_zeros(poly_coeffs_shape)
poly_coeffs[..., 0] = 1
poly_coeffs[..., -1] = 1
# perform the Horner's rule
for i in range(1, poly_order + 1):
# note that it is computationally hard to compute backward for this method,
# because then given the coefficients it would require finding the roots and/or
# calculating the sensitivity based on the Vieta's theorem.
# So the code below tries to circumvent the explicit root finding by series
# of operations on memory copies imitating the Horner's method.
# The memory copies are required to construct nodes in the computational graph
# by exploting the explicit (not in-place, separate node for each step)
# recursion of the Horner's method.
# Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(-1, poly_order - i + 1, i + 1)
poly_coeffs = poly_coeffs_new
return poly_coeffs.narrow(-1, 1, poly_order + 1)
def _polynomial_value(poly, x, zero_power, transition):
"""
A generic method for computing poly(x) using the Horner's rule.
Args:
poly (Tensor): the (possibly batched) 1D Tensor representing
polynomial coefficients such that
poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
zero_power (Tensor): the represenation of `x^0`. It is application-specific.
transition (Callable): the function that accepts some intermediate result `int_val`,
the `x` and a specific polynomial coefficient
`poly[..., k]` for some iteration `k`.
It basically performs one iteration of the Horner's rule
defined as `x * int_val + poly[..., k] * zero_power`.
Note that `zero_power` is not a parameter,
because the step `+ poly[..., k] * zero_power` depends on `x`,
whether it is a vector, a matrix, or something else, so this
functionality is delegated to the user.
"""
res = zero_power.clone()
for k in range(poly.size(-1) - 2, -1, -1):
res = transition(res, x, poly[..., k])
return res
def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) matrix input `x`.

    When `zero_power` is not given, an identity matrix of size
    `x.size(-1)` is used, viewed with singleton batch dimensions.
    See `_polynomial_value` for the meaning of the arguments.
    """
    def horner_step(acc, mat, coeff):
        # mat @ acc + coeff * I; the identity term is added in place
        # directly onto the diagonal of the product
        nxt = mat.matmul(acc)
        nxt.diagonal(dim1=-2, dim2=-1).add_(coeff.unsqueeze(-1))
        return nxt

    if zero_power is None:
        size = x.size(-1)
        batch_ones = [1] * len(x.shape[:-2])
        identity = torch.eye(size, size, dtype=x.dtype, device=x.device)
        zero_power = identity.view(*batch_ones, size, size)

    return _polynomial_value(poly, x, zero_power, horner_step)
def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` elementwise for the (batched) vector input `x`.

    When `zero_power` is not given, a tensor of ones expanded to the shape
    of `x` is used. See `_polynomial_value` for the meaning of the arguments.
    """
    def horner_step(acc, vec, coeff):
        # coeff + vec * acc, computed elementwise in a single call
        return torch.addcmul(coeff.unsqueeze(-1), vec, acc)

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, horner_step)
def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    """Backward formula used when ``U`` holds only part of the eigenbasis.

    The result is the complete-eigenspace formula (the part living in
    span(U)) minus a correction living in the orthogonal complement of
    span(U). The correction comes from the explicit polynomial solution of a
    Sylvester equation (see the reference in the body).

    Args:
      D_grad, U_grad: incoming gradients for ``D`` and ``U``.
      A: the input matrix; its last two dimensions are used as a matrix.
      D: eigenvalues along the last dimension.
      U: eigenvectors stored as columns.
      largest: truthy if the largest eigenvalues were requested. It is only
               used to decide the sign that makes the matrix handed to the
               Cholesky factorization positive-definite.

    Returns:
      The gradient with respect to ``A``.
    """
    # number of eigenpairs in (D, U); named explicitly because it also
    # decides the definiteness sign further below
    num_eigenpairs = D.size(-1)

    # projector onto the orthogonal complement of span(U): I - U U^T
    Ut = U.transpose(-2, -1).contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement to the span(U),
    # by projecting a random [..., m, m - k] matrix onto that complement.
    #
    # a dedicated, freshly constructed generator is used for determinism
    gen = torch.Generator(A.device)

    # orthogonal complement to the span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - num_eigenpairs),
            dtype=A.dtype,
            device=A.device,
            generator=gen
        )
    )
    U_ortho_t = U_ortho.transpose(-2, -1).contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied mathematics letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    # chr_poly_D_at_A = A.new_zeros(A.shape)
    # for k in range(chr_poly_D.size(-1)):
    #     chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, for better performance we use the Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t,
        torch.matmul(
            chr_poly_D_at_A,
            U_ortho
        )
    )

    # we need to invert `chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and the number of eigenpairs `k` is even
    # under the assumption that `A` has distinct eigenvalues.
    #
    # The parity is taken from `num_eigenpairs` explicitly instead of relying
    # on the value the loop variable above happened to end with.
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (num_eigenpairs % 2 == 1)) else +1
    # NOTE(review): `torch.cholesky` is deprecated in newer PyTorch releases in
    # favor of `torch.linalg.cholesky` - confirm the minimum supported version
    # before switching.
    chr_poly_D_at_A_to_U_ortho_L = torch.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(
        D_grad, U_grad, A, D, U
    )

    # incorporate the Sylvester equation solution into the full gradient
    # it resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc),
            chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res
def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    """Pick the backward formula matching the shape of ``U``.

    A square ``U`` (last two dimensions equal) means its columns form a
    complete eigenbasis; otherwise the partial-eigenspace formula is used.
    ``largest`` is only forwarded to the partial-eigenspace formula.
    """
    has_full_basis = U.size(-1) == U.size(-2)
    if not has_full_basis:
        return _symeig_backward_partial_eigenspace(
            D_grad, U_grad, A, D, U, largest
        )
    return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
class LOBPCGAutogradFunction(torch.autograd.Function):
    """Autograd function pairing the `_lobpcg` forward with a symeig-style backward.

    Only the gradient with respect to `A` is produced. `backward` raises
    `ValueError` for sparse input, complex input, and any `B` other than
    `None`.
    """

    @staticmethod
    def forward(ctx,  # type: ignore[override]
                A: Tensor,
                k: Optional[int] = None,
                B: Optional[Tensor] = None,
                X: Optional[Tensor] = None,
                n: Optional[int] = None,
                iK: Optional[Tensor] = None,
                niter: Optional[int] = None,
                tol: Optional[float] = None,
                largest: Optional[bool] = None,
                method: Optional[str] = None,
                tracker: Optional[None] = None,
                ortho_iparams: Optional[Dict[str, int]] = None,
                ortho_fparams: Optional[Dict[str, float]] = None,
                ortho_bparams: Optional[Dict[str, bool]] = None
                ) -> Tuple[Tensor, Tensor]:
        """Run `_lobpcg` and stash what `backward` needs.

        All arguments are forwarded to `_lobpcg` unchanged, except that
        dense `A` and `B` are made contiguous first.

        Returns:
          The pair `(D, U)` returned by `_lobpcg`.
        """
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A, k, B, X,
            n, iK, niter, tol, largest, method, tracker,
            ortho_iparams, ortho_fparams, ortho_bparams
        )

        # `save_for_backward` is meant for tensors (or None) only; `largest`
        # is a plain Python value, so it is kept as an attribute on `ctx`.
        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        """Compute the gradient with respect to `A`.

        Returns:
          A 14-tuple, one slot per `forward` input; only slot 0 (`A`) can be
          non-None.

        Raises:
          ValueError: if `A` is sparse, if `B` is sparse and needs a
            gradient, if `A` or `B` is complex, or if `B` is not None.
        """
        # one slot per forward input; everything but A stays None
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Checks for unsupported input
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                'lobpcg.backward does not support sparse input yet. '
                'Note that lobpcg.forward does though.'
            )
        complex_dtypes = (torch.complex64, torch.complex128)
        if A.dtype in complex_dtypes or (B is not None and B.dtype in complex_dtypes):
            raise ValueError(
                'lobpcg.backward does not support complex input yet. '
                'Note that lobpcg.forward does though.'
            )
        if B is not None:
            raise ValueError(
                'lobpcg.backward does not support backward with B != I yet.'
            )

        if largest is None:
            largest = True

        # symeig backward; B is necessarily None past the checks above
        # A has index 0 (B, at index 2, gets no gradient)
        grads[0] = _symeig_backward(
            D_grad, U_grad, A, D, U, largest
        )
        return tuple(grads)
def lobpcg(A: Tensor,
k: Optional[int] = None,
B: Optional[Tensor] = None,
X: Optional[Tensor] = None,
n: Optional[int] = None,
iK: Optional[Tensor] = None,
niter: Optional[int] = None,
tol: Optional[float] = None,
Loading ...