236 lines
8.1 KiB
Python
236 lines
8.1 KiB
Python
|
import itertools
|
||
|
from llvmlite import ir
|
||
|
from numba.core import cgutils, targetconfig
|
||
|
from .cudadrv import nvvm
|
||
|
|
||
|
|
||
|
def declare_atomic_cas_int(lmod, isize):
    """Return the integer compare-and-swap helper declared in *lmod*.

    The symbol is named ``___numba_atomic_i<isize>_cas_hack`` and is typed
    ``iN (iN*, iN, iN)`` with ``N == isize``.  Lookup/insertion in the module
    is delegated to ``cgutils.get_or_insert_function``.
    """
    symbol = ''.join(('___numba_atomic_i', str(isize), '_cas_hack'))
    int_ty = ir.IntType(isize)
    arg_types = (ir.PointerType(int_ty), int_ty, int_ty)
    signature = ir.FunctionType(int_ty, arg_types)
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def atomic_cmpxchg(builder, lmod, isize, ptr, cmp, val):
    """Emit a ``cmpxchg`` on *ptr* and return element 0 of its result.

    Both orderings passed to ``builder.cmpxchg`` are ``'monotonic'``.
    Element 0 of the returned aggregate is presumably the value that was
    loaded from *ptr* (LLVM ``cmpxchg`` semantics -- confirm against the
    LLVM reference).

    *lmod* and *isize* are accepted but not used by this implementation.
    """
    ordering = 'monotonic'
    result_pair = builder.cmpxchg(ptr, cmp, val, ordering, ordering)
    return builder.extract_value(result_pair, 0)
|
||
|
|
||
|
|
||
|
def declare_atomic_add_float32(lmod):
    """Return ``llvm.nvvm.atomic.load.add.f32.p0f32`` declared in *lmod*.

    The declared type is ``float (float*, float)``; the pointer argument is
    built with an explicit address space of 0.
    """
    symbol = 'llvm.nvvm.atomic.load.add.f32.p0f32'
    f32 = ir.FloatType()
    f32_ptr = ir.PointerType(f32, 0)
    signature = ir.FunctionType(f32, (f32_ptr, f32))
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def declare_atomic_add_float64(lmod):
    """Return the double-precision atomic-add function declared in *lmod*.

    The symbol depends on the compute capability found on top of the
    ``targetconfig.ConfigStack``:

    * ``>= (6, 0)``: ``llvm.nvvm.atomic.load.add.f64.p0f64``
    * otherwise: ``___numba_atomic_double_add``

    Either way the declared type is ``double (double*, double)``.
    """
    compute_capability = targetconfig.ConfigStack().top().compute_capability
    symbol = ('llvm.nvvm.atomic.load.add.f64.p0f64'
              if compute_capability >= (6, 0)
              else '___numba_atomic_double_add')
    f64 = ir.DoubleType()
    signature = ir.FunctionType(f64, (ir.PointerType(f64), f64))
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def declare_atomic_sub_float32(lmod):
    """Return ``___numba_atomic_float_sub`` declared in *lmod*.

    The declared type is ``float (float*, float)``.
    """
    f32 = ir.FloatType()
    signature = ir.FunctionType(f32, (ir.PointerType(f32), f32))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_float_sub')
|
||
|
|
||
|
|
||
|
def declare_atomic_sub_float64(lmod):
    """Return ``___numba_atomic_double_sub`` declared in *lmod*.

    The declared type is ``double (double*, double)``.
    """
    f64 = ir.DoubleType()
    signature = ir.FunctionType(f64, (ir.PointerType(f64), f64))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_double_sub')
|
||
|
|
||
|
|
||
|
def declare_atomic_inc_int32(lmod):
    """Return ``llvm.nvvm.atomic.load.inc.32.p0i32`` declared in *lmod*.

    The declared type is ``i32 (i32*, i32)``.
    """
    symbol = 'llvm.nvvm.atomic.load.inc.32.p0i32'
    i32 = ir.IntType(32)
    signature = ir.FunctionType(i32, (ir.PointerType(i32), i32))
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def declare_atomic_inc_int64(lmod):
    """Return ``___numba_atomic_u64_inc`` declared in *lmod*.

    The declared type is ``i64 (i64*, i64)``.
    """
    symbol = '___numba_atomic_u64_inc'
    i64 = ir.IntType(64)
    signature = ir.FunctionType(i64, (ir.PointerType(i64), i64))
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def declare_atomic_dec_int32(lmod):
    """Return ``llvm.nvvm.atomic.load.dec.32.p0i32`` declared in *lmod*.

    The declared type is ``i32 (i32*, i32)``.
    """
    symbol = 'llvm.nvvm.atomic.load.dec.32.p0i32'
    i32 = ir.IntType(32)
    signature = ir.FunctionType(i32, (ir.PointerType(i32), i32))
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def declare_atomic_dec_int64(lmod):
    """Return ``___numba_atomic_u64_dec`` declared in *lmod*.

    The declared type is ``i64 (i64*, i64)``.
    """
    symbol = '___numba_atomic_u64_dec'
    i64 = ir.IntType(64)
    signature = ir.FunctionType(i64, (ir.PointerType(i64), i64))
    return cgutils.get_or_insert_function(lmod, signature, symbol)
