210 lines
5.5 KiB
Python
210 lines
5.5 KiB
Python
|
from sympy.parsing.sym_expr import SymPyExpression
|
||
|
from sympy.testing.pytest import raises
|
||
|
from sympy.external import import_module
|
||
|
|
||
|
lfortran = import_module('lfortran')
|
||
|
cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']})
|
||
|
|
||
|
if lfortran and cin:
|
||
|
from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String,
|
||
|
Declaration, FloatType)
|
||
|
from sympy.core import Integer, Float
|
||
|
from sympy.core.symbol import Symbol
|
||
|
|
||
|
# Fortran declarations reused as the common prefix of the Fortran-based tests.
src = (
    'integer :: a, b, c, d\n'
    'real :: p, q, r, s\n'
)
# Single parser object shared by the tests that call ``convert_to_expr``.
expr1 = SymPyExpression()
|
||
|
|
||
|
def test_c_parse():
    """Parse a short C snippet and compare every resulting declaration.

    The snippet declares two ``int`` and two ``float`` variables; the second
    variable of each pair carries an initializer.  The entries returned by
    ``return_expr`` are checked, in order, against the expected
    ``Declaration`` objects.
    """
    c_code = (
        'int a, b = 4;\n'
        'float c, d = 2.4;\n'
    )
    expr1.convert_to_expr(c_code, 'c')
    parsed = expr1.return_expr()

    def c_int():
        # Fresh type object per variable, as in a hand-written expectation.
        return IntBaseType(String('intc'))

    def c_float():
        return FloatType(
            String('float32'),
            nbits=Integer(32),
            nmant=Integer(23),
            nexp=Integer(8)
        )

    expected_variables = [
        Variable(Symbol('a'), type=c_int()),
        Variable(Symbol('b'), type=c_int(), value=Integer(4)),
        Variable(Symbol('c'), type=c_float()),
        Variable(
            Symbol('d'),
            type=c_float(),
            value=Float('2.3999999999999999', precision=53)
        ),
    ]

    # Index explicitly so that a too-short result still raises IndexError.
    for position, variable in enumerate(expected_variables):
        assert parsed[position] == Declaration(variable)
|
||
|
|
||
|
|
||
|
def test_fortran_parse():
    """Parse the shared Fortran declarations and compare each entry.

    The four ``integer`` names are expected as ``IntBaseType`` variables with
    value ``Integer(0)`` and the four ``real`` names as ``FloatBaseType``
    variables with value ``Float('0.0', precision=53)``, in declaration order.
    """
    parsed = SymPyExpression(src, 'f').return_expr()

    expected_integers = [
        Variable(
            Symbol(name),
            type=IntBaseType(String('integer')),
            value=Integer(0)
        )
        for name in ('a', 'b', 'c', 'd')
    ]
    expected_reals = [
        Variable(
            Symbol(name),
            type=FloatBaseType(String('real')),
            value=Float('0.0', precision=53)
        )
        for name in ('p', 'q', 'r', 's')
    ]

    # Index explicitly so that a too-short result still raises IndexError.
    for position, variable in enumerate(expected_integers + expected_reals):
        assert parsed[position] == Declaration(variable)
|
||
|
|
||
|
|
||
|
def test_convert_py():
    """Convert Fortran declarations plus two assignments to Python code.

    The expected result is one initialisation line per declared variable
    followed by the two assignment statements.
    """
    fortran_code = src + (
        'a = b + c\n'
        's = p * q / r\n'
    )
    expr1.convert_to_expr(fortran_code, 'f')

    initialisations = [
        'a = 0', 'b = 0', 'c = 0', 'd = 0',
        'p = 0.0', 'q = 0.0', 'r = 0.0', 's = 0.0',
    ]
    assignments = ['a = b + c', 's = p*q/r']

    assert expr1.convert_to_python() == initialisations + assignments
|
||
|
|
||
|
|
||
|
def test_convert_fort():
    """Convert Fortran declarations plus two assignments back to Fortran.

    The shared declarations in ``src`` are extended with two assignment
    statements, parsed in Fortran mode, and rendered again with
    ``convert_to_fortran``.  The expected result is one type declaration per
    variable followed by the two assignments.
    """
    src1 = src + (
        'a = b + c\n'
        's = p * q / r\n'
    )
    expr1.convert_to_expr(src1, 'f')
    exp_fort = expr1.convert_to_fortran()

    # SymPy's Fortran printer defaults to fixed source format, which starts
    # statements in column 7.  Each expected line therefore begins with six
    # spaces; a single leading space can never match the printer output.
    assert exp_fort == [
        '      integer*4 a',
        '      integer*4 b',
        '      integer*4 c',
        '      integer*4 d',
        '      real*8 p',
        '      real*8 q',
        '      real*8 r',
        '      real*8 s',
        '      a = b + c',
        '      s = p*q/r'
    ]
|
||
|
|
||
|
|
||
|
def test_convert_c():
    """Convert Fortran declarations plus two assignments to C code.

    The expected result is one initialised declaration per variable followed
    by the two assignment statements.
    """
    fortran_code = src + (
        'a = b + c\n'
        's = p * q / r\n'
    )
    expr1.convert_to_expr(fortran_code, 'f')

    declarations = [
        'int a = 0', 'int b = 0', 'int c = 0', 'int d = 0',
        'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0',
    ]
    assignments = ['a = b + c;', 's = p*q/r;']

    assert expr1.convert_to_c() == declarations + assignments
|
||
|
|
||
|
|
||
|
def test_exceptions():
    """Check the exceptions raised for invalid ``SymPyExpression`` arguments."""
    c_declaration = 'int a;'

    invalid_constructions = [
        # Source code given without a mode.
        (ValueError, lambda: SymPyExpression(c_declaration)),
        # Mode given without source code.
        (ValueError, lambda: SymPyExpression(mode='c')),
        # Mode string 'd' is expected to be rejected as not implemented.
        (NotImplementedError, lambda: SymPyExpression(c_declaration, mode='d')),
    ]

    for expected_error, construct in invalid_constructions:
        raises(expected_error, construct)
|
||
|
|
||
|
elif not lfortran and not cin:
|
||
|
def test_raise():
    """Constructing a parser with source code must raise ``ImportError``.

    This test is only defined when neither ``lfortran`` nor ``clang.cindex``
    could be imported.
    """
    snippets = (
        ('int a;', 'c'),
        ('integer :: a', 'f'),
    )
    for code, mode in snippets:
        # Bind the loop values as defaults so each lambda keeps its own pair.
        raises(ImportError, lambda code=code, mode=mode: SymPyExpression(code, mode))
|