""" Overloads for ClassInstanceType for built-in functions that call dunder methods on an object. """ from functools import wraps import inspect import operator from numba.core.extending import overload from numba.core.types import ClassInstanceType def _get_args(n_args): assert n_args in (1, 2) return list("xy")[:n_args] def class_instance_overload(target): """ Decorator to add an overload for target that applies when the first argument is a ClassInstanceType. """ def decorator(func): @wraps(func) def wrapped(*args, **kwargs): if not isinstance(args[0], ClassInstanceType): return return func(*args, **kwargs) if target is not complex: # complex ctor needs special treatment as it uses kwargs params = list(inspect.signature(wrapped).parameters) assert params == _get_args(len(params)) return overload(target)(wrapped) return decorator def extract_template(template, name): """ Extract a code-generated function from a string template. """ namespace = {} exec(template, namespace) return namespace[name] def register_simple_overload(func, *attrs, n_args=1,): """ Register an overload for func that checks for methods __attr__ for each attr in attrs. """ # Use a template to set the signature correctly. arg_names = _get_args(n_args) template = f""" def func({','.join(arg_names)}): pass """ @wraps(extract_template(template, "func")) def overload_func(*args, **kwargs): options = [ try_call_method(args[0], f"__{attr}__", n_args) for attr in attrs ] return take_first(*options) return class_instance_overload(func)(overload_func) def try_call_method(cls_type, method, n_args=1): """ If method is defined for cls_type, return a callable that calls this method. If not, return None. 
""" if method in cls_type.jit_methods: arg_names = _get_args(n_args) template = f""" def func({','.join(arg_names)}): return {arg_names[0]}.{method}({','.join(arg_names[1:])}) """ return extract_template(template, "func") def try_call_complex_method(cls_type, method): """ __complex__ needs special treatment as the argument names are kwargs and therefore specific in name and default value. """ if method in cls_type.jit_methods: template = f""" def func(real=0, imag=0): return real.{method}() """ return extract_template(template, "func") def take_first(*options): """ Take the first non-None option. """ assert all(o is None or inspect.isfunction(o) for o in options), options for o in options: if o is not None: return o @class_instance_overload(bool) def class_bool(x): using_bool_impl = try_call_method(x, "__bool__") if '__len__' in x.jit_methods: def using_len_impl(x): return bool(len(x)) else: using_len_impl = None always_true_impl = lambda x: True return take_first(using_bool_impl, using_len_impl, always_true_impl) @class_instance_overload(complex) def class_complex(real=0, imag=0): return take_first( try_call_complex_method(real, "__complex__"), lambda real=0, imag=0: complex(float(real)) ) @class_instance_overload(operator.contains) def class_contains(x, y): # https://docs.python.org/3/reference/expressions.html#membership-test-operations return try_call_method(x, "__contains__", 2) # TODO: use __iter__ if defined. 
@class_instance_overload(float)
def class_float(x):
    """float(x): use __float__, falling back to __index__ if defined."""
    options = [try_call_method(x, "__float__")]

    if "__index__" in x.jit_methods:
        options.append(lambda x: float(x.__index__()))

    return take_first(*options)


@class_instance_overload(int)
def class_int(x):
    """int(x): use __int__, falling back to __index__ if defined."""
    return take_first(
        try_call_method(x, "__int__"),
        try_call_method(x, "__index__"),
    )


@class_instance_overload(str)
def class_str(x):
    """str(x): use __str__, falling back to repr(x)."""
    return take_first(
        try_call_method(x, "__str__"),
        lambda x: repr(x),
    )


@class_instance_overload(operator.ne)
def class_ne(x, y):
    """``x != y``: use x.__ne__(y), falling back to ``not (x == y)``."""
    # This doesn't use register_reflected_overload like the other operators
    # because it falls back to inverting __eq__ rather than reflecting its
    # arguments (as per the definition of the Python data model).
    return take_first(
        try_call_method(x, "__ne__", 2),
        lambda x, y: not (x == y),
    )


def register_reflected_overload(func, meth_forward, meth_reflected):
    """
    Register an overload for the binary comparison ``func(x, y)``.

    Calls ``x.__meth_forward__(y)`` if x defines it. Otherwise, when y is
    also a jitclass instance defining ``__meth_reflected__``, calls
    ``y.__meth_reflected__(x)`` (e.g. ``x >= y`` -> ``y.__le__(x)``).

    :param func: comparison operator function, e.g. ``operator.ge``.
    :param meth_forward: dunder name (without underscores) tried on x.
    :param meth_reflected: mirrored dunder name tried on y.
    """
    reflected_method = f"__{meth_reflected}__"

    def class_cmp(x, y):
        normal_impl = try_call_method(x, f"__{meth_forward}__", 2)

        # Only the first argument is guaranteed to be a ClassInstanceType;
        # y may be any type, which would have no ``jit_methods``.
        if (isinstance(y, ClassInstanceType)
                and reflected_method in y.jit_methods):
            # Call the mirrored dunder directly so every operator uses its
            # own reflection (lt<->gt, le<->ge, eq<->eq) instead of a single
            # fixed comparison.
            template = f"""
