385 lines
11 KiB
Python
385 lines
11 KiB
Python
|
"""Utilities for defining a mutable struct.
|
||
|
|
||
|
A mutable struct is passed by reference;
|
||
|
hence, structref (a reference to a struct).
|
||
|
|
||
|
"""
|
||
|
from numba import njit
|
||
|
from numba.core import types, imputils, cgutils
|
||
|
from numba.core.datamodel import default_manager, models
|
||
|
from numba.core.extending import (
|
||
|
infer_getattr,
|
||
|
lower_getattr_generic,
|
||
|
lower_setattr_generic,
|
||
|
box,
|
||
|
unbox,
|
||
|
NativeValue,
|
||
|
intrinsic,
|
||
|
overload,
|
||
|
)
|
||
|
from numba.core.typing.templates import AttributeTemplate
|
||
|
|
||
|
|
||
|
class _Utils:
|
||
|
"""Internal builder-code utils for structref definitions.
|
||
|
"""
|
||
|
def __init__(self, context, builder, struct_type):
|
||
|
"""
|
||
|
Parameters
|
||
|
----------
|
||
|
context :
|
||
|
a numba target context
|
||
|
builder :
|
||
|
a llvmlite IRBuilder
|
||
|
struct_type : numba.core.types.StructRef
|
||
|
"""
|
||
|
self.context = context
|
||
|
self.builder = builder
|
||
|
self.struct_type = struct_type
|
||
|
|
||
|
def new_struct_ref(self, mi):
|
||
|
"""Encapsulate the MemInfo from a `StructRefPayload` in a `StructRef`
|
||
|
"""
|
||
|
context = self.context
|
||
|
builder = self.builder
|
||
|
struct_type = self.struct_type
|
||
|
|
||
|
st = cgutils.create_struct_proxy(struct_type)(context, builder)
|
||
|
st.meminfo = mi
|
||
|
return st
|
||
|
|
||
|
def get_struct_ref(self, val):
|
||
|
"""Return a helper for accessing a StructRefType
|
||
|
"""
|
||
|
context = self.context
|
||
|
builder = self.builder
|
||
|
struct_type = self.struct_type
|
||
|
|
||
|
return cgutils.create_struct_proxy(struct_type)(
|
||
|
context, builder, value=val
|
||
|
)
|
||
|
|
||
|
def get_data_pointer(self, val):
|
||
|
"""Get the data pointer to the payload from a `StructRefType`.
|
||
|
"""
|
||
|
context = self.context
|
||
|
builder = self.builder
|
||
|
struct_type = self.struct_type
|
||
|
|
||
|
structval = self.get_struct_ref(val)
|
||
|
meminfo = structval.meminfo
|
||
|
data_ptr = context.nrt.meminfo_data(builder, meminfo)
|
||
|
|
||
|
valtype = struct_type.get_data_type()
|
||
|
model = context.data_model_manager[valtype]
|
||
|
alloc_type = model.get_value_type()
|
||
|
data_ptr = builder.bitcast(data_ptr, alloc_type.as_pointer())
|
||
|
return data_ptr
|
||
|
|
||
|
def get_data_struct(self, val):
|
||
|
"""Get a getter/setter helper for accessing a `StructRefPayload`
|
||
|
"""
|
||
|
context = self.context
|
||
|
builder = self.builder
|
||
|
struct_type = self.struct_type
|
||
|
|
||
|
data_ptr = self.get_data_pointer(val)
|
||
|
valtype = struct_type.get_data_type()
|
||
|
dataval = cgutils.create_struct_proxy(valtype)(
|
||
|
context, builder, ref=data_ptr
|
||
|
)
|
||
|
return dataval
|
||
|
|
||
|
|
||
|
def define_attributes(struct_typeclass):
    """Define attributes on `struct_typeclass`.

    Defines both setters and getters in jit-code.

    This is called directly in `register()`.

    Parameters
    ----------
    struct_typeclass :
        The StructRef type class whose instances get attribute typing,
        getattr lowering and setattr lowering for every declared field.
    """
    @infer_getattr
    class StructAttribute(AttributeTemplate):
        key = struct_typeclass

        def generic_resolve(self, typ, attr):
            # Declared fields resolve to their field type; any other
            # attribute name returns None (unresolved).
            if attr in typ.field_dict:
                attrty = typ.field_dict[attr]
                return attrty

    @lower_getattr_generic(struct_typeclass)
    def struct_getattr_impl(context, builder, typ, val, attr):
        utils = _Utils(context, builder, typ)
        dataval = utils.get_data_struct(val)
        ret = getattr(dataval, attr)
        fieldtype = typ.field_dict[attr]
        # The value is still owned by the payload, so it is returned borrowed.
        return imputils.impl_ret_borrowed(context, builder, fieldtype, ret)

    @lower_setattr_generic(struct_typeclass)
    def struct_setattr_impl(context, builder, sig, args, attr):
        [inst_type, val_type] = sig.args
        [instance, val] = args
        utils = _Utils(context, builder, inst_type)
        dataval = utils.get_data_struct(instance)
        # cast val to the correct type
        field_type = inst_type.field_dict[attr]
        casted = context.cast(builder, val, val_type, field_type)
        # read old
        old_value = getattr(dataval, attr)
        # Both `casted` and `old_value` are values of `field_type`, so the
        # refcount operations must use `field_type`, not the incoming
        # `val_type`. The two types differ whenever the cast above changes
        # the value's representation.
        # incref new value
        context.nrt.incref(builder, field_type, casted)
        # decref old value (must be last in case new value is old value)
        context.nrt.decref(builder, field_type, old_value)
        # write new
        setattr(dataval, attr, casted)
