"""
Overloads for ClassInstanceType for built-in functions that call dunder methods
on an object.
"""
from functools import wraps
import inspect
import operator
from numba.core.extending import overload
from numba.core.types import ClassInstanceType
def _get_args(n_args):
assert n_args in (1, 2)
return list("xy")[:n_args]
def class_instance_overload(target):
    """
    Decorator to add an overload for target that applies when the first argument
    is a ClassInstanceType.

    The decorated typing function is only consulted when the first positional
    argument is a ``ClassInstanceType``; for any other call the overload
    declines by returning ``None``.
    """
    def decorator(func):
        # ``wraps`` makes ``inspect.signature(wrapped)`` report ``func``'s own
        # parameters instead of ``(*args, **kwargs)``; the parameter-name
        # check below relies on that.
        @wraps(func)
        def wrapped(*args, **kwargs):
            # Decline unless the first positional argument is a jitclass
            # instance type (this also covers calls with no positional
            # arguments, which would otherwise raise IndexError).
            if not args or not isinstance(args[0], ClassInstanceType):
                return None
            return func(*args, **kwargs)

        if target is not complex:
            # complex ctor needs special treatment as it uses kwargs
            # (its typing function takes ``real``/``imag``, not ``x``/``y``).
            params = list(inspect.signature(wrapped).parameters)
            assert params == _get_args(len(params))
        return overload(target)(wrapped)

    return decorator
def extract_template(template, name):
    """
    Extract a code-generated function from a string template.

    ``template`` is executed in a fresh namespace and the object it binds to
    ``name`` is returned; a missing ``name`` raises ``KeyError``.
    """
    scope = dict()
    exec(template, scope)
    generated = scope[name]
    return generated
def register_simple_overload(func, *attrs, n_args=1,):
    """
    Register an overload for func that checks for methods __attr__ for each
    attr in attrs.

    The dunder methods are tried in the order given, and the first one the
    jitclass defines is used. For example, ``("iadd", "add")`` prefers
    ``__iadd__`` and falls back to ``__add__``. ``n_args`` is the number of
    positional arguments ``func`` takes (1 or 2).
    """
    # Use a template to set the signature correctly: ``wraps`` copies the
    # template function's metadata, so ``inspect.signature`` reports
    # ``(x)`` / ``(x, y)``, which ``class_instance_overload`` checks.
    arg_names = _get_args(n_args)
    template = f"""
def func({','.join(arg_names)}):
    pass
"""

    @wraps(extract_template(template, "func"))
    def overload_func(*args, **kwargs):
        options = [
            try_call_method(args[0], f"__{attr}__", n_args)
            for attr in attrs
        ]
        return take_first(*options)

    return class_instance_overload(func)(overload_func)
def try_call_method(cls_type, method, n_args=1):
    """
    If method is defined for cls_type, return a callable that calls this method.
    If not, return None.

    The generated function takes ``n_args`` positional parameters (named as
    by ``_get_args``). It calls ``method`` on the first parameter and passes
    the remaining ones through, e.g. ``func(x, y): return x.__add__(y)``.
    """
    if method not in cls_type.jit_methods:
        return None
    arg_names = _get_args(n_args)
    # The body line must stay indented inside the generated source.
    template = f"""
def func({','.join(arg_names)}):
    return {arg_names[0]}.{method}({','.join(arg_names[1:])})
"""
    return extract_template(template, "func")
def try_call_complex_method(cls_type, method):
    """ __complex__ needs special treatment as the argument names are kwargs
    and therefore specific in name and default value.

    If ``method`` is defined for ``cls_type``, return a function with the
    ``complex`` constructor's signature ``(real=0, imag=0)`` that returns
    ``real.<method>()``; otherwise return None. The generated function
    accepts ``imag`` but does not use it.
    """
    if method not in cls_type.jit_methods:
        return None
    template = f"""
def func(real=0, imag=0):
    return real.{method}()
"""
    return extract_template(template, "func")
