187 lines
5.8 KiB
Python
187 lines
5.8 KiB
Python
|
from sympy.core import I, symbols, Basic, Mul, S
|
||
|
from sympy.core.mul import mul
|
||
|
from sympy.functions import adjoint, transpose
|
||
|
from sympy.matrices.common import ShapeError
|
||
|
from sympy.matrices import (Identity, Inverse, Matrix, MatrixSymbol, ZeroMatrix,
|
||
|
eye, ImmutableMatrix)
|
||
|
from sympy.matrices.expressions import Adjoint, Transpose, det, MatPow
|
||
|
from sympy.matrices.expressions.special import GenericIdentity
|
||
|
from sympy.matrices.expressions.matmul import (factor_in_front, remove_ids,
|
||
|
MatMul, combine_powers, any_zeros, unpack, only_squares)
|
||
|
from sympy.strategies import null_safe
|
||
|
from sympy.assumptions.ask import Q
|
||
|
from sympy.assumptions.refine import refine
|
||
|
from sympy.core.symbol import Symbol
|
||
|
|
||
|
from sympy.testing.pytest import XFAIL, raises
|
||
|
|
||
|
n, m, l, k = symbols('n m l k', integer=True)
|
||
|
x = symbols('x')
|
||
|
A = MatrixSymbol('A', n, m)
|
||
|
B = MatrixSymbol('B', m, l)
|
||
|
C = MatrixSymbol('C', n, n)
|
||
|
D = MatrixSymbol('D', n, n)
|
||
|
E = MatrixSymbol('E', m, n)
|
||
|
|
||
|
def test_evaluate():
|
||
|
assert MatMul(C, C, evaluate=True) == MatMul(C, C).doit()
|
||
|
|
||
|
def test_adjoint():
|
||
|
assert adjoint(A*B) == Adjoint(B)*Adjoint(A)
|
||
|
assert adjoint(2*A*B) == 2*Adjoint(B)*Adjoint(A)
|
||
|
assert adjoint(2*I*C) == -2*I*Adjoint(C)
|
||
|
|
||
|
M = Matrix(2, 2, [1, 2 + I, 3, 4])
|
||
|
MA = Matrix(2, 2, [1, 3, 2 - I, 4])
|
||
|
assert adjoint(M) == MA
|
||
|
assert adjoint(2*M) == 2*MA
|
||
|
assert adjoint(MatMul(2, M)) == MatMul(2, MA).doit()
|
||
|
|
||
|
|
||
|
def test_transpose():
|
||
|
assert transpose(A*B) == Transpose(B)*Transpose(A)
|
||
|
assert transpose(2*A*B) == 2*Transpose(B)*Transpose(A)
|
||
|
assert transpose(2*I*C) == 2*I*Transpose(C)
|
||
|
|
||
|
M = Matrix(2, 2, [1, 2 + I, 3, 4])
|
||
|
MT = Matrix(2, 2, [1, 3, 2 + I, 4])
|
||
|
assert transpose(M) == MT
|
||
|
assert transpose(2*M) == 2*MT
|
||
|
assert transpose(x*M) == x*MT
|
||
|
assert transpose(MatMul(2, M)) == MatMul(2, MT).doit()
|
||
|
|
||
|
|
||
|
def test_factor_in_front():
|
||
|
assert factor_in_front(MatMul(A, 2, B, evaluate=False)) ==\
|
||
|
MatMul(2, A, B, evaluate=False)
|
||
|
|
||
|
|
||
|
def test_remove_ids():
|
||
|
assert remove_ids(MatMul(A, Identity(m), B, evaluate=False)) == \
|
||
|
MatMul(A, B, evaluate=False)
|
||
|
assert null_safe(remove_ids)(MatMul(Identity(n), evaluate=False)) == \
|
||
|
MatMul(Identity(n), evaluate=False)
|
||
|
|
||
|
|
||
|
def test_combine_powers():
|
||
|
assert combine_powers(MatMul(D, Inverse(D), D, evaluate=False)) == \
|
||
|
MatMul(Identity(n), D, evaluate=False)
|
||
|
assert combine_powers(MatMul(B.T, Inverse(E*A), E, A, B, evaluate=False)) == \
|
||
|
MatMul(B.T, Identity(m), B, evaluate=False)
|
||
|
assert combine_powers(MatMul(A, E, Inverse(A*E), D, evaluate=False)) == \
|
||
|
MatMul(Identity(n), D, evaluate=False)
|
||
|
|
||
|
|
||
|
def test_any_zeros():
|
||
|
assert any_zeros(MatMul(A, ZeroMatrix(m, k), evaluate=False)) == \
|
||
|
ZeroMatrix(n, k)
|
||
|
|
||
|
|
||
|
def test_unpack():
|
||
|
assert unpack(MatMul(A, evaluate=False)) == A
|
||
|
x = MatMul(A, B)
|
||
|
assert unpack(x) == x
|
||
|
|
||
|
|
||
|
def test_only_squares():
|
||
|
assert only_squares(C) == [C]
|
||
|
assert only_squares(C, D) == [C, D]
|
||
|
assert only_squares(C, A, A.T, D) == [C, A*A.T, D]
|
||
|
|
||
|
|
||
|
def test_determinant():
|
||
|
assert det(2*C) == 2**n*det(C)
|
||
|
assert det(2*C*D) == 2**n*det(C)*det(D)
|
||
|
assert det(3*C*A*A.T*D) == 3**n*det(C)*det(A*A.T)*det(D)
|
||
|
|
||
|
|
||
|
def test_doit():
|
||
|
assert MatMul(C, 2, D).args == (C, 2, D)
|
||
|
assert MatMul(C, 2, D).doit().args == (2, C, D)
|
||
|
assert MatMul(C, Transpose(D*C)).args == (C, Transpose(D*C))
|
||
|
assert MatMul(C, Transpose(D*C)).doit(deep=True).args == (C, C.T, D.T)
|
||
|
|
||
|
|
||
|
def test_doit_drills_down():
|
||
|
X = ImmutableMatrix([[1, 2], [3, 4]])
|
||
|
Y = ImmutableMatrix([[2, 3], [4, 5]])
|
||
|
assert MatMul(X, MatPow(Y, 2)).doit() == X*Y**2
|
||
|
assert MatMul(C, Transpose(D*C)).doit().args == (C, C.T, D.T)
|
||
|
|
||
|
|
||
|
def test_doit_deep_false_still_canonical():
|
||
|
assert (MatMul(C, Transpose(D*C), 2).doit(deep=False).args ==
|
||
|
(2, C, Transpose(D*C)))
|
||
|
|
||
|
|
||
|
def test_matmul_scalar_Matrix_doit():
|
||
|
# Issue 9053
|
||
|
X = Matrix([[1, 2], [3, 4]])
|
||
|
assert MatMul(2, X).doit() == 2*X
|
||
|
|
||
|
|
||
|
def test_matmul_sympify():
|
||
|
assert isinstance(MatMul(eye(1), eye(1)).args[0], Basic)
|
||
|
|
||
|
|
||
|
def test_collapse_MatrixBase():
|
||
|
A = Matrix([[1, 1], [1, 1]])
|
||
|
B = Matrix([[1, 2], [3, 4]])
|
||
|
assert MatMul(A, B).doit() == ImmutableMatrix([[4, 6], [4, 6]])
|
||
|
|
||
|
|
||
|
def test_refine():
|
||
|
assert refine(C*C.T*D, Q.orthogonal(C)).doit() == D
|
||
|
|
||
|
kC = k*C
|
||
|
assert refine(kC*C.T, Q.orthogonal(C)).doit() == k*Identity(n)
|
||
|
assert refine(kC* kC.T, Q.orthogonal(C)).doit() == (k**2)*Identity(n)
|
||
|
|
||
|
def test_matmul_no_matrices():
|
||
|
assert MatMul(1) == 1
|
||
|
assert MatMul(n, m) == n*m
|
||
|
assert not isinstance(MatMul(n, m), MatMul)
|
||
|
|
||
|
def test_matmul_args_cnc():
|
||
|
assert MatMul(n, A, A.T).args_cnc() == [[n], [A, A.T]]
|
||
|
assert MatMul(A, A.T).args_cnc() == [[], [A, A.T]]
|
||
|
|
||
|
@XFAIL
|
||
|
def test_matmul_args_cnc_symbols():
|
||
|
# Not currently supported
|
||
|
a, b = symbols('a b', commutative=False)
|
||
|
assert MatMul(n, a, b, A, A.T).args_cnc() == [[n], [a, b, A, A.T]]
|
||
|
assert MatMul(n, a, A, b, A.T).args_cnc() == [[n], [a, A, b, A.T]]
|
||
|
|
||
|
def test_issue_12950():
|
||
|
M = Matrix([[Symbol("x")]]) * MatrixSymbol("A", 1, 1)
|
||
|
assert MatrixSymbol("A", 1, 1).as_explicit()[0]*Symbol('x') == M.as_explicit()[0]
|
||
|
|
||
|
def test_construction_with_Mul():
|
||
|
assert Mul(C, D) == MatMul(C, D)
|
||
|
assert Mul(D, C) == MatMul(D, C)
|
||
|
|
||
|
def test_construction_with_mul():
|
||
|
assert mul(C, D) == MatMul(C, D)
|
||
|
assert mul(D, C) == MatMul(D, C)
|
||
|
assert mul(C, D) != MatMul(D, C)
|
||
|
|
||
|
def test_generic_identity():
|
||
|
assert MatMul.identity == GenericIdentity()
|
||
|
assert MatMul.identity != S.One
|
||
|
|
||
|
|
||
|
def test_issue_23519():
|
||
|
N = Symbol("N", integer=True)
|
||
|
M1 = MatrixSymbol("M1", N, N)
|
||
|
M2 = MatrixSymbol("M2", N, N)
|
||
|
I = Identity(N)
|
||
|
z = (M2 + 2 * (M2 + I) * M1 + I)
|
||
|
assert z.coeff(M1) == 2*I + 2*M2
|
||
|
|
||
|
|
||
|
def test_shape_error():
|
||
|
A = MatrixSymbol('A', 2, 2)
|
||
|
B = MatrixSymbol('B', 3, 3)
|
||
|
raises(ShapeError, lambda: MatMul(A, B))
|