import functools from llvmlite import ir from numba.core.datamodel.registry import DataModelManager, register from numba.core.extending import models from numba.core import types from numba.cuda.types import Dim3, GridGroup, CUDADispatcher cuda_data_manager = DataModelManager() register_model = functools.partial(register, cuda_data_manager) @register_model(Dim3) class Dim3Model(models.StructModel): def __init__(self, dmm, fe_type): members = [ ('x', types.int32), ('y', types.int32), ('z', types.int32) ] super().__init__(dmm, fe_type, members) @register_model(GridGroup) class GridGroupModel(models.PrimitiveModel): def __init__(self, dmm, fe_type): be_type = ir.IntType(64) super().__init__(dmm, fe_type, be_type) @register_model(types.Float) class FloatModel(models.PrimitiveModel): def __init__(self, dmm, fe_type): if fe_type == types.float16: be_type = ir.IntType(16) elif fe_type == types.float32: be_type = ir.FloatType() elif fe_type == types.float64: be_type = ir.DoubleType() else: raise NotImplementedError(fe_type) super(FloatModel, self).__init__(dmm, fe_type, be_type) register_model(CUDADispatcher)(models.OpaqueModel)