from typing import (
Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
)
import torch
import torch.nn.functional as F
from torch.types import _size
from ._lowrank import svd_lowrank, pca_lowrank
from .overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
handle_torch_function)
from ._jit_internal import boolean_dispatch, List
from ._jit_internal import _overload as overload
from torch._autograd_functions import _LU
Tensor = torch.Tensor
from torch import _VF
__all__ = [
'atleast_1d',
'atleast_2d',
'atleast_3d',
'align_tensors',
'broadcast_shapes',
'broadcast_tensors',
'cartesian_prod',
'block_diag',
'cdist',
'chain_matmul',
'einsum',
'istft',
'lu',
'lu_unpack',
'norm',
'meshgrid',
'pca_lowrank',
'split',
'stft',
'svd_lowrank',
'tensordot',
'unique',
'unique_consecutive',
]
def broadcast_tensors(*tensors):
r"""broadcast_tensors(*tensors) -> List of Tensors
Broadcasts the given tensors according to :ref:`broadcasting-semantics`.
Args:
*tensors: any number of tensors of the same type
.. warning::
More than one element of a broadcasted tensor may refer to a single
memory location. As a result, in-place operations (especially ones that
are vectorized) may result in incorrect behavior. If you need to write
to the tensors, please clone them first.
Example::
>>> x = torch.arange(3).view(1, 3)
>>> y = torch.arange(2).view(2, 1)
>>> a, b = torch.broadcast_tensors(x, y)
>>> a.size()
torch.Size([2, 3])
>>> a
tensor([[0, 1, 2],
[0, 1, 2]])
"""
if has_torch_function(tensors):
return handle_torch_function(broadcast_tensors, tensors, *tensors)
return _VF.broadcast_tensors(tensors) # type: ignore
def broadcast_shapes(*shapes):
r"""broadcast_shapes(*shapes) -> Size
Similar to :func:`broadcast_tensors` but for shapes.
This is equivalent to
``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
    but avoids the need to create intermediate tensors. This is useful for
broadcasting tensors of common batch shape but different rightmost shape,
e.g. to broadcast mean vectors with covariance matrices.
Example::
>>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
torch.Size([1, 3, 2])
Args:
\*shapes (torch.Size): Shapes of tensors.
Returns:
shape (torch.Size): A shape compatible with all input shapes.
Raises:
RuntimeError: If shapes are incompatible.
"""
    # TODO Move this to C++ once the jit has better support for torch.Size.
with torch.no_grad():
scalar = torch.zeros((), device="cpu")
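        # expand() below only creates views of this 0-dim scalar, so no full-size
        # intermediate tensors are materialized while computing the result shape.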
tensors = [scalar.expand(shape) for shape in shapes]
tensors = broadcast_tensors(*tensors)
return tensors[0].shape
def split(tensor, split_size_or_sections, dim=0):
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
be split into equally sized chunks (if possible). Last chunk will be smaller if
the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size_or_sections`.
If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
to :attr:`split_size_or_sections`.
Args:
tensor (Tensor): tensor to split.
split_size_or_sections (int) or (list(int)): size of a single chunk or
list of sizes for each chunk
dim (int): dimension along which to split the tensor.
Example::
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
"""
if has_torch_function_unary(tensor):
return handle_torch_function(
split, (tensor,), tensor, split_size_or_sections, dim=dim)
# Overwriting reason:
# This dispatches to two ATen functions depending on the type of
# split_size_or_sections. The branching code is in tensor.py, which we
# call here.
return tensor.split(split_size_or_sections, dim)
if TYPE_CHECKING:
_Indices = _size
else:
_Indices = List[int]
# equivalent to itertools.product(indices)
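# e.g. _indices_product([2, 3]) == [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]];
# written with explicit loops because TorchScript cannot yet compile itertools.product
# (see the TODO about issue 33781 in lu_unpack below).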
def _indices_product(indices: _Indices) -> List[List[int]]:
empty_list = torch.jit.annotate(List[int], [])
result = [empty_list]
for idx in indices:
result_temp = torch.jit.annotate(List[List[int]], [])
for res in result:
for i in range(idx):
result_temp.append(res + [i])
result = result_temp
return result
def _index_tensor_with_indices_list(tensor, indices):
# type: (Tensor, List[int]) -> Tensor
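    # Indexes one dimension at a time, i.e. tensor[indices[0]][indices[1]]...;
    # a stand-in for indexing a Tensor with a list, which TorchScript does not
    # support yet (see the TODO in lu_unpack below).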
out = tensor
for index in indices:
out = out[index]
return out
def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
# type: (Tensor, Tensor, bool, bool) -> (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])
r"""Unpacks the data and pivots from a LU factorization of a tensor.
Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
Args:
LU_data (Tensor): the packed LU factorization data
LU_pivots (Tensor): the packed LU factorization pivots
unpack_data (bool): flag indicating if the data should be unpacked
unpack_pivots (bool): flag indicating if the pivots should be unpacked
Examples::
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
>>> # LU factorization of a rectangular matrix:
>>> A = torch.randn(2, 3, 2)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>> P
tensor([[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]],
[[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.]]])
>>> A_L
tensor([[[ 1.0000, 0.0000],
[ 0.4763, 1.0000],
[ 0.3683, 0.1135]],
[[ 1.0000, 0.0000],
[ 0.2957, 1.0000],
[-0.9668, -0.3335]]])
>>> A_U
tensor([[[ 2.1962, 1.0881],
[ 0.0000, -0.8681]],
[[-1.0947, 0.3736],
[ 0.0000, 0.5718]]])
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
>>> torch.norm(A_ - A)
tensor(2.9802e-08)
"""
if has_torch_function_variadic(LU_data, LU_pivots):
return handle_torch_function(
lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots,
unpack_data=unpack_data,
unpack_pivots=unpack_pivots)
shape = LU_data.shape
# In generalized LU factorization, the following shape relations hold:
# A.shape[-2:] == (m, n)
# P.shape[-2:] == (m, m)
# L.shape[-2:] == (m, k)
# U.shape[-2:] == (k, n)
# where k = min(m, n)
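    # e.g. for the rectangular 3 x 2 matrices in the docstring example above,
    # P is 3 x 3, L is 3 x 2 and U is 2 x 2.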
m, n = shape[-2:]
k = min(m, n)
if unpack_data:
U: Optional[Tensor] = LU_data.triu()
assert U is not None
if m != k:
U = U.narrow(-2, 0, k)
L: Optional[Tensor] = LU_data.tril()
assert L is not None
if k != n:
L = L.narrow(-1, 0, k)
L.diagonal(dim1=-2, dim2=-1).fill_(1)
else:
L = U = None
if unpack_pivots:
LU_pivots_zero_idx = LU_pivots - 1
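        # LAPACK-style pivots are 1-based, hence the shift above. Entry k of the pivots
        # records the row that row k was swapped with during factorization; replaying
        # those swaps below recovers the permutation encoded by P.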
if LU_data.dim() > 2:
P: Optional[Tensor] = torch.eye(m, device=LU_data.device,
dtype=LU_data.dtype) \
.expand(shape[:-1] + (m,)) \
.clone(memory_format=torch.contiguous_format)
assert P is not None
# TODO: rewrite when TorchScript supports product and map as
# product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed
indices = _indices_product(shape[:-2])
for idx in indices:
final_order = list(range(m))
for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):
final_order[k], final_order[j] = final_order[j], final_order[k]
# TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list
p_idx = _index_tensor_with_indices_list(P, idx)
p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))
else:
P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)
final_order = list(range(m))
            for k, j in enumerate(LU_pivots_zero_idx):
final_order[k], final_order[j] = final_order[j], final_order[k]
P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
else:
P = None
return P, L, U
def einsum(equation, *operands):
r"""einsum(equation, *operands) -> Tensor
Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
based on the Einstein summation convention.
Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
with some subscript and define which subscripts are part of the output. The output is then computed by summing
the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
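    For instance, with operand shapes chosen purely for illustration, the matrix
    multiplication example above reads::

        >>> A, B = torch.randn(3, 4), torch.randn(4, 5)
        >>> torch.einsum('ij,jk->ik', A, B).shape
        torch.Size([3, 5])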
Equation:
The :attr:`equation` string specifies the subscripts (lower case letters `['a', 'z']`) for each dimension of
        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
        comma (','), e.g. `'ij,jk'` specifies subscripts for two 2D operands. The dimensions labeled with the same subscript
must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
followed by the subscripts for the output. For instance, the following equation computes the transpose of a
matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
at most once for the output.
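        For instance, with operand shapes chosen purely for illustration, the explicit-output
        equation above yields the transposed matrix product::

            >>> A, B = torch.randn(3, 4), torch.randn(4, 5)
            >>> torch.einsum('ij,jk->ki', A, B).shape
            torch.Size([5, 3])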
Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
        e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the third and fourth
dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
batch matrix multiplication `'...ij,...jk'`.
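        For instance, with batch and matrix sizes chosen purely for illustration::

            >>> A, B = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
            >>> torch.einsum('...ij,...jk', A, B).shape
            torch.Size([2, 3, 5])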
A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.
.. note::
``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
.. note::
        This function does not optimize the given expression, so a different formula for the same
        computation may run faster or consume less memory.