49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
|
import functools
|
||
|
|
||
|
from llvmlite import ir
|
||
|
|
||
|
from numba.core.datamodel.registry import DataModelManager, register
|
||
|
from numba.core.extending import models
|
||
|
from numba.core import types
|
||
|
from numba.cuda.types import Dim3, GridGroup, CUDADispatcher
|
||
|
|
||
|
|
||
|
cuda_data_manager = DataModelManager()
|
||
|
|
||
|
register_model = functools.partial(register, cuda_data_manager)
|
||
|
|
||
|
|
||
|
@register_model(Dim3)
|
||
|
class Dim3Model(models.StructModel):
|
||
|
def __init__(self, dmm, fe_type):
|
||
|
members = [
|
||
|
('x', types.int32),
|
||
|
('y', types.int32),
|
||
|
('z', types.int32)
|
||
|
]
|
||
|
super().__init__(dmm, fe_type, members)
|
||
|
|
||
|
|
||
|
@register_model(GridGroup)
|
||
|
class GridGroupModel(models.PrimitiveModel):
|
||
|
def __init__(self, dmm, fe_type):
|
||
|
be_type = ir.IntType(64)
|
||
|
super().__init__(dmm, fe_type, be_type)
|
||
|
|
||
|
|
||
|
@register_model(types.Float)
|
||
|
class FloatModel(models.PrimitiveModel):
|
||
|
def __init__(self, dmm, fe_type):
|
||
|
if fe_type == types.float16:
|
||
|
be_type = ir.IntType(16)
|
||
|
elif fe_type == types.float32:
|
||
|
be_type = ir.FloatType()
|
||
|
elif fe_type == types.float64:
|
||
|
be_type = ir.DoubleType()
|
||
|
else:
|
||
|
raise NotImplementedError(fe_type)
|
||
|
super(FloatModel, self).__init__(dmm, fe_type, be_type)
|
||
|
|
||
|
|
||
|
register_model(CUDADispatcher)(models.OpaqueModel)
|