from __future__ import annotations

from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.numbers import Integer
from sympy.matrices.common import MatrixCommon
from .ndim_array import NDimArray
from .arrayop import derive_by_array
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix
from sympy.matrices.expressions.matexpr import _matrix_derivative


class ArrayDerivative(Derivative):
    """Derivative whose expression and/or variables may be array-like.

    The expression being differentiated, the differentiation variables, or
    both may be explicit matrices (``MatrixCommon``), symbolic matrix
    expressions (``MatrixExpr``) or N-dimensional arrays (``NDimArray``).
    Instances are never scalar (``is_scalar = False``) and expose a
    ``shape`` computed when the object is created.
    """

    is_scalar = False

    def __new__(cls, expr, *variables, **kwargs):
        obj = super().__new__(cls, expr, *variables, **kwargs)
        # The parent constructor may hand back something other than an
        # ArrayDerivative; only attach a cached shape when it did not.
        if isinstance(obj, ArrayDerivative):
            obj._shape = obj._get_shape()
        return obj

    def _get_shape(self):
        """Compute the shape of this derivative.

        Each differentiation variable that has a ``shape`` contributes that
        shape once per differentiation count. The expression's own shape,
        if it has one, is appended last. Operands without a ``shape``
        attribute contribute nothing.
        """
        shape = ()
        for v, count in self.variable_count:
            if hasattr(v, "shape"):
                for i in range(count):
                    shape += v.shape
        if hasattr(self.expr, "shape"):
            shape += self.expr.shape
        return shape

    @property
    def shape(self):
        """Shape tuple computed by ``_get_shape`` at construction time."""
        return self._shape

    @classmethod
    def _get_zero_with_shape_like(cls, expr):
        """Return a zero object of the same kind and shape as ``expr``.

        Explicit matrices and N-dim arrays use their own ``zeros``
        constructor. Symbolic matrix expressions get a ``ZeroMatrix``.

        Raises:
            RuntimeError: if ``expr`` is none of the supported array types.
        """
        if isinstance(expr, (MatrixCommon, NDimArray)):
            return expr.zeros(*expr.shape)
        elif isinstance(expr, MatrixExpr):
            return ZeroMatrix(*expr.shape)
        else:
            raise RuntimeError("Unable to determine shape of array-derivative.")

    @staticmethod
    def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixCommon) -> Expr:
        """Differentiate scalar ``expr`` by each entry of explicit matrix ``v``."""
        return v.applyfunc(lambda x: expr.diff(x))

    @staticmethod
    def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr:
        """Differentiate scalar ``expr`` by symbolic matrix ``v``.

        If ``expr`` does not contain ``v``, a ``ZeroMatrix`` shaped like
        ``v`` is returned.
        """
        if expr.has(v):
            return _matrix_derivative(expr, v)
        else:
            return ZeroMatrix(*v.shape)

    @staticmethod
    def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr:
        """Differentiate scalar ``expr`` by each entry of array ``v``."""
        return v.applyfunc(lambda x: expr.diff(x))

    @staticmethod
    def _call_derive_matrix_by_scalar(expr: MatrixCommon, v: Expr) -> Expr:
        """Differentiate explicit matrix ``expr`` by scalar ``v``."""
        return _matrix_derivative(expr, v)

    @staticmethod
    def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr:
        """Differentiate symbolic matrix ``expr`` by scalar ``v``."""
        return expr._eval_derivative(v)

    @staticmethod
    def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr:
        """Differentiate each entry of array ``expr`` by scalar ``v``."""
        return expr.applyfunc(lambda x: x.diff(v))

    @staticmethod
    def _call_derive_default(expr: Expr, v: Expr) -> Expr | None:
        """Matrix-calculus derivative of ``expr`` by ``v``.

        Returns ``None`` (leave unevaluated) when ``expr`` does not
        contain ``v``.
        """
        if expr.has(v):
            return _matrix_derivative(expr, v)
        else:
            return None

    @classmethod
    def _dispatch_eval_derivative_n_times(cls, expr, v, count):
        """Differentiate ``expr`` by ``v`` a total of ``count`` times.

        The derivative routine is chosen from the kinds of ``expr`` and
        ``v``: scalar, explicit matrix, symbolic matrix expression, or
        N-dim array. For ``count > 1``, one derivative is taken and this
        method recurses on the result with ``count - 1``.

        Returns:
            The evaluated derivative, or ``None`` to signal that the
            derivative should stay unevaluated. That happens when
            ``count`` is not a positive integer or when no rule applies.
        """
        # Evaluate the derivative `n` times.  If
        # `_eval_derivative_n_times` is not overridden by the current
        # object, the default in `Basic` will call a loop over
        # `_eval_derivative`:

        # ``== True`` keeps the comparison safe when ``count <= 0`` yields
        # a SymPy boolean rather than a Python bool.
        if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
            return None

        # TODO: this could be done with multiple-dispatching:
        if expr.is_scalar:
            if isinstance(v, MatrixCommon):
                result = cls._call_derive_scalar_by_matrix(expr, v)
            elif isinstance(v, MatrixExpr):
                result = cls._call_derive_scalar_by_matexpr(expr, v)
            elif isinstance(v, NDimArray):
                result = cls._call_derive_scalar_by_array(expr, v)
            elif v.is_scalar:
                # Scalar-by-scalar is delegated to the parent class
                # implementation.
                return super()._dispatch_eval_derivative_n_times(expr, v, count)
            else:
                return None
        elif v.is_scalar:
            if isinstance(expr, MatrixCommon):
                result = cls._call_derive_matrix_by_scalar(expr, v)
            elif isinstance(expr, MatrixExpr):
                result = cls._call_derive_matexpr_by_scalar(expr, v)
            elif isinstance(expr, NDimArray):
                result = cls._call_derive_array_by_scalar(expr, v)
            else:
                return None
        else:
            # Both `expr` and `v` are some array/matrix type:
            # NOTE(review): only `expr` is tested here. An explicit-matrix
            # `v` paired with a MatrixExpr `expr` falls through to the
            # branches below instead. Possibly `v` was meant to be tested
            # too; confirm the intent before changing behavior.
            if isinstance(expr, MatrixCommon):
                result = derive_by_array(expr, v)
            elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr):
                result = cls._call_derive_default(expr, v)
            elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr):
                # if one expression is a symbolic matrix expression while the other isn't, don't evaluate:
                return None
            else:
                result = derive_by_array(expr, v)
        if result is None:
            return None
        if count == 1:
            return result
        else:
            return cls._dispatch_eval_derivative_n_times(result, v, count - 1)