|
||
|
|
||
|
|
||
|
def define_boxing(struct_type, obj_class):
    """Define the boxing & unboxing logic for `struct_type` to `obj_class`.

    Defines both boxing and unboxing.

    - boxing turns an instance of `struct_type` into a PyObject of `obj_class`
    - unboxing turns an instance of `obj_class` into an instance of
      `struct_type` in jit-code.

    Use this directly instead of `define_proxy()` when the user does not
    want any constructor to be defined.

    Parameters
    ----------
    struct_type :
        The StructRef type class to register boxing/unboxing for. It must
        not be `types.StructRef` itself.
    obj_class :
        The Python proxy class. Its ``_numba_box_`` attribute is the callable
        used to build the PyObject when boxing.

    Raises
    ------
    ValueError
        If `struct_type` is `types.StructRef`.
    """
    if struct_type is types.StructRef:
        raise ValueError(f"cannot register {types.StructRef}")

    # Captured once here; the boxing code below serializes this callable.
    obj_ctor = obj_class._numba_box_

    @box(struct_type)
    def box_struct_ref(typ, val, c):
        """
        Convert a native StructRef value into a Python `obj_class` instance.

        The StructRef's MemInfo is boxed as a ``MemInfoPointer(voidptr)``
        object. It is then passed, together with the Numba type ``typ``, to
        ``obj_class._numba_box_(typ, boxed_meminfo)``. The temporaries
        created here are decref'd, and the call's result is returned.
        """
        utils = _Utils(c.context, c.builder, typ)
        struct_ref = utils.get_struct_ref(val)
        meminfo = struct_ref.meminfo

        mip_type = types.MemInfoPointer(types.voidptr)
        boxed_meminfo = c.box(mip_type, meminfo)

        # Rehydrate the constructor and the Numba type as PyObjects at runtime.
        ctor_pyfunc = c.pyapi.unserialize(c.pyapi.serialize_object(obj_ctor))
        ty_pyobj = c.pyapi.unserialize(c.pyapi.serialize_object(typ))

        res = c.pyapi.call_function_objargs(
            ctor_pyfunc, [ty_pyobj, boxed_meminfo],
        )
        # Release the temporary references created above.
        c.pyapi.decref(ctor_pyfunc)
        c.pyapi.decref(ty_pyobj)
        c.pyapi.decref(boxed_meminfo)
        return res

    @unbox(struct_type)
    def unbox_struct_ref(typ, obj, c):
        """
        Convert an `obj_class` instance into a native StructRef value.

        Reads ``obj._meminfo``, unboxes it as a ``MemInfoPointer(voidptr)``
        and wraps that MemInfo in a new StructRef value.
        """
        # NOTE(review): the result of this getattr is not checked for NULL,
        # so an object without `_meminfo` is not handled here -- confirm
        # that callers guarantee only proxy instances reach this path.
        mi_obj = c.pyapi.object_getattr_string(obj, "_meminfo")

        mip_type = types.MemInfoPointer(types.voidptr)

        mi = c.unbox(mip_type, mi_obj).value

        utils = _Utils(c.context, c.builder, typ)
        struct_ref = utils.new_struct_ref(mi)
        out = struct_ref._getvalue()

        # Drop the reference returned by the getattr above.
        c.pyapi.decref(mi_obj)
        return NativeValue(out)
