400 lines
12 KiB
Python
400 lines
12 KiB
Python
import functools, itertools
|
|
from sympy.core.sympify import _sympify, sympify
|
|
from sympy.core.expr import Expr
|
|
from sympy.core import Basic, Tuple
|
|
from sympy.tensor.array import ImmutableDenseNDimArray
|
|
from sympy.core.symbol import Symbol
|
|
from sympy.core.numbers import Integer
|
|
|
|
|
|
class ArrayComprehension(Basic):
|
|
"""
|
|
Generate a list comprehension.
|
|
|
|
Explanation
|
|
===========
|
|
|
|
If there is a symbolic dimension, for example, say [i for i in range(1, N)] where
|
|
N is a Symbol, then the expression will not be expanded to an array. Otherwise,
|
|
calling the doit() function will launch the expansion.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a
|
|
ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.doit()
|
|
[[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
|
|
>>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
|
|
>>> b.doit()
|
|
ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
|
|
"""
|
|
def __new__(cls, function, *symbols, **assumptions):
|
|
if any(len(l) != 3 or None for l in symbols):
|
|
raise ValueError('ArrayComprehension requires values lower and upper bound'
|
|
' for the expression')
|
|
arglist = [sympify(function)]
|
|
arglist.extend(cls._check_limits_validity(function, symbols))
|
|
obj = Basic.__new__(cls, *arglist, **assumptions)
|
|
obj._limits = obj._args[1:]
|
|
obj._shape = cls._calculate_shape_from_limits(obj._limits)
|
|
obj._rank = len(obj._shape)
|
|
obj._loop_size = cls._calculate_loop_size(obj._shape)
|
|
return obj
|
|
|
|
@property
|
|
def function(self):
|
|
"""The function applied across limits.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j = symbols('i j')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.function
|
|
10*i + j
|
|
"""
|
|
return self._args[0]
|
|
|
|
@property
|
|
def limits(self):
|
|
"""
|
|
The list of limits that will be applied while expanding the array.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j = symbols('i j')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.limits
|
|
((i, 1, 4), (j, 1, 3))
|
|
"""
|
|
return self._limits
|
|
|
|
@property
|
|
def free_symbols(self):
|
|
"""
|
|
The set of the free_symbols in the array.
|
|
Variables appeared in the bounds are supposed to be excluded
|
|
from the free symbol set.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.free_symbols
|
|
set()
|
|
>>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
|
|
>>> b.free_symbols
|
|
{k}
|
|
"""
|
|
expr_free_sym = self.function.free_symbols
|
|
for var, inf, sup in self._limits:
|
|
expr_free_sym.discard(var)
|
|
curr_free_syms = inf.free_symbols.union(sup.free_symbols)
|
|
expr_free_sym = expr_free_sym.union(curr_free_syms)
|
|
return expr_free_sym
|
|
|
|
@property
|
|
def variables(self):
|
|
"""The tuples of the variables in the limits.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.variables
|
|
[i, j]
|
|
"""
|
|
return [l[0] for l in self._limits]
|
|
|
|
@property
|
|
def bound_symbols(self):
|
|
"""The list of dummy variables.
|
|
|
|
Note
|
|
====
|
|
|
|
Note that all variables are dummy variables since a limit without
|
|
lower bound or upper bound is not accepted.
|
|
"""
|
|
return [l[0] for l in self._limits if len(l) != 1]
|
|
|
|
@property
|
|
def shape(self):
|
|
"""
|
|
The shape of the expanded array, which may have symbols.
|
|
|
|
Note
|
|
====
|
|
|
|
Both the lower and the upper bounds are included while
|
|
calculating the shape.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.shape
|
|
(4, 3)
|
|
>>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
|
|
>>> b.shape
|
|
(4, k + 3)
|
|
"""
|
|
return self._shape
|
|
|
|
@property
|
|
def is_shape_numeric(self):
|
|
"""
|
|
Test if the array is shape-numeric which means there is no symbolic
|
|
dimension.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.is_shape_numeric
|
|
True
|
|
>>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
|
|
>>> b.is_shape_numeric
|
|
False
|
|
"""
|
|
for _, inf, sup in self._limits:
|
|
if Basic(inf, sup).atoms(Symbol):
|
|
return False
|
|
return True
|
|
|
|
def rank(self):
|
|
"""The rank of the expanded array.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.rank()
|
|
2
|
|
"""
|
|
return self._rank
|
|
|
|
def __len__(self):
|
|
"""
|
|
The length of the expanded array which means the number
|
|
of elements in the array.
|
|
|
|
Raises
|
|
======
|
|
|
|
ValueError : When the length of the array is symbolic
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j = symbols('i j')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> len(a)
|
|
12
|
|
"""
|
|
if self._loop_size.free_symbols:
|
|
raise ValueError('Symbolic length is not supported')
|
|
return self._loop_size
|
|
|
|
@classmethod
|
|
def _check_limits_validity(cls, function, limits):
|
|
#limits = sympify(limits)
|
|
new_limits = []
|
|
for var, inf, sup in limits:
|
|
var = _sympify(var)
|
|
inf = _sympify(inf)
|
|
#since this is stored as an argument, it should be
|
|
#a Tuple
|
|
if isinstance(sup, list):
|
|
sup = Tuple(*sup)
|
|
else:
|
|
sup = _sympify(sup)
|
|
new_limits.append(Tuple(var, inf, sup))
|
|
if any((not isinstance(i, Expr)) or i.atoms(Symbol, Integer) != i.atoms()
|
|
for i in [inf, sup]):
|
|
raise TypeError('Bounds should be an Expression(combination of Integer and Symbol)')
|
|
if (inf > sup) == True:
|
|
raise ValueError('Lower bound should be inferior to upper bound')
|
|
if var in inf.free_symbols or var in sup.free_symbols:
|
|
raise ValueError('Variable should not be part of its bounds')
|
|
return new_limits
|
|
|
|
@classmethod
|
|
def _calculate_shape_from_limits(cls, limits):
|
|
return tuple([sup - inf + 1 for _, inf, sup in limits])
|
|
|
|
@classmethod
|
|
def _calculate_loop_size(cls, shape):
|
|
if not shape:
|
|
return 0
|
|
loop_size = 1
|
|
for l in shape:
|
|
loop_size = loop_size * l
|
|
|
|
return loop_size
|
|
|
|
def doit(self, **hints):
|
|
if not self.is_shape_numeric:
|
|
return self
|
|
|
|
return self._expand_array()
|
|
|
|
def _expand_array(self):
|
|
res = []
|
|
for values in itertools.product(*[range(inf, sup+1)
|
|
for var, inf, sup
|
|
in self._limits]):
|
|
res.append(self._get_element(values))
|
|
|
|
return ImmutableDenseNDimArray(res, self.shape)
|
|
|
|
def _get_element(self, values):
|
|
temp = self.function
|
|
for var, val in zip(self.variables, values):
|
|
temp = temp.subs(var, val)
|
|
return temp
|
|
|
|
def tolist(self):
|
|
"""Transform the expanded array to a list.
|
|
|
|
Raises
|
|
======
|
|
|
|
ValueError : When there is a symbolic dimension
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j = symbols('i j')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.tolist()
|
|
[[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
|
|
"""
|
|
if self.is_shape_numeric:
|
|
return self._expand_array().tolist()
|
|
|
|
raise ValueError("A symbolic array cannot be expanded to a list")
|
|
|
|
def tomatrix(self):
|
|
"""Transform the expanded array to a matrix.
|
|
|
|
Raises
|
|
======
|
|
|
|
ValueError : When there is a symbolic dimension
|
|
ValueError : When the rank of the expanded array is not equal to 2
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehension
|
|
>>> from sympy import symbols
|
|
>>> i, j = symbols('i j')
|
|
>>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
|
|
>>> a.tomatrix()
|
|
Matrix([
|
|
[11, 12, 13],
|
|
[21, 22, 23],
|
|
[31, 32, 33],
|
|
[41, 42, 43]])
|
|
"""
|
|
from sympy.matrices import Matrix
|
|
|
|
if not self.is_shape_numeric:
|
|
raise ValueError("A symbolic array cannot be expanded to a matrix")
|
|
if self._rank != 2:
|
|
raise ValueError('Dimensions must be of size of 2')
|
|
|
|
return Matrix(self._expand_array().tomatrix())
|
|
|
|
|
|
def isLambda(v):
|
|
LAMBDA = lambda: 0
|
|
return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__
|
|
|
|
class ArrayComprehensionMap(ArrayComprehension):
|
|
'''
|
|
A subclass of ArrayComprehension dedicated to map external function lambda.
|
|
|
|
Notes
|
|
=====
|
|
|
|
Only the lambda function is considered.
|
|
At most one argument in lambda function is accepted in order to avoid ambiguity
|
|
in value assignment.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy.tensor.array import ArrayComprehensionMap
|
|
>>> from sympy import symbols
|
|
>>> i, j, k = symbols('i j k')
|
|
>>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4))
|
|
>>> a.doit()
|
|
[1, 1, 1, 1]
|
|
>>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4))
|
|
>>> b.doit()
|
|
[2, 3, 4, 5]
|
|
|
|
'''
|
|
def __new__(cls, function, *symbols, **assumptions):
|
|
if any(len(l) != 3 or None for l in symbols):
|
|
raise ValueError('ArrayComprehension requires values lower and upper bound'
|
|
' for the expression')
|
|
|
|
if not isLambda(function):
|
|
raise ValueError('Data type not supported')
|
|
|
|
arglist = cls._check_limits_validity(function, symbols)
|
|
obj = Basic.__new__(cls, *arglist, **assumptions)
|
|
obj._limits = obj._args
|
|
obj._shape = cls._calculate_shape_from_limits(obj._limits)
|
|
obj._rank = len(obj._shape)
|
|
obj._loop_size = cls._calculate_loop_size(obj._shape)
|
|
obj._lambda = function
|
|
return obj
|
|
|
|
@property
|
|
def func(self):
|
|
class _(ArrayComprehensionMap):
|
|
def __new__(cls, *args, **kwargs):
|
|
return ArrayComprehensionMap(self._lambda, *args, **kwargs)
|
|
return _
|
|
|
|
def _get_element(self, values):
|
|
temp = self._lambda
|
|
if self._lambda.__code__.co_argcount == 0:
|
|
temp = temp()
|
|
elif self._lambda.__code__.co_argcount == 1:
|
|
temp = temp(functools.reduce(lambda a, b: a*b, values))
|
|
return temp
|