def take_first(*options):
    """
    Take the first non-None option.

    Every option must be either None or a plain Python function. Returns the
    first option that is not None, or None when there is no such option.
    """
    assert all(o is None or inspect.isfunction(o) for o in options), options
    return next((candidate for candidate in options if candidate is not None),
                None)
@class_instance_overload(bool)
def class_bool(x):
    """
    Typing for ``bool(x)`` on a jitclass instance.

    Uses ``x.__bool__()`` if defined, else ``bool(len(x))`` if ``__len__`` is
    defined, else always ``True``.
    """
    using_bool_impl = try_call_method(x, "__bool__")

    if '__len__' in x.jit_methods:
        def using_len_impl(x):
            return bool(len(x))
    else:
        using_len_impl = None

    def always_true_impl(x):
        return True

    return take_first(using_bool_impl, using_len_impl, always_true_impl)
@class_instance_overload(complex)
def class_complex(real=0, imag=0):
    """
    Typing for ``complex(real, imag)`` when ``real`` is a jitclass instance.

    Uses ``real.__complex__()`` if defined, otherwise ``float(real)`` as the
    real part. An explicitly passed ``imag`` is added to the imaginary part
    in both cases rather than being dropped, matching Python's constructor.
    """
    # NOTE(review): when ``imag`` is omitted at the call site, the typing
    # function is expected to see its Python default (the int ``0``) rather
    # than a Numba type. In that case ``complex(x)`` keeps returning
    # ``x.__complex__()`` unchanged. If numba ever passes a type here instead,
    # the with-imag path is taken, which differs only in the sign of a zero
    # imaginary part. Confirm against numba's overload machinery.
    imag_given = not isinstance(imag, int)

    if imag_given and "__complex__" in real.jit_methods:
        def complex_with_imag_impl(real=0, imag=0):
            c = real.__complex__()
            return complex(c.real, c.imag + imag)
        return complex_with_imag_impl

    return take_first(
        try_call_complex_method(real, "__complex__"),
        # complex(f, 0) == complex(f), so this is also right when imag is
        # omitted.
        lambda real=0, imag=0: complex(float(real), imag),
    )
@class_instance_overload(operator.contains)
def class_contains(x, y):
    """
    Typing for ``y in x`` (``operator.contains(x, y)``) on a jitclass
    instance ``x``.

    Dispatches to ``x.__contains__(y)`` when it is defined; otherwise the
    overload does not apply.
    """
    # https://docs.python.org/3/reference/expressions.html#membership-test-operations
    # TODO: use __iter__ if defined.
    return try_call_method(x, "__contains__", 2)
@class_instance_overload(float)
def class_float(x):
    """
    Typing for ``float(x)`` on a jitclass instance.

    Prefers ``x.__float__()``, falling back to ``float(x.__index__())`` when
    only ``__index__`` is defined.
    """
    options = [try_call_method(x, "__float__")]

    if "__index__" in x.jit_methods:
        options.append(lambda x: float(x.__index__()))

    return take_first(*options)
@class_instance_overload(int)
def class_int(x):
    """
    Typing for ``int(x)`` on a jitclass instance.

    Uses ``x.__int__()`` if defined, otherwise ``x.__index__()``. If neither
    is defined, the overload does not apply.
    """
    return take_first(
        try_call_method(x, "__int__"),
        try_call_method(x, "__index__"),
    )
@class_instance_overload(str)
def class_str(x):
    """
    Typing for ``str(x)`` on a jitclass instance: ``x.__str__()`` if
    defined, otherwise ``repr(x)``.
    """
    return take_first(
        try_call_method(x, "__str__"),
        lambda x: repr(x),
    )
@class_instance_overload(operator.ne)
def class_ne(x, y):
    """
    Typing for ``x != y`` on a jitclass instance ``x``: ``x.__ne__(y)`` if
    defined, otherwise ``not (x == y)``.
    """
    # This doesn't use register_reflected_overload like the other operators
    # because it falls back to inverting __eq__ rather than reflecting its
    # arguments (as per the definition of the Python data model).
    return take_first(
        try_call_method(x, "__ne__", 2),
        lambda x, y: not (x == y),
    )