|
||
|
|
||
|
|
||
|
def define_constructor(py_class, struct_typeclass, fields):
    """Define the jit-code constructor for `struct_typeclass` using the
    Python type `py_class` and the required `fields`.

    Use this instead of `define_proxy()` if the user does not want boxing
    logic defined.

    The constructor is generated as Python source and executed, because
    its parameter list must spell out one positional parameter per field
    name. The generated ``ctor`` is registered with ``overload(py_class)``.
    At typing time it builds a ``struct_typeclass`` from the field names
    and argument types. Its ``impl`` allocates the struct with ``new`` and
    assigns every field in order.
    """
    arg_list = ', '.join(fields)
    body_indent = ' ' * 8
    # One "st.<name> = <name>" line per field, aligned to the impl body.
    assignments = f'\n{body_indent}'.join(
        f"st.{name} = {name}" for name in fields
    )

    src_lines = [
        "",
        f"def ctor({arg_list}):",
        f"    struct_type = struct_typeclass("
        f"list(zip({list(fields)}, [{arg_list}])))",
        f"    def impl({arg_list}):",
        "        st = new(struct_type)",
        f"        {assignments}",
        "        return st",
        "    return impl",
        "",
    ]

    # The generated code only needs these two names from this module.
    namespace = {'struct_typeclass': struct_typeclass, 'new': new}
    exec('\n'.join(src_lines), namespace)
    # Make it an overload
    overload(py_class)(namespace['ctor'])
|
||
|
|
||
|
|
||
|
def define_proxy(py_class, struct_typeclass, fields):
    """Define a PyObject proxy for a structref.

    Registers a jit-code constructor so ``py_class`` can create an instance
    of `struct_typeclass` with the members named in `fields`. Also
    registers boxing/unboxing between `struct_typeclass` values and
    ``py_class`` objects.

    Parameters
    ----------
    py_class : type
        The Python class for constructing an instance of `struct_typeclass`.
    struct_typeclass : numba.core.types.Type
        The structref type class to bind to.
    fields : Sequence[str]
        A sequence of field names.

    Returns
    -------
    None
    """
    # jit-code constructor: ``py_class(*field_values)`` in compiled code
    define_constructor(py_class, struct_typeclass, fields)
    # conversions between native values and ``py_class`` instances
    define_boxing(struct_typeclass, py_class)
|
||
|
|
||
|
|
||
|
def register(struct_type):
    """Register a `numba.core.types.StructRef` subclass for use in jit-code.

    Registers ``models.StructRefModel`` as the data model used to lower
    instances of `struct_type`. It then installs the attribute accessors
    and mutators for those instances via `define_attributes`.

    Parameters
    ----------
    struct_type : type
        A subclass of `numba.core.types.StructRef`.

    Returns
    -------
    struct_type : type
        The input argument, unchanged, so this can act like a decorator.

    Raises
    ------
    ValueError
        If `struct_type` is `numba.core.types.StructRef` itself.

    Examples
    --------

    .. code-block::

        class MyStruct(numba.core.types.StructRef):
            ...  # the simplest subclass can be empty

        numba.experimental.structref.register(MyStruct)

    """
    base = types.StructRef
    if struct_type is base:
        raise ValueError(f"cannot register {base}")
    default_manager.register(struct_type, models.StructRefModel)
    define_attributes(struct_type)
    return struct_type
