ai-content-maker/.venv/Lib/site-packages/sympy/tensor/array/expressions/utils.py

124 lines
3.8 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
import bisect
from collections import defaultdict
from sympy.combinatorics import Permutation
from sympy.core.containers import Tuple
from sympy.core.numbers import Integer
def _get_mapping_from_subranks(subranks):
mapping = {}
counter = 0
for i, rank in enumerate(subranks):
for j in range(rank):
mapping[counter] = (i, j)
counter += 1
return mapping
def _get_contraction_links(args, subranks, *contraction_indices):
    """
    Build a lookup of pairwise contraction links between argument indices.

    Every flattened index in ``contraction_indices`` is first translated to
    an ``(argument, local index)`` pair using ``subranks``.  For each
    contraction group made of exactly two indices, both endpoints are
    recorded as pointing at each other.

    Returns ``(args, links)`` where ``args`` is passed through untouched and
    ``links[arg][pos]`` is the ``(other_arg, other_pos)`` pair it is
    contracted with.
    """
    position_of = _get_mapping_from_subranks(subranks)
    links_by_arg = defaultdict(dict)
    for group in contraction_indices:
        endpoints = [position_of[flat] for flat in group]
        # Only two-index contractions are recorded; groups of any other
        # size are left out of the result.
        if len(endpoints) != 2:
            continue
        first, second = endpoints
        links_by_arg[first[0]][first[1]] = second
        links_by_arg[second[0]][second[1]] = first
    return args, dict(links_by_arg)
def _sort_contraction_indices(pairing_indices):
    """
    Return the contraction groups in a canonical order.

    Each group is turned into a ``Tuple`` with its entries sorted, and the
    resulting groups are ordered by their smallest entry.  The sort is
    stable, so groups sharing the same minimum keep their relative order.
    """
    normalized = [Tuple(*sorted(group)) for group in pairing_indices]
    return sorted(normalized, key=min)
def _get_diagonal_indices(flattened_indices):
    """
    Find the repeated (diagonalized) indices in ``flattened_indices``.

    Returns ``(diagonal_indices, ret_indices)``:

    * ``diagonal_indices`` is a list with one tuple per repeated index,
      holding the positions at which that index occurs, ordered by the first
      occurrence of the index.
    * ``ret_indices`` is a tuple with the non-repeated indices in their
      original order, followed by each repeated index once.

    Entries that are ``int`` or ``Integer`` are never treated as repeated.
    """
    positions_by_index = defaultdict(list)
    for position, index in enumerate(flattened_indices):
        # A numeric index cannot take part in a diagonal operation.
        if not isinstance(index, (int, Integer)):
            positions_by_index[index].append(position)

    repeated = {
        index: positions
        for index, positions in positions_by_index.items()
        if len(positions) > 1
    }

    diag_indices = sorted(repeated, key=flattened_indices.index)
    diagonal_indices = [tuple(repeated[index]) for index in diag_indices]

    # Non-repeated indices keep their order; the diagonalized ones go last.
    kept = [index for index in flattened_indices if index not in repeated]
    ret_indices = tuple(kept + diag_indices)
    return diagonal_indices, ret_indices
def _get_argindex(subindices, ind):
for i, sind in enumerate(subindices):
if ind == sind:
return i
if isinstance(sind, (set, frozenset)) and ind in sind:
return i
raise IndexError("%s not found in %s" % (ind, subindices))
def _apply_recursively_over_nested_lists(func, arr):
    """
    Apply ``func`` to every leaf of an arbitrarily nested sequence.

    ``tuple``, ``list`` and SymPy ``Tuple`` containers are traversed
    recursively and each of them is rebuilt as a builtin ``tuple``.  Any
    other object is treated as a leaf and replaced by ``func(arr)``.
    """
    # ``Tuple`` is handled together with the builtin sequences, so a ``Tuple``
    # input comes back as a plain ``tuple``.  A separate ``Tuple`` branch
    # placed after this check could never be reached and has been dropped.
    if isinstance(arr, (tuple, list, Tuple)):
        return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
    return func(arr)
def _build_push_indices_up_func_transformation(flattened_contraction_indices):
shifts = {0: 0}
i = 0
cumulative = 0
while i < len(flattened_contraction_indices):
j = 1
while i+j < len(flattened_contraction_indices):
if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
break
j += 1
cumulative += j
shifts[flattened_contraction_indices[i]] = cumulative
i += j
shift_keys = sorted(shifts.keys())
def func(idx):
return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
def transform(j):
if j in flattened_contraction_indices:
return None
else:
return j - func(j)
return transform
def _build_push_indices_down_func_transformation(flattened_contraction_indices):
N = flattened_contraction_indices[-1]+2
shifts = [i for i in range(N) if i not in flattened_contraction_indices]
def transform(j):
if j < len(shifts):
return shifts[j]
else:
return j + shifts[-1] - len(shifts) + 1
return transform
def _apply_permutation_to_list(perm: Permutation, target_list: list):
    """
    Return a new list with the items of ``target_list`` moved by ``perm``.

    The item found at position ``i`` of ``target_list`` is stored at position
    ``perm(i)`` of the result.  The result has ``perm.size`` slots; any slot
    that receives no item is left as ``None``.
    """
    permuted = [None for _ in range(perm.size)]
    for source_pos, item in enumerate(target_list):
        destination = perm(source_pos)
        permuted[destination] = item
    return permuted