def func(x, y):
    return y.{reflected_method}(x)
"""
            reflected_impl = extract_template(template, "func")
        else:
            reflected_impl = None

        return take_first(normal_impl, reflected_impl)

    class_instance_overload(func)(class_cmp)


register_simple_overload(abs, "abs")
register_simple_overload(len, "len")
register_simple_overload(hash, "hash")

# Comparison operators.
register_reflected_overload(operator.ge, "ge", "le")
register_reflected_overload(operator.gt, "gt", "lt")
register_reflected_overload(operator.le, "le", "ge")
register_reflected_overload(operator.lt, "lt", "gt")

# Note that eq is missing support for fallback to `x is y`, but `is` and
# `operator.is` are presently unsupported in general.
register_reflected_overload(operator.eq, "eq", "eq")

# Arithmetic operators.
# Each entry is (operator function, dunder names to try in order, arity).
# Registration happens in table order.
for _op_func, _dunders, _arity in (
    (operator.add, ("add",), 2),
    (operator.floordiv, ("floordiv",), 2),
    (operator.lshift, ("lshift",), 2),
    (operator.mul, ("mul",), 2),
    (operator.mod, ("mod",), 2),
    (operator.neg, ("neg",), 1),
    (operator.pos, ("pos",), 1),
    (operator.pow, ("pow",), 2),
    (operator.rshift, ("rshift",), 2),
    (operator.sub, ("sub",), 2),
    (operator.truediv, ("truediv",), 2),
    # Inplace arithmetic operators: try the inplace dunder first, then the
    # plain binary one.
    (operator.iadd, ("iadd", "add"), 2),
    (operator.ifloordiv, ("ifloordiv", "floordiv"), 2),
    (operator.ilshift, ("ilshift", "lshift"), 2),
    (operator.imul, ("imul", "mul"), 2),
    (operator.imod, ("imod", "mod"), 2),
    (operator.ipow, ("ipow", "pow"), 2),
    (operator.irshift, ("irshift", "rshift"), 2),
    (operator.isub, ("isub", "sub"), 2),
    (operator.itruediv, ("itruediv", "truediv"), 2),
    # Bitwise operators (&, |, ^).
    (operator.and_, ("and",), 2),
    (operator.or_, ("or",), 2),
    (operator.xor, ("xor",), 2),
    # Inplace bitwise operators (&=, |=, ^=).
    (operator.iand, ("iand", "and"), 2),
    (operator.ior, ("ior", "or"), 2),
    (operator.ixor, ("ixor", "xor"), 2),
):
    register_simple_overload(_op_func, *_dunders, n_args=_arity)

# Keep the module namespace free of the loop temporaries.
del _op_func, _dunders, _arity