import functools
import math
import numbers
import operator
import weakref
from typing import List
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
lazy_property, tril_matrix_to_vec,
vec_to_tril_matrix)
from torch.nn.functional import pad
from torch.nn.functional import softplus
# Public API of this module: the names bound by ``from <module> import *``.
# NOTE(review): only some of these names are defined in the portion of the
# file visible here; the rest are presumably defined further down -- confirm
# every entry resolves before adding or removing names.
__all__ = [
'AbsTransform',
'AffineTransform',
'CatTransform',
'ComposeTransform',
'CorrCholeskyTransform',
'ExpTransform',
'IndependentTransform',
'LowerCholeskyTransform',
'PowerTransform',
'ReshapeTransform',
'SigmoidTransform',
'TanhTransform',
'SoftmaxTransform',
'StackTransform',
'StickBreakingTransform',
'Transform',
'identity_transform',
]
class Transform(object):
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Instances of this base class compare equal by identity and hash by
    identity. Python resets ``__hash__`` to ``None`` on any subclass that
    overrides ``__eq__`` without also defining ``__hash__``.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether transform is monotone
            increasing or decreasing.
    """
    bijective = False
    domain: constraints.Constraint
    codomain: constraints.Constraint

    def __init__(self, cache_size=0):
        """
        Record the cache size and set up the (optional) one-slot cache.

        Raises:
            ValueError: if ``cache_size`` is neither 0 nor 1.
        """
        self._cache_size = cache_size
        # Weak reference to the inverse wrapper; created lazily by ``inv``.
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            # Single (input, output) pair shared by forward and inverse calls.
            self._cached_x_y = None, None
        else:
            raise ValueError('cache_size must be 0 or 1')
        super(Transform, self).__init__()

    @property
    def event_dim(self):
        """
        The event dimension shared by ``domain`` and ``codomain``.

        Raises:
            ValueError: if the two constraints report different
                ``event_dim`` values, in which case there is no single
                answer.
        """
        if self.domain.event_dim == self.codomain.event_dim:
            return self.domain.event_dim
        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            # Dereference the weakref; yields None if the wrapper was collected.
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            # Only a weak reference is kept here; the wrapper holds the
            # strong reference back to this transform.
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def with_cache(self, cache_size=1):
        """
        Return a transform of the same kind using ``cache_size``.

        Returns ``self`` when the cache size already matches. Otherwise a new
        instance is built, but only for classes that inherit
        ``Transform.__init__`` unchanged.

        Raises:
            NotImplementedError: if the concrete class defines its own
                ``__init__`` and does not override this method.
        """
        if self._cache_size == cache_size:
            return self
        if type(self).__init__ is Transform.__init__:
            return type(self)(cache_size=cache_size)
        raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))

    def __eq__(self, other):
        """Base transforms are equal only to themselves (identity)."""
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __hash__(self):
        """
        Identity hash, consistent with the identity-based ``__eq__``.

        Defining ``__eq__`` without ``__hash__`` makes Python set
        ``__hash__`` to ``None``; this restores hashability so instances can
        be stored in sets and used as dict keys.
        """
        return object.__hash__(self)

    def __call__(self, x):
        """
        Computes the transform `x => y`.

        With caching enabled, the cached output is returned when ``x`` is
        the same object (identity, not equality) as the cached input.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.

        With caching enabled, the cached input is returned when ``y`` is
        the same object (identity, not equality) as the cached output.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        """Return the class name followed by ``()``."""
        return self.__class__.__name__ + '()'

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape
class _InverseTransform(Transform):
    """
    View of a single :class:`Transform` applied in the opposite direction.

    Every operation is delegated to the wrapped transform with the roles of
    input and output swapped. This class is private; obtain instances via
    the ``Transform.inv`` property instead of constructing them directly.
    """

    def __init__(self, transform: Transform):
        # Mirror the wrapped transform's cache size, then hold a strong
        # reference to it (overwriting the ``None`` set by the base class).
        super().__init__(cache_size=transform._cache_size)
        self._inv: Transform = transform

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        """Valid inputs: the wrapped transform's ``codomain``."""
        wrapped = self._inv
        assert wrapped is not None
        return wrapped.codomain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        """Valid outputs: the wrapped transform's ``domain``."""
        wrapped = self._inv
        assert wrapped is not None
        return wrapped.domain

    @property
    def bijective(self):
        """Same ``bijective`` flag as the wrapped transform."""
        wrapped = self._inv
        assert wrapped is not None
        return wrapped.bijective

    @property
    def sign(self):
        """Same ``sign`` as the wrapped transform."""
        wrapped = self._inv
        assert wrapped is not None
        return wrapped.sign

    @property
    def inv(self):
        """The wrapped transform itself."""
        return self._inv

    def with_cache(self, cache_size=1):
        """Re-cache the wrapped transform and return that result's inverse."""
        assert self._inv is not None
        recached = self.inv.with_cache(cache_size)
        return recached.inv

    def __eq__(self, other):
        """Equal when ``other`` is also an inverse wrapper around an equal transform."""
        if isinstance(other, _InverseTransform):
            assert self._inv is not None
            return self._inv == other._inv
        return False

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, repr(self._inv))

    def __call__(self, x):
        """Apply the wrapped transform's (possibly cached) inverse to ``x``."""
        wrapped = self._inv
        assert wrapped is not None
        return wrapped._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        """Negated log-abs-det-Jacobian of the wrapped transform, evaluated at ``(y, x)``."""
        wrapped = self._inv
        assert wrapped is not None
        return -wrapped.log_abs_det_jacobian(y, x)

    def forward_shape(self, shape):
        """Forward shape here is the wrapped transform's inverse shape."""
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape):
        """Inverse shape here is the wrapped transform's forward shape."""
        return self._inv.forward_shape(shape)