|
||
|
|
||
|
|
||
|
@intrinsic
def new(typingctx, struct_type):
    """new(struct_type)

    A jit-code only intrinsic. Used to allocate an **empty** mutable struct.
    The fields are zero-initialized and must be set manually after calling
    the function.

    Example:

        instance = new(MyStruct)
        instance.field = field_value

    Parameters
    ----------
    typingctx :
        the typing context (not used by this intrinsic).
    struct_type :
        the Numba type of the argument. Its ``instance_type`` is the
        StructRef type that gets allocated and returned.

    Returns
    -------
    (signature, codegen)
        The signature ``inst_type(struct_type)`` and the codegen callback,
        as ``@intrinsic`` expects.
    """
    from numba.experimental.jitclass.base import imp_dtor

    inst_type = struct_type.instance_type

    def codegen(context, builder, signature, args):
        # FIXME: mostly the same as jitclass ctor_impl()
        # The payload's LLVM value type determines the allocation size.
        model = context.data_model_manager[inst_type.get_data_type()]
        alloc_type = model.get_value_type()
        alloc_size = context.get_abi_sizeof(alloc_type)

        # Allocate a MemInfo owning `alloc_size` bytes with the jitclass
        # destructor attached. NOTE(review): presumably imp_dtor releases
        # the payload's fields when the MemInfo is freed -- confirm.
        meminfo = context.nrt.meminfo_alloc_dtor(
            builder,
            context.get_constant(types.uintp, alloc_size),
            imp_dtor(context, builder.module, inst_type),
        )
        data_pointer = context.nrt.meminfo_data(builder, meminfo)
        data_pointer = builder.bitcast(data_pointer, alloc_type.as_pointer())

        # Nullify all data
        builder.store(cgutils.get_null_value(alloc_type), data_pointer)

        # Wrap the new MemInfo in a StructRef value.
        inst_struct = context.make_helper(builder, inst_type)
        inst_struct.meminfo = meminfo

        return inst_struct._getvalue()

    sig = inst_type(struct_type)
    return sig, codegen
|
||
|
|
||
|
|
||
|
class StructRefProxy:
    """A PyObject proxy to the Numba allocated structref data structure.

    Notes
    -----

    * Subclasses should not define ``__init__``.
    * Subclasses can override ``__new__``.
    """
    __slots__ = ('_type', '_meminfo')

    @classmethod
    def _numba_box_(cls, ty, mi):
        """Called by boxing logic, the conversion of Numba internal
        representation into a PyObject.

        Parameters
        ----------
        ty :
            a Numba type instance.
        mi :
            a wrapped MemInfoPointer.

        Returns
        -------
        instance :
            a StructRefProxy instance.
        """
        # Bypass ``cls.__new__`` (which would compile and call a jit ctor).
        instance = super().__new__(cls)
        instance._type = ty
        instance._meminfo = mi
        return instance

    def __new__(cls, *args):
        """Construct a new instance of the structref.

        This takes positional-arguments only due to limitation of the compiler.
        The arguments are mapped to ``cls(*args)`` in jit-code.
        """
        # The compiled ctor is cached per class, under the name that
        # ``cls.__numba_ctor`` mangles to inside this class. Look it up in
        # ``cls``'s own namespace only: a normal attribute lookup would also
        # find a ctor cached on a parent proxy class and construct the
        # parent type instead of ``cls``.
        cache_attr = '_StructRefProxy__numba_ctor'
        ctor = vars(cls).get(cache_attr)
        if ctor is None:
            # lazily define the ctor
            @njit
            def ctor(*args):
                return cls(*args)
            # cache it on ``cls`` itself to avoid recompilation
            setattr(cls, cache_attr, ctor)
        return ctor(*args)

    @property
    def _numba_type_(self):
        """Returns the Numba type instance for this structref instance.

        Subclasses should NOT override.
        """
        return self._type
|