from typing import Union

from sympy.core.relational import Eq
from sympy.core.expr import Expr
from sympy.core.numbers import Integer
from sympy.logic.boolalg import Boolean, And
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.common import ShapeError


def is_matadd_valid(*args: MatrixExpr) -> Boolean:
    """Return the symbolic condition how ``MatAdd``, ``HadamardProduct`` makes sense.

    Every operand must have the same number of rows and the same number of
    columns, so the condition equates the rows (and the columns) of each
    pair of adjacent operands.

    Parameters
    ==========

    args
        The list of arguments of matrices to be tested for.

    Returns
    =======

    Boolean
        Conjunction of ``Eq`` conditions on the dimensions. With zero or one
        argument there is nothing to compare, so ``And()`` with no arguments
        is returned.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matadd_valid

    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matadd_valid(A, B)
    Eq(m, p) & Eq(n, q)
    """
    # Collect shapes into a list rather than ``zip(*...)``-unpacking them:
    # the unpacking form raises ValueError when ``args`` is empty.
    shapes = [arg.shape for arg in args]
    adjacent = list(zip(shapes[:-1], shapes[1:]))
    return And(
        *(Eq(prev[0], nxt[0]) for prev, nxt in adjacent),
        *(Eq(prev[1], nxt[1]) for prev, nxt in adjacent),
    )


def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean:
    """Return the symbolic condition how ``MatMul`` makes sense

    Scalar (non-``MatrixExpr``) arguments are skipped; for each pair of
    adjacent matrix arguments the column count of the left one must equal
    the row count of the right one.

    Parameters
    ==========

    args
        The list of arguments of matrices and scalar expressions to be tested
        for.

    Returns
    =======

    Boolean
        Conjunction of ``Eq`` conditions. When fewer than two matrix
        arguments are present (including the all-scalar case), ``And()``
        with no arguments is returned instead of raising.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matmul_valid

    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matmul_valid(A, B)
    Eq(n, p)
    """
    # Only matrices constrain the product; scalars commute out of it.
    # A plain list keeps the all-scalar case from failing on tuple unpacking.
    shapes = [arg.shape for arg in args if isinstance(arg, MatrixExpr)]
    return And(
        *(Eq(left[1], right[0]) for left, right in zip(shapes[:-1], shapes[1:]))
    )


def is_square(arg: MatrixExpr, /) -> Boolean:
    """Return the symbolic condition how the matrix is assumed to be square

    Parameters
    ==========

    arg
        The matrix to be tested for.

    Returns
    =======

    Boolean
        ``Eq(arg.rows, arg.cols)``.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_square

    >>> m, n = symbols('m n')
    >>> A = MatrixSymbol('A', m, n)
    >>> is_square(A)
    Eq(m, n)
    """
    return Eq(arg.rows, arg.cols)


def validate_matadd_integer(*args: MatrixExpr) -> None:
    """Validate matrix shape for addition only for integer values

    Only dimensions that are ``int`` or sympy ``Integer`` are compared;
    symbolic dimensions are ignored and never cause an error. Calling with
    no arguments is a no-op.

    Raises
    ======

    ShapeError
        If more than one distinct integer row count, or more than one
        distinct integer column count, occurs among ``args``.
    """
    shapes = [x.shape for x in args]
    # Tuples (not lists) so the error messages render exactly as before.
    rows = tuple(shape[0] for shape in shapes)
    cols = tuple(shape[1] for shape in shapes)
    if len({r for r in rows if isinstance(r, (int, Integer))}) > 1:
        raise ShapeError(f"Matrices have mismatching shape: {rows}")
    if len({c for c in cols if isinstance(c, (int, Integer))}) > 1:
        raise ShapeError(f"Matrices have mismatching shape: {cols}")


def validate_matmul_integer(*args: MatrixExpr) -> None:
    """Validate matrix shape for multiplication only for integer values

    For each pair of adjacent arguments, the left one's ``cols`` is compared
    with the right one's ``rows``, but only when both are ``int`` or sympy
    ``Integer``; symbolic dimensions are skipped.

    Raises
    ======

    ShapeError
        If an adjacent pair has integer dimensions that do not align.
    """
    for A, B in zip(args[:-1], args[1:]):
        i, j = A.cols, B.rows
        if isinstance(i, (int, Integer)) and isinstance(j, (int, Integer)) and i != j:
            raise ShapeError("Matrices are not aligned", i, j)