Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
jax / jax / _src / third_party / numpy / linalg.py
Size: Mime:
import numpy as np

import jax.numpy as jnp
import jax.numpy.linalg as la
from jax._src.numpy.util import check_arraylike, _wraps


def _isEmpty2d(arr):
  # check size first for efficiency
  return arr.size == 0 and np.prod(arr.shape[-2:]) == 0


def _assertNoEmpty2d(*arrays):
  # Raise LinAlgError as soon as any argument is reported empty by
  # `_isEmpty2d`; arguments after the first empty one are not inspected.
  if any(_isEmpty2d(arr) for arr in arrays):
    raise np.linalg.LinAlgError("Arrays cannot be empty")


def _assertRankAtLeast2(*arrays):
  for a in arrays:
    if a.ndim < 2:
      raise np.linalg.LinAlgError(
          '%d-dimensional array given. Array must be '
          'at least two-dimensional' % a.ndim)


def _assertNdSquareness(*arrays):
  for a in arrays:
    m, n = a.shape[-2:]
    if m != n:
      raise np.linalg.LinAlgError(
          'Last 2 dimensions of the array must be square')


def _assert2d(*arrays):
  for a in arrays:
    if a.ndim != 2:
      raise ValueError(f'{a.ndim}-dimensional array given. '
                       'Array must be two-dimensional')


@_wraps(np.linalg.cond)
def cond(x, p=None):
  # Condition number of `x` (or of each matrix in a leading batch) in the
  # norm selected by `p`:
  #   p in (None, 2): largest / smallest singular value.
  #   p == -2:        smallest / largest singular value.
  #   otherwise:      norm(x, p) * norm(inv(x), p), square input required.
  # Raises LinAlgError for empty matrices (and, in the last branch, for
  # rank < 2 or non-square input).
  check_arraylike('jnp.linalg.cond', x)
  # Convert once so the validation helpers below can rely on
  # .size / .ndim / .shape regardless of which array-like was passed.
  x = jnp.asarray(x)
  _assertNoEmpty2d(x)
  if p in (None, 2):
    s = la.svd(x, compute_uv=False)
    r = s[..., 0] / s[..., -1]
  elif p == -2:
    s = la.svd(x, compute_uv=False)
    r = s[..., -1] / s[..., 0]
  else:
    _assertRankAtLeast2(x)
    _assertNdSquareness(x)
    invx = la.inv(x)
    r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(invx, ord=p, axis=(-2, -1))

  # Singular input can produce NaN in `r` (e.g. 0/0 for an all-zero matrix).
  # As in NumPy, report such entries as inf, but leave NaN in place for any
  # matrix that itself contains NaN. This applies to every branch above,
  # including p in (None, 2).
  nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
  return jnp.where(nan_mask, jnp.inf, r)


@_wraps(np.linalg.tensorinv)
def tensorinv(a, ind=2):
  # Invert `a` viewed as a matrix whose rows are indexed by its first `ind`
  # axes and whose columns by the remaining axes. The result has the two
  # axis groups swapped: shape a.shape[ind:] + a.shape[:ind].
  # Raises ValueError when `ind` is not positive.
  check_arraylike('jnp.linalg.tensorinv', a)
  a = jnp.asarray(a)
  if not ind > 0:
    raise ValueError("Invalid ind argument.")
  head, tail = a.shape[:ind], a.shape[ind:]
  size = 1
  for dim in tail:
    size *= dim
  inverse = la.inv(a.reshape(size, -1))
  return inverse.reshape(*(tail + head))


