593 lines
21 KiB
Python
593 lines
21 KiB
Python
import inspect
|
|
import operator
|
|
import types as pytypes
|
|
import typing as pt
|
|
from collections import OrderedDict
|
|
from collections.abc import Sequence
|
|
|
|
from llvmlite import ir as llvmir
|
|
from numba import njit
|
|
from numba.core import cgutils, errors, imputils, types, utils
|
|
from numba.core.datamodel import default_manager, models
|
|
from numba.core.registry import cpu_target
|
|
from numba.core.typing import templates
|
|
from numba.core.typing.asnumbatype import as_numba_type
|
|
from numba.core.serialize import disable_pickling
|
|
from numba.experimental.jitclass import _box
|
|
|
|
##############################################################################
|
|
# Data model
|
|
|
|
|
|
class InstanceModel(models.StructModel):
|
|
def __init__(self, dmm, fe_typ):
|
|
cls_data_ty = types.ClassDataType(fe_typ)
|
|
# MemInfoPointer uses the `dtype` attribute to traverse for nested
|
|
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
|
|
# we will replace provide MemInfoPointer with an opaque type
|
|
# so that it does not raise exception for nested meminfo.
|
|
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
|
|
members = [
|
|
('meminfo', types.MemInfoPointer(dtype)),
|
|
('data', types.CPointer(cls_data_ty)),
|
|
]
|
|
super(InstanceModel, self).__init__(dmm, fe_typ, members)
|
|
|
|
|
|
class InstanceDataModel(models.StructModel):
|
|
def __init__(self, dmm, fe_typ):
|
|
clsty = fe_typ.class_type
|
|
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
|
|
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
|
|
|
|
|
|
default_manager.register(types.ClassInstanceType, InstanceModel)
|
|
default_manager.register(types.ClassDataType, InstanceDataModel)
|
|
default_manager.register(types.ClassType, models.OpaqueModel)
|
|
|
|
|
|
def _mangle_attr(name):
|
|
"""
|
|
Mangle attributes.
|
|
The resulting name does not startswith an underscore '_'.
|
|
"""
|
|
return 'm_' + name
|
|
|
|
|
|
##############################################################################
|
|
# Class object
|
|
|
|
_ctor_template = """
|
|
def ctor({args}):
|
|
return __numba_cls_({args})
|
|
"""
|
|
|
|
|
|
def _getargs(fn_sig):
|
|
"""
|
|
Returns list of positional and keyword argument names in order.
|
|
"""
|
|
params = fn_sig.parameters
|
|
args = []
|
|
for k, v in params.items():
|
|
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
|
|
args.append(k)
|
|
else:
|
|
msg = "%s argument type unsupported in jitclass" % v.kind
|
|
raise errors.UnsupportedError(msg)
|
|
return args
|
|
|
|
|
|
@disable_pickling
|
|
class JitClassType(type):
|
|
"""
|
|
The type of any jitclass.
|
|
"""
|
|
def __new__(cls, name, bases, dct):
|
|
if len(bases) != 1:
|
|
raise TypeError("must have exactly one base class")
|
|
[base] = bases
|
|
if isinstance(base, JitClassType):
|
|
raise TypeError("cannot subclass from a jitclass")
|
|
assert 'class_type' in dct, 'missing "class_type" attr'
|
|
outcls = type.__new__(cls, name, bases, dct)
|
|
outcls._set_init()
|
|
return outcls
|
|
|
|
def _set_init(cls):
|
|
"""
|
|
Generate a wrapper for calling the constructor from pure Python.
|
|
Note the wrapper will only accept positional arguments.
|
|
"""
|
|
init = cls.class_type.instance_type.methods['__init__']
|
|
init_sig = utils.pysignature(init)
|
|
# get postitional and keyword arguments
|
|
# offset by one to exclude the `self` arg
|
|
args = _getargs(init_sig)[1:]
|
|
cls._ctor_sig = init_sig
|
|
ctor_source = _ctor_template.format(args=', '.join(args))
|
|
glbls = {"__numba_cls_": cls}
|
|
exec(ctor_source, glbls)
|
|
ctor = glbls['ctor']
|
|
cls._ctor = njit(ctor)
|
|
|
|
def __instancecheck__(cls, instance):
|
|
if isinstance(instance, _box.Box):
|
|
return instance._numba_type_.class_type is cls.class_type
|
|
return False
|
|
|
|
def __call__(cls, *args, **kwargs):
|
|
# The first argument of _ctor_sig is `cls`, which here
|
|
# is bound to None and then skipped when invoking the constructor.
|
|
bind = cls._ctor_sig.bind(None, *args, **kwargs)
|
|
bind.apply_defaults()
|
|
return cls._ctor(*bind.args[1:], **bind.kwargs)
|
|
|
|
|
|
##############################################################################
|
|
# Registration utils
|
|
|
|
def _validate_spec(spec):
|
|
for k, v in spec.items():
|
|
if not isinstance(k, str):
|
|
raise TypeError("spec keys should be strings, got %r" % (k,))
|
|
if not isinstance(v, types.Type):
|
|
raise TypeError("spec values should be Numba type instances, got %r"
|
|
% (v,))
|
|
|
|
|
|
def _fix_up_private_attr(clsname, spec):
|
|
"""
|
|
Apply the same changes to dunder names as CPython would.
|
|
"""
|
|
out = OrderedDict()
|
|
for k, v in spec.items():
|
|
if k.startswith('__') and not k.endswith('__'):
|
|
k = '_' + clsname + k
|
|
out[k] = v
|
|
return out
|
|
|
|
|
|
def _add_linking_libs(context, call):
|
|
"""
|
|
Add the required libs for the callable to allow inlining.
|
|
"""
|
|
libs = getattr(call, "libs", ())
|
|
if libs:
|
|
context.add_linking_libs(libs)
|
|
|
|
|
|
def register_class_type(cls, spec, class_ctor, builder):
|
|
"""
|
|
Internal function to create a jitclass.
|
|
|
|
Args
|
|
----
|
|
cls: the original class object (used as the prototype)
|
|
spec: the structural specification contains the field types.
|
|
class_ctor: the numba type to represent the jitclass
|
|
builder: the internal jitclass builder
|
|
"""
|
|
# Normalize spec
|
|
if spec is None:
|
|
spec = OrderedDict()
|
|
elif isinstance(spec, Sequence):
|
|
spec = OrderedDict(spec)
|
|
|
|
# Extend spec with class annotations.
|
|
for attr, py_type in pt.get_type_hints(cls).items():
|
|
if attr not in spec:
|
|
spec[attr] = as_numba_type(py_type)
|
|
|
|
_validate_spec(spec)
|
|
|
|
# Fix up private attribute names
|
|
spec = _fix_up_private_attr(cls.__name__, spec)
|
|
|
|
# Copy methods from base classes
|
|
clsdct = {}
|
|
for basecls in reversed(inspect.getmro(cls)):
|
|
clsdct.update(basecls.__dict__)
|
|
|
|
methods, props, static_methods, others = {}, {}, {}, {}
|
|
for k, v in clsdct.items():
|
|
if isinstance(v, pytypes.FunctionType):
|
|
methods[k] = v
|
|
elif isinstance(v, property):
|
|
props[k] = v
|
|
elif isinstance(v, staticmethod):
|
|
static_methods[k] = v
|
|
else:
|
|
others[k] = v
|
|
|
|
# Check for name shadowing
|
|
shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
|
|
if shadowed:
|
|
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
|
|
|
|
docstring = others.pop('__doc__', "")
|
|
_drop_ignored_attrs(others)
|
|
if others:
|
|
msg = "class members are not yet supported: {0}"
|
|
members = ', '.join(others.keys())
|
|
raise TypeError(msg.format(members))
|
|
|
|
for k, v in props.items():
|
|
if v.fdel is not None:
|
|
raise TypeError("deleter is not supported: {0}".format(k))
|
|
|
|
jit_methods = {k: njit(v) for k, v in methods.items()}
|
|
|
|
jit_props = {}
|
|
for k, v in props.items():
|
|
dct = {}
|
|
if v.fget:
|
|
dct['get'] = njit(v.fget)
|
|
if v.fset:
|
|
dct['set'] = njit(v.fset)
|
|
jit_props[k] = dct
|
|
|
|
jit_static_methods = {
|
|
k: njit(v.__func__) for k, v in static_methods.items()}
|
|
|
|
# Instantiate class type
|
|
class_type = class_ctor(
|
|
cls,
|
|
ConstructorTemplate,
|
|
spec,
|
|
jit_methods,
|
|
jit_props,
|
|
jit_static_methods)
|
|
|
|
jit_class_dct = dict(class_type=class_type, __doc__=docstring)
|
|
jit_class_dct.update(jit_static_methods)
|
|
cls = JitClassType(cls.__name__, (cls,), jit_class_dct)
|
|
|
|
# Register resolution of the class object
|
|
typingctx = cpu_target.typing_context
|
|
typingctx.insert_global(cls, class_type)
|
|
|
|
# Register class
|
|
targetctx = cpu_target.target_context
|
|
builder(class_type, typingctx, targetctx).register()
|
|
as_numba_type.register(cls, class_type.instance_type)
|
|
|
|
return cls
|
|
|
|
|
|
class ConstructorTemplate(templates.AbstractTemplate):
|
|
"""
|
|
Base class for jitclass constructor templates.
|
|
"""
|
|
|
|
def generic(self, args, kws):
|
|
# Redirect resolution to __init__
|
|
instance_type = self.key.instance_type
|
|
ctor = instance_type.jit_methods['__init__']
|
|
boundargs = (instance_type.get_reference_type(),) + args
|
|
disp_type = types.Dispatcher(ctor)
|
|
sig = disp_type.get_call_type(self.context, boundargs, kws)
|
|
|
|
if not isinstance(sig.return_type, types.NoneType):
|
|
raise errors.NumbaTypeError(
|
|
f"__init__() should return None, not '{sig.return_type}'")
|
|
|
|
# Actual constructor returns an instance value (not None)
|
|
out = templates.signature(instance_type, *sig.args[1:])
|
|
return out
|
|
|
|
|
|
def _drop_ignored_attrs(dct):
|
|
# ignore anything defined by object
|
|
drop = set(['__weakref__',
|
|
'__module__',
|
|
'__dict__'])
|
|
|
|
if '__annotations__' in dct:
|
|
drop.add('__annotations__')
|
|
|
|
for k, v in dct.items():
|
|
if isinstance(v, (pytypes.BuiltinFunctionType,
|
|
pytypes.BuiltinMethodType)):
|
|
drop.add(k)
|
|
elif getattr(v, '__objclass__', None) is object:
|
|
drop.add(k)
|
|
|
|
# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
|
|
# None. This is a class member, and class members are not presently
|
|
# supported.
|
|
if '__hash__' in dct and dct['__hash__'] is None:
|
|
drop.add('__hash__')
|
|
|
|
for k in drop:
|
|
del dct[k]
|
|
|
|
|
|
class ClassBuilder(object):
|
|
"""
|
|
A jitclass builder for a mutable jitclass. This will register
|
|
typing and implementation hooks to the given typing and target contexts.
|
|
"""
|
|
class_impl_registry = imputils.Registry('jitclass builder')
|
|
implemented_methods = set()
|
|
|
|
def __init__(self, class_type, typingctx, targetctx):
|
|
self.class_type = class_type
|
|
self.typingctx = typingctx
|
|
self.targetctx = targetctx
|
|
|
|
def register(self):
|
|
"""
|
|
Register to the frontend and backend.
|
|
"""
|
|
# Register generic implementations for all jitclasses
|
|
self._register_methods(self.class_impl_registry,
|
|
self.class_type.instance_type)
|
|
# NOTE other registrations are done at the top-level
|
|
# (see ctor_impl and attr_impl below)
|
|
self.targetctx.install_registry(self.class_impl_registry)
|
|
|
|
def _register_methods(self, registry, instance_type):
|
|
"""
|
|
Register method implementations.
|
|
This simply registers that the method names are valid methods. Inside
|
|
of imp() below we retrieve the actual method to run from the type of
|
|
the receiver argument (i.e. self).
|
|
"""
|
|
to_register = list(instance_type.jit_methods) + \
|
|
list(instance_type.jit_static_methods)
|
|
for meth in to_register:
|
|
|
|
# There's no way to retrieve the particular method name
|
|
# inside the implementation function, so we have to register a
|
|
# specific closure for each different name
|
|
if meth not in self.implemented_methods:
|
|
self._implement_method(registry, meth)
|
|
self.implemented_methods.add(meth)
|
|
|
|
def _implement_method(self, registry, attr):
|
|
# create a separate instance of imp method to avoid closure clashing
|
|
def get_imp():
|
|
def imp(context, builder, sig, args):
|
|
instance_type = sig.args[0]
|
|
|
|
if attr in instance_type.jit_methods:
|
|
method = instance_type.jit_methods[attr]
|
|
elif attr in instance_type.jit_static_methods:
|
|
method = instance_type.jit_static_methods[attr]
|
|
# imp gets called as a method, where the first argument is
|
|
# self. We drop this for a static method.
|
|
sig = sig.replace(args=sig.args[1:])
|
|
args = args[1:]
|
|
|
|
disp_type = types.Dispatcher(method)
|
|
call = context.get_function(disp_type, sig)
|
|
out = call(builder, args)
|
|
_add_linking_libs(context, call)
|
|
return imputils.impl_ret_new_ref(context, builder,
|
|
sig.return_type, out)
|
|
return imp
|
|
|
|
def _getsetitem_gen(getset):
|
|
_dunder_meth = "__%s__" % getset
|
|
op = getattr(operator, getset)
|
|
|
|
@templates.infer_global(op)
|
|
class GetSetItem(templates.AbstractTemplate):
|
|
def generic(self, args, kws):
|
|
instance = args[0]
|
|
if isinstance(instance, types.ClassInstanceType) and \
|
|
_dunder_meth in instance.jit_methods:
|
|
meth = instance.jit_methods[_dunder_meth]
|
|
disp_type = types.Dispatcher(meth)
|
|
sig = disp_type.get_call_type(self.context, args, kws)
|
|
return sig
|
|
|
|
# lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
|
|
# from python and numba
|
|
imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
|
|
types.ClassInstanceType,
|
|
types.VarArg(types.Any))(get_imp())
|
|
imputils.lower_builtin(op,
|
|
types.ClassInstanceType,
|
|
types.VarArg(types.Any))(get_imp())
|
|
|
|
dunder_stripped = attr.strip('_')
|
|
if dunder_stripped in ("getitem", "setitem"):
|
|
_getsetitem_gen(dunder_stripped)
|
|
else:
|
|
registry.lower((types.ClassInstanceType, attr),
|
|
types.ClassInstanceType,
|
|
types.VarArg(types.Any))(get_imp())
|
|
|
|
|
|
@templates.infer_getattr
|
|
class ClassAttribute(templates.AttributeTemplate):
|
|
key = types.ClassInstanceType
|
|
|
|
def generic_resolve(self, instance, attr):
|
|
if attr in instance.struct:
|
|
# It's a struct field => the type is well-known
|
|
return instance.struct[attr]
|
|
|
|
elif attr in instance.jit_methods:
|
|
# It's a jitted method => typeinfer it
|
|
meth = instance.jit_methods[attr]
|
|
disp_type = types.Dispatcher(meth)
|
|
|
|
class MethodTemplate(templates.AbstractTemplate):
|
|
key = (self.key, attr)
|
|
|
|
def generic(self, args, kws):
|
|
args = (instance,) + tuple(args)
|
|
sig = disp_type.get_call_type(self.context, args, kws)
|
|
return sig.as_method()
|
|
|
|
return types.BoundFunction(MethodTemplate, instance)
|
|
|
|
elif attr in instance.jit_static_methods:
|
|
# It's a jitted method => typeinfer it
|
|
meth = instance.jit_static_methods[attr]
|
|
disp_type = types.Dispatcher(meth)
|
|
|
|
class StaticMethodTemplate(templates.AbstractTemplate):
|
|
key = (self.key, attr)
|
|
|
|
def generic(self, args, kws):
|
|
# Don't add instance as the first argument for a static
|
|
# method.
|
|
sig = disp_type.get_call_type(self.context, args, kws)
|
|
# sig itself does not include ClassInstanceType as it's
|
|
# first argument, so instead of calling sig.as_method()
|
|
# we insert the recvr. This is equivalent to
|
|
# sig.replace(args=(instance,) + sig.args).as_method().
|
|
return sig.replace(recvr=instance)
|
|
|
|
return types.BoundFunction(StaticMethodTemplate, instance)
|
|
|
|
elif attr in instance.jit_props:
|
|
# It's a jitted property => typeinfer its getter
|
|
impdct = instance.jit_props[attr]
|
|
getter = impdct['get']
|
|
disp_type = types.Dispatcher(getter)
|
|
sig = disp_type.get_call_type(self.context, (instance,), {})
|
|
return sig.return_type
|
|
|
|
|
|
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
|
|
def get_attr_impl(context, builder, typ, value, attr):
|
|
"""
|
|
Generic getattr() for @jitclass instances.
|
|
"""
|
|
if attr in typ.struct:
|
|
# It's a struct field
|
|
inst = context.make_helper(builder, typ, value=value)
|
|
data_pointer = inst.data
|
|
data = context.make_data_helper(builder, typ.get_data_type(),
|
|
ref=data_pointer)
|
|
return imputils.impl_ret_borrowed(context, builder,
|
|
typ.struct[attr],
|
|
getattr(data, _mangle_attr(attr)))
|
|
elif attr in typ.jit_props:
|
|
# It's a jitted property
|
|
getter = typ.jit_props[attr]['get']
|
|
sig = templates.signature(None, typ)
|
|
dispatcher = types.Dispatcher(getter)
|
|
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
|
|
call = context.get_function(dispatcher, sig)
|
|
out = call(builder, [value])
|
|
_add_linking_libs(context, call)
|
|
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
|
|
|
|
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
|
|
|
|
|
|
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
|
|
def set_attr_impl(context, builder, sig, args, attr):
|
|
"""
|
|
Generic setattr() for @jitclass instances.
|
|
"""
|
|
typ, valty = sig.args
|
|
target, val = args
|
|
|
|
if attr in typ.struct:
|
|
# It's a struct member
|
|
inst = context.make_helper(builder, typ, value=target)
|
|
data_ptr = inst.data
|
|
data = context.make_data_helper(builder, typ.get_data_type(),
|
|
ref=data_ptr)
|
|
|
|
# Get old value
|
|
attr_type = typ.struct[attr]
|
|
oldvalue = getattr(data, _mangle_attr(attr))
|
|
|
|
# Store n
|
|
setattr(data, _mangle_attr(attr), val)
|
|
context.nrt.incref(builder, attr_type, val)
|
|
|
|
# Delete old value
|
|
context.nrt.decref(builder, attr_type, oldvalue)
|
|
|
|
elif attr in typ.jit_props:
|
|
# It's a jitted property
|
|
setter = typ.jit_props[attr]['set']
|
|
disp_type = types.Dispatcher(setter)
|
|
sig = disp_type.get_call_type(context.typing_context,
|
|
(typ, valty), {})
|
|
call = context.get_function(disp_type, sig)
|
|
call(builder, (target, val))
|
|
_add_linking_libs(context, call)
|
|
else:
|
|
raise NotImplementedError(
|
|
'attribute {0!r} not implemented'.format(attr))
|
|
|
|
|
|
def imp_dtor(context, module, instance_type):
|
|
llvoidptr = context.get_value_type(types.voidptr)
|
|
llsize = context.get_value_type(types.uintp)
|
|
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
|
|
[llvoidptr, llsize, llvoidptr])
|
|
|
|
fname = "_Dtor.{0}".format(instance_type.name)
|
|
dtor_fn = cgutils.get_or_insert_function(module, dtor_ftype, fname)
|
|
if dtor_fn.is_declaration:
|
|
# Define
|
|
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
|
|
|
|
alloc_fe_type = instance_type.get_data_type()
|
|
alloc_type = context.get_value_type(alloc_fe_type)
|
|
|
|
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
|
|
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
|
|
|
|
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
|
|
|
|
builder.ret_void()
|
|
|
|
return dtor_fn
|
|
|
|
|
|
@ClassBuilder.class_impl_registry.lower(types.ClassType,
|
|
types.VarArg(types.Any))
|
|
def ctor_impl(context, builder, sig, args):
|
|
"""
|
|
Generic constructor (__new__) for jitclasses.
|
|
"""
|
|
# Allocate the instance
|
|
inst_typ = sig.return_type
|
|
alloc_type = context.get_data_type(inst_typ.get_data_type())
|
|
alloc_size = context.get_abi_sizeof(alloc_type)
|
|
|
|
meminfo = context.nrt.meminfo_alloc_dtor(
|
|
builder,
|
|
context.get_constant(types.uintp, alloc_size),
|
|
imp_dtor(context, builder.module, inst_typ),
|
|
)
|
|
data_pointer = context.nrt.meminfo_data(builder, meminfo)
|
|
data_pointer = builder.bitcast(data_pointer,
|
|
alloc_type.as_pointer())
|
|
|
|
# Nullify all data
|
|
builder.store(cgutils.get_null_value(alloc_type),
|
|
data_pointer)
|
|
|
|
inst_struct = context.make_helper(builder, inst_typ)
|
|
inst_struct.meminfo = meminfo
|
|
inst_struct.data = data_pointer
|
|
|
|
# Call the jitted __init__
|
|
# TODO: extract the following into a common util
|
|
init_sig = (sig.return_type,) + sig.args
|
|
|
|
init = inst_typ.jit_methods['__init__']
|
|
disp_type = types.Dispatcher(init)
|
|
call = context.get_function(disp_type, types.void(*init_sig))
|
|
_add_linking_libs(context, call)
|
|
realargs = [inst_struct._getvalue()] + list(args)
|
|
call(builder, realargs)
|
|
|
|
# Prepare return value
|
|
ret = inst_struct._getvalue()
|
|
|
|
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)
|