class ComposeTransform(Transform):
"""
Composes multiple transforms in a chain.
The transforms being composed are responsible for caching.
Args:
parts (list of :class:`Transform`): A list of transforms to compose.
cache_size (int): Size of cache. If zero, no caching is done. If one,
the latest single value is cached. Only 0 and 1 are supported.
"""
def __init__(self, parts: List[Transform], cache_size=0):
    """
    Store the chain of transforms.

    When ``cache_size`` is truthy, every part is first replaced by
    ``part.with_cache(cache_size)``; the base-class initializer then
    records ``cache_size`` for the composition.
    """
    if cache_size:
        recached = []
        for part in parts:
            recached.append(part.with_cache(cache_size))
        parts = recached
    super(ComposeTransform, self).__init__(cache_size=cache_size)
    self.parts = parts
def __eq__(self, other):
    """
    Two compositions are equal when their ``parts`` compare equal;
    anything that is not a :class:`ComposeTransform` is unequal.
    """
    if isinstance(other, ComposeTransform):
        return self.parts == other.parts
    return False
@constraints.dependent_property(is_discrete=False)
def domain(self):
    """
    Constraint on inputs to the whole chain.

    An empty chain yields ``constraints.real``. Otherwise the first part's
    domain is used, wrapped in ``constraints.independent`` when the chain
    needs a larger ``event_dim`` than that domain reports.
    """
    chain = self.parts
    if not chain:
        return constraints.real
    first_domain = chain[0].domain
    # Start from the last codomain and walk the parts in reverse, shifting
    # by each part's (domain - codomain) event_dim difference and never
    # dropping below that part's own domain event_dim.
    needed_dim = chain[-1].codomain.event_dim
    for step in reversed(chain):
        shifted = needed_dim + (step.domain.event_dim - step.codomain.event_dim)
        needed_dim = max(shifted, step.domain.event_dim)
    assert needed_dim >= first_domain.event_dim
    extra_dims = needed_dim - first_domain.event_dim
    if extra_dims > 0:
        return constraints.independent(first_domain, extra_dims)
    return first_domain
@constraints.dependent_property(is_discrete=False)
def codomain(self):
    """
    Constraint on outputs of the whole chain.

    An empty chain yields ``constraints.real``. Otherwise the last part's
    codomain is used, wrapped in ``constraints.independent`` when the chain
    needs a larger ``event_dim`` than that codomain reports.
    """
    chain = self.parts
    if not chain:
        return constraints.real
    last_codomain = chain[-1].codomain
    # Start from the first domain and walk the parts in order, shifting by
    # each part's (codomain - domain) event_dim difference and never
    # dropping below that part's own codomain event_dim.
    needed_dim = chain[0].domain.event_dim
    for step in chain:
        shifted = needed_dim + (step.codomain.event_dim - step.domain.event_dim)
        needed_dim = max(shifted, step.codomain.event_dim)
    assert needed_dim >= last_codomain.event_dim
    extra_dims = needed_dim - last_codomain.event_dim
    if extra_dims > 0:
        return constraints.independent(last_codomain, extra_dims)
    return last_codomain
@lazy_property
def bijective(self):
    """
    ``True`` only if every part reports itself bijective; an empty chain
    is therefore bijective.
    """
    for part in self.parts:
        if not part.bijective:
            return False
    return True
@lazy_property
def sign(self):
    """
    Product of the ``sign`` of every part, starting from ``1`` (which is
    also the result for an empty chain).
    """
    part_signs = (part.sign for part in self.parts)
    return functools.reduce(operator.mul, part_signs, 1)
@property
def inv(self):
    """
    Inverse composition: each part's ``inv`` in reverse order.

    The result is remembered through a weak reference, and the new
    composition is given a weak reference back to this one, so neither
    object keeps the other alive.
    """
    ref = self._inv
    existing = None if ref is None else ref()
    if existing is not None:
        return existing
    flipped = ComposeTransform([part.inv for part in reversed(self.parts)])
    self._inv = weakref.ref(flipped)
    flipped._inv = weakref.ref(self)
    return flipped
def with_cache(self, cache_size=1):
    """
    Return a composition with the requested ``cache_size``.

    ``self`` is returned when its cache size already matches; otherwise a
    new :class:`ComposeTransform` is built from the same ``parts``.
    """
    if self._cache_size != cache_size:
        return ComposeTransform(self.parts, cache_size=cache_size)
    return self
def __call__(self, x):
    """
    Apply the parts in order, feeding each part's output to the next.
    With no parts, ``x`` is returned unchanged.
    """
    return functools.reduce(lambda value, part: part(value), self.parts, x)
def log_abs_det_jacobian(self, x, y):
Loading ...