|
||
|
|
||
|
|
||
|
def declare_atomic_max_float32(lmod):
    """Return ``___numba_atomic_float_max`` declared in *lmod*.

    The declared type is ``float (float*, float)``.
    """
    f32 = ir.FloatType()
    signature = ir.FunctionType(f32, (ir.PointerType(f32), f32))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_float_max')
|
||
|
|
||
|
|
||
|
def declare_atomic_max_float64(lmod):
    """Return ``___numba_atomic_double_max`` declared in *lmod*.

    The declared type is ``double (double*, double)``.
    """
    f64 = ir.DoubleType()
    signature = ir.FunctionType(f64, (ir.PointerType(f64), f64))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_double_max')
|
||
|
|
||
|
|
||
|
def declare_atomic_min_float32(lmod):
    """Return ``___numba_atomic_float_min`` declared in *lmod*.

    The declared type is ``float (float*, float)``.
    """
    f32 = ir.FloatType()
    signature = ir.FunctionType(f32, (ir.PointerType(f32), f32))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_float_min')
|
||
|
|
||
|
|
||
|
def declare_atomic_min_float64(lmod):
    """Return ``___numba_atomic_double_min`` declared in *lmod*.

    The declared type is ``double (double*, double)``.
    """
    f64 = ir.DoubleType()
    signature = ir.FunctionType(f64, (ir.PointerType(f64), f64))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_double_min')
|
||
|
|
||
|
|
||
|
def declare_atomic_nanmax_float32(lmod):
    """Return ``___numba_atomic_float_nanmax`` declared in *lmod*.

    The declared type is ``float (float*, float)``.
    """
    f32 = ir.FloatType()
    signature = ir.FunctionType(f32, (ir.PointerType(f32), f32))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_float_nanmax')
|
||
|
|
||
|
|
||
|
def declare_atomic_nanmax_float64(lmod):
    """Return ``___numba_atomic_double_nanmax`` declared in *lmod*.

    The declared type is ``double (double*, double)``.
    """
    f64 = ir.DoubleType()
    signature = ir.FunctionType(f64, (ir.PointerType(f64), f64))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_double_nanmax')
|
||
|
|
||
|
|
||
|
def declare_atomic_nanmin_float32(lmod):
    """Return ``___numba_atomic_float_nanmin`` declared in *lmod*.

    The declared type is ``float (float*, float)``.
    """
    f32 = ir.FloatType()
    signature = ir.FunctionType(f32, (ir.PointerType(f32), f32))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_float_nanmin')
|
||
|
|
||
|
|
||
|
def declare_atomic_nanmin_float64(lmod):
    """Return ``___numba_atomic_double_nanmin`` declared in *lmod*.

    The declared type is ``double (double*, double)``.
    """
    f64 = ir.DoubleType()
    signature = ir.FunctionType(f64, (ir.PointerType(f64), f64))
    return cgutils.get_or_insert_function(
        lmod, signature, '___numba_atomic_double_nanmin')
|
||
|
|
||
|
|
||
|
def declare_cudaCGGetIntrinsicHandle(lmod):
    """Return ``cudaCGGetIntrinsicHandle`` declared in *lmod*.

    The declared type is ``i64 (i32)``.
    """
    return_type = ir.IntType(64)
    arg_types = (ir.IntType(32),)
    signature = ir.FunctionType(return_type, arg_types)
    return cgutils.get_or_insert_function(
        lmod, signature, 'cudaCGGetIntrinsicHandle')
|
||
|
|
||
|
|
||
|
def declare_cudaCGSynchronize(lmod):
    """Return ``cudaCGSynchronize`` declared in *lmod*.

    The declared type is ``i32 (i64, i32)``.
    """
    return_type = ir.IntType(32)
    arg_types = (ir.IntType(64), ir.IntType(32))
    signature = ir.FunctionType(return_type, arg_types)
    return cgutils.get_or_insert_function(
        lmod, signature, 'cudaCGSynchronize')
