63 lines
1.5 KiB
Python
63 lines
1.5 KiB
Python
|
from numba.core import types
|
||
|
from numba.core.extending import overload, overload_method
|
||
|
from numba.core.typing import signature
|
||
|
from numba.cuda import nvvmutils
|
||
|
from numba.cuda.extending import intrinsic
|
||
|
from numba.cuda.types import grid_group, GridGroup as GridGroupClass
|
||
|
|
||
|
|
||
|
class GridGroup:
    """A cooperative group representing the entire grid.

    This pure-Python class only declares the interface. For the CUDA
    target, the behaviour of ``sync`` is provided by an ``overload_method``
    registered in this module on the corresponding Numba type.
    """

    def sync(self) -> None:
        """Synchronize this grid group.

        The Python-level body is intentionally empty and returns ``None``.
        ``self`` is accepted so that the method can be called on an
        instance (``this_grid().sync()``) without raising ``TypeError``.
        """
|
||
|
|
||
|
|
||
|
def this_grid() -> GridGroup:
    """Get the current grid group.

    In plain Python this simply constructs and returns a new ``GridGroup``
    object; a separate overload in this module covers the CUDA target.
    """
    group = GridGroup()
    return group
|
||
|
|
||
|
|
||
|
@intrinsic
def _this_grid(typingctx):
    """Intrinsic producing a handle for the current grid group.

    It takes no call-site arguments and is typed as returning
    ``grid_group``.
    """
    handle_sig = signature(grid_group)

    def codegen(context, builder, sig, args):
        # Declare the runtime entry point in the module under construction
        # and call it with the int32 constant 1.
        # NOTE(review): 1 is presumably the scope selector meaning "grid";
        # confirm against the CUDA device runtime headers.
        fn = nvvmutils.declare_cudaCGGetIntrinsicHandle(builder.module)
        scope = context.get_constant(types.int32, 1)
        return builder.call(fn, (scope,))

    return handle_sig, codegen
|
||
|
|
||
|
|
||
|
@overload(this_grid, target='cuda')
def _ol_this_grid():
    """CUDA-target overload of ``this_grid``.

    The returned implementation defers to the ``_this_grid`` intrinsic.
    """
    def this_grid_impl():
        grid = _this_grid()
        return grid

    return this_grid_impl
|
||
|
|
||
|
|
||
|
@intrinsic
def _grid_group_sync(typingctx, group):
    """Intrinsic that synchronizes the given grid group.

    It is typed as taking ``group`` and returning ``int32``, which is the
    result of calling the function declared by
    ``nvvmutils.declare_cudaCGSynchronize``.
    """
    sync_sig = signature(types.int32, group)

    def codegen(context, builder, sig, args):
        fn = nvvmutils.declare_cudaCGSynchronize(builder.module)
        # Forward the lowered arguments unchanged, followed by an int32
        # flags value that is always zero.
        zero_flags = context.get_constant(types.int32, 0)
        call_args = (*args, zero_flags)
        return builder.call(fn, call_args)

    return sync_sig, codegen
|
||
|
|
||
|
|
||
|
@overload_method(GridGroupClass, 'sync', target='cuda')
def _ol_grid_group_sync(group):
    """CUDA-target overload of the ``sync`` method on ``GridGroupClass``.

    The returned implementation passes the group to the
    ``_grid_group_sync`` intrinsic and returns its result.
    """
    def sync_impl(group):
        status = _grid_group_sync(group)
        return status

    return sync_impl
|