r"""
The following constraints are implemented:
- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.stack``
- ``constraints.square``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""
import torch
# Names exported by ``from <this module> import *``.
# NOTE(review): the module docstring lists ``constraints.one_hot`` but
# 'one_hot' is not exported here -- confirm whether that is intentional.
__all__ = [
'Constraint',
'boolean',
'cat',
'corr_cholesky',
'dependent',
'dependent_property',
'greater_than',
'greater_than_eq',
'independent',
'integer_interval',
'interval',
'half_open_interval',
'is_dependent',
'less_than',
'lower_cholesky',
'lower_triangular',
'multinomial',
'nonnegative_integer',
'positive',
'positive_semidefinite',
'positive_definite',
'positive_integer',
'real',
'real_vector',
'simplex',
'square',
'stack',
'symmetric',
'unit_interval',
]
class Constraint:
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.

    Attributes:
        is_discrete (bool): Whether constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together define
            an event. The :meth:`check` method will remove this many dimensions
            when computing validity.
    """

    is_discrete = False  # Default to continuous.
    event_dim = 0  # Default to univariate.

    def check(self, value):
        """
        Returns a byte tensor of ``sample_shape + batch_shape`` indicating
        whether each event in value satisfies this constraint.

        Raises:
            NotImplementedError: Always, in this base class; subclasses are
                expected to override this method.
        """
        raise NotImplementedError

    def __repr__(self):
        """
        Return the class name followed by ``()``.

        A single leading underscore is dropped so that private implementation
        classes such as ``_Foo`` print as ``Foo()``.
        """
        name = self.__class__.__name__
        # Only strip the privacy marker. Slicing unconditionally would also
        # eat the first letter of public names ("Constraint" -> "onstraint").
        if name.startswith('_'):
            name = name[1:]
        return name + '()'
class _Dependent(Constraint):
    """
    Placeholder constraint for variables whose support depends on other
    variables, and which therefore obey no simple coordinate-wise constraint.

    Args:
        is_discrete (bool): Optional static value for ``.is_discrete``.
            When omitted, reading ``.is_discrete`` raises a
            NotImplementedError.
        event_dim (int): Optional static value for ``.event_dim``.
            When omitted, reading ``.event_dim`` raises a
            NotImplementedError.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        # ``NotImplemented`` is the "not statically known" sentinel.
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    @property
    def is_discrete(self):
        """Static ``is_discrete`` value; raises NotImplementedError if unset."""
        if self._is_discrete is not NotImplemented:
            return self._is_discrete
        raise NotImplementedError(".is_discrete cannot be determined statically")

    @property
    def event_dim(self):
        """Static ``event_dim`` value; raises NotImplementedError if unset."""
        if self._event_dim is not NotImplemented:
            return self._event_dim
        raise NotImplementedError(".event_dim cannot be determined statically")

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Support for syntax to customize static attributes::

            constraints.dependent(is_discrete=True, event_dim=1)

        Arguments left at their default are inherited from this instance.
        A new ``_Dependent`` is returned; this instance is not modified.
        """
        chosen_discrete = self._is_discrete if is_discrete is NotImplemented else is_discrete
        chosen_event_dim = self._event_dim if event_dim is NotImplemented else event_dim
        return _Dependent(is_discrete=chosen_discrete, event_dim=chosen_event_dim)

    def check(self, x):
        """Always raises ValueError: validity cannot be decided here."""
        raise ValueError('Cannot determine validity of dependent constraint')
def is_dependent(constraint):
    """
    Return ``True`` when ``constraint`` is an instance of ``_Dependent``
    (or of a subclass of it), and ``False`` otherwise.
    """
    dependent_type = _Dependent
    return isinstance(constraint, dependent_type)
