# File: ai-content-maker/.venv/Lib/site-packages/sympy/utilities/tests/test_codegen_rust.py
#
# Listing metadata: 402 lines, 12 KiB, Python.

from io import StringIO
from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function
from sympy.core.relational import Equality
from sympy.functions.elementary.piecewise import Piecewise
from sympy.utilities.codegen import RustCodeGen, codegen, make_routine
from sympy.testing.pytest import XFAIL
import sympy
# Module-level symbols shared by the tests below; some tests (e.g.
# test_global_vars_rust) shadow them with local symbols of the same name.
x, y, z = symbols('x,y,z')
def test_empty_rust_code():
    """Dumping no routines, without header or blank lines, emits nothing."""
    buf = StringIO()
    RustCodeGen().dump_rs([], buf, "file", header=False, empty=False)
    assert buf.getvalue() == ""
def test_simple_rust_code():
    """A single unnamed expression becomes a function returning ``out1``."""
    name_expr = ("test", (x + y)*z)
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    assert result[0] == "test.rs"
    source = result[1]
    # Body lines carry the generator's four-space indentation.
    expected = (
        "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
        "    let out1 = z*(x + y);\n"
        "    out1\n"
        "}\n"
    )
    assert source == expected
def test_simple_code_with_header():
    """With ``header=True`` a SymPy banner comment precedes the function."""
    name_expr = ("test", (x + y)*z)
    result, = codegen(name_expr, "Rust", header=True, empty=False)
    assert result[0] == "test.rs"
    source = result[1]
    version_str = "Code generated with SymPy %s" % sympy.__version__
    # The banner centres the version string in a 76-column field.
    version_line = version_str.center(76).rstrip()
    expected = (
        "/*\n"
        " *%(version_line)s\n"
        " *\n"
        " * See http://www.sympy.org/ for more information.\n"
        " *\n"
        " * This file is part of 'project'\n"
        " */\n"
        "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
        "    let out1 = z*(x + y);\n"
        "    out1\n"
        "}\n"
    ) % {'version_line': version_line}
    assert source == expected
def test_simple_code_nameout():
    """An ``Equality`` names the output variable after its left-hand side."""
    expr = Equality(z, (x + y))
    name_expr = ("test", expr)
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    expected = (
        "fn test(x: f64, y: f64) -> f64 {\n"
        "    let z = x + y;\n"
        "    z\n"
        "}\n"
    )
    assert source == expected
def test_numbersymbol():
    """Catalan is emitted as a local ``const``; pi maps to ``PI``."""
    name_expr = ("test", pi**Catalan)
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    expected = (
        "fn test() -> f64 {\n"
        "    const Catalan: f64 = %s;\n"
        "    let out1 = PI.powf(Catalan);\n"
        "    out1\n"
        "}\n"
    ) % Catalan.evalf(17)
    assert source == expected
@XFAIL
def test_numbersymbol_inline():
    """Expected output if ``inline=True`` were supported (currently XFAIL)."""
    # FIXME: how to pass inline to the RustCodePrinter?
    name_expr = ("test", [pi**Catalan, EulerGamma])
    result, = codegen(name_expr, "Rust", header=False,
                      empty=False, inline=True)
    source = result[1]
    expected = (
        "fn test() -> (f64, f64) {\n"
        "    const Catalan: f64 = %s;\n"
        "    const EulerGamma: f64 = %s;\n"
        "    let out1 = PI.powf(Catalan);\n"
        # Previously read "EulerGamma);" -- a stray ')' that is not valid Rust.
        "    let out2 = EulerGamma;\n"
        "    (out1, out2)\n"
        "}\n"
    ) % (Catalan.evalf(17), EulerGamma.evalf(17))
    assert source == expected
def test_argument_order():
    """An explicit ``argument_sequence`` fixes the parameter order, incl. unused ``z``."""
    expr = x + y
    routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust")
    code_gen = RustCodeGen()
    output = StringIO()
    code_gen.dump_rs([routine], output, "test", header=False, empty=False)
    source = output.getvalue()
    expected = (
        "fn test(z: f64, x: f64, y: f64) -> f64 {\n"
        "    let out1 = x + y;\n"
        "    out1\n"
        "}\n"
    )
    assert source == expected
def test_multiple_results_rust():
    """Several unnamed expressions are returned as a tuple in input order."""
    expr1 = (x + y)*z
    expr2 = (x - y)*z
    name_expr = ("test", [expr1, expr2])
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    expected = (
        "fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
        "    let out1 = z*(x + y);\n"
        "    let out2 = z*(x - y);\n"
        "    (out1, out2)\n"
        "}\n"
    )
    assert source == expected
