"""
Implementation of optimized einsum.
"""
from __future__ import division, absolute_import, print_function
import itertools
from numpy.compat import basestring
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
from numpy.core.overrides import array_function_dispatch
# Names exported when this module is star-imported.
__all__ = ['einsum', 'einsum_path']
# The 52 ASCII letters, lowercase first and then uppercase.
# NOTE(review): presumably the characters accepted as einsum subscript
# labels -- the code that consumes this is outside the visible chunk; confirm.
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form of ``einsum_symbols`` for hash-based membership tests.
einsum_symbols_set = set(einsum_symbols)
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPS needed for a single contraction.

    The estimate is the product of the sizes of every index taking part in
    the contraction, multiplied by an operation factor: ``num_terms - 1``
    (but never less than one), plus one more when ``inner`` is truthy.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction.
    inner : bool
        Whether the contraction requires an inner product. Only the
        truthiness of the value is used, so callers may pass e.g. the
        (possibly empty) set of removed indices.
    num_terms : int
        The number of terms in the contraction.
    size_dictionary : dict
        The size of each of the indices in ``idx_contraction``.

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    n_elements = _compute_size_by_dict(idx_contraction, size_dictionary)

    # One operation per additional term, with a floor of one.
    base_factor = max(1, num_terms - 1)
    # An inner product adds one more operation per element.
    factor = base_factor + 1 if inner else base_factor

    return n_elements * factor
def _compute_size_by_dict(indices, idx_dict):
"""
Computes the product of the elements in indices based on the dictionary
idx_dict.
Parameters
----------
indices : iterable
Indices to base the product on.
idx_dict : dictionary
Dictionary of index sizes
Returns
-------
ret : int
The resulting product.
Examples
--------
>>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
90
"""
ret = 1
for i in indices:
ret *= idx_dict[i]
return ret
def _find_contraction(positions, input_sets, output_set):
"""
Finds the contraction for a given set of input and output sets.
Parameters
----------
positions : iterable
Integer positions of terms used in the contraction.
input_sets : list
List of sets that represent the lhs side of the einsum subscript
output_set : set
Set that represents the rhs side of the overall einsum subscript
Returns
-------
new_result : set
The indices of the resulting contraction
remaining : list
List of sets that have not been contracted, the new set is appended to
the end of this list
idx_removed : set
Indices removed from the entire contraction
idx_contraction : set
The indices used in the current contraction
Examples
--------
# A simple dot product test case
>>> pos = (0, 1)
>>> isets = [set('ab'), set('bc')]
>>> oset = set('ac')
>>> _find_contraction(pos, isets, oset)
({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
# A more complex case with additional terms in the contraction
>>> pos = (0, 2)
>>> isets = [set('abd'), set('ac'), set('bdc')]
>>> oset = set('ac')
>>> _find_contraction(pos, isets, oset)
({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
"""
idx_contract = set()
idx_remain = output_set.copy()
remaining = []
for ind, value in enumerate(input_sets):
if ind in positions:
idx_contract |= value
else:
remaining.append(value)
idx_remain |= value
new_result = idx_remain & idx_contract
idx_removed = (idx_contract - new_result)
remaining.append(new_result)
return (new_result, remaining, idx_removed, idx_contract)
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorially with respect to the elements in the list
    ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        Each entry is a tuple of term positions. If at some depth no pair
        contraction fits within ``memory_limit``, the cheapest path found
        so far is returned with one final entry that contracts all of the
        remaining terms at once.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_terms = len(input_sets)

    # Each candidate is (accumulated_cost, contractions_so_far, remaining_sets).
    # The list is never empty: it starts with one entry and is only replaced
    # by a non-empty list of extended candidates.
    full_results = [(0, [], input_sets)]

    for iteration in range(num_terms - 1):
        iter_results = []

        # Every candidate at this depth has ``num_terms - iteration`` terms
        # left, so the set of position pairs is the same for all of them.
        pairs = list(itertools.combinations(range(num_terms - iteration), 2))

        for cost, positions, remaining in full_results:
            for con in pairs:
                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining); an inner
                # product is needed exactly when some index is summed away.
                total_cost = cost + _flop_count(
                    idx_contract, bool(idx_removed), len(con), idx_dict)
                iter_results.append(
                    (total_cost, positions + [con], new_input_sets))

        if not iter_results:
            # No pair fits in memory: return the best path so far followed
            # by a single contraction over everything that is left.
            path = min(full_results, key=lambda candidate: candidate[0])[1]
            path += [tuple(range(num_terms - iteration))]
            return path

        full_results = iter_results

    return min(full_results, key=lambda candidate: candidate[0])[1]
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict,
                                memory_limit, path_cost, naive_cost):
    """Evaluate the contraction specified by ``positions`` as a candidate.

    The candidate is scored by the size it removes and its flop cost, and
    is rejected outright when it is too large or too expensive.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    candidate : list or None
        ``None`` when the resulting tensor would be larger than
        ``memory_limit`` or when ``path_cost`` plus the cost of this
        contraction exceeds ``naive_cost``. Otherwise a three element list
        ``[sort, positions, new_input_sets]`` where

        - ``sort`` is the tuple ``(-removed_size, cost)``: the negated
          difference between the combined size of the contracted inputs
          and the size of the result, followed by the flop cost;
        - ``positions`` is the argument of the same name, unchanged;
        - ``new_input_sets`` is the list of index sets that would remain
          if this contraction were performed.
    """
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set)

    # Reject anything whose result would not fit in memory.
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Total size of the inputs that this contraction would consume.
    consumed_size = 0
    for p in positions:
        consumed_size += _compute_size_by_dict(input_sets[p], idx_dict)
    removed_size = consumed_size - new_size

    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject anything that would make the path dearer than the naive one.
    if (path_cost + cost) > naive_cost:
        return None

    # Most size removed sorts first, ties broken by the lower flop cost.
    return [(-removed_size, cost), positions, new_input_sets]
