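# NOTE: The two fields below are package-index page residue, not module source.
# Repository URL to install this package: (empty in the captured page)
# Version: 1.14.0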
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for linear algebra."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# Linear algebra ops.
# Each assignment below binds a short public name in this module to an op
# implemented in another module. Some of them are first passed through
# `dispatch.add_dispatch_support`; the rest are plain aliases.
band_part = array_ops.matrix_band_part
cholesky = dispatch.add_dispatch_support(linalg_ops.cholesky)
cholesky_solve = linalg_ops.cholesky_solve
det = dispatch.add_dispatch_support(linalg_ops.matrix_determinant)
# `slogdet` comes straight from the generated op module, so it is registered
# under its public name explicitly with `tf_export` on the next line.
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(slogdet)
diag = array_ops.matrix_diag
diag_part = dispatch.add_dispatch_support(array_ops.matrix_diag_part)
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = dispatch.add_dispatch_support(linalg_ops.matrix_inverse)
# `logm` and `lu` both come from the generated op module. The `tf_export`
# call two lines below applies to `logm` only; no export call for `lu` is
# visible in this section.
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(logm)
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = dispatch.add_dispatch_support(linalg_ops.matrix_solve)
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = dispatch.add_dispatch_support(math_ops.trace)
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve
@tf_export('linalg.logdet')
@dispatch.add_dispatch_support
def logdet(matrix, name=None):
  """Computes log of the determinant of a hermitian positive definite matrix.

  ```python
  # Compute the determinant of a matrix while reducing the chance of over- or
  # underflow:
  A = ... # shape 10 x 10
  det = tf.exp(tf.linalg.logdet(A))  # scalar
  ```

  Args:
    matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op`. Defaults to `logdet`.

  Returns:
    The natural log of the determinant of `matrix`.

  @compatibility(numpy)
  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
  hermitian positive definite matrices are supported.
  @end_compatibility
  """
  # With C the Cholesky factor of A, log det(A) = 2 * sum(log(real(diag(C)))),
  # so the determinant itself is never formed.
  with ops.name_scope(name, 'logdet', [matrix]):
    cholesky_factor = gen_linalg_ops.cholesky(matrix)
    factor_diagonal = array_ops.matrix_diag_part(cholesky_factor)
    log_diagonal = math_ops.log(math_ops.real(factor_diagonal))
    return 2.0 * math_ops.reduce_sum(log_diagonal, axis=[-1])
@tf_export('linalg.adjoint')
@dispatch.add_dispatch_support
def adjoint(matrix, name=None):
  """Transposes the last two dimensions of and conjugates tensor `matrix`.

  For example:

  ```python
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
                        #  [2 - 2j, 5 - 5j],
                        #  [3 - 3j, 6 - 6j]]
  ```

  Args:
    matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, N]` (the example above uses a
      non-square `2 x 3` input).
    name: A name to give this `Op` (optional).

  Returns:
    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
    matrix.
  """
  with ops.name_scope(name, 'adjoint', [matrix]):
    as_tensor = ops.convert_to_tensor(matrix, name='matrix')
    # Transposition and conjugation are done by one call.
    return array_ops.matrix_transpose(as_tensor, conjugate=True)
# This section is ported nearly verbatim from Eigen's implementation:
# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
def _matrix_exp_pade3(matrix):
  """3rd-order Pade approximant for matrix exponential.

  Args:
    matrix: A `Tensor` whose two inner-most dimensions hold the matrices.

  Returns:
    A tuple `(matrix_u, matrix_v)`. `matrix_u` collects the odd powers of
    `matrix` and `matrix_v` the even powers (including the identity term).
  """
  b0, b1, b2 = [
      constant_op.constant(coeff, matrix.dtype)
      for coeff in (120.0, 60.0, 12.0)
  ]
  # Batched identity with the same batch shape and dtype as `matrix`.
  num_rows = array_ops.shape(matrix)[-2]
  batch_dims = array_ops.shape(matrix)[:-2]
  identity = linalg_ops.eye(
      num_rows, batch_shape=batch_dims, dtype=matrix.dtype)
  power_2 = math_ops.matmul(matrix, matrix)
  odd_factor = power_2 + b1 * identity
  matrix_u = math_ops.matmul(matrix, odd_factor)
  matrix_v = b2 * power_2 + b0 * identity
  return matrix_u, matrix_v
def _matrix_exp_pade5(matrix):
  """5th-order Pade approximant for matrix exponential.

  Args:
    matrix: A `Tensor` whose two inner-most dimensions hold the matrices.

  Returns:
    A tuple `(matrix_u, matrix_v)`. `matrix_u` collects the odd powers of
    `matrix` and `matrix_v` the even powers (including the identity term).
  """
  b0, b1, b2, b3, b4 = [
      constant_op.constant(coeff, matrix.dtype)
      for coeff in (30240.0, 15120.0, 3360.0, 420.0, 30.0)
  ]
  # Batched identity with the same batch shape and dtype as `matrix`.
  num_rows = array_ops.shape(matrix)[-2]
  batch_dims = array_ops.shape(matrix)[:-2]
  identity = linalg_ops.eye(
      num_rows, batch_shape=batch_dims, dtype=matrix.dtype)
  power_2 = math_ops.matmul(matrix, matrix)
  power_4 = math_ops.matmul(power_2, power_2)
  odd_factor = power_4 + b3 * power_2 + b1 * identity
  matrix_u = math_ops.matmul(matrix, odd_factor)
  matrix_v = b4 * power_4 + b2 * power_2 + b0 * identity
  return matrix_u, matrix_v
def _matrix_exp_pade7(matrix):
  """7th-order Pade approximant for matrix exponential.

  Args:
    matrix: A `Tensor` whose two inner-most dimensions hold the matrices.

  Returns:
    A tuple `(matrix_u, matrix_v)`. `matrix_u` collects the odd powers of
    `matrix` and `matrix_v` the even powers (including the identity term).
  """
  b0, b1, b2, b3, b4, b5, b6 = [
      constant_op.constant(coeff, matrix.dtype)
      for coeff in (17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0,
                    1512.0, 56.0)
  ]
  # Batched identity with the same batch shape and dtype as `matrix`.
  num_rows = array_ops.shape(matrix)[-2]
  batch_dims = array_ops.shape(matrix)[:-2]
  identity = linalg_ops.eye(
      num_rows, batch_shape=batch_dims, dtype=matrix.dtype)
  power_2 = math_ops.matmul(matrix, matrix)
  power_4 = math_ops.matmul(power_2, power_2)
  power_6 = math_ops.matmul(power_4, power_2)
  odd_factor = power_6 + b5 * power_4 + b3 * power_2 + b1 * identity
  matrix_u = math_ops.matmul(matrix, odd_factor)
  matrix_v = b6 * power_6 + b4 * power_4 + b2 * power_2 + b0 * identity
  return matrix_u, matrix_v
def _matrix_exp_pade9(matrix):
  """9th-order Pade approximant for matrix exponential.

  Args:
    matrix: A `Tensor` whose two inner-most dimensions hold the matrices.

  Returns:
    A tuple `(matrix_u, matrix_v)`. `matrix_u` collects the odd powers of
    `matrix` and `matrix_v` the even powers (including the identity term).
  """
  raw_coeffs = (17643225600.0, 8821612800.0, 2075673600.0, 302702400.0,
                30270240.0, 2162160.0, 110880.0, 3960.0, 90.0)
  coeffs = [constant_op.constant(value, matrix.dtype) for value in raw_coeffs]
  # Batched identity with the same batch shape and dtype as `matrix`.
  num_rows = array_ops.shape(matrix)[-2]
  batch_dims = array_ops.shape(matrix)[:-2]
  identity = linalg_ops.eye(
      num_rows, batch_shape=batch_dims, dtype=matrix.dtype)
  power_2 = math_ops.matmul(matrix, matrix)
  power_4 = math_ops.matmul(power_2, power_2)
  power_6 = math_ops.matmul(power_4, power_2)
  power_8 = math_ops.matmul(power_6, power_2)
  odd_factor = (
      power_8 + coeffs[7] * power_6 + coeffs[5] * power_4 +
      coeffs[3] * power_2 + coeffs[1] * identity)
  matrix_u = math_ops.matmul(matrix, odd_factor)
  matrix_v = (
      coeffs[8] * power_8 + coeffs[6] * power_6 + coeffs[4] * power_4 +
      coeffs[2] * power_2 + coeffs[0] * identity)
  return matrix_u, matrix_v
def _matrix_exp_pade13(matrix):
  """13th-order Pade approximant for matrix exponential.

  Only powers up to the 6th are formed explicitly; the higher-order terms are
  obtained by multiplying a cubic in `matrix^2` by `matrix^6`.

  Args:
    matrix: A `Tensor` whose two inner-most dimensions hold the matrices.

  Returns:
    A tuple `(matrix_u, matrix_v)`. `matrix_u` collects the odd powers of
    `matrix` and `matrix_v` the even powers (including the identity term).
  """
  raw_coeffs = (64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
                1187353796428800.0, 129060195264000.0, 10559470521600.0,
                670442572800.0, 33522128640.0, 1323241920.0, 40840800.0,
                960960.0, 16380.0, 182.0)
  coeffs = [constant_op.constant(value, matrix.dtype) for value in raw_coeffs]
  # Batched identity with the same batch shape and dtype as `matrix`.
  num_rows = array_ops.shape(matrix)[-2]
  batch_dims = array_ops.shape(matrix)[:-2]
  identity = linalg_ops.eye(
      num_rows, batch_shape=batch_dims, dtype=matrix.dtype)
  power_2 = math_ops.matmul(matrix, matrix)
  power_4 = math_ops.matmul(power_2, power_2)
  power_6 = math_ops.matmul(power_4, power_2)

  high_odd = power_6 + coeffs[11] * power_4 + coeffs[9] * power_2
  odd_factor = (
      math_ops.matmul(power_6, high_odd) + coeffs[7] * power_6 +
      coeffs[5] * power_4 + coeffs[3] * power_2 + coeffs[1] * identity)
  matrix_u = math_ops.matmul(matrix, odd_factor)

  high_even = coeffs[12] * power_6 + coeffs[10] * power_4 + coeffs[8] * power_2
  matrix_v = (
      math_ops.matmul(power_6, high_even) + coeffs[6] * power_6 +
      coeffs[4] * power_4 + coeffs[2] * power_2 + coeffs[0] * identity)
  return matrix_u, matrix_v
@tf_export('linalg.expm')
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  r"""Computes the matrix exponential of one or more square matrices.

  exp(A) = \sum_{n=0}^\infty A^n/n!

  The exponential is computed using a combination of the scaling and squaring
  method and the Pade approximation. Details can be found in:
  Nicholas J. Higham, "The scaling and squaring method for the matrix
  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

  Each matrix whose 1-norm exceeds the bound of the highest-order approximant
  is divided by `2**s`, with `s = ceil(log2(norm / bound))`, before the
  approximant is evaluated; the result is then squared `s` times.

  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
  form square matrices. The output is a tensor of the same shape as the input
  containing the exponential for all input submatrices `[..., :, :]`.

  Args:
    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
      `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op` (optional).

  Returns:
    the matrix exponential of the input.

  Raises:
    ValueError: An unsupported type is provided as input.

  @compatibility(scipy)
  Equivalent to scipy.linalg.expm
  @end_compatibility
  """
  with ops.name_scope(name, 'matrix_exponential', [input]):
    matrix = ops.convert_to_tensor(input, name='input')
    if matrix.shape[-2:] == [0, 0]:
      return matrix
    # Remember the batch shape (statically if possible, otherwise as a tensor)
    # so it can be restored after the computation.
    batch_shape = matrix.shape[:-2]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(matrix)[:-2]

    # Flatten all batch dimensions into a single leading one. Reshaping the
    # batch makes the where statements work better: `l1_norm` below then has
    # exactly one entry per matrix.
    matrix = array_ops.reshape(
        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
    # Maximum absolute column sum of every matrix in the batch.
    l1_norm = math_ops.reduce_max(
        math_ops.reduce_sum(
            math_ops.abs(matrix),
            axis=array_ops.size(array_ops.shape(matrix)) - 2),
        axis=-1)

    def const(value):
      """Builds a constant with the (real) dtype of `l1_norm`."""
      return constant_op.constant(value, l1_norm.dtype)

    def _nest_where(vals, cases):
      """Selects, per matrix, `cases[k]` for the first `l1_norm < vals[k]`.

      Falls back to the last entry of `cases` when no threshold applies.
      """
      assert len(vals) == len(cases) - 1
      below_first = math_ops.less(l1_norm, const(vals[0]))
      if len(vals) == 1:
        return array_ops.where(below_first, cases[0], cases[1])
      return array_ops.where(below_first, cases[0],
                             _nest_where(vals[1:], cases[1:]))

    def _num_squarings(maxnorm):
      """Smallest `s >= 0` such that `l1_norm / 2**s <= maxnorm`."""
      # `ceil`, not `floor`: rounding down would leave matrices with a norm in
      # (maxnorm, 2 * maxnorm) unscaled, i.e. outside the range for which the
      # highest-order approximant is meant to be used.
      return math_ops.maximum(
          math_ops.ceil(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)

    def _scale_down(squarings):
      """Divides every matrix by `2**squarings` (one exponent per matrix)."""
      scale = math_ops.pow(
          constant_op.constant(2.0, dtype=matrix.dtype),
          math_ops.cast(squarings, matrix.dtype))
      return matrix / scale[..., array_ops.newaxis, array_ops.newaxis]

    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
      squarings = _num_squarings(const(3.925724783138660))
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(_scale_down(squarings))
      conds = (4.258730016922831e-001, 1.880152677804762e+000)
      u = _nest_where(conds, (u3, u5, u7))
      v = _nest_where(conds, (v3, v5, v7))
    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
      squarings = _num_squarings(const(5.371920351148152))
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(matrix)
      u9, v9 = _matrix_exp_pade9(matrix)
      u13, v13 = _matrix_exp_pade13(_scale_down(squarings))
      conds = (1.495585217958292e-002, 2.539398330063230e-001,
               9.504178996162932e-001, 2.097847961257068e+000)
      u = _nest_where(conds, (u3, u5, u7, u9, u13))
      v = _nest_where(conds, (v3, v5, v7, v9, v13))
    else:
      raise ValueError('tf.linalg.expm does not support matrices of type %s' %
                       matrix.dtype)

    # The approximant is (v - u)^{-1} (v + u).
    numer = u + v
    denom = -u + v
    result = linalg_ops.matrix_solve(denom, numer)

    # Undo the scaling: square each result as many times as its matrix was
    # halved. The loop runs for the largest count in the batch and leaves
    # entries that need fewer squarings untouched.
    max_squarings = math_ops.reduce_max(squarings)

    def _more_squarings_needed(step, unused_partial):
      return math_ops.less(step, max_squarings)

    def _square_pending(step, partial):
      return step + 1, array_ops.where(
          math_ops.less(step, squarings), math_ops.matmul(partial, partial),
          partial)

    _, result = control_flow_ops.while_loop(
        _more_squarings_needed, _square_pending, [const(0.0), result])

    # Restore the original batch dimensions.
    if not matrix.shape.is_fully_defined():
      return array_ops.reshape(
          result,
          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
@tf_export('linalg.tridiagonal_solve')
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True):
  r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored. Padding from `M-1` to `M` is
  only done when both lengths are statically known.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with best performance. In case
  you need to cast a tensor into a compact format manually, use `tf.gather_nd`.
  An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals=tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or
  `[..., M, K]`. The latter allows to simultaneously solve K systems with the
  same left-hand sides and K different right-hand sides. If `transpose_rhs`
  is set to `True` the expected shape is `[..., M]` or `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow, if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider if it's possible to disable pivoting
  and have `K = 1`, or, alternatively, consider using CPU.

  On CPU, solution is computed via Gaussian elimination with or without partial
  pivoting, depending on `partial_pivoting` parameter. On GPU, Nvidia's cuSPARSE
  library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name:  A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
  Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
  """
  if diagonals_format == 'compact':
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  if diagonals_format == 'sequence':
    if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
      raise ValueError('Expected diagonals to be a sequence of length 3.')

    superdiag, maindiag, subdiag = diagonals
    if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or
        not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])):
      raise ValueError(
          'Tensors representing the three diagonals must have the same shape, '
          'except for the last dimension, got {}, {}, {}'.format(
              subdiag.shape, maindiag.shape, superdiag.shape))

    m = tensor_shape.dimension_value(maindiag.shape[-1])

    def pad_if_necessary(t, label, last_dim_padding):
      """Pads an off-diagonal of static length `m - 1` up to length `m`."""
      n = tensor_shape.dimension_value(t.shape[-1])
      # Without both static lengths there is nothing to compare (and `m - 1`
      # would fail on `None`), so the tensor is passed through unchanged.
      if not n or not m or n == m:
        return t
      if n == m - 1:
        paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                    [last_dim_padding])
        return array_ops.pad(t, paddings)
      raise ValueError('Expected {} to have length {} or {}, got {}.'.format(
          label, m, m - 1, n))

    # The ignored slot is the first one for the subdiagonal and the last one
    # for the superdiagonal, hence the different padding sides.
    subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
    superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

    diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  if diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))
    m = m1 or m2
    if not m:
      raise ValueError('The size of the matrix needs to be known for '
                       'diagonals_format="matrix"')

    # Extract diagonals; use input[..., 0, 0] as "dummy" m-th elements of sub-
    # and superdiagonal.
    # gather_nd slices into first indices, whereas we need to slice into the
    # last two, so transposing back and forth is necessary. Indices therefore
    # address the transposed tensor, i.e. they are (column, row) pairs.
    dummy_idx = [0, 0]
    # One [superdiag, maindiag, subdiag] triple per column. Guarding both ends
    # with the dummy keeps every index in range, including for m == 1.
    indices = [[[i + 1, i] if i < m - 1 else dummy_idx,
                [i, i],
                [i - 1, i] if i > 0 else dummy_idx]
               for i in range(m)]
    diagonals = array_ops.transpose(
        array_ops.gather_nd(array_ops.transpose(diagonals), indices))
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             name)

  raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format))
def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                      conjugate_rhs, partial_pivoting, name):
  """Helper function used after the input has been cast to compact form.

  Validates the static shapes of `diagonals` and `rhs`, applies the requested
  conjugation/transposition to `rhs`, and calls
  `linalg_ops.tridiagonal_solve`.

  Args:
    diagonals: The diagonals in compact form; its second-to-last dimension,
      when statically known, must be 3.
    rhs: The right-hand sides, with the same rank as `diagonals` or one less.
    transpose_rhs: Whether a matrix `rhs` is transposed before solving; the
      result is then transposed back. Ignored when `rhs` is a vector.
    conjugate_rhs: Whether `rhs` is conjugated before solving.
    partial_pivoting: Forwarded to `linalg_ops.tridiagonal_solve`.
    name: Forwarded to `linalg_ops.tridiagonal_solve`.

  Returns:
    The solution, with a trailing dimension removed again when `rhs` was
    given as a vector.

  Raises:
    ValueError: If the static ranks or shapes are inconsistent.
  """
  diags_rank = len(diagonals.shape)
  rhs_rank = len(rhs.shape)

  if diags_rank < 2:
    raise ValueError(
        'Expected diagonals to have rank at least 2, got {}'.format(diags_rank))
  if rhs_rank not in (diags_rank, diags_rank - 1):
    raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
        diags_rank - 1, diags_rank, rhs_rank))

  num_diags = diagonals.shape[-2]
  if num_diags and num_diags != 3:
    raise ValueError('Expected 3 diagonals got {}'.format(num_diags))

  diags_batch = diagonals.shape[:-2]
  rhs_batch = rhs.shape[:diags_rank - 2]
  if not diags_batch.is_compatible_with(rhs_batch):
    raise ValueError('Batch shapes {} and {} are incompatible'.format(
        diags_batch, rhs_batch))

  def _require_matching_sizes(prepared_rhs):
    """Checks the system size against `prepared_rhs` when both are known."""
    system_size = diagonals.shape[-1]
    rhs_size = prepared_rhs.shape[-2]
    if system_size and rhs_size and system_size != rhs_size:
      raise ValueError('Expected number of left-hand sided and right-hand '
                       'sides to be equal, got {} and {}'.format(
                           system_size, rhs_size))

  if rhs_rank == diags_rank - 1:
    # Rhs provided as a vector, ignoring transpose_rhs. It is solved as a
    # single-column matrix and the extra dimension is dropped afterwards.
    column = math_ops.conj(rhs) if conjugate_rhs else rhs
    column = array_ops.expand_dims(column, -1)
    _require_matching_sizes(column)
    solution = linalg_ops.tridiagonal_solve(diagonals, column,
                                            partial_pivoting, name)
    return array_ops.squeeze(solution, -1)

  if transpose_rhs:
    prepared = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
  elif conjugate_rhs:
    prepared = math_ops.conj(rhs)
  else:
    prepared = rhs

  _require_matching_sizes(prepared)
  solution = linalg_ops.tridiagonal_solve(diagonals, prepared,
                                          partial_pivoting, name)
  if transpose_rhs:
    return array_ops.matrix_transpose(solution)
  return solution
@tf_export('linalg.tridiagonal_matmul')
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal `M x M` matrix, which
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix to the right of the multiplication. It has shape
  `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.
  """
  if diagonals_format == 'compact':
    superdiag, maindiag, subdiag = (
        diagonals[..., row, :] for row in range(3))
  elif diagonals_format == 'sequence':
    superdiag, maindiag, subdiag = diagonals
  elif diagonals_format == 'matrix':
    num_cols = tensor_shape.dimension_value(diagonals.shape[-1])
    num_rows = tensor_shape.dimension_value(diagonals.shape[-2])
    if not (num_cols and num_rows):
      raise ValueError('The size of the matrix needs to be known for '
                       'diagonals_format="matrix"')
    if num_cols != num_rows:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(num_cols, num_rows))

    # TODO(b/131695260): use matrix_diag_part when it supports extracting
    # arbitrary diagonals.
    maindiag = array_ops.matrix_diag_part(diagonals)

    # gather_nd indexes the leading dimensions, so the off-diagonals are read
    # from the fully transposed tensor (indices are (column, row) pairs) and
    # transposed back. Element [0, 0] fills the one ignored slot of each.
    transposed = array_ops.transpose(diagonals)
    placeholder = [0, 0]
    above_indices = [[i + 1, i] for i in range(num_cols - 1)] + [placeholder]
    below_indices = [placeholder] + [[i - 1, i] for i in range(1, num_cols)]
    superdiag = array_ops.transpose(
        array_ops.gather_nd(transposed, above_indices))
    subdiag = array_ops.transpose(
        array_ops.gather_nd(transposed, below_indices))
  else:
    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)

  # C++ backend requires matrices.
  # Converting 1-dimensional vectors to matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)