|
||
|
|
||
|
|
||
|
def declare_string(builder, value):
    """Store *value* as a constant global and return an ``i8*`` to it.

    *value* is encoded as UTF-8 and a trailing NUL byte is appended.  The
    bytes become the initializer of a global requested under the name
    ``"_str"`` in ``nvvm.ADDRSPACE_CONSTANT``; the global is given internal
    linkage and flagged as a global constant.  The module used is the one
    owning the function of the builder's current basic block.

    The return value is an ``addrspacecast`` of that global to a plain
    ``i8*``.
    """
    module = builder.basic_block.function.module
    payload = cgutils.make_bytearray(value.encode("utf-8") + b"\x00")

    global_var = cgutils.add_global_variable(
        module, payload.type, name="_str",
        addrspace=nvvm.ADDRSPACE_CONSTANT)
    global_var.linkage = 'internal'
    global_var.global_constant = True
    global_var.initializer = payload

    char_ptr = ir.PointerType(ir.IntType(8))
    # NOTE(review): the third positional argument looks like the name given
    # to the cast instruction in llvmlite, not an address space -- confirm.
    return builder.addrspacecast(global_var, char_ptr, 'generic')
|
||
|
|
||
|
|
||
|
def declare_vprint(lmod):
    """Return ``vprintf`` declared in *lmod* with type ``i32 (i8*, i8*)``."""
    i8_ptr = ir.PointerType(ir.IntType(8))
    # NOTE: the second argument to vprintf() points to the variable-length
    # array of arguments (after the format)
    signature = ir.FunctionType(ir.IntType(32), [i8_ptr, i8_ptr])
    return cgutils.get_or_insert_function(lmod, signature, "vprintf")
|
||
|
|
||
|
|
||
|
# -----------------------------------------------------------------------------
|
||
|
|
||
|
# Maps a short special-register name to the symbol name passed to
# cgutils.get_or_insert_function by call_sreg.  Every symbol is the short
# name prefixed with 'llvm.nvvm.read.ptx.sreg.'.
SREG_MAPPING = {
    sreg: 'llvm.nvvm.read.ptx.sreg.' + sreg
    for sreg in (
        'tid.x', 'tid.y', 'tid.z',
        'ntid.x', 'ntid.y', 'ntid.z',
        'ctaid.x', 'ctaid.y', 'ctaid.z',
        'nctaid.x', 'nctaid.y', 'nctaid.z',
        'warpsize',
        'laneid',
    )
}
|
||
|
|
||
|
|
||
|
def call_sreg(builder, name):
    """Emit a call to the special-register function for *name*.

    *name* must be a key of ``SREG_MAPPING``; any other value raises
    ``KeyError`` from the dictionary lookup.  The function is declared in
    ``builder.module`` with type ``i32 ()`` and called with no arguments;
    the value returned by ``builder.call`` is returned.
    """
    target_module = builder.module
    signature = ir.FunctionType(ir.IntType(32), ())
    reader = cgutils.get_or_insert_function(
        target_module, signature, SREG_MAPPING[name])
    return builder.call(reader, ())
|
||
|
|
||
|
|
||
|
class SRegBuilder:
    """Convenience wrapper emitting special-register reads on one builder.

    Each accessor formats a register name (``'tid.<xyz>'`` etc.) and
    forwards it to ``call_sreg`` together with the wrapped builder.
    """

    def __init__(self, builder):
        # IR builder that every emitted instruction goes through.
        self.builder = builder

    def tid(self, xyz):
        """Read ``tid.<xyz>``."""
        register = 'tid.%s' % xyz
        return call_sreg(self.builder, register)

    def ctaid(self, xyz):
        """Read ``ctaid.<xyz>``."""
        register = 'ctaid.%s' % xyz
        return call_sreg(self.builder, register)

    def ntid(self, xyz):
        """Read ``ntid.<xyz>``."""
        register = 'ntid.%s' % xyz
        return call_sreg(self.builder, register)

    def nctaid(self, xyz):
        """Read ``nctaid.<xyz>``."""
        register = 'nctaid.%s' % xyz
        return call_sreg(self.builder, register)

    def getdim(self, xyz):
        """Return ``ntid * ctaid + tid`` along *xyz* as a 64-bit value.

        Each of the three register reads is sign-extended to ``i64`` before
        the multiply and add are emitted.
        """
        wide = ir.IntType(64)
        emit = self.builder
        # Emission order (tid, ntid, ctaid, mul, add) is kept deliberately.
        thread_index = emit.sext(self.tid(xyz), wide)
        block_size = emit.sext(self.ntid(xyz), wide)
        block_index = emit.sext(self.ctaid(xyz), wide)
        block_offset = emit.mul(block_size, block_index)
        return emit.add(block_offset, thread_index)
|
||
|
|
||
|
|
||
|
def get_global_id(builder, dim):
    """Return the ``SRegBuilder.getdim`` value(s) for the first *dim* axes.

    Axes are taken in the order ``x``, ``y``, ``z``.  When ``dim == 1`` the
    single value is returned directly; otherwise a list of the values is
    returned.
    """
    registers = SRegBuilder(builder)
    # islice stops after ``dim`` axes, so getdim is only emitted for the
    # axes actually requested.
    indices = [registers.getdim(axis)
               for axis in itertools.islice('xyz', dim)]
    return indices[0] if dim == 1 else indices
|