ai-content-maker/.venv/Lib/site-packages/numba/experimental/jitclass/base.py

593 lines
21 KiB
Python

import inspect
import operator
import types as pytypes
import typing as pt
from collections import OrderedDict
from collections.abc import Sequence
from llvmlite import ir as llvmir
from numba import njit
from numba.core import cgutils, errors, imputils, types, utils
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target
from numba.core.typing import templates
from numba.core.typing.asnumbatype import as_numba_type
from numba.core.serialize import disable_pickling
from numba.experimental.jitclass import _box
##############################################################################
# Data model
class InstanceModel(models.StructModel):
def __init__(self, dmm, fe_typ):
cls_data_ty = types.ClassDataType(fe_typ)
# MemInfoPointer uses the `dtype` attribute to traverse for nested
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
# we will replace provide MemInfoPointer with an opaque type
# so that it does not raise exception for nested meminfo.
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
members = [
('meminfo', types.MemInfoPointer(dtype)),
('data', types.CPointer(cls_data_ty)),
]
super(InstanceModel, self).__init__(dmm, fe_typ, members)
class InstanceDataModel(models.StructModel):
def __init__(self, dmm, fe_typ):
clsty = fe_typ.class_type
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
default_manager.register(types.ClassInstanceType, InstanceModel)
default_manager.register(types.ClassDataType, InstanceDataModel)
default_manager.register(types.ClassType, models.OpaqueModel)
def _mangle_attr(name):
"""
Mangle attributes.
The resulting name does not startswith an underscore '_'.
"""
return 'm_' + name
##############################################################################
# Class object
_ctor_template = """
def ctor({args}):
return __numba_cls_({args})
"""
def _getargs(fn_sig):
"""
Returns list of positional and keyword argument names in order.
"""
params = fn_sig.parameters
args = []
for k, v in params.items():
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
args.append(k)
else:
msg = "%s argument type unsupported in jitclass" % v.kind
raise errors.UnsupportedError(msg)
return args
@disable_pickling
class JitClassType(type):
"""
The type of any jitclass.
"""
def __new__(cls, name, bases, dct):
if len(bases) != 1:
raise TypeError("must have exactly one base class")
[base] = bases
if isinstance(base, JitClassType):
raise TypeError("cannot subclass from a jitclass")
assert 'class_type' in dct, 'missing "class_type" attr'
outcls = type.__new__(cls, name, bases, dct)
outcls._set_init()
return outcls
def _set_init(cls):
"""
Generate a wrapper for calling the constructor from pure Python.
Note the wrapper will only accept positional arguments.
"""
init = cls.class_type.instance_type.methods['__init__']
init_sig = utils.pysignature(init)
# get postitional and keyword arguments
# offset by one to exclude the `self` arg
args = _getargs(init_sig)[1:]
cls._ctor_sig = init_sig
ctor_source = _ctor_template.format(args=', '.join(args))
glbls = {"__numba_cls_": cls}
exec(ctor_source, glbls)
ctor = glbls['ctor']
cls._ctor = njit(ctor)
def __instancecheck__(cls, instance):
if isinstance(instance, _box.Box):
return instance._numba_type_.class_type is cls.class_type
return False
def __call__(cls, *args, **kwargs):
# The first argument of _ctor_sig is `cls`, which here
# is bound to None and then skipped when invoking the constructor.
bind = cls._ctor_sig.bind(None, *args, **kwargs)
bind.apply_defaults()
return cls._ctor(*bind.args[1:], **bind.kwargs)
##############################################################################
# Registration utils
def _validate_spec(spec):
for k, v in spec.items():
if not isinstance(k, str):
raise TypeError("spec keys should be strings, got %r" % (k,))
if not isinstance(v, types.Type):
raise TypeError("spec values should be Numba type instances, got %r"
% (v,))
def _fix_up_private_attr(clsname, spec):
"""
Apply the same changes to dunder names as CPython would.
"""
out = OrderedDict()
for k, v in spec.items():
if k.startswith('__') and not k.endswith('__'):
k = '_' + clsname + k
out[k] = v
return out
def _add_linking_libs(context, call):
"""
Add the required libs for the callable to allow inlining.
"""
libs = getattr(call, "libs", ())
if libs:
context.add_linking_libs(libs)
def register_class_type(cls, spec, class_ctor, builder):
"""
Internal function to create a jitclass.
Args
----
cls: the original class object (used as the prototype)
spec: the structural specification contains the field types.
class_ctor: the numba type to represent the jitclass
builder: the internal jitclass builder
"""
# Normalize spec
if spec is None:
spec = OrderedDict()
elif isinstance(spec, Sequence):
spec = OrderedDict(spec)
# Extend spec with class annotations.
for attr, py_type in pt.get_type_hints(cls).items():
if attr not in spec:
spec[attr] = as_numba_type(py_type)
_validate_spec(spec)
# Fix up private attribute names
spec = _fix_up_private_attr(cls.__name__, spec)
# Copy methods from base classes
clsdct = {}
for basecls in reversed(inspect.getmro(cls)):
clsdct.update(basecls.__dict__)
methods, props, static_methods, others = {}, {}, {}, {}
for k, v in clsdct.items():
if isinstance(v, pytypes.FunctionType):
methods[k] = v
elif isinstance(v, property):
props[k] = v
elif isinstance(v, staticmethod):
static_methods[k] = v
else:
others[k] = v
# Check for name shadowing
shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
if shadowed:
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
docstring = others.pop('__doc__', "")
_drop_ignored_attrs(others)
if others:
msg = "class members are not yet supported: {0}"
members = ', '.join(others.keys())
raise TypeError(msg.format(members))
for k, v in props.items():
if v.fdel is not None:
raise TypeError("deleter is not supported: {0}".format(k))
jit_methods = {k: njit(v) for k, v in methods.items()}
jit_props = {}
for k, v in props.items():
dct = {}
if v.fget:
dct['get'] = njit(v.fget)
if v.fset:
dct['set'] = njit(v.fset)
jit_props[k] = dct
jit_static_methods = {
k: njit(v.__func__) for k, v in static_methods.items()}
# Instantiate class type
class_type = class_ctor(
cls,
ConstructorTemplate,
spec,
jit_methods,
jit_props,
jit_static_methods)
jit_class_dct = dict(class_type=class_type, __doc__=docstring)
jit_class_dct.update(jit_static_methods)
cls = JitClassType(cls.__name__, (cls,), jit_class_dct)
# Register resolution of the class object
typingctx = cpu_target.typing_context
typingctx.insert_global(cls, class_type)
# Register class
targetctx = cpu_target.target_context
builder(class_type, typingctx, targetctx).register()
as_numba_type.register(cls, class_type.instance_type)
return cls
class ConstructorTemplate(templates.AbstractTemplate):
"""
Base class for jitclass constructor templates.
"""
def generic(self, args, kws):
# Redirect resolution to __init__
instance_type = self.key.instance_type
ctor = instance_type.jit_methods['__init__']
boundargs = (instance_type.get_reference_type(),) + args
disp_type = types.Dispatcher(ctor)
sig = disp_type.get_call_type(self.context, boundargs, kws)
if not isinstance(sig.return_type, types.NoneType):
raise errors.NumbaTypeError(
f"__init__() should return None, not '{sig.return_type}'")
# Actual constructor returns an instance value (not None)
out = templates.signature(instance_type, *sig.args[1:])
return out
def _drop_ignored_attrs(dct):
# ignore anything defined by object
drop = set(['__weakref__',
'__module__',
'__dict__'])
if '__annotations__' in dct:
drop.add('__annotations__')
for k, v in dct.items():
if isinstance(v, (pytypes.BuiltinFunctionType,
pytypes.BuiltinMethodType)):
drop.add(k)
elif getattr(v, '__objclass__', None) is object:
drop.add(k)
# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
# None. This is a class member, and class members are not presently
# supported.
if '__hash__' in dct and dct['__hash__'] is None:
drop.add('__hash__')
for k in drop:
del dct[k]
class ClassBuilder(object):
"""
A jitclass builder for a mutable jitclass. This will register
typing and implementation hooks to the given typing and target contexts.
"""
class_impl_registry = imputils.Registry('jitclass builder')
implemented_methods = set()
def __init__(self, class_type, typingctx, targetctx):
self.class_type = class_type
self.typingctx = typingctx
self.targetctx = targetctx
def register(self):
"""
Register to the frontend and backend.
"""
# Register generic implementations for all jitclasses
self._register_methods(self.class_impl_registry,
self.class_type.instance_type)
# NOTE other registrations are done at the top-level
# (see ctor_impl and attr_impl below)
self.targetctx.install_registry(self.class_impl_registry)
def _register_methods(self, registry, instance_type):
"""
Register method implementations.
This simply registers that the method names are valid methods. Inside
of imp() below we retrieve the actual method to run from the type of
the receiver argument (i.e. self).
"""
to_register = list(instance_type.jit_methods) + \
list(instance_type.jit_static_methods)
for meth in to_register:
# There's no way to retrieve the particular method name
# inside the implementation function, so we have to register a
# specific closure for each different name
if meth not in self.implemented_methods:
self._implement_method(registry, meth)
self.implemented_methods.add(meth)
def _implement_method(self, registry, attr):
# create a separate instance of imp method to avoid closure clashing
def get_imp():
def imp(context, builder, sig, args):
instance_type = sig.args[0]
if attr in instance_type.jit_methods:
method = instance_type.jit_methods[attr]
elif attr in instance_type.jit_static_methods:
method = instance_type.jit_static_methods[attr]
# imp gets called as a method, where the first argument is
# self. We drop this for a static method.
sig = sig.replace(args=sig.args[1:])
args = args[1:]
disp_type = types.Dispatcher(method)
call = context.get_function(disp_type, sig)
out = call(builder, args)
_add_linking_libs(context, call)
return imputils.impl_ret_new_ref(context, builder,
sig.return_type, out)
return imp
def _getsetitem_gen(getset):
_dunder_meth = "__%s__" % getset
op = getattr(operator, getset)
@templates.infer_global(op)
class GetSetItem(templates.AbstractTemplate):
def generic(self, args, kws):
instance = args[0]
if isinstance(instance, types.ClassInstanceType) and \
_dunder_meth in instance.jit_methods:
meth = instance.jit_methods[_dunder_meth]
disp_type = types.Dispatcher(meth)
sig = disp_type.get_call_type(self.context, args, kws)
return sig
# lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
# from python and numba
imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
imputils.lower_builtin(op,
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
dunder_stripped = attr.strip('_')
if dunder_stripped in ("getitem", "setitem"):
_getsetitem_gen(dunder_stripped)
else:
registry.lower((types.ClassInstanceType, attr),
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
@templates.infer_getattr
class ClassAttribute(templates.AttributeTemplate):
key = types.ClassInstanceType
def generic_resolve(self, instance, attr):
if attr in instance.struct:
# It's a struct field => the type is well-known
return instance.struct[attr]
elif attr in instance.jit_methods:
# It's a jitted method => typeinfer it
meth = instance.jit_methods[attr]
disp_type = types.Dispatcher(meth)
class MethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
args = (instance,) + tuple(args)
sig = disp_type.get_call_type(self.context, args, kws)
return sig.as_method()
return types.BoundFunction(MethodTemplate, instance)
elif attr in instance.jit_static_methods:
# It's a jitted method => typeinfer it
meth = instance.jit_static_methods[attr]
disp_type = types.Dispatcher(meth)
class StaticMethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
# Don't add instance as the first argument for a static
# method.
sig = disp_type.get_call_type(self.context, args, kws)
# sig itself does not include ClassInstanceType as it's
# first argument, so instead of calling sig.as_method()
# we insert the recvr. This is equivalent to
# sig.replace(args=(instance,) + sig.args).as_method().
return sig.replace(recvr=instance)
return types.BoundFunction(StaticMethodTemplate, instance)
elif attr in instance.jit_props:
# It's a jitted property => typeinfer its getter
impdct = instance.jit_props[attr]
getter = impdct['get']
disp_type = types.Dispatcher(getter)
sig = disp_type.get_call_type(self.context, (instance,), {})
return sig.return_type
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
def get_attr_impl(context, builder, typ, value, attr):
"""
Generic getattr() for @jitclass instances.
"""
if attr in typ.struct:
# It's a struct field
inst = context.make_helper(builder, typ, value=value)
data_pointer = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_pointer)
return imputils.impl_ret_borrowed(context, builder,
typ.struct[attr],
getattr(data, _mangle_attr(attr)))
elif attr in typ.jit_props:
# It's a jitted property
getter = typ.jit_props[attr]['get']
sig = templates.signature(None, typ)
dispatcher = types.Dispatcher(getter)
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
call = context.get_function(dispatcher, sig)
out = call(builder, [value])
_add_linking_libs(context, call)
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
def set_attr_impl(context, builder, sig, args, attr):
"""
Generic setattr() for @jitclass instances.
"""
typ, valty = sig.args
target, val = args
if attr in typ.struct:
# It's a struct member
inst = context.make_helper(builder, typ, value=target)
data_ptr = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_ptr)
# Get old value
attr_type = typ.struct[attr]
oldvalue = getattr(data, _mangle_attr(attr))
# Store n
setattr(data, _mangle_attr(attr), val)
context.nrt.incref(builder, attr_type, val)
# Delete old value
context.nrt.decref(builder, attr_type, oldvalue)
elif attr in typ.jit_props:
# It's a jitted property
setter = typ.jit_props[attr]['set']
disp_type = types.Dispatcher(setter)
sig = disp_type.get_call_type(context.typing_context,
(typ, valty), {})
call = context.get_function(disp_type, sig)
call(builder, (target, val))
_add_linking_libs(context, call)
else:
raise NotImplementedError(
'attribute {0!r} not implemented'.format(attr))
def imp_dtor(context, module, instance_type):
llvoidptr = context.get_value_type(types.voidptr)
llsize = context.get_value_type(types.uintp)
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
[llvoidptr, llsize, llvoidptr])
fname = "_Dtor.{0}".format(instance_type.name)
dtor_fn = cgutils.get_or_insert_function(module, dtor_ftype, fname)
if dtor_fn.is_declaration:
# Define
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
alloc_fe_type = instance_type.get_data_type()
alloc_type = context.get_value_type(alloc_fe_type)
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
builder.ret_void()
return dtor_fn
@ClassBuilder.class_impl_registry.lower(types.ClassType,
types.VarArg(types.Any))
def ctor_impl(context, builder, sig, args):
"""
Generic constructor (__new__) for jitclasses.
"""
# Allocate the instance
inst_typ = sig.return_type
alloc_type = context.get_data_type(inst_typ.get_data_type())
alloc_size = context.get_abi_sizeof(alloc_type)
meminfo = context.nrt.meminfo_alloc_dtor(
builder,
context.get_constant(types.uintp, alloc_size),
imp_dtor(context, builder.module, inst_typ),
)
data_pointer = context.nrt.meminfo_data(builder, meminfo)
data_pointer = builder.bitcast(data_pointer,
alloc_type.as_pointer())
# Nullify all data
builder.store(cgutils.get_null_value(alloc_type),
data_pointer)
inst_struct = context.make_helper(builder, inst_typ)
inst_struct.meminfo = meminfo
inst_struct.data = data_pointer
# Call the jitted __init__
# TODO: extract the following into a common util
init_sig = (sig.return_type,) + sig.args
init = inst_typ.jit_methods['__init__']
disp_type = types.Dispatcher(init)
call = context.get_function(disp_type, types.void(*init_sig))
_add_linking_libs(context, call)
realargs = [inst_struct._getvalue()] + list(args)
call(builder, realargs)
# Prepare return value
ret = inst_struct._getvalue()
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)