ai-content-maker/.venv/Lib/site-packages/sympy/tensor/array/array_derivatives.py

130 lines
4.7 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from __future__ import annotations
from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.numbers import Integer
from sympy.matrices.common import MatrixCommon
from .ndim_array import NDimArray
from .arrayop import derive_by_array
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix
from sympy.matrices.expressions.matexpr import _matrix_derivative
class ArrayDerivative(Derivative):
    """Derivative that carries a ``shape`` and dispatches on operand kind.

    The evaluation logic in ``_dispatch_eval_derivative_n_times`` picks a
    differentiation routine depending on whether the expression and the
    variable are scalars, explicit matrices (``MatrixCommon``), symbolic
    matrix expressions (``MatrixExpr``) or N-dimensional arrays
    (``NDimArray``).
    """

    is_scalar = False

    def __new__(cls, expr, *variables, **kwargs):
        """Create the derivative; cache its shape if it is an ArrayDerivative.

        The parent constructor may return an object of another type, in
        which case no shape is attached.
        """
        instance = super().__new__(cls, expr, *variables, **kwargs)
        if isinstance(instance, ArrayDerivative):
            instance._shape = instance._get_shape()
        return instance

    def _get_shape(self):
        """Return the shape tuple of this derivative.

        The shape of every differentiation variable is appended once per
        differentiation order, followed by the shape of the expression.
        Operands without a ``shape`` attribute contribute nothing.
        """
        dims = ()
        for variable, order in self.variable_count:
            if not hasattr(variable, "shape"):
                continue
            for _ in range(order):
                dims += variable.shape
        if hasattr(self.expr, "shape"):
            dims += self.expr.shape
        return dims

    @property
    def shape(self):
        """Shape computed at construction time."""
        return self._shape

    @classmethod
    def _get_zero_with_shape_like(cls, expr):
        """Return a zero object with the same shape as ``expr``.

        Raises
        ======

        RuntimeError
            If ``expr`` is neither an explicit matrix, an N-dim array nor a
            matrix expression.
        """
        if isinstance(expr, (MatrixCommon, NDimArray)):
            return expr.zeros(*expr.shape)
        if isinstance(expr, MatrixExpr):
            return ZeroMatrix(*expr.shape)
        raise RuntimeError("Unable to determine shape of array-derivative.")

    @staticmethod
    def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixCommon) -> Expr:
        """Differentiate ``expr`` by every entry of the explicit matrix ``v``."""
        return v.applyfunc(lambda entry: expr.diff(entry))

    @staticmethod
    def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr:
        """Differentiate ``expr`` by the matrix expression ``v``.

        When ``expr`` does not contain ``v`` a ``ZeroMatrix`` with the shape
        of ``v`` is returned instead of calling ``_matrix_derivative``.
        """
        if not expr.has(v):
            return ZeroMatrix(*v.shape)
        return _matrix_derivative(expr, v)

    @staticmethod
    def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr:
        """Differentiate ``expr`` by every element of the array ``v``."""
        return v.applyfunc(lambda element: expr.diff(element))

    @staticmethod
    def _call_derive_matrix_by_scalar(expr: MatrixCommon, v: Expr) -> Expr:
        """Differentiate the explicit matrix ``expr`` via ``_matrix_derivative``."""
        return _matrix_derivative(expr, v)

    @staticmethod
    def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr:
        """Differentiate a matrix expression using its own ``_eval_derivative``."""
        return expr._eval_derivative(v)

    @staticmethod
    def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr:
        """Differentiate every element of the array ``expr`` by ``v``."""
        return expr.applyfunc(lambda element: element.diff(v))

    @staticmethod
    def _call_derive_default(expr: Expr, v: Expr) -> Expr | None:
        """Call ``_matrix_derivative`` if ``expr`` contains ``v``, else give ``None``."""
        return _matrix_derivative(expr, v) if expr.has(v) else None

    @classmethod
    def _dispatch_eval_derivative_n_times(cls, expr, v, count):
        """Differentiate ``expr`` by ``v`` ``count`` times, dispatching on kind.

        Returns ``None`` whenever the derivative is left unevaluated:
        ``count`` is not a positive ``int``/``Integer``, the operand
        combination is not handled, or a handler itself returned ``None``.
        """
        # `== True` is kept as in the original comparison; presumably it
        # guards against a non-bool comparison result -- TODO confirm.
        if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
            return None

        # TODO: this could be done with multiple-dispatching:
        if expr.is_scalar:
            # Order matters: the first matching kind wins.
            scalar_handlers = (
                (MatrixCommon, cls._call_derive_scalar_by_matrix),
                (MatrixExpr, cls._call_derive_scalar_by_matexpr),
                (NDimArray, cls._call_derive_scalar_by_array),
            )
            handler = next(
                (fn for kind, fn in scalar_handlers if isinstance(v, kind)),
                None,
            )
            if handler is not None:
                derived = handler(expr, v)
            elif v.is_scalar:
                # Scalar by scalar is delegated entirely to the parent
                # class, including the repetition over `count`.
                return super()._dispatch_eval_derivative_n_times(expr, v, count)
            else:
                return None
        elif v.is_scalar:
            if isinstance(expr, MatrixCommon):
                derived = cls._call_derive_matrix_by_scalar(expr, v)
            elif isinstance(expr, MatrixExpr):
                derived = cls._call_derive_matexpr_by_scalar(expr, v)
            elif isinstance(expr, NDimArray):
                derived = cls._call_derive_array_by_scalar(expr, v)
            else:
                return None
        else:
            # Both `expr` and `v` are some array/matrix type.
            expr_is_matexpr = isinstance(expr, MatrixExpr)
            var_is_matexpr = isinstance(v, MatrixExpr)
            # NOTE(review): only `expr` is tested against MatrixCommon here,
            # so an explicit-matrix `v` paired with a MatrixExpr `expr`
            # reaches the "don't evaluate" branch below -- confirm that this
            # asymmetry is intended.
            if isinstance(expr, MatrixCommon):
                derived = derive_by_array(expr, v)
            elif expr_is_matexpr and var_is_matexpr:
                derived = cls._call_derive_default(expr, v)
            elif expr_is_matexpr or var_is_matexpr:
                # If one operand is a symbolic matrix expression while the
                # other isn't, don't evaluate.
                return None
            else:
                derived = derive_by_array(expr, v)

        if derived is None:
            return None
        if count == 1:
            return derived
        # Apply the remaining orders to the freshly computed derivative.
        return cls._dispatch_eval_derivative_n_times(derived, v, count - 1)