@_wraps(np.linalg.tensorsolve)
def tensorsolve(a, b, axes=None):
  # Solve the tensor equation described by np.linalg.tensorsolve: flatten
  # `a` to a 2-D system, solve against the flattened `b`, and reshape the
  # solution to the trailing axes of `a`.
  check_arraylike('jnp.linalg.tensorsolve', a, b)
  a = jnp.asarray(a)
  b = jnp.asarray(b)
  a_ndim = a.ndim

  if axes is not None:
    # Move each listed axis to the back (in the given order), keeping the
    # relative order of the remaining axes.
    perm = list(range(a_ndim))
    for ax in axes:
      perm.remove(ax)
      perm.append(ax)
    a = a.transpose(perm)

  # The last (a.ndim - b.ndim) axes of `a` give the solution's shape.
  # NOTE: when a.ndim == b.ndim the slice is a.shape[-0:], i.e. the full shape.
  x_shape = a.shape[-(a_ndim - b.ndim):]
  x_size = 1
  for dim in x_shape:
    x_size *= dim

  solution = la.solve(a.reshape(-1, x_size), b.ravel())
  return jnp.asarray(solution).reshape(x_shape)


@_wraps(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
  # Chained matrix product of `arrays`, choosing a multiplication order that
  # minimizes scalar multiplications. A 1-D first (last) argument is treated
  # as a row (column) vector, and the corresponding axis is dropped from the
  # result. `precision` is forwarded to every jnp.dot call.
  check_arraylike('jnp.linalg.multi_dot', *arrays)
  count = len(arrays)
  if count < 2:
    raise ValueError("Expecting at least two arrays.")
  if count == 2:
    # Nothing to reorder for a single product.
    return jnp.dot(arrays[0], arrays[1], precision=precision)

  mats = list(map(jnp.asarray, arrays))

  # Promote vector endpoints to 2-D so the helpers only deal with matrices;
  # remember which ones were vectors to restore the output rank at the end.
  first_was_vector = mats[0].ndim == 1
  last_was_vector = mats[-1].ndim == 1
  if first_was_vector:
    mats[0] = jnp.atleast_2d(mats[0])
  if last_was_vector:
    mats[-1] = jnp.atleast_2d(mats[-1]).T
  _assert2d(*mats)

  if count == 3:
    # The closed-form three-matrix comparison avoids the DP table.
    product = _multi_dot_three(*mats, precision)
  else:
    chain_order = _multi_dot_matrix_chain_order(mats)
    product = _multi_dot(mats, chain_order, 0, count - 1, precision)

  if first_was_vector and last_was_vector:
    return product[0, 0]  # vector . ... . vector -> scalar
  if first_was_vector or last_was_vector:
    return product.ravel()  # one vector endpoint -> 1-D
  return product


def _multi_dot_three(A, B, C, precision):
  """
  Find the best order for three arrays and do the multiplication.
  For three arguments `_multi_dot_three` is approximately 15 times faster
  than `_multi_dot_matrix_chain_order`
  """
  a0, a1b0 = A.shape
  b1c0, c1 = C.shape
  # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
  cost1 = a0 * b1c0 * (a1b0 + c1)
  # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
  cost2 = a1b0 * c1 * (a0 + b1c0)

  if cost1 < cost2:
    return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
  else:
    return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
  """
  Return a jnp.array that encodes the optimal order of mutiplications.
  The optimal order array is then used by `_multi_dot()` to do the
  multiplication.
  Also return the cost matrix if `return_costs` is `True`
  The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
  Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.
      cost[i, j] = min([
          cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
          for k in range(i, j)])
  """
  n = len(arrays)
  # p stores the dimensions of the matrices
  # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
  p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
  # m is a matrix of costs of the subproblems
  # m[i,j]: min number of scalar multiplications needed to compute A_{i..j}
  m = np.zeros((n, n), dtype=np.double)
  # s is the actual ordering
  # s[i, j] is the value of k at which we split the product A_i..A_j
  s = np.empty((n, n), dtype=np.intp)

  for l in range(1, n):
    for i in range(n - l):
      j = i + l
      m[i, j] = jnp.inf
      for k in range(i, j):
        q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
        if q < m[i, j]:
          m[i, j] = q
          s[i, j] = k  # Note that Cormen uses 1-based index

  return (s, m) if return_costs else s


def _multi_dot(arrays, order, i, j, precision):
  """Actually do the multiplication with the given order."""
  if i == j:
    return arrays[i]
  else:
    return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
                   _multi_dot(arrays, order, order[i, j] + 1, j, precision),
                   precision=precision)