""" Python code printers This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code. """ from collections import defaultdict from itertools import chain from sympy.core import S from sympy.core.mod import Mod from .precedence import precedence from .codeprinter import CodePrinter _kw = { 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', 'with', 'yield', 'None', 'False', 'nonlocal', 'True' } _known_functions = { 'Abs': 'abs', 'Min': 'min', 'Max': 'max', } _known_functions_math = { 'acos': 'acos', 'acosh': 'acosh', 'asin': 'asin', 'asinh': 'asinh', 'atan': 'atan', 'atan2': 'atan2', 'atanh': 'atanh', 'ceiling': 'ceil', 'cos': 'cos', 'cosh': 'cosh', 'erf': 'erf', 'erfc': 'erfc', 'exp': 'exp', 'expm1': 'expm1', 'factorial': 'factorial', 'floor': 'floor', 'gamma': 'gamma', 'hypot': 'hypot', 'loggamma': 'lgamma', 'log': 'log', 'ln': 'log', 'log10': 'log10', 'log1p': 'log1p', 'log2': 'log2', 'sin': 'sin', 'sinh': 'sinh', 'Sqrt': 'sqrt', 'tan': 'tan', 'tanh': 'tanh' } # Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp pow modf # radians trunc fmod fsum gcd degrees fabs] _known_constants_math = { 'Exp1': 'e', 'Pi': 'pi', 'E': 'e', 'Infinity': 'inf', 'NaN': 'nan', 'ComplexInfinity': 'nan' } def _print_known_func(self, expr): known = self.known_functions[expr.__class__.__name__] return '{name}({args})'.format(name=self._module_format(known), args=', '.join((self._print(arg) for arg in expr.args))) def _print_known_const(self, expr): known = self.known_constants[expr.__class__.__name__] return self._module_format(known) class AbstractPythonCodePrinter(CodePrinter): printmethod = "_pythoncode" language = "Python" reserved_words = _kw modules = None # initialized to a set in __init__ tab = ' ' _kf = dict(chain( _known_functions.items(), [(k, 'math.' + v) for k, v in _known_functions_math.items()] )) _kc = {k: 'math.'+v for k, v in _known_constants_math.items()} _operators = {'and': 'and', 'or': 'or', 'not': 'not'} _default_settings = dict( CodePrinter._default_settings, user_functions={}, precision=17, inline=True, fully_qualified_modules=True, contract=False, standard='python3', ) def __init__(self, settings=None): super().__init__(settings) # Python standard handler std = self._settings['standard'] if std is None: import sys std = 'python{}'.format(sys.version_info.major) if std != 'python3': raise ValueError('Only Python 3 is supported.') self.standard = std self.module_imports = defaultdict(set) # Known functions and constants handler self.known_functions = dict(self._kf, **(settings or {}).get( 'user_functions', {})) self.known_constants = dict(self._kc, **(settings or {}).get( 'user_constants', {})) def _declare_number_const(self, name, value): return "%s = %s" % (name, value) def _module_format(self, fqn, register=True): parts = fqn.split('.') if register and len(parts) > 1: self.module_imports['.'.join(parts[:-1])].add(parts[-1]) if self._settings['fully_qualified_modules']: return fqn else: return fqn.split('(')[0].split('[')[0].split('.')[-1] def _format_code(self, lines): return lines def _get_statement(self, codestring): return "{}".format(codestring) def _get_comment(self, text): return " # {}".format(text) def _expand_fold_binary_op(self, op, args): """ This method expands a fold on binary operations. ``functools.reduce`` is an example of a folded operation. For example, the expression `A + B + C + D` is folded into `((A + B) + C) + D` """ if len(args) == 1: return self._print(args[0]) else: return "%s(%s, %s)" % ( self._module_format(op), self._expand_fold_binary_op(op, args[:-1]), self._print(args[-1]), ) def _expand_reduce_binary_op(self, op, args): """ This method expands a reductin on binary operations. Notice: this is NOT the same as ``functools.reduce``. For example, the expression `A + B + C + D` is reduced into: `(A + B) + (C + D)` """ if len(args) == 1: return self._print(args[0]) else: N = len(args) Nhalf = N // 2 return "%s(%s, %s)" % ( self._module_format(op), self._expand_reduce_binary_op(args[:Nhalf]), self._expand_reduce_binary_op(args[Nhalf:]), ) def _print_NaN(self, expr): return "float('nan')" def _print_Infinity(self, expr): return "float('inf')" def _print_NegativeInfinity(self, expr): return "float('-inf')" def _print_ComplexInfinity(self, expr): return self._print_NaN(expr) def _print_Mod(self, expr): PREC = precedence(expr) return ('{} % {}'.format(*(self.parenthesize(x, PREC) for x in expr.args))) def _print_Piecewise(self, expr): result = [] i = 0 for arg in expr.args: e = arg.expr c = arg.cond if i == 0: result.append('(') result.append('(') result.append(self._print(e)) result.append(')') result.append(' if ') result.append(self._print(c)) result.append(' else ') i += 1 result = result[:-1] if result[-1] == 'True': result = result[:-2] result.append(')') else: result.append(' else None)') return ''.join(result) def _print_Relational(self, expr): "Relational printer for Equality and Unequality" op = { '==' :'equal', '!=' :'not_equal', '<' :'less', '<=' :'less_equal', '>' :'greater', '>=' :'greater_equal', } if expr.rel_op in op: lhs = self._print(expr.lhs) rhs = self._print(expr.rhs) return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs) return super()._print_Relational(expr) def _print_ITE(self, expr): from sympy.functions.elementary.piecewise import Piecewise return self._print(expr.rewrite(Piecewise)) def _print_Sum(self, expr): loops = ( 'for {i} in range({a}, {b}+1)'.format( i=self._print(i), a=self._print(a), b=self._print(b)) for i, a, b in expr.limits) return '(builtins.sum({function} {loops}))'.format( function=self._print(expr.function), loops=' '.join(loops)) def _print_ImaginaryUnit(self, expr): return '1j' def _print_KroneckerDelta(self, expr): a, b = expr.args return '(1 if {a} == {b} else 0)'.format( a = self._print(a), b = self._print(b) ) def _print_MatrixBase(self, expr): name = expr.__class__.__name__ func = self.known_functions.get(name, name) return "%s(%s)" % (func, self._print(expr.tolist())) _print_SparseRepMatrix = \ _print_MutableSparseMatrix = \ _print_ImmutableSparseMatrix = \ _print_Matrix = \ _print_DenseMatrix = \ _print_MutableDenseMatrix = \ _print_ImmutableMatrix = \ _print_ImmutableDenseMatrix = \ lambda self, expr: self._print_MatrixBase(expr) def _indent_codestring(self, codestring): return '\n'.join([self.tab + line for line in codestring.split('\n')]) def _print_FunctionDefinition(self, fd): body = '\n'.join((self._print(arg) for arg in fd.body)) return "def {name}({parameters}):\n{body}".format( name=self._print(fd.name), parameters=', '.join([self._print(var.symbol) for var in fd.parameters]), body=self._indent_codestring(body) ) def _print_While(self, whl): body = '\n'.join((self._print(arg) for arg in whl.body)) return "while {cond}:\n{body}".format( cond=self._print(whl.condition), body=self._indent_codestring(body) ) def _print_Declaration(self, decl): return '%s = %s' % ( self._print(decl.variable.symbol), self._print(decl.variable.value) ) def _print_Return(self, ret): arg, = ret.args return 'return %s' % self._print(arg) def _print_Print(self, prnt): print_args = ', '.join((self._print(arg) for arg in prnt.print_args)) if prnt.format_string != None: # Must be '!= None', cannot be 'is not None' print_args = '{} % ({})'.format( self._print(prnt.format_string), print_args) if prnt.file != None: # Must be '!= None', cannot be 'is not None' print_args += ', file=%s' % self._print(prnt.file) return 'print(%s)' % print_args def _print_Stream(self, strm): if str(strm.name) == 'stdout': return self._module_format('sys.stdout') elif str(strm.name) == 'stderr': return self._module_format('sys.stderr') else: return self._print(strm.name) def _print_NoneToken(self, arg): return 'None' def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'): """Printing helper function for ``Pow`` Notes ===== This preprocesses the ``sqrt`` as math formatter and prints division Examples ======== >>> from sympy import sqrt >>> from sympy.printing.pycode import PythonCodePrinter >>> from sympy.abc import x Python code printer automatically looks up ``math.sqrt``. >>> printer = PythonCodePrinter() >>> printer._hprint_Pow(sqrt(x), rational=True) 'x**(1/2)' >>> printer._hprint_Pow(sqrt(x), rational=False) 'math.sqrt(x)' >>> printer._hprint_Pow(1/sqrt(x), rational=True) 'x**(-1/2)' >>> printer._hprint_Pow(1/sqrt(x), rational=False) '1/math.sqrt(x)' >>> printer._hprint_Pow(1/x, rational=False) '1/x' >>> printer._hprint_Pow(1/x, rational=True) 'x**(-1)' Using sqrt from numpy or mpmath >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt') 'numpy.sqrt(x)' >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt') 'mpmath.sqrt(x)' See Also ======== sympy.printing.str.StrPrinter._print_Pow """ PREC = precedence(expr) if expr.exp == S.Half and not rational: func = self._module_format(sqrt) arg = self._print(expr.base) return '{func}({arg})'.format(func=func, arg=arg) if expr.is_commutative and not rational: if -expr.exp is S.Half: func = self._module_format(sqrt) num = self._print(S.One) arg = self._print(expr.base) return f"{num}/{func}({arg})" if expr.exp is S.NegativeOne: num = self._print(S.One) arg = self.parenthesize(expr.base, PREC, strict=False) return f"{num}/{arg}" base_str = self.parenthesize(expr.base, PREC, strict=False) exp_str = self.parenthesize(expr.exp, PREC, strict=False) return "{}**{}".format(base_str, exp_str) class ArrayPrinter: def _arrayify(self, indexed): from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array try: return convert_indexed_to_array(indexed) except Exception: return indexed def _get_einsum_string(self, subranks, contraction_indices): letters = self._get_letter_generator_for_einsum() contraction_string = "" counter = 0 d = {j: min(i) for i in contraction_indices for j in i} indices = [] for rank_arg in subranks: lindices = [] for i in range(rank_arg): if counter in d: lindices.append(d[counter]) else: lindices.append(counter) counter += 1 indices.append(lindices) mapping = {} letters_free = [] letters_dum = [] for i in indices: for j in i: if j not in mapping: l = next(letters) mapping[j] = l else: l = mapping[j] contraction_string += l if j in d: if l not in letters_dum: letters_dum.append(l) else: letters_free.append(l) contraction_string += "," contraction_string = contraction_string[:-1] return contraction_string, letters_free, letters_dum def _get_letter_generator_for_einsum(self): for i in range(97, 123): yield chr(i) for i in range(65, 91): yield chr(i) raise ValueError("out of letters") def _print_ArrayTensorProduct(self, expr): letters = self._get_letter_generator_for_einsum() contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks]) return '%s("%s", %s)' % ( self._module_format(self._module + "." + self._einsum), contraction_string, ", ".join([self._print(arg) for arg in expr.args]) ) def _print_ArrayContraction(self, expr): from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct base = expr.expr contraction_indices = expr.contraction_indices if isinstance(base, ArrayTensorProduct): elems = ",".join(["%s" % (self._print(arg)) for arg in base.args]) ranks = base.subranks else: elems = self._print(base) ranks = [len(base.shape)] contraction_string, letters_free, letters_dum = self._get_einsum_string(ranks, contraction_indices) if not contraction_indices: return self._print(base) if isinstance(base, ArrayTensorProduct): elems = ",".join(["%s" % (self._print(arg)) for arg in base.args]) else: elems = self._print(base) return "%s(\"%s\", %s)" % ( self._module_format(self._module + "." + self._einsum), "{}->{}".format(contraction_string, "".join(sorted(letters_free))), elems, ) def _print_ArrayDiagonal(self, expr): from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct diagonal_indices = list(expr.diagonal_indices) if isinstance(expr.expr, ArrayTensorProduct): subranks = expr.expr.subranks elems = expr.expr.args else: subranks = expr.subranks elems = [expr.expr] diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices) elems = [self._print(i) for i in elems] return '%s("%s", %s)' % ( self._module_format(self._module + "." + self._einsum), "{}->{}".format(diagonal_string, "".join(letters_free+letters_dum)), ", ".join(elems) ) def _print_PermuteDims(self, expr): return "%s(%s, %s)" % ( self._module_format(self._module + "." + self._transpose), self._print(expr.expr), self._print(expr.permutation.array_form), ) def _print_ArrayAdd(self, expr): return self._expand_fold_binary_op(self._module + "." + self._add, expr.args) def _print_OneArray(self, expr): return "%s((%s,))" % ( self._module_format(self._module+ "." + self._ones), ','.join(map(self._print,expr.args)) ) def _print_ZeroArray(self, expr): return "%s((%s,))" % ( self._module_format(self._module+ "." + self._zeros), ','.join(map(self._print,expr.args)) ) def _print_Assignment(self, expr): #XXX: maybe this needs to happen at a higher level e.g. at _print or #doprint? lhs = self._print(self._arrayify(expr.lhs)) rhs = self._print(self._arrayify(expr.rhs)) return "%s = %s" % ( lhs, rhs ) def _print_IndexedBase(self, expr): return self._print_ArraySymbol(expr) class PythonCodePrinter(AbstractPythonCodePrinter): def _print_sign(self, e): return '(0.0 if {e} == 0 else {f}(1, {e}))'.format( f=self._module_format('math.copysign'), e=self._print(e.args[0])) def _print_Not(self, expr): PREC = precedence(expr) return self._operators['not'] + self.parenthesize(expr.args[0], PREC) def _print_Indexed(self, expr): base = expr.args[0] index = expr.args[1:] return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index])) def _print_Pow(self, expr, rational=False): return self._hprint_Pow(expr, rational=rational) def _print_Rational(self, expr): return '{}/{}'.format(expr.p, expr.q) def _print_Half(self, expr): return self._print_Rational(expr) def _print_frac(self, expr): return self._print_Mod(Mod(expr.args[0], 1)) def _print_Symbol(self, expr): name = super()._print_Symbol(expr) if name in self.reserved_words: if self._settings['error_on_reserved']: msg = ('This expression includes the symbol "{}" which is a ' 'reserved keyword in this language.') raise ValueError(msg.format(name)) return name + self._settings['reserved_word_suffix'] elif '{' in name: # Remove curly braces from subscripted variables return name.replace('{', '').replace('}', '') else: return name _print_lowergamma = CodePrinter._print_not_supported _print_uppergamma = CodePrinter._print_not_supported _print_fresnelc = CodePrinter._print_not_supported _print_fresnels = CodePrinter._print_not_supported for k in PythonCodePrinter._kf: setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func) for k in _known_constants_math: setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const) def pycode(expr, **settings): """ Converts an expr to a string of Python code Parameters ========== expr : Expr A SymPy expression. fully_qualified_modules : bool Whether or not to write out full module names of functions (``math.sin`` vs. ``sin``). default: ``True``. standard : str or None, optional Only 'python3' (default) is supported. This parameter may be removed in the future. Examples ======== >>> from sympy import pycode, tan, Symbol >>> pycode(tan(Symbol('x')) + 1) 'math.tan(x) + 1' """ return PythonCodePrinter(settings).doprint(expr) _not_in_mpmath = 'log1p log2'.split() _in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath] _known_functions_mpmath = dict(_in_mpmath, **{ 'beta': 'beta', 'frac': 'frac', 'fresnelc': 'fresnelc', 'fresnels': 'fresnels', 'sign': 'sign', 'loggamma': 'loggamma', 'hyper': 'hyper', 'meijerg': 'meijerg', 'besselj': 'besselj', 'bessely': 'bessely', 'besseli': 'besseli', 'besselk': 'besselk', }) _known_constants_mpmath = { 'Exp1': 'e', 'Pi': 'pi', 'GoldenRatio': 'phi', 'EulerGamma': 'euler', 'Catalan': 'catalan', 'NaN': 'nan', 'Infinity': 'inf', 'NegativeInfinity': 'ninf' } def _unpack_integral_limits(integral_expr): """ helper function for _print_Integral that - accepts an Integral expression - returns a tuple of - a list variables of integration - a list of tuples of the upper and lower limits of integration """ integration_vars = [] limits = [] for integration_range in integral_expr.limits: if len(integration_range) == 3: integration_var, lower_limit, upper_limit = integration_range else: raise NotImplementedError("Only definite integrals are supported") integration_vars.append(integration_var) limits.append((lower_limit, upper_limit)) return integration_vars, limits class MpmathPrinter(PythonCodePrinter): """ Lambda printer for mpmath which maintains precision for floats """ printmethod = "_mpmathcode" language = "Python with mpmath" _kf = dict(chain( _known_functions.items(), [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()] )) _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()} def _print_Float(self, e): # XXX: This does not handle setting mpmath.mp.dps. It is assumed that # the caller of the lambdified function will have set it to sufficient # precision to match the Floats in the expression. # Remove 'mpz' if gmpy is installed. args = str(tuple(map(int, e._mpf_))) return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args) def _print_Rational(self, e): return "{func}({p})/{func}({q})".format( func=self._module_format('mpmath.mpf'), q=self._print(e.q), p=self._print(e.p) ) def _print_Half(self, e): return self._print_Rational(e) def _print_uppergamma(self, e): return "{}({}, {}, {})".format( self._module_format('mpmath.gammainc'), self._print(e.args[0]), self._print(e.args[1]), self._module_format('mpmath.inf')) def _print_lowergamma(self, e): return "{}({}, 0, {})".format( self._module_format('mpmath.gammainc'), self._print(e.args[0]), self._print(e.args[1])) def _print_log2(self, e): return '{0}({1})/{0}(2)'.format( self._module_format('mpmath.log'), self._print(e.args[0])) def _print_log1p(self, e): return '{}({})'.format( self._module_format('mpmath.log1p'), self._print(e.args[0])) def _print_Pow(self, expr, rational=False): return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt') def _print_Integral(self, e): integration_vars, limits = _unpack_integral_limits(e) return "{}(lambda {}: {}, {})".format( self._module_format("mpmath.quad"), ", ".join(map(self._print, integration_vars)), self._print(e.args[0]), ", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits)) for k in MpmathPrinter._kf: setattr(MpmathPrinter, '_print_%s' % k, _print_known_func) for k in _known_constants_mpmath: setattr(MpmathPrinter, '_print_%s' % k, _print_known_const) class SymPyPrinter(AbstractPythonCodePrinter): language = "Python with SymPy" def _print_Function(self, expr): mod = expr.func.__module__ or '' return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__), ', '.join((self._print(arg) for arg in expr.args))) def _print_Pow(self, expr, rational=False): return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')