def test_results_named_unordered():
    """Named outputs keep the order given in ``name_expr`` (not alphabetical)."""
    A, B, C = symbols('A,B,C')
    expr1 = Equality(C, (x + y)*z)
    expr2 = Equality(A, (x - y)*z)
    expr3 = Equality(B, 2*x)
    name_expr = ("test", [expr1, expr2, expr3])
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    expected = (
        "fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n"
        "    let C = z*(x + y);\n"
        "    let A = z*(x - y);\n"
        "    let B = 2*x;\n"
        "    (C, A, B)\n"
        "}\n"
    )
    assert source == expected
def test_results_named_ordered():
    """``argument_sequence`` reorders inputs; named outputs keep given order."""
    A, B, C = symbols('A,B,C')
    expr1 = Equality(C, (x + y)*z)
    expr2 = Equality(A, (x - y)*z)
    expr3 = Equality(B, 2*x)
    name_expr = ("test", [expr1, expr2, expr3])
    result = codegen(name_expr, "Rust", header=False, empty=False,
                     argument_sequence=(x, z, y))
    assert result[0][0] == "test.rs"
    source = result[0][1]
    expected = (
        "fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n"
        "    let C = z*(x + y);\n"
        "    let A = z*(x - y);\n"
        "    let B = 2*x;\n"
        "    (C, A, B)\n"
        "}\n"
    )
    assert source == expected
def test_complicated_rs_codegen():
    """Trig functions print as method calls; nested calls chain."""
    from sympy.functions.elementary.trigonometric import (cos, sin, tan)
    name_expr = ("testlong",
            [ ((sin(x) + cos(y) + tan(z))**3).expand(),
            cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
    ])
    result = codegen(name_expr, "Rust", header=False, empty=False)
    assert result[0][0] == "testlong.rs"
    source = result[0][1]
    # Long expressions are split across adjacent literals; continuation
    # pieces start with " + " / "." so they join into one generated line.
    expected = (
        "fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
        "    let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()"
        " + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)"
        " + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)"
        " + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()"
        " + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n"
        "    let out2 = (x + y + z).cos().cos().cos().cos()"
        ".cos().cos().cos().cos();\n"
        "    (out1, out2)\n"
        "}\n"
    )
    assert source == expected
def test_output_arg_mixed_unordered():
    """Named and unnamed outputs are returned together in the given order."""
    # Named outputs (y, a) are NOT sorted alphabetically; unnamed ones get
    # out<position> names, and the tuple follows the order in name_expr.
    from sympy.functions.elementary.trigonometric import (cos, sin)
    a = symbols("a")
    name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    assert result[0] == "foo.rs"
    source = result[1]
    expected = (
        "fn foo(x: f64) -> (f64, f64, f64, f64) {\n"
        "    let out1 = (2*x).cos();\n"
        "    let y = x.sin();\n"
        "    let out3 = x.cos();\n"
        "    let a = (2*x).sin();\n"
        "    (out1, y, out3, a)\n"
        "}\n"
    )
    assert source == expected
def test_piecewise_():
    """A Piecewise becomes an ``if``/``else if``/``else`` expression."""
    pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
    name_expr = ("pwtest", pw)
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    # Branch bodies are nested one level deeper (eight spaces).
    expected = (
        "fn pwtest(x: f64) -> f64 {\n"
        "    let out1 = if (x < -1) {\n"
        "        0\n"
        "    } else if (x <= 1) {\n"
        "        x.powi(2)\n"
        "    } else if (x > 1) {\n"
        "        2 - x\n"
        "    } else {\n"
        "        1\n"
        "    };\n"
        "    out1\n"
        "}\n"
    )
    assert source == expected
@XFAIL
def test_piecewise_inline():
    """Expected one-line Piecewise output for ``inline=True`` (currently XFAIL)."""
    # FIXME: how to pass inline to the RustCodePrinter?
    pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
    name_expr = ("pwtest", pw)
    result, = codegen(name_expr, "Rust", header=False, empty=False,
                      inline=True)
    source = result[1]
    expected = (
        "fn pwtest(x: f64) -> f64 {\n"
        "    let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }"
        " else if (x > 1) { -x + 2 } else { 1 };\n"
        "    out1\n"
        "}\n"
    )
    assert source == expected
