"""
A Printer for generating executable code.

The most important function here is srepr that returns a string so that the
relation eval(srepr(expr))=expr holds in an appropriate environment.
"""

from __future__ import annotations
from typing import Any

from sympy.core.function import AppliedUndef
from sympy.core.mul import Mul
from mpmath.libmp import repr_dps, to_str as mlib_to_str

from .printer import Printer, print_function


class ReprPrinter(Printer):
    """Printer whose output is meant to be evaluated back into the object.

    Each ``_print_<ClassName>`` method returns the source-like string for
    objects of that class; ``emptyPrinter`` is the fallback.
    """

    printmethod = "_sympyrepr"

    _default_settings: dict[str, Any] = {
        "order": None,
        "perm_cyclic" : True,
    }

    def reprify(self, args, sep):
        """
        Prints each item in `args` and joins them with `sep`.
        """
        return sep.join([self.doprint(item) for item in args])

    def emptyPrinter(self, expr):
        """
        The fallback printer.

        Tries, in order: plain strings (returned unchanged), an
        ``__srepr__`` hook, objects with iterable ``args`` (printed as
        ``ClassName(arg, ...)``), objects with ``__module__`` and
        ``__name__`` (printed as ``<'module.name'>``), and finally
        ``str(expr)``.
        """
        if isinstance(expr, str):
            return expr
        elif hasattr(expr, "__srepr__"):
            return expr.__srepr__()
        elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"):
            printed = [self._print(o) for o in expr.args]
            return expr.__class__.__name__ + '(%s)' % ', '.join(printed)
        elif hasattr(expr, "__module__") and hasattr(expr, "__name__"):
            return "<'%s.%s'>" % (expr.__module__, expr.__name__)
        else:
            return str(expr)

    def _print_Add(self, expr, order=None):
        """Print ``expr`` as ``ClassName(term, ...)`` using ordered terms."""
        args = self._as_ordered_terms(expr, order=order)
        args = map(self._print, args)
        clsname = type(expr).__name__
        return clsname + "(%s)" % ", ".join(args)

    def _print_Cycle(self, expr):
        """Delegate to the object's own ``__repr__``."""
        return expr.__repr__()

    def _print_Permutation(self, expr):
        """Print a Permutation in cyclic or array form.

        The form is chosen by ``Permutation.print_cyclic`` when that class
        attribute is not None (which also emits a deprecation warning),
        otherwise by the ``perm_cyclic`` printer setting.
        """
        from sympy.combinatorics.permutations import Permutation, Cycle
        from sympy.utilities.exceptions import sympy_deprecation_warning

        perm_cyclic = Permutation.print_cyclic
        if perm_cyclic is not None:
            sympy_deprecation_warning(
                f"""
                Setting Permutation.print_cyclic is deprecated. Instead use
                init_printing(perm_cyclic={perm_cyclic}).
                """,
                deprecated_since_version="1.6",
                active_deprecations_target="deprecated-permutation-print_cyclic",
                stacklevel=7,
            )
        else:
            perm_cyclic = self._settings.get("perm_cyclic", True)

        if perm_cyclic:
            if not expr.size:
                return 'Permutation()'
            # before taking Cycle notation, see if the last element is
            # a singleton and move it to the head of the string
            s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):]
            last = s.rfind('(')
            if not last == 0 and ',' not in s[last:]:
                s = s[last:] + s[:last]
            return 'Permutation%s' %s
        else:
            s = expr.support()
            if not s:
                if expr.size < 5:
                    return 'Permutation(%s)' % str(expr.array_form)
                return 'Permutation([], size=%s)' % expr.size
            # Use the truncated array form plus an explicit size only when
            # that spelling is strictly shorter than the full array form.
            trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size
            use = full = str(expr.array_form)
            if len(trim) < len(full):
                use = trim
            return 'Permutation(%s)' % use

    def _print_Function(self, expr):
        """Print as ``<printed func>(<printed args>)``."""
        r = self._print(expr.func)
        r += '(%s)' % ', '.join([self._print(a) for a in expr.args])
        return r

    def _print_Heaviside(self, expr):
        # Same as _print_Function but uses pargs to suppress default value for
        # 2nd arg.
        r = self._print(expr.func)
        r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs])
        return r

    def _print_FunctionClass(self, expr):
        """Print a function class; AppliedUndef subclasses as ``Function('name')``."""
        if issubclass(expr, AppliedUndef):
            return 'Function(%r)' % (expr.__name__)
        else:
            return expr.__name__

    def _print_Half(self, expr):
        return 'Rational(1, 2)'

    def _print_RationalConstant(self, expr):
        return str(expr)

    def _print_AtomicExpr(self, expr):
        return str(expr)

    def _print_NumberSymbol(self, expr):
        return str(expr)

    def _print_Integer(self, expr):
        return 'Integer(%i)' % expr.p

    def _print_Complexes(self, expr):
        return 'Complexes'

    def _print_Integers(self, expr):
        return 'Integers'

    def _print_Naturals(self, expr):
        return 'Naturals'

    def _print_Naturals0(self, expr):
        return 'Naturals0'

    def _print_Rationals(self, expr):
        return 'Rationals'

    def _print_Reals(self, expr):
        return 'Reals'

    def _print_EmptySet(self, expr):
        return 'EmptySet'

    def _print_UniversalSet(self, expr):
        return 'UniversalSet'

    def _print_EmptySequence(self, expr):
        return 'EmptySequence'

    def _print_list(self, expr):
        return "[%s]" % self.reprify(expr, ", ")

    def _print_dict(self, expr):
        """Print as ``{key: value, ...}`` with both sides run through doprint."""
        sep = ", "
        dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()]
        return "{%s}" % sep.join(dict_kvs)

    def _print_set(self, expr):
        # "{}" would be an empty dict literal, so the empty set is spelled out.
        if not expr:
            return "set()"
        return "{%s}" % self.reprify(expr, ", ")

    def _print_MatrixBase(self, expr):
        """Print as ``ClassName([[...], ...])``, or with explicit shape when
        exactly one dimension is zero."""
        # special case for some empty matrices
        if (expr.rows == 0) ^ (expr.cols == 0):
            return '%s(%s, %s, %s)' % (expr.__class__.__name__,
                                       self._print(expr.rows),
                                       self._print(expr.cols),
                                       self._print([]))
        rows = [[expr[i, j] for j in range(expr.cols)]
                for i in range(expr.rows)]
        return '%s(%s)' % (expr.__class__.__name__, self._print(rows))

    def _print_BooleanTrue(self, expr):
        return "true"

    def _print_BooleanFalse(self, expr):
        return "false"

    def _print_NaN(self, expr):
        return "nan"

    def _print_Mul(self, expr, order=None):
        """Print ``expr`` as ``ClassName(factor, ...)``.

        The factor order is taken from the printer's ``order`` setting
        (``self.order``); the ``order`` parameter is not consulted here.
        """
        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        args = map(self._print, args)
        clsname = type(expr).__name__
        return clsname + "(%s)" % ", ".join(args)

    def _print_Rational(self, expr):
        return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q))

    def _print_PythonRational(self, expr):
        return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q)

    def _print_Fraction(self, expr):
        return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator))

    def _print_Float(self, expr):
        """Print as ``ClassName('<digits>', precision=<prec>)``.

        The digit string comes from mpmath's ``to_str`` using
        ``repr_dps(expr._prec)`` digits.
        """
        r = mlib_to_str(expr._mpf_, repr_dps(expr._prec))
        return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec)

    def _print_Sum2(self, expr):
        return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i),
                                           self._print(expr.a), self._print(expr.b))

    def _print_Str(self, s):
        return "%s(%s)" % (s.__class__.__name__, self._print(s.name))

    def _print_Symbol(self, expr):
        """Print as ``ClassName('name'[, key=value, ...])``.

        The keyword part is built from ``expr._assumptions_orig``; for a
        Dummy, ``dummy_index`` is appended as if it were an assumption.
        """
        d = expr._assumptions_orig
        # print the dummy_index like it was an assumption
        if expr.is_Dummy:
            # Work on a copy: adding the key to the dict held on ``expr``
            # would make printing modify the object being printed.
            d = dict(d)
            d['dummy_index'] = expr.dummy_index

        if d == {}:
            return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name))
        else:
            attr = ['%s=%s' % (k, v) for k, v in d.items()]
            return "%s(%s, %s)" % (expr.__class__.__name__,
                                   self._print(expr.name), ', '.join(attr))

    def _print_CoordinateSymbol(self, expr):
        """Print as ``ClassName(coord_sys, index[, key=value, ...])``."""
        d = expr._assumptions.generator

        if d == {}:
            return "%s(%s, %s)" % (
                expr.__class__.__name__,
                self._print(expr.coord_sys),
                self._print(expr.index)
            )
        else:
            attr = ['%s=%s' % (k, v) for k, v in d.items()]
            return "%s(%s, %s, %s)" % (
                expr.__class__.__name__,
                self._print(expr.coord_sys),
                self._print(expr.index),
                ', '.join(attr)
            )

    def _print_Predicate(self, expr):
        return "Q.%s" % expr.name

    def _print_AppliedPredicate(self, expr):
        # will be changed to just expr.args when args overriding is removed
        args = expr._args
        return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", "))

    def _print_str(self, expr):
        return repr(expr)

    def _print_tuple(self, expr):
        # A one-element tuple needs the trailing comma to stay a tuple.
        if len(expr) == 1:
            return "(%s,)" % self._print(expr[0])
        else:
            return "(%s)" % self.reprify(expr, ", ")

    def _print_WildFunction(self, expr):
        return "%s('%s')" % (expr.__class__.__name__, expr.name)

    def _print_AlgebraicNumber(self, expr):
        return "%s(%s, %s)" % (expr.__class__.__name__,
                               self._print(expr.root), self._print(expr.coeffs()))

    def _print_PolyRing(self, ring):
        return "%s(%s, %s, %s)" % (ring.__class__.__name__,
                                   self._print(ring.symbols), self._print(ring.domain), self._print(ring.order))

    def _print_FracField(self, field):
        return "%s(%s, %s, %s)" % (field.__class__.__name__,
                                   self._print(field.symbols), self._print(field.domain), self._print(field.order))

    def _print_PolyElement(self, poly):
        """Print as ``ClassName(ring, terms)`` with terms sorted by the ring
        order, largest first."""
        terms = list(poly.terms())
        terms.sort(key=poly.ring.order, reverse=True)
        return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms))

    def _print_FracElement(self, frac):
        """Print as ``ClassName(field, numer_terms, denom_terms)`` with both
        term lists sorted by the field order, largest first."""
        numer_terms = list(frac.numer.terms())
        numer_terms.sort(key=frac.field.order, reverse=True)
        denom_terms = list(frac.denom.terms())
        denom_terms.sort(key=frac.field.order, reverse=True)
        numer = self._print(numer_terms)
        denom = self._print(denom_terms)
        return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom)

    def _print_FractionField(self, domain):
        cls = domain.__class__.__name__
        field = self._print(domain.field)
        return "%s(%s)" % (cls, field)

    def _print_PolynomialRingBase(self, ring):
        """Print as ``ClassName(domain, gens...[, order=...])``; the order is
        included only when it differs from ``ring.default_order``."""
        cls = ring.__class__.__name__
        dom = self._print(ring.domain)
        gens = ', '.join(map(self._print, ring.gens))
        order = str(ring.order)
        if order != ring.default_order:
            orderstr = ", order=" + order
        else:
            orderstr = ""
        return "%s(%s, %s%s)" % (cls, dom, gens, orderstr)

    def _print_DMP(self, p):
        """Print as ``ClassName(rep, dom[, ring=...])``; the ring is included
        only when ``p.ring`` is not None."""
        cls = p.__class__.__name__
        rep = self._print(p.rep)
        dom = self._print(p.dom)
        if p.ring is not None:
            ringstr = ", ring=" + self._print(p.ring)
        else:
            ringstr = ""
        return "%s(%s, %s%s)" % (cls, rep, dom, ringstr)

    def _print_MonogenicFiniteExtension(self, ext):
        # The expanded tree shown by srepr(ext.modulus)
        # is not practical.
        return "FiniteExtension(%s)" % str(ext.modulus)

    def _print_ExtensionElement(self, f):
        rep = self._print(f.rep)
        ext = self._print(f.ext)
        return "ExtElem(%s, %s)" % (rep, ext)


@print_function(ReprPrinter)
def srepr(expr, **settings):
    """return expr in repr form"""
    return ReprPrinter(settings).doprint(expr)