from enum import Enum, auto
import torch
from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F
from typing import Optional
__all__ = ['orthogonal', 'spectral_norm']
def _is_orthogonal(Q, eps=None):
n, k = Q.size(-2), Q.size(-1)
Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    if eps is None:
        # A reasonable eps, but not too large
        eps = 10. * n * torch.finfo(Q.dtype).eps
return torch.allclose(Q.mH @ Q, Id, atol=eps)
def _make_orthogonal(A):
""" Assume that A is a tall matrix.
Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
"""
X, tau = torch.geqrf(A)
Q = torch.linalg.householder_product(X, tau)
# The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
return Q
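# Illustrative usage of the two helpers above (a sketch, not part of the public API):
#
#     A = torch.randn(5, 3)                      # a tall matrix
#     Q = _make_orthogonal(A)                    # Q factor of A = QR with diag(R) >= 0
#     assert _is_orthogonal(Q)                   # Q has orthonormal columns
#     assert torch.allclose(_make_orthogonal(Q), Q, atol=1e-6)  # orthonormal inputs are fixed points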
class _OrthMaps(Enum):
matrix_exp = auto()
cayley = auto()
householder = auto()
class _Orthogonal(Module):
base: Tensor
def __init__(self,
weight,
orthogonal_map: _OrthMaps,
*,
use_trivialization=True) -> None:
super().__init__()
# Note [Householder complex]
# For complex tensors, it is not possible to compute the tensor `tau` necessary for
# linalg.householder_product from the reflectors.
# To see this, note that the reflectors have a shape like:
# 0 0 0
# * 0 0
# * * 0
        # which, for complex matrices, gives n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimised
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
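        # Worked count for n = 3 (illustrative): the strictly lower triangle holds
        # 3 complex entries, i.e. 6 real parameters, while U(3) is a real manifold of
        # dimension 3^2 = 9, so the reflectors alone cannot reach every unitary matrix.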
if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
raise ValueError("The householder parametrization does not support complex tensors.")
self.shape = weight.shape
self.orthogonal_map = orthogonal_map
if use_trivialization:
self.register_buffer("base", None)
def forward(self, X: torch.Tensor) -> torch.Tensor:
n, k = X.size(-2), X.size(-1)
transposed = n < k
if transposed:
X = X.mT
n, k = k, n
# Here n > k and X is a tall matrix
if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
# We just need n x k - k(k-1)/2 parameters
X = X.tril()
if n != k:
# Embed into a square matrix
X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
A = X - X.mH
# A is skew-symmetric (or skew-hermitian)
if self.orthogonal_map == _OrthMaps.matrix_exp:
Q = torch.matrix_exp(A)
elif self.orthogonal_map == _OrthMaps.cayley:
# Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
Id = torch.eye(n, dtype=A.dtype, device=A.device)
Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:
# X is real here, as we do not support householder with complex numbers
A = X.tril(diagonal=-1)
tau = 2. / (1. + (A * A).sum(dim=-2))
Q = torch.linalg.householder_product(A, tau)
# The diagonal of X is 1's and -1's
# We do not want to differentiate through this or update the diagonal of X hence the casting
Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
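            # Why this gives an orthogonal Q (illustrative recap): LAPACK stores each
            # Householder vector as v_i = (0, ..., 0, 1, a_i)^T, with the 1 implicit on the
            # diagonal and a_i the strictly lower part kept in A. The reflector
            # H_i = I - tau_i * v_i v_i^T is orthogonal precisely when
            # tau_i = 2 / ||v_i||^2 = 2 / (1 + a_i^T a_i), which is the `tau` computed above.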
if hasattr(self, "base"):
Q = self.base @ Q
if transposed:
Q = Q.mT
return Q
@torch.autograd.no_grad()
def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
if Q.shape != self.shape:
raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
f"Got a tensor of shape {Q.shape}.")
Q_init = Q
n, k = Q.size(-2), Q.size(-1)
transpose = n < k
if transpose:
Q = Q.mT
n, k = k, n
        # We make sure to always copy Q in every path
if not hasattr(self, "base"):
# Note [right_inverse expm cayley]
# If we do not have use_trivialization=True, we just implement the inverse of the forward
# map for the Householder. To see why, think that for the Cayley map,
# we would need to find the matrix X \in R^{n x k} such that:
# Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
# A = Y - Y.mH
# cayley(A)[:, :k]
# gives the original tensor. It is not clear how to do this.
# Perhaps via some algebraic manipulation involving the QR like that of
# Corollary 2.2 in Edelman, Arias and Smith?
if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
raise NotImplementedError("It is not possible to assign to the matrix exponential "
"or the Cayley parametrizations when use_trivialization=False.")
# If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
# Here Q is always real because we do not support householder and complex matrices.
# See note [Householder complex]
A, tau = torch.geqrf(Q)
            # We want the QR decomposition to satisfy diag(R) > 0, as otherwise we could
            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is also a valid QR decomposition.
            # The diagonal of A holds the diagonal of R from the QR decomposition.
A.diagonal(dim1=-2, dim2=-1).sign_()
# Equality with zero is ok because LAPACK returns exactly zero when it does not want
# to use a particular reflection
A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
return A.mT if transpose else A
else:
if n == k:
# We check whether Q is orthogonal
if not _is_orthogonal(Q):
Q = _make_orthogonal(Q)
else: # Is orthogonal
Q = Q.clone()
else:
# Complete Q into a full n x n orthogonal matrix
N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
Q = torch.cat([Q, N], dim=-1)
Q = _make_orthogonal(Q)
self.base = Q
            # It is necessary to return -Id rather than zeros, as the Householder
            # parametrization reads the signs off the diagonal of the stored tensor.
            # With -Id, forward() maps the stored tensor to torch.eye(m, n) before the
            # multiplication by `base`, so the parametrized weight starts off equal to the
            # original (orthogonalized) tensor.
            # Poor man's version of eye_like
neg_Id = torch.zeros_like(Q_init)
neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
return neg_Id
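# Illustrative sketch of the trivialization round trip (assumes use_trivialization=True,
# the default): right_inverse() stores an orthogonal `base` and returns -Id, and the forward
# map sends -Id back to the leading columns of `base`, so the parametrized weight initially
# equals the original (orthogonalized) tensor:
#
#     orth = _Orthogonal(torch.randn(5, 3), _OrthMaps.householder)
#     W = torch.linalg.qr(torch.randn(5, 3)).Q   # an orthonormal 5 x 3 matrix
#     X = orth.right_inverse(W)                  # registers `base`, returns -Id
#     assert torch.allclose(orth(X), W, atol=1e-6)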
def orthogonal(module: Module,
name: str = 'weight',
orthogonal_map: Optional[str] = None,
*,
use_trivialization: bool = True) -> Module:
r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
and the transpose when :math:`Q` is real-valued, and
:math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
and orthonormal rows otherwise.
If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:
- ``"matrix_exp"``/``"cayley"``:
the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
:math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
:math:`A` to give an orthogonal matrix.
- ``"householder"``: computes a product of Householder reflectors
(:func:`~torch.linalg.householder_product`).
``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
``"householder"``, but they are slower to compute for very thin or very wide matrices.
If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
``module.parametrizations.weight[0].base``. This helps the
convergence of the parametrized layer at the expense of some extra memory use.
See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
Initial value of :math:`Q`:
If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
        The same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
Otherwise, the initial value is the result of the composition of all the registered
parametrizations applied to the original tensor.
.. note::
This function is implemented using the parametrization functionality
in :func:`~torch.nn.utils.parametrize.register_parametrization`.
.. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
.. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
Args:
module (nn.Module): module on which to register the parametrization.
name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
Default: ``True``.
Returns:
The original module with an orthogonal parametrization registered to the specified
weight
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): _Orthogonal()
)
)
)
>>> # xdoctest: +IGNORE_WANT
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)
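
    A second, purely illustrative sketch with a complex square weight, for which the
    parametrized matrix is unitary (the sizes and the choice of map here are arbitrary)::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> unitary_linear = orthogonal(nn.Linear(10, 10, dtype=torch.complex64), orthogonal_map="cayley")
        >>> Q = unitary_linear.weight
        >>> dist = torch.dist(Q.mH @ Q, torch.eye(10, dtype=torch.complex64))  # close to zero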
"""
weight = getattr(module, name, None)
if not isinstance(weight, Tensor):
raise ValueError(
"Module '{}' has no parameter or buffer with name '{}'".format(module, name)
)
# We could implement this for 1-dim tensors as the maps on the sphere
# but I believe it'd bite more people than it'd help
if weight.ndim < 2:
raise ValueError("Expected a matrix or batch of matrices. "
f"Got a tensor of {weight.ndim} dimensions.")
if orthogonal_map is None:
orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"
orth_enum = getattr(_OrthMaps, orthogonal_map, None)
if orth_enum is None:
raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
f'Got: {orthogonal_map}')
orth = _Orthogonal(weight,
orth_enum,
use_trivialization=use_trivialization)
parametrize.register_parametrization(module, name, orth, unsafe=True)
return module
class _SpectralNorm(Module):
def __init__(
self,
weight: torch.Tensor,
n_power_iterations: int = 1,
dim: int = 0,
eps: float = 1e-12
) -> None:
super().__init__()
ndim = weight.ndim
if dim >= ndim or dim < -ndim:
raise IndexError("Dimension out of range (expected to be in range of "
f"[-{ndim}, {ndim - 1}] but got {dim})")
if n_power_iterations <= 0:
raise ValueError('Expected n_power_iterations to be positive, but '
'got n_power_iterations={}'.format(n_power_iterations))
self.dim = dim if dim >= 0 else dim + ndim
self.eps = eps
if ndim > 1:
# For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
self.n_power_iterations = n_power_iterations
weight_mat = self._reshape_weight_to_matrix(weight)
h, w = weight_mat.size()
u = weight_mat.new_empty(h).normal_(0, 1)
v = weight_mat.new_empty(w).normal_(0, 1)
self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))
# Start with u, v initialized to some reasonable values by performing a number
# of iterations of the power method
self._power_method(weight_mat, 15)
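            # Recap of what is being maintained (illustrative, not a code path): the power
            # method keeps `u` and `v` close to the leading left/right singular vectors of
            # `weight_mat`, so sigma = u^T @ weight_mat @ v approximates its largest singular
            # value, and the parametrized weight is the original weight divided by sigma.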
def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
# Precondition
assert weight.ndim > 1
if self.dim != 0:
# permute dim to front
weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))
return weight.flatten(1)
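    # Illustrative shape example (not from the original code): a Conv2d weight of shape
    # (out_channels, in_channels, kH, kW) with dim=0 flattens to
    # (out_channels, in_channels * kH * kW); with dim=1, the in_channels axis is first
    # permuted to the front, giving (in_channels, out_channels * kH * kW).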
@torch.autograd.no_grad()
def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
# See original note at torch/nn/utils/spectral_norm.py
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
# updated in power iteration **in-place**. This is very important
# because in `DataParallel` forward, the vectors (being buffers) are
# broadcast from the parallelized module to each module replica,
# which is a new module object created on the fly. And each replica
# runs its own spectral norm power iteration. So simply assigning
# the updated vectors to the module this function runs on will cause
# the update to be lost forever. And the next time the parallelized
# module is replicated, the same randomly initialized vectors are
# broadcast and used!
#