def test_multifcns_per_file():
    """Multiple routines are written one after another into the first file."""
    name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
    result = codegen(name_expr, "Rust", header=False, empty=False)
    assert result[0][0] == "foo.rs"
    source = result[0][1]
    expected = (
        "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
        "    let out1 = 2*x;\n"
        "    let out2 = 3*y;\n"
        "    (out1, out2)\n"
        "}\n"
        "fn bar(y: f64) -> (f64, f64) {\n"
        "    let out1 = y.powi(2);\n"
        "    let out2 = 4*y;\n"
        "    (out1, out2)\n"
        "}\n"
    )
    assert source == expected
def test_multifcns_per_file_w_header():
    """Multiple routines in one file share a single leading banner."""
    name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
    result = codegen(name_expr, "Rust", header=True, empty=False)
    assert result[0][0] == "foo.rs"
    source = result[0][1]
    version_str = "Code generated with SymPy %s" % sympy.__version__
    version_line = version_str.center(76).rstrip()
    expected = (
        "/*\n"
        " *%(version_line)s\n"
        " *\n"
        " * See http://www.sympy.org/ for more information.\n"
        " *\n"
        " * This file is part of 'project'\n"
        " */\n"
        "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
        "    let out1 = 2*x;\n"
        "    let out2 = 3*y;\n"
        "    (out1, out2)\n"
        "}\n"
        "fn bar(y: f64) -> (f64, f64) {\n"
        "    let out1 = y.powi(2);\n"
        "    let out2 = 4*y;\n"
        "    (out1, out2)\n"
        "}\n"
    ) % {'version_line': version_line}
    assert source == expected
def test_filename_match_prefix():
    """The output file name is derived from ``prefix``, not the routine name."""
    routines = [("foo", [2*x, 3*y]), ("bar", [y**2, 4*y])]
    (filename, _), = codegen(routines, "Rust", prefix="baz",
                             header=False, empty=False)
    assert filename == "baz.rs"
def test_InOutArgument():
    """An argument that is also the output is rebound with ``let``."""
    expr = Equality(x, x**2)
    name_expr = ("mysqr", expr)
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    expected = (
        "fn mysqr(x: f64) -> f64 {\n"
        "    let x = x.powi(2);\n"
        "    x\n"
        "}\n"
    )
    assert source == expected
def test_InOutArgument_order():
    """An in/out argument yields ``(x, y)`` order, explicit or inferred."""
    expr = Equality(x, x**2 + y)
    name_expr = ("test", expr)
    # Both code paths below must produce exactly this function.
    expected = (
        "fn test(x: f64, y: f64) -> f64 {\n"
        "    let x = x.powi(2) + y;\n"
        "    x\n"
        "}\n"
    )
    # can specify the order as (x, y)
    result, = codegen(name_expr, "Rust", header=False,
                      empty=False, argument_sequence=(x, y))
    assert result[1] == expected
    # make sure the default order gives (x, y) not (y, x)
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    assert result[1] == expected
def test_not_supported():
    """Unsupported expressions are flagged with ``// unsupported`` comments."""
    f = Function('f')
    name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
    result, = codegen(name_expr, "Rust", header=False, empty=False)
    source = result[1]
    expected = (
        "fn test(x: f64) -> (f64, f64) {\n"
        "    // unsupported: Derivative(f(x), x)\n"
        "    // unsupported: zoo\n"
        "    let out1 = Derivative(f(x), x);\n"
        "    let out2 = zoo;\n"
        "    (out1, out2)\n"
        "}\n"
    )
    assert source == expected
def test_global_vars_rust():
    """Symbols listed in ``global_vars`` are not turned into parameters."""
    x, y, z, t = symbols("x y z t")
    result = codegen(('f', x*y), "Rust", header=False, empty=False,
                     global_vars=(y,))
    source = result[0][1]
    expected = (
        "fn f(x: f64) -> f64 {\n"
        "    let out1 = x*y;\n"
        "    out1\n"
        "}\n"
    )
    assert source == expected

    # An unused global (t) and an explicit argument_sequence combine cleanly.
    result = codegen(('f', x*y+z), "Rust", header=False, empty=False,
                     argument_sequence=(x, y), global_vars=(z, t))
    source = result[0][1]
    expected = (
        "fn f(x: f64, y: f64) -> f64 {\n"
        "    let out1 = x*y + z;\n"
        "    out1\n"
        "}\n"
    )
    assert source == expected