class _DependentProperty(property, _Dependent):
    """
    Decorator that extends @property to act like a `Dependent` constraint when
    called on a class and act like a property when called on an object.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function to be decorated.
        is_discrete (bool): Optional static value for ``.is_discrete``.
            When omitted, reading ``.is_discrete`` raises a
            NotImplementedError.
        event_dim (int): Optional static value for ``.event_dim``.
            When omitted, reading ``.event_dim`` raises a
            NotImplementedError.
    """

    def __init__(self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        # ``fn`` is forwarded to the next ``__init__`` in the MRO as the getter.
        super().__init__(fn)
        self._is_discrete = is_discrete
        self._event_dim = event_dim

    def __call__(self, fn):
        """
        Support for syntax to customize static attributes::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self):
                ...

        Returns a new ``_DependentProperty`` wrapping ``fn`` that carries over
        this instance's static attributes.
        """
        static_attrs = {
            'is_discrete': self._is_discrete,
            'event_dim': self._event_dim,
        }
        return _DependentProperty(fn, **static_attrs)
class _IndependentConstraint(Constraint):
    """
    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.

    Args:
        base_constraint (Constraint): The constraint being wrapped.
        reinterpreted_batch_ndims (int): Non-negative number of rightmost
            dims of the base check result to fold into the event.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def is_discrete(self):
        """Delegates to the wrapped constraint's ``is_discrete``."""
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        """The wrapped constraint's ``event_dim`` plus the reinterpreted dims."""
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        """
        Run the base constraint's check and reduce the trailing
        ``reinterpreted_batch_ndims`` dims of its result with ``all``.

        Raises:
            ValueError: If the base check result has fewer dims than
                ``reinterpreted_batch_ndims``.
        """
        ndims = self.reinterpreted_batch_ndims
        base_result = self.base_constraint.check(value)
        if base_result.dim() < ndims:
            expected = self.base_constraint.event_dim + ndims
            raise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}")
        # Collapse the reinterpreted dims into one trailing axis, then reduce it.
        kept_shape = base_result.shape[:base_result.dim() - ndims]
        flattened = base_result.reshape(kept_shape + (-1,))
        return flattened.all(-1)

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}({self.base_constraint!r}, {self.reinterpreted_batch_ndims})"
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        """Elementwise test that ``value`` equals 0 or equals 1."""
        equals_zero = value == 0
        equals_one = value == 1
        return equals_zero | equals_one
class _OneHot(Constraint):
    """
    Constrain to one-hot vectors.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        """
        Valid when every entry along the last dim is 0 or 1 and the entries
        along that dim sum to 1. The last dim is reduced away.
        """
        entries_binary = ((value == 0) | (value == 1)).all(-1)
        sums_to_one = value.sum(-1).eq(1)
        return entries_binary & sums_to_one
class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Valid when ``value`` is whole and lies within both bounds (inclusive)."""
        is_whole = value % 1 == 0
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return is_whole & above_lower & below_upper

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f'{name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})'
class _IntegerLessThan(Constraint):
    """
    Constrain to an integer interval `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        """Valid when ``value`` is whole and does not exceed ``upper_bound``."""
        is_whole = value % 1 == 0
        within_bound = value <= self.upper_bound
        return is_whole & within_bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f'{name}(upper_bound={self.upper_bound})'
class _IntegerGreaterThan(Constraint):
    """
    Constrain to an integer interval `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        """Valid when ``value`` is whole and is at least ``lower_bound``."""
        is_whole = value % 1 == 0
        within_bound = value >= self.lower_bound
        return is_whole & within_bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f'{name}(lower_bound={self.lower_bound})'
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        """Elementwise self-equality test, which fails only for NaN entries."""
        # NaN never compares equal to itself, so this is False exactly there.
        equals_itself = value == value
        return equals_itself
class _GreaterThan(Constraint):
    """
    Constrain to a real half line `(lower_bound, inf]`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        """Valid when ``value`` is strictly greater than ``lower_bound``."""
        exceeds_bound = self.lower_bound < value
        return exceeds_bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f'{name}(lower_bound={self.lower_bound})'
class _GreaterThanEq(Constraint):
    """
    Constrain to a real half line `[lower_bound, inf)`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        """Valid when ``value`` is greater than or equal to ``lower_bound``."""
        meets_bound = self.lower_bound <= value
        return meets_bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f'{name}(lower_bound={self.lower_bound})'
class _LessThan(Constraint):
Loading ...