590 lines
17 KiB
Python
590 lines
17 KiB
Python
|
from io import StringIO
|
||
|
|
||
|
from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
|
||
|
from sympy.core.relational import Equality
|
||
|
from sympy.functions.elementary.piecewise import Piecewise
|
||
|
from sympy.matrices import Matrix, MatrixSymbol
|
||
|
from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine
|
||
|
from sympy.testing.pytest import raises
|
||
|
from sympy.testing.pytest import XFAIL
|
||
|
import sympy
|
||
|
|
||
|
|
||
|
# Plain symbols (no assumptions) shared by most tests in this module.
# Some tests (e.g. test_m_loops, test_global_vars_octave) rebind x/y/z
# locally when they need different objects.
x, y, z = symbols('x,y,z')
|
||
|
|
||
|
|
||
|
def test_empty_m_code():
    """Dumping an empty routine list (no header, no blank lines) emits nothing."""
    buf = StringIO()
    OctaveCodeGen().dump_m([], buf, "file", header=False, empty=False)
    assert buf.getvalue() == ""
|
||
|
|
||
|
|
||
|
def test_m_simple_code():
    """A single unnamed expression is returned as ``out1`` in ``test.m``."""
    (filename, generated), = codegen(("test", (x + y)*z), "Octave",
                                     header=False, empty=False)
    assert filename == "test.m"
    want = (
        "function out1 = test(x, y, z)\n"
        " out1 = z.*(x + y);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_simple_code_with_header():
    """With ``header=True`` a SymPy banner comment follows the signature line."""
    outputs = codegen(("test", (x + y)*z), "Octave", header=True, empty=False)
    (filename, generated), = outputs
    assert filename == "test.m"
    # The banner embeds the running SymPy version string.
    banner = (
        " %TEST Autogenerated by SymPy\n"
        " % Code generated with SymPy " + sympy.__version__ + "\n"
        " %\n"
        " % See http://www.sympy.org/ for more information.\n"
        " %\n"
        " % This file is part of 'project'\n"
    )
    want = (
        "function out1 = test(x, y, z)\n"
        + banner
        + " out1 = z.*(x + y);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_simple_code_nameout():
    """An ``Equality`` names the output after its left-hand side."""
    routine_spec = ("test", Equality(z, (x + y)))
    (_, generated), = codegen(routine_spec, "Octave", header=False, empty=False)
    want = (
        "function z = test(x, y)\n"
        " z = x + y;\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_numbersymbol():
    """Number symbols are inlined: ``pi`` by name, Catalan as a 17-digit value."""
    (_, generated), = codegen(("test", pi**Catalan), "Octave",
                              header=False, empty=False)
    template = (
        "function out1 = test()\n"
        " out1 = pi^%s;\n"
        "end\n"
    )
    assert generated == template % Catalan.evalf(17)
|
||
|
|
||
|
|
||
|
@XFAIL
def test_m_numbersymbol_no_inline():
    """Desired ``inline=False`` output: constants hoisted into local variables.

    Expected to fail: per the FIXME below, it is unclear how to forward
    ``inline=False`` to the OctaveCodePrinter.
    """
    # FIXME: how to pass inline=False to the OctaveCodePrinter?
    exprs = [pi**Catalan, EulerGamma]
    (_, generated), = codegen(("test", exprs), "Octave", header=False,
                              empty=False, inline=False)
    want = (
        "function [out1, out2] = test()\n"
        " Catalan = 0.915965594177219; % constant\n"
        " EulerGamma = 0.5772156649015329; % constant\n"
        " out1 = pi^Catalan;\n"
        " out2 = EulerGamma;\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_code_argument_order():
    """``argument_sequence`` controls input order, even for unused symbols (z)."""
    rout = make_routine("test", x + y, argument_sequence=[z, x, y],
                        language="octave")
    sink = StringIO()
    OctaveCodeGen().dump_m([rout], sink, "test", header=False, empty=False)
    want = (
        "function out1 = test(z, x, y)\n"
        " out1 = x + y;\n"
        "end\n"
    )
    assert sink.getvalue() == want
|
||
|
|
||
|
|
||
|
def test_multiple_results_m():
    """Unnamed outputs are numbered in the order the expressions are given."""
    exprs = [(x + y)*z, (x - y)*z]
    (_, generated), = codegen(("test", exprs), "Octave",
                              header=False, empty=False)
    want = (
        "function [out1, out2] = test(x, y, z)\n"
        " out1 = z.*(x + y);\n"
        " out2 = z.*(x - y);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_results_named_unordered():
    """Named outputs keep the order in which they appear in ``name_expr``."""
    A, B, C = symbols('A,B,C')
    eqs = [
        Equality(C, (x + y)*z),
        Equality(A, (x - y)*z),
        Equality(B, 2*x),
    ]
    (_, generated), = codegen(("test", eqs), "Octave",
                              header=False, empty=False)
    want = (
        "function [C, A, B] = test(x, y, z)\n"
        " C = z.*(x + y);\n"
        " A = z.*(x - y);\n"
        " B = 2*x;\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_results_named_ordered():
    """Named outputs keep given order; inputs follow ``argument_sequence``."""
    A, B, C = symbols('A,B,C')
    eqs = [
        Equality(C, (x + y)*z),
        Equality(A, (x - y)*z),
        Equality(B, 2*x),
    ]
    first = codegen(("test", eqs), "Octave", header=False, empty=False,
                    argument_sequence=(x, z, y))[0]
    assert first[0] == "test.m"
    want = (
        "function [C, A, B] = test(x, z, y)\n"
        " C = z.*(x + y);\n"
        " A = z.*(x - y);\n"
        " B = 2*x;\n"
        "end\n"
    )
    assert first[1] == want
|
||
|
|
||
|
|
||
|
def test_complicated_m_codegen():
    """Long expanded and deeply nested expressions print as elementwise Octave."""
    from sympy.functions.elementary.trigonometric import (cos, sin, tan)
    cubed = ((sin(x) + cos(y) + tan(z))**3).expand()
    nested = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
    first = codegen(("testlong", [cubed, nested]), "Octave",
                    header=False, empty=False)[0]
    assert first[0] == "testlong.m"
    want = (
        "function [out1, out2] = testlong(x, y, z)\n"
        " out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)"
        " + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2"
        " + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n"
        " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n"
        "end\n"
    )
    assert first[1] == want
|
||
|
|
||
|
|
||
|
def test_m_output_arg_mixed_unordered():
    """Named and unnamed outputs may be mixed freely.

    All outputs appear in the order given (named ones are NOT sorted).
    Each unnamed output is called ``out<k>``, where ``k`` is its 1-based
    position in the output list.
    """
    from sympy.functions.elementary.trigonometric import (cos, sin)
    a = symbols("a")
    name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
    result, = codegen(name_expr, "Octave", header=False, empty=False)
    assert result[0] == "foo.m"
    source = result[1]
    expected = (
        'function [out1, y, out3, a] = foo(x)\n'
        ' out1 = cos(2*x);\n'
        ' y = sin(x);\n'
        ' out3 = cos(x);\n'
        ' a = sin(2*x);\n'
        'end\n'
    )
    assert source == expected
|
||
|
|
||
|
|
||
|
def test_m_piecewise_():
    """An unevaluated Piecewise prints inline as masked, line-continued sums."""
    pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True),
                   evaluate=False)
    (_, generated), = codegen(("pwtest", pw), "Octave",
                              header=False, empty=False)
    want = (
        "function out1 = pwtest(x)\n"
        " out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n"
        " (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n"
        " (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
@XFAIL
def test_m_piecewise_no_inline():
    """Desired ``inline=False`` Piecewise output: an if/elseif/else chain.

    Expected to fail: per the FIXME below, it is unclear how to forward
    ``inline=False`` to the OctaveCodePrinter.
    """
    # FIXME: how to pass inline=False to the OctaveCodePrinter?
    pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
    (_, generated), = codegen(("pwtest", pw), "Octave", header=False,
                              empty=False, inline=False)
    want = (
        "function out1 = pwtest(x)\n"
        " if (x < -1)\n"
        " out1 = 0;\n"
        " elseif (x <= 1)\n"
        " out1 = x.^2;\n"
        " elseif (x > 1)\n"
        " out1 = -x + 2;\n"
        " else\n"
        " out1 = 1;\n"
        " end\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_multifcns_per_file():
    """Several routines are written into one .m file named after the first."""
    name_expr = [("foo", [2*x, 3*y]), ("bar", [y**2, 4*y])]
    result = codegen(name_expr, "Octave", header=False, empty=False)
    assert result[0][0] == "foo.m"
    source = result[0][1]
    expected = (
        "function [out1, out2] = foo(x, y)\n"
        " out1 = 2*x;\n"
        " out2 = 3*y;\n"
        "end\n"
        "function [out1, out2] = bar(y)\n"
        " out1 = y.^2;\n"
        " out2 = 4*y;\n"
        "end\n"
    )
    assert source == expected
|
||
|
|
||
|
|
||
|
def test_m_multifcns_per_file_w_header():
    """With several routines per file, only the first one gets the header."""
    name_expr = [("foo", [2*x, 3*y]), ("bar", [y**2, 4*y])]
    result = codegen(name_expr, "Octave", header=True, empty=False)
    assert result[0][0] == "foo.m"
    source = result[0][1]
    expected = (
        "function [out1, out2] = foo(x, y)\n"
        " %FOO Autogenerated by SymPy\n"
        " % Code generated with SymPy " + sympy.__version__ + "\n"
        " %\n"
        " % See http://www.sympy.org/ for more information.\n"
        " %\n"
        " % This file is part of 'project'\n"
        " out1 = 2*x;\n"
        " out2 = 3*y;\n"
        "end\n"
        "function [out1, out2] = bar(y)\n"
        " out1 = y.^2;\n"
        " out2 = 4*y;\n"
        "end\n"
    )
    assert source == expected
|
||
|
|
||
|
|
||
|
def test_m_filename_match_first_fcn():
    """A ``prefix`` that differs from the first routine's name raises ValueError."""
    routines = [("foo", [2*x, 3*y]), ("bar", [y**2, 4*y])]

    def generate():
        return codegen(routines, "Octave", prefix="bar",
                       header=False, empty=False)

    raises(ValueError, generate)
|
||
|
|
||
|
|
||
|
def test_m_matrix_named():
    """A MatrixSymbol on the LHS names the matrix output."""
    row = Matrix([[x, 2*y, pi*z]])
    lhs = MatrixSymbol('myout1', 1, 3)
    first = codegen(("test", Equality(lhs, row)), "Octave",
                    header=False, empty=False)[0]
    assert first[0] == "test.m"
    want = (
        "function myout1 = test(x, y, z)\n"
        " myout1 = [x 2*y pi*z];\n"
        "end\n"
    )
    assert first[1] == want
|
||
|
|
||
|
|
||
|
def test_m_matrix_named_matsym():
    """Same as test_m_matrix_named but with an unevaluated Equality."""
    myout1 = MatrixSymbol('myout1', 1, 3)
    row = Matrix([[x, 2*y, pi*z]])
    eq = Equality(myout1, row, evaluate=False)
    (_, generated), = codegen(("test", eq), "Octave", header=False, empty=False)
    want = (
        "function myout1 = test(x, y, z)\n"
        " myout1 = [x 2*y pi*z];\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_matrix_output_autoname():
    """An unnamed Matrix result is returned as ``out1``."""
    row = Matrix([[x, x+y, 3]])
    (_, generated), = codegen(("test", row), "Octave", header=False, empty=False)
    want = (
        "function out1 = test(x, y)\n"
        " out1 = [x x + y 3];\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_matrix_output_autoname_2():
    """Scalar, row, column, and 2x2 outputs are auto-named out1..out4."""
    outputs = (
        (x + y),
        Matrix([[2*x, 2*y, 2*z]]),
        Matrix([[x], [y], [z]]),
        Matrix([[x, y], [z, 16]]),
    )
    (_, generated), = codegen(("test", outputs), "Octave",
                              header=False, empty=False)
    want = (
        "function [out1, out2, out3, out4] = test(x, y, z)\n"
        " out1 = x + y;\n"
        " out2 = [2*x 2*y 2*z];\n"
        " out3 = [x; y; z];\n"
        " out4 = [x y; z 16];\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_results_matrix_named_ordered():
    """Scalar and matrix named outputs mix; inputs follow ``argument_sequence``."""
    B, C = symbols('B,C')
    A = MatrixSymbol('A', 1, 3)
    eqs = [
        Equality(C, (x + y)*z),
        Equality(A, Matrix([[1, 2, x]])),
        Equality(B, 2*x),
    ]
    (_, generated), = codegen(("test", eqs), "Octave", header=False,
                              empty=False, argument_sequence=(x, z, y))
    want = (
        "function [C, A, B] = test(x, z, y)\n"
        " C = z.*(x + y);\n"
        " A = [1 2 x];\n"
        " B = 2*x;\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_matrixsymbol_slice():
    """Row/column slices print with 1-based indices and ``:``."""
    A = MatrixSymbol('A', 2, 3)
    B = MatrixSymbol('B', 1, 3)
    C = MatrixSymbol('C', 1, 3)
    D = MatrixSymbol('D', 2, 1)
    eqs = [
        Equality(B, A[0, :]),
        Equality(C, A[1, :]),
        Equality(D, A[:, 2]),
    ]
    (_, generated), = codegen(("test", eqs), "Octave",
                              header=False, empty=False)
    want = (
        "function [B, C, D] = test(A)\n"
        " B = A(1, :);\n"
        " C = A(2, :);\n"
        " D = A(:, 3);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_matrixsymbol_slice2():
    """Half-open Python ranges become inclusive 1-based Octave ranges."""
    A = MatrixSymbol('A', 3, 4)
    B = MatrixSymbol('B', 2, 2)
    C = MatrixSymbol('C', 2, 2)
    eqs = [
        Equality(B, A[0:2, 0:2]),
        Equality(C, A[0:2, 1:3]),
    ]
    (_, generated), = codegen(("test", eqs), "Octave",
                              header=False, empty=False)
    want = (
        "function [B, C] = test(A)\n"
        " B = A(1:2, 1:2);\n"
        " C = A(1:2, 2:3);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_matrixsymbol_slice3():
    """Open-ended and strided slices print with ``end`` and a step."""
    A = MatrixSymbol('A', 8, 7)
    B = MatrixSymbol('B', 2, 2)
    C = MatrixSymbol('C', 4, 2)
    eqs = [
        Equality(B, A[6:, 1::3]),
        Equality(C, A[::2, ::3]),
    ]
    (_, generated), = codegen(("test", eqs), "Octave",
                              header=False, empty=False)
    want = (
        "function [B, C] = test(A)\n"
        " B = A(7:end, 2:3:end);\n"
        " C = A(1:2:end, 1:3:end);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_matrixsymbol_slice_autoname():
    """Unnamed slice outputs are auto-named ``out<k>`` by position."""
    A = MatrixSymbol('A', 2, 3)
    B = MatrixSymbol('B', 1, 3)
    outputs = [Equality(B, A[0, :]), A[1, :], A[:, 0], A[:, 1]]
    (_, generated), = codegen(("test", outputs), "Octave",
                              header=False, empty=False)
    want = (
        "function [B, out2, out3, out4] = test(A)\n"
        " B = A(1, :);\n"
        " out2 = A(2, :);\n"
        " out3 = A(:, 1);\n"
        " out4 = A(:, 2);\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_loops():
    """An indexed matrix-vector product becomes explicit nested ``for`` loops.

    The product's operand order is not fixed, so either ``A(i, j).*x(j)``
    or ``x(j).*A(i, j)`` is accepted.
    """
    # Note: an Octave programmer would probably vectorize this across one or
    # more dimensions. Also, size(A) would be used rather than passing in m
    # and n. Perhaps users would expect us to vectorize automatically here?
    # Or is it possible to represent such things using IndexedBase?
    from sympy.tensor import IndexedBase, Idx
    # `symbols` is already imported at module level; no local re-import needed.
    n, m = symbols('n m', integer=True)
    A = IndexedBase('A')
    x = IndexedBase('x')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave",
                      header=False, empty=False)
    source = result[1]
    expected = (
        'function y = mat_vec_mult(A, m, n, x)\n'
        ' for i = 1:m\n'
        ' y(i) = 0;\n'
        ' end\n'
        ' for i = 1:m\n'
        ' for j = 1:n\n'
        ' y(i) = %(rhs)s + y(i);\n'
        ' end\n'
        ' end\n'
        'end\n'
    )
    assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or
            source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)})
|
||
|
|
||
|
|
||
|
def test_m_tensor_loops_multiple_contractions():
    """Contracting over three repeated indices yields four nested loops."""
    # see comments in previous test about vectorizing
    from sympy.tensor import IndexedBase, Idx
    # `symbols` is already imported at module level; no local re-import needed.
    n, m, o, p = symbols('n m o p', integer=True)
    A = IndexedBase('A')
    B = IndexedBase('B')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)
    result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
                      "Octave", header=False, empty=False)
    source = result[1]
    expected = (
        'function y = tensorthing(A, B, m, n, o, p)\n'
        ' for i = 1:m\n'
        ' y(i) = 0;\n'
        ' end\n'
        ' for i = 1:m\n'
        ' for j = 1:n\n'
        ' for k = 1:o\n'
        ' for l = 1:p\n'
        ' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n'
        ' end\n'
        ' end\n'
        ' end\n'
        ' end\n'
        'end\n'
    )
    assert source == expected
|
||
|
|
||
|
|
||
|
def test_m_InOutArgument():
    """A symbol on both sides of an Equality is both input and output."""
    in_out = Equality(x, x**2)
    (_, generated), = codegen(("mysqr", in_out), "Octave",
                              header=False, empty=False)
    want = (
        "function x = mysqr(x)\n"
        " x = x.^2;\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_InOutArgument_order():
    """In/out argument order is (x, y) whether given explicitly or by default."""
    expr = Equality(x, x**2 + y)
    want = (
        "function x = test(x, y)\n"
        " x = x.^2 + y;\n"
        "end\n"
    )
    # First pass: order given explicitly as (x, y).
    # Second pass: no argument_sequence; it must still give (x, y) not (y, x).
    for extra in ({"argument_sequence": (x, y)}, {}):
        (_, generated), = codegen(("test", expr), "Octave", header=False,
                                  empty=False, **extra)
        assert generated == want
|
||
|
|
||
|
|
||
|
def test_m_not_supported():
    """Unprintable expressions are emitted verbatim and listed as unsupported."""
    f = Function('f')
    unsupported = [f(x).diff(x), S.ComplexInfinity]
    (_, generated), = codegen(("test", unsupported), "Octave",
                              header=False, empty=False)
    want = (
        "function [out1, out2] = test(x)\n"
        " % unsupported: Derivative(f(x), x)\n"
        " % unsupported: zoo\n"
        " out1 = Derivative(f(x), x);\n"
        " out2 = zoo;\n"
        "end\n"
    )
    assert generated == want
|
||
|
|
||
|
|
||
|
def test_global_vars_octave():
    """Symbols in ``global_vars`` become a ``global`` line instead of inputs."""
    x, y, z, t = symbols("x y z t")

    # A single global, with inputs inferred.
    generated = codegen(('f', x*y), "Octave", header=False, empty=False,
                        global_vars=(y,))[0][1]
    want = (
        "function out1 = f(x)\n"
        " global y\n"
        " out1 = x.*y;\n"
        "end\n"
    )
    assert generated == want

    # Several globals (one of them unused), inputs given explicitly.
    generated = codegen(('f', x*y+z), "Octave", header=False, empty=False,
                        argument_sequence=(x, y), global_vars=(z, t))[0][1]
    want = (
        "function out1 = f(x, y)\n"
        " global t z\n"
        " out1 = x.*y + z;\n"
        "end\n"
    )
    assert generated == want
|