def register_reflected_overload(func, meth_forward, meth_reflected):
    """
    Register an overload for the binary comparison ``func(x, y)`` where ``x``
    is a jitclass instance.

    ``x.__<meth_forward>__(y)`` is used when ``x`` defines it. Otherwise, if
    ``y`` is also a jitclass instance defining ``__<meth_reflected>__``, the
    reflected call ``y.__<meth_reflected>__(x)`` is used; for example,
    ``x >= y`` falls back to ``y.__le__(x)``. If neither applies, the
    overload does not match. Returns None.
    """
    forward_name = f"__{meth_forward}__"
    reflected_name = f"__{meth_reflected}__"

    def class_compare(x, y):
        normal_impl = try_call_method(x, forward_name, 2)

        reflected_impl = None
        # Only jitclass instance types carry ``jit_methods``; other
        # right-hand operand types cannot supply the reflected method.
        if isinstance(y, ClassInstanceType) and reflected_name in y.jit_methods:
            # Call the reflected dunder directly with the operands swapped,
            # so each comparison reflects to its own counterpart.
            template = f"""
def func(x, y):
    return y.{reflected_name}(x)
"""
            reflected_impl = extract_template(template, "func")

        return take_first(normal_impl, reflected_impl)

    class_instance_overload(func)(class_compare)
# Built-in functions that dispatch to a single unary dunder method.
for _func, _attr in (
    (abs, "abs"),
    (len, "len"),
    (hash, "hash"),
):
    register_simple_overload(_func, _attr)

# Comparison operators: (operator, forward dunder, reflected dunder). The
# reflected dunder is checked on the right-hand operand when the left-hand
# operand does not define the forward one.
# Note that eq is missing support for fallback to `x is y`, but `is` and
# `operator.is` are presently unsupported in general.
for _func, _forward, _reflected in (
    (operator.ge, "ge", "le"),
    (operator.gt, "gt", "lt"),
    (operator.le, "le", "ge"),
    (operator.lt, "lt", "gt"),
    (operator.eq, "eq", "eq"),
):
    register_reflected_overload(_func, _forward, _reflected)

# Arithmetic and logical operators: (operator, dunder names tried in order,
# number of arguments). In-place variants fall back to the plain binary
# dunder when the in-place one is not defined.
for _func, _attrs, _n_args in (
    # Arithmetic operators.
    (operator.add, ("add",), 2),
    (operator.floordiv, ("floordiv",), 2),
    (operator.lshift, ("lshift",), 2),
    (operator.mul, ("mul",), 2),
    (operator.mod, ("mod",), 2),
    (operator.neg, ("neg",), 1),
    (operator.pos, ("pos",), 1),
    (operator.pow, ("pow",), 2),
    (operator.rshift, ("rshift",), 2),
    (operator.sub, ("sub",), 2),
    (operator.truediv, ("truediv",), 2),
    # Inplace arithmetic operators.
    (operator.iadd, ("iadd", "add"), 2),
    (operator.ifloordiv, ("ifloordiv", "floordiv"), 2),
    (operator.ilshift, ("ilshift", "lshift"), 2),
    (operator.imul, ("imul", "mul"), 2),
    (operator.imod, ("imod", "mod"), 2),
    (operator.ipow, ("ipow", "pow"), 2),
    (operator.irshift, ("irshift", "rshift"), 2),
    (operator.isub, ("isub", "sub"), 2),
    (operator.itruediv, ("itruediv", "truediv"), 2),
    # Logical operators.
    (operator.and_, ("and",), 2),
    (operator.or_, ("or",), 2),
    (operator.xor, ("xor",), 2),
    # Inplace logical operators.
    (operator.iand, ("iand", "and"), 2),
    (operator.ior, ("ior", "or"), 2),
    (operator.ixor, ("ixor", "xor"), 2),
):
    register_simple_overload(_func, *_attrs, n_args=_n_args)

# Keep the loop variables out of the module namespace.
del _func, _attr, _forward, _reflected, _attrs, _n_args