199 lines
5.8 KiB
Python
199 lines
5.8 KiB
Python
from llvmlite import ir
|
|
|
|
from numba import cuda, types
|
|
from numba.core import cgutils
|
|
from numba.core.errors import RequireLiteralValue
|
|
from numba.core.typing import signature
|
|
from numba.core.extending import overload_attribute
|
|
from numba.cuda import nvvmutils
|
|
from numba.cuda.extending import intrinsic
|
|
|
|
|
|
#-------------------------------------------------------------------------------
|
|
# Grid functions
|
|
|
|
def _type_grid_function(ndim):
    '''Build the typing signature shared by ``grid`` and ``gridsize``.

    *ndim* must expose a ``literal_value`` attribute. A value of 1 yields an
    ``int64`` return type; 2 or 3 yields a ``UniTuple`` holding that many
    ``int64`` elements. The single argument is always typed as ``int32``.

    Raises ``ValueError`` for any other literal value.
    '''
    dims = ndim.literal_value

    # Scalar result for the one-dimensional case.
    if dims == 1:
        return signature(types.int64, types.int32)

    # Homogeneous tuple, one element per dimension, for 2-D and 3-D.
    if dims in (2, 3):
        return signature(types.UniTuple(types.int64, dims), types.int32)

    raise ValueError('argument can only be 1, 2, 3')
|
|
|
|
|
|
@intrinsic
def grid(typingctx, ndim):
    '''grid(ndim)

    Return the absolute position of the current thread within the entire
    grid of blocks. *ndim* should match the number of dimensions declared
    when the kernel was instantiated. A single integer is returned when
    *ndim* is 1; a tuple with *ndim* integers is returned when *ndim* is
    2 or 3.

    The first integer is computed as::

        cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x

    The remaining indices are computed the same way from the ``y`` and ``z``
    attributes.
    '''
    # The return type depends on the value of ndim, so it has to be known at
    # typing time: ask for a literal if we were not given one.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    # Raises ValueError if the literal is not 1, 2 or 3.
    sig = _type_grid_function(ndim)

    def codegen(context, builder, sig, args):
        result_ty = sig.return_type

        # 1-D: a single int64 global id.
        if result_ty == types.int64:
            return nvvmutils.get_global_id(builder, dim=1)

        # 2-D / 3-D: pack the per-dimension ids into an array value.
        if isinstance(result_ty, types.UniTuple):
            global_ids = nvvmutils.get_global_id(builder, dim=result_ty.count)
            return cgutils.pack_array(builder, global_ids)

    return sig, codegen
|
|
|
|
|
|
@intrinsic
def gridsize(typingctx, ndim):
    '''gridsize(ndim)

    Return the absolute size (or shape), measured in threads, of the entire
    grid of blocks. *ndim* should match the number of dimensions declared
    when the kernel was instantiated. A single integer is returned when
    *ndim* is 1; a tuple with *ndim* integers is returned when *ndim* is
    2 or 3.

    The first integer is computed as::

        cuda.blockDim.x * cuda.gridDim.x

    The remaining indices are computed the same way from the ``y`` and ``z``
    attributes.
    '''
    # The return type depends on the value of ndim, so it has to be known at
    # typing time: ask for a literal if we were not given one.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    # Raises ValueError if the literal is not 1, 2 or 3.
    sig = _type_grid_function(ndim)

    def _nthreads_for_dim(builder, dim):
        # Emit ``ntid.<dim> * nctaid.<dim>`` as a 64-bit product.
        i64 = ir.IntType(64)
        threads_per_block = nvvmutils.call_sreg(builder, f"ntid.{dim}")
        blocks_per_grid = nvvmutils.call_sreg(builder, f"nctaid.{dim}")
        # Sign-extend both operands to 64 bits before multiplying so the
        # multiplication itself is carried out at the result width.
        wide_threads = builder.sext(threads_per_block, i64)
        wide_blocks = builder.sext(blocks_per_grid, i64)
        return builder.mul(wide_threads, wide_blocks)

    def codegen(context, builder, sig, args):
        result_ty = sig.return_type

        # The x extent is needed by every variant, so it is emitted first.
        extents = [_nthreads_for_dim(builder, 'x')]

        if result_ty == types.int64:
            return extents[0]
        if not isinstance(result_ty, types.UniTuple):
            return None

        extents.append(_nthreads_for_dim(builder, 'y'))
        if result_ty.count == 3:
            extents.append(_nthreads_for_dim(builder, 'z'))
        elif result_ty.count != 2:
            # Only 2- and 3-element tuples are produced by the typing above.
            return None

        return cgutils.pack_array(builder, tuple(extents))

    return sig, codegen
|
|
|
|
|
|
@intrinsic
def _warpsize(typingctx):
    '''Intrinsic taking no arguments and returning the warp size as int32.'''

    def codegen(context, builder, sig, args):
        # The value is obtained through nvvmutils.call_sreg using the
        # 'warpsize' name; no arguments are consumed.
        return nvvmutils.call_sreg(builder, 'warpsize')

    return signature(types.int32), codegen
|
|
|
|
|
|
@overload_attribute(types.Module(cuda), 'warpsize', target='cuda')
def cuda_warpsize(mod):
    '''
    The size of a warp. All architectures implemented to date have a warp size
    of 32.
    '''
    # Attribute access on the ``cuda`` module is implemented by delegating to
    # the ``_warpsize`` intrinsic; the module argument itself is not used.
    def warpsize_getter(mod):
        return _warpsize()

    return warpsize_getter
|
|
|
|
|
|
#-------------------------------------------------------------------------------
|
|
# syncthreads
|
|
|
|
@intrinsic
def syncthreads(typingctx):
    '''
    Synchronize all threads in the same thread block. This behaves like a
    barrier in traditional multi-threaded programming: the function waits
    until every thread in the block has called it, and only then returns
    control to all of its callers.
    '''

    def codegen(context, builder, sig, args):
        # Declare (or reuse) the barrier function in the current module; it
        # takes no arguments and returns void.
        barrier = cgutils.get_or_insert_function(
            builder.module,
            ir.FunctionType(ir.VoidType(), ()),
            'llvm.nvvm.barrier0',
        )
        builder.call(barrier, ())
        # The intrinsic is typed as returning none, so hand back the
        # context's dummy value.
        return context.get_dummy_value()

    return signature(types.none), codegen
|
|
|
|
|
|
def _syncthreads_predicate(typingctx, predicate, fname):
    '''Shared typing/codegen for the predicate-taking syncthreads variants.

    *predicate* is the Numba type of the argument and *fname* is the name of
    the function to call. Returns a ``(signature, codegen)`` pair typed as
    ``i4(i4)`` when *predicate* is an integer type, and ``None`` otherwise.
    '''
    if isinstance(predicate, types.Integer):

        def codegen(context, builder, sig, args):
            # The callee takes one 32-bit integer and returns one.
            i32 = ir.IntType(32)
            callee_ty = ir.FunctionType(i32, (i32,))
            callee = cgutils.get_or_insert_function(
                builder.module, callee_ty, fname
            )
            return builder.call(callee, args)

        return signature(types.i4, types.i4), codegen

    # Non-integer predicates are not handled here.
    return None
|
|
|
|
|
|
@intrinsic
def syncthreads_count(typingctx, predicate):
    '''
    syncthreads_count(predicate)

    An extension to numba.cuda.syncthreads whose return value is the number
    of threads for which predicate is true.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.popc'
    )
|
|
|
|
|
|
@intrinsic
def syncthreads_and(typingctx, predicate):
    '''
    syncthreads_and(predicate)

    An extension to numba.cuda.syncthreads that returns 1 if predicate is
    true for all threads, and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.and'
    )
|
|
|
|
|
|
@intrinsic
def syncthreads_or(typingctx, predicate):
    '''
    syncthreads_or(predicate)

    An extension to numba.cuda.syncthreads that returns 1 if predicate is
    true for any thread, and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.or'
    )
|