124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
|
import bisect
|
||
|
from collections import defaultdict
|
||
|
|
||
|
from sympy.combinatorics import Permutation
|
||
|
from sympy.core.containers import Tuple
|
||
|
from sympy.core.numbers import Integer
|
||
|
|
||
|
|
||
|
def _get_mapping_from_subranks(subranks):
|
||
|
mapping = {}
|
||
|
counter = 0
|
||
|
for i, rank in enumerate(subranks):
|
||
|
for j in range(rank):
|
||
|
mapping[counter] = (i, j)
|
||
|
counter += 1
|
||
|
return mapping
|
||
|
|
||
|
|
||
|
def _get_contraction_links(args, subranks, *contraction_indices):
|
||
|
mapping = _get_mapping_from_subranks(subranks)
|
||
|
contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices]
|
||
|
dlinks = defaultdict(dict)
|
||
|
for links in contraction_tuples:
|
||
|
if len(links) == 2:
|
||
|
(arg1, pos1), (arg2, pos2) = links
|
||
|
dlinks[arg1][pos1] = (arg2, pos2)
|
||
|
dlinks[arg2][pos2] = (arg1, pos1)
|
||
|
continue
|
||
|
|
||
|
return args, dict(dlinks)
|
||
|
|
||
|
|
||
|
def _sort_contraction_indices(pairing_indices):
|
||
|
pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices]
|
||
|
pairing_indices.sort(key=lambda x: min(x))
|
||
|
return pairing_indices
|
||
|
|
||
|
|
||
|
def _get_diagonal_indices(flattened_indices):
|
||
|
axes_contraction = defaultdict(list)
|
||
|
for i, ind in enumerate(flattened_indices):
|
||
|
if isinstance(ind, (int, Integer)):
|
||
|
# If the indices is a number, there can be no diagonal operation:
|
||
|
continue
|
||
|
axes_contraction[ind].append(i)
|
||
|
axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1}
|
||
|
# Put the diagonalized indices at the end:
|
||
|
ret_indices = [i for i in flattened_indices if i not in axes_contraction]
|
||
|
diag_indices = list(axes_contraction)
|
||
|
diag_indices.sort(key=lambda x: flattened_indices.index(x))
|
||
|
diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices]
|
||
|
ret_indices += diag_indices
|
||
|
ret_indices = tuple(ret_indices)
|
||
|
return diagonal_indices, ret_indices
|
||
|
|
||
|
|
||
|
def _get_argindex(subindices, ind):
|
||
|
for i, sind in enumerate(subindices):
|
||
|
if ind == sind:
|
||
|
return i
|
||
|
if isinstance(sind, (set, frozenset)) and ind in sind:
|
||
|
return i
|
||
|
raise IndexError("%s not found in %s" % (ind, subindices))
|
||
|
|
||
|
|
||
|
def _apply_recursively_over_nested_lists(func, arr):
|
||
|
if isinstance(arr, (tuple, list, Tuple)):
|
||
|
return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
|
||
|
elif isinstance(arr, Tuple):
|
||
|
return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr)
|
||
|
else:
|
||
|
return func(arr)
|
||
|
|
||
|
|
||
|
def _build_push_indices_up_func_transformation(flattened_contraction_indices):
|
||
|
shifts = {0: 0}
|
||
|
i = 0
|
||
|
cumulative = 0
|
||
|
while i < len(flattened_contraction_indices):
|
||
|
j = 1
|
||
|
while i+j < len(flattened_contraction_indices):
|
||
|
if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
|
||
|
break
|
||
|
j += 1
|
||
|
cumulative += j
|
||
|
shifts[flattened_contraction_indices[i]] = cumulative
|
||
|
i += j
|
||
|
shift_keys = sorted(shifts.keys())
|
||
|
|
||
|
def func(idx):
|
||
|
return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
|
||
|
|
||
|
def transform(j):
|
||
|
if j in flattened_contraction_indices:
|
||
|
return None
|
||
|
else:
|
||
|
return j - func(j)
|
||
|
|
||
|
return transform
|
||
|
|
||
|
|
||
|
def _build_push_indices_down_func_transformation(flattened_contraction_indices):
|
||
|
N = flattened_contraction_indices[-1]+2
|
||
|
|
||
|
shifts = [i for i in range(N) if i not in flattened_contraction_indices]
|
||
|
|
||
|
def transform(j):
|
||
|
if j < len(shifts):
|
||
|
return shifts[j]
|
||
|
else:
|
||
|
return j + shifts[-1] - len(shifts) + 1
|
||
|
|
||
|
return transform
|
||
|
|
||
|
|
||
|
def _apply_permutation_to_list(perm: Permutation, target_list: list):
|
||
|
"""
|
||
|
Permute a list according to the given permutation.
|
||
|
"""
|
||
|
new_list = [None for i in range(perm.size)]
|
||
|
for i, e in enumerate(target_list):
|
||
|
new_list[perm(i)] = e
|
||
|
return new_list
|