def _update_other_results(results, best):
"""Update the positions and provisional input_sets of ``results`` based on
performing the contraction result ``best``. Remove any involving the tensors
contracted.
Parameters
----------
results : list
List of contraction results produced by ``_parse_possible_contraction``.
best : list
The best contraction of ``results`` i.e. the one that will be performed.
Returns
-------
mod_results : list
The list of modifed results, updated with outcome of ``best`` contraction.
"""
best_con = best[1]
bx, by = best_con
mod_results = []
for cost, (x, y), con_sets in results:
# Ignore results involving tensors just contracted
if x in best_con or y in best_con:
continue
# Update the input_sets
del con_sets[by - int(by > x) - int(by > y)]
del con_sets[bx - int(bx > x) - int(bx > y)]
con_sets.insert(-1, best[2][-1])
# Update the position indices
mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
mod_results.append((cost, mod_con, con_sets))
return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
"""
Finds the path by contracting the best pair until the input list is
exhausted. The best pair is found by minimizing the tuple
``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
matrix multiplication or inner product operations, then Hadamard like
operations, and finally outer operations. Outer products are limited by
``memory_limit``. This algorithm scales cubically with respect to the
number of elements in the list ``input_sets``.
Parameters
----------
input_sets : list
List of sets that represent the lhs side of the einsum subscript
output_set : set
Set that represents the rhs side of the overall einsum subscript
idx_dict : dictionary
Dictionary of index sizes
memory_limit_limit : int
The maximum number of elements in a temporary array
Returns
-------
path : list
The greedy contraction order within the memory limit constraint.
Examples
--------
>>> isets = [set('abd'), set('ac'), set('bdc')]
>>> oset = set()
>>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
>>> _path__greedy_path(isets, oset, idx_sizes, 5000)
Loading ...