ai-content-maker/.venv/Lib/site-packages/numba/cuda/intrinsics.py

199 lines
5.8 KiB
Python

from llvmlite import ir
from numba import cuda, types
from numba.core import cgutils
from numba.core.errors import RequireLiteralValue
from numba.core.typing import signature
from numba.core.extending import overload_attribute
from numba.cuda import nvvmutils
from numba.cuda.extending import intrinsic
#-------------------------------------------------------------------------------
# Grid functions
def _type_grid_function(ndim):
    '''Build the typing signature shared by ``grid`` and ``gridsize``.

    *ndim* must expose a ``literal_value`` attribute holding the requested
    number of dimensions. A value of 1 gives an ``int64`` return type; a
    value of 2 or 3 gives a uniform tuple of that many ``int64`` items. The
    single argument is always typed as ``int32``.

    Raises ``ValueError`` for any other literal value.
    '''
    count = ndim.literal_value

    # Scalar result for the one-dimensional case.
    if count == 1:
        return signature(types.int64, types.int32)

    # Homogeneous tuple result for the two- and three-dimensional cases.
    if count in (2, 3):
        return signature(types.UniTuple(types.int64, count), types.int32)

    raise ValueError('argument can only be 1, 2, 3')
@intrinsic
def grid(typingctx, ndim):
    '''grid(ndim)

    Give the absolute position of the calling thread within the whole grid
    of blocks. *ndim* should match the number of dimensions the kernel was
    instantiated with. For *ndim* equal to 1 the result is a single
    integer; for *ndim* equal to 2 or 3 it is a tuple holding that many
    integers.

    The first integer is computed as::

        cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x

    and the remaining ones follow the same formula with the ``y`` and ``z``
    attributes.
    '''
    # The dimension count selects the return type, so it has to be known at
    # typing time as an integer literal.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    grid_sig = _type_grid_function(ndim)

    def codegen(context, builder, sig, args):
        result_type = sig.return_type

        if result_type == types.int64:
            # One dimension: the global id is returned directly.
            return nvvmutils.get_global_id(builder, dim=1)

        if isinstance(result_type, types.UniTuple):
            # Two or three dimensions: the per-dimension ids are packed
            # into a single array value.
            global_ids = nvvmutils.get_global_id(
                builder, dim=result_type.count
            )
            return cgutils.pack_array(builder, global_ids)

    return grid_sig, codegen
@intrinsic
def gridsize(typingctx, ndim):
    '''gridsize(ndim)

    Give the absolute size (or shape), in threads, of the whole grid of
    blocks. *ndim* should match the number of dimensions the kernel was
    instantiated with. For *ndim* equal to 1 the result is a single
    integer; for *ndim* equal to 2 or 3 it is a tuple holding that many
    integers.

    The first integer is computed as::

        cuda.blockDim.x * cuda.gridDim.x

    and the remaining ones follow the same formula with the ``y`` and ``z``
    attributes.
    '''
    # The dimension count selects the return type, so it has to be known at
    # typing time as an integer literal.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    gridsize_sig = _type_grid_function(ndim)

    def _total_threads(builder, axis):
        # Read the ``ntid`` and ``nctaid`` special registers for this axis,
        # sign-extend both to 64 bits, then multiply, so the product is
        # computed at 64-bit width.
        wide = ir.IntType(64)
        threads_per_block = nvvmutils.call_sreg(builder, f"ntid.{axis}")
        blocks_per_grid = nvvmutils.call_sreg(builder, f"nctaid.{axis}")
        wide_threads = builder.sext(threads_per_block, wide)
        wide_blocks = builder.sext(blocks_per_grid, wide)
        return builder.mul(wide_threads, wide_blocks)

    def codegen(context, builder, sig, args):
        result_type = sig.return_type

        # The x extent is needed whatever the dimensionality.
        size_x = _total_threads(builder, 'x')
        if result_type == types.int64:
            return size_x

        if not isinstance(result_type, types.UniTuple):
            return None

        size_y = _total_threads(builder, 'y')
        if result_type.count == 2:
            return cgutils.pack_array(builder, (size_x, size_y))

        if result_type.count == 3:
            size_z = _total_threads(builder, 'z')
            return cgutils.pack_array(builder, (size_x, size_y, size_z))

        return None

    return gridsize_sig, codegen
@intrinsic
def _warpsize(typingctx):
    '''Argument-less intrinsic that reads the ``warpsize`` special register
    and is typed as returning ``int32``.
    '''
    def codegen(context, builder, sig, args):
        # The value comes straight from the special register read.
        return nvvmutils.call_sreg(builder, 'warpsize')

    return signature(types.int32), codegen
@overload_attribute(types.Module(cuda), 'warpsize', target='cuda')
def cuda_warpsize(mod):
    '''
    Warp size, exposed as the ``warpsize`` attribute of the ``cuda`` module.
    Every architecture implemented so far has a warp size of 32.
    '''
    # The attribute getter simply forwards to the ``_warpsize`` intrinsic.
    def warpsize_impl(mod):
        return _warpsize()

    return warpsize_impl
#-------------------------------------------------------------------------------
# syncthreads
@intrinsic
def syncthreads(typingctx):
    '''
    Synchronize every thread of the same thread block. This follows the
    barrier pattern of traditional multi-threaded programming: the call
    waits until all threads of the block have reached it, and only then
    hands control back to each caller.
    '''
    def codegen(context, builder, sig, args):
        # Declare (or reuse) the argument-less, void-returning barrier
        # intrinsic in the current module and emit a call to it.
        barrier_type = ir.FunctionType(ir.VoidType(), ())
        barrier = cgutils.get_or_insert_function(
            builder.module, barrier_type, 'llvm.nvvm.barrier0'
        )
        builder.call(barrier, ())
        # The intrinsic is typed as returning ``none``.
        return context.get_dummy_value()

    return signature(types.none), codegen
def _syncthreads_predicate(typingctx, predicate, fname):
    '''Shared typing and lowering for the predicate-taking barrier variants.

    Returns ``None`` when *predicate* is not an integer type. Otherwise
    returns a ``(signature, codegen)`` pair typed as ``i4(i4)``, where the
    generated code calls the function named *fname* with the given
    arguments and returns its result.
    '''
    if isinstance(predicate, types.Integer):
        def codegen(context, builder, sig, args):
            # The callee takes one 32-bit integer and returns one.
            i32 = ir.IntType(32)
            barrier_type = ir.FunctionType(i32, (i32,))
            barrier = cgutils.get_or_insert_function(
                builder.module, barrier_type, fname
            )
            return builder.call(barrier, args)

        return signature(types.i4, types.i4), codegen

    return None
@intrinsic
def syncthreads_count(typingctx, predicate):
    '''
    syncthreads_count(predicate)

    Variant of numba.cuda.syncthreads whose return value is the number of
    threads for which predicate is true.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, fname='llvm.nvvm.barrier0.popc'
    )
@intrinsic
def syncthreads_and(typingctx, predicate):
    '''
    syncthreads_and(predicate)

    Variant of numba.cuda.syncthreads that returns 1 when predicate is true
    for all threads and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, fname='llvm.nvvm.barrier0.and'
    )
@intrinsic
def syncthreads_or(typingctx, predicate):
    '''
    syncthreads_or(predicate)

    Variant of numba.cuda.syncthreads that returns 1 when predicate is true
    for any thread and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, fname='llvm.nvvm.barrier0.or'
    )