import itertools

from llvmlite import ir

from numba.core import cgutils, targetconfig
from .cudadrv import nvvm


def _declare_atomic_binop(lmod, fname, valty):
    """Get or insert ``fname`` in ``lmod`` with signature ``T (T*, T)``.

    ``valty`` is the llvmlite type used for ``T``; the pointer argument is
    created with ``ir.PointerType``'s default address space.
    """
    signature = ir.FunctionType(valty, (ir.PointerType(valty), valty))
    return cgutils.get_or_insert_function(lmod, signature, fname)


def declare_atomic_cas_int(lmod, isize):
    """Get or insert ``___numba_atomic_i<isize>_cas_hack`` in ``lmod``.

    The function takes a pointer to an ``isize``-bit integer followed by
    two ``isize``-bit integers, and returns an ``isize``-bit integer.
    """
    int_ty = ir.IntType(isize)
    signature = ir.FunctionType(
        int_ty, (ir.PointerType(int_ty), int_ty, int_ty)
    )
    fname = '___numba_atomic_i' + str(isize) + '_cas_hack'
    return cgutils.get_or_insert_function(lmod, signature, fname)


def atomic_cmpxchg(builder, lmod, isize, ptr, cmp, val):
    """Emit a ``cmpxchg`` on ``ptr`` and return element 0 of its result.

    Both orderings passed to ``builder.cmpxchg`` are ``'monotonic'``.
    ``lmod`` and ``isize`` are accepted but not used by this function.
    """
    result = builder.cmpxchg(ptr, cmp, val, 'monotonic', 'monotonic')
    return builder.extract_value(result, 0)


def declare_atomic_add_float32(lmod):
    """Get or insert the ``float (float*, float)`` atomic add intrinsic."""
    return _declare_atomic_binop(
        lmod, 'llvm.nvvm.atomic.load.add.f32.p0f32', ir.FloatType()
    )


def declare_atomic_add_float64(lmod):
    """Get or insert the ``double (double*, double)`` atomic add function.

    The name depends on the compute capability of the flags at the top of
    the target configuration stack: the NVVM intrinsic is used for
    ``(6, 0)`` and above, ``___numba_atomic_double_add`` otherwise.
    """
    flags = targetconfig.ConfigStack().top()
    has_intrinsic = flags.compute_capability >= (6, 0)
    fname = ('llvm.nvvm.atomic.load.add.f64.p0f64' if has_intrinsic
             else '___numba_atomic_double_add')
    return _declare_atomic_binop(lmod, fname, ir.DoubleType())


def declare_atomic_sub_float32(lmod):
    """Get or insert ``___numba_atomic_float_sub`` (float, float* args)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_float_sub', ir.FloatType()
    )


def declare_atomic_sub_float64(lmod):
    """Get or insert ``___numba_atomic_double_sub`` (double, double*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_double_sub', ir.DoubleType()
    )


def declare_atomic_inc_int32(lmod):
    """Get or insert the ``i32 (i32*, i32)`` atomic inc intrinsic."""
    return _declare_atomic_binop(
        lmod, 'llvm.nvvm.atomic.load.inc.32.p0i32', ir.IntType(32)
    )


def declare_atomic_inc_int64(lmod):
    """Get or insert ``___numba_atomic_u64_inc`` as ``i64 (i64*, i64)``."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_u64_inc', ir.IntType(64)
    )


def declare_atomic_dec_int32(lmod):
    """Get or insert the ``i32 (i32*, i32)`` atomic dec intrinsic."""
    return _declare_atomic_binop(
        lmod, 'llvm.nvvm.atomic.load.dec.32.p0i32', ir.IntType(32)
    )


def declare_atomic_dec_int64(lmod):
    """Get or insert ``___numba_atomic_u64_dec`` as ``i64 (i64*, i64)``."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_u64_dec', ir.IntType(64)
    )


def declare_atomic_max_float32(lmod):
    """Get or insert ``___numba_atomic_float_max`` (float, float* args)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_float_max', ir.FloatType()
    )


def declare_atomic_max_float64(lmod):
    """Get or insert ``___numba_atomic_double_max`` (double, double*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_double_max', ir.DoubleType()
    )


def declare_atomic_min_float32(lmod):
    """Get or insert ``___numba_atomic_float_min`` (float, float* args)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_float_min', ir.FloatType()
    )


def declare_atomic_min_float64(lmod):
    """Get or insert ``___numba_atomic_double_min`` (double, double*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_double_min', ir.DoubleType()
    )


def declare_atomic_nanmax_float32(lmod):
    """Get or insert ``___numba_atomic_float_nanmax`` (float, float*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_float_nanmax', ir.FloatType()
    )


def declare_atomic_nanmax_float64(lmod):
    """Get or insert ``___numba_atomic_double_nanmax`` (double, double*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_double_nanmax', ir.DoubleType()
    )


def declare_atomic_nanmin_float32(lmod):
    """Get or insert ``___numba_atomic_float_nanmin`` (float, float*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_float_nanmin', ir.FloatType()
    )


def declare_atomic_nanmin_float64(lmod):
    """Get or insert ``___numba_atomic_double_nanmin`` (double, double*)."""
    return _declare_atomic_binop(
        lmod, '___numba_atomic_double_nanmin', ir.DoubleType()
    )


def declare_cudaCGGetIntrinsicHandle(lmod):
    """Get or insert ``cudaCGGetIntrinsicHandle`` typed ``i64 (i32)``."""
    signature = ir.FunctionType(ir.IntType(64), (ir.IntType(32),))
    return cgutils.get_or_insert_function(
        lmod, signature, 'cudaCGGetIntrinsicHandle'
    )


def declare_cudaCGSynchronize(lmod):
    """Get or insert ``cudaCGSynchronize`` typed ``i32 (i64, i32)``."""
    signature = ir.FunctionType(
        ir.IntType(32), (ir.IntType(64), ir.IntType(32))
    )
    return cgutils.get_or_insert_function(
        lmod, signature, 'cudaCGSynchronize'
    )


def declare_string(builder, value):
    """Store ``value`` as a constant global and return an ``i8*`` to it.

    ``value`` is encoded as UTF-8 with a trailing NUL byte. The bytes become
    the initializer of a global named ``"_str"`` that is added to the module
    of the builder's current function in ``nvvm.ADDRSPACE_CONSTANT``, with
    internal linkage and marked as a global constant. The returned value is
    an ``addrspacecast`` of that global to a pointer to ``i8``.
    """
    module = builder.basic_block.function.module
    payload = cgutils.make_bytearray(value.encode("utf-8") + b"\x00")

    gvar = cgutils.add_global_variable(
        module, payload.type, name="_str",
        addrspace=nvvm.ADDRSPACE_CONSTANT,
    )
    gvar.linkage = 'internal'
    gvar.global_constant = True
    gvar.initializer = payload

    i8_ptr = ir.PointerType(ir.IntType(8))
    return builder.addrspacecast(gvar, i8_ptr, 'generic')


def declare_vprint(lmod):
    """Get or insert ``vprintf`` in ``lmod`` typed ``i32 (i8*, i8*)``."""
    i8_ptr = ir.PointerType(ir.IntType(8))
    # NOTE: the second argument to vprintf() points to the variable-length
    # array of arguments (after the format)
    signature = ir.FunctionType(ir.IntType(32), [i8_ptr, i8_ptr])
    return cgutils.get_or_insert_function(lmod, signature, "vprintf")


# -----------------------------------------------------------------------------

# Short special-register name -> name of the intrinsic that reads it.
SREG_MAPPING = {
    'tid.x': 'llvm.nvvm.read.ptx.sreg.tid.x',
    'tid.y': 'llvm.nvvm.read.ptx.sreg.tid.y',
    'tid.z': 'llvm.nvvm.read.ptx.sreg.tid.z',

    'ntid.x': 'llvm.nvvm.read.ptx.sreg.ntid.x',
    'ntid.y': 'llvm.nvvm.read.ptx.sreg.ntid.y',
    'ntid.z': 'llvm.nvvm.read.ptx.sreg.ntid.z',

    'ctaid.x': 'llvm.nvvm.read.ptx.sreg.ctaid.x',
    'ctaid.y': 'llvm.nvvm.read.ptx.sreg.ctaid.y',
    'ctaid.z': 'llvm.nvvm.read.ptx.sreg.ctaid.z',

    'nctaid.x': 'llvm.nvvm.read.ptx.sreg.nctaid.x',
    'nctaid.y': 'llvm.nvvm.read.ptx.sreg.nctaid.y',
    'nctaid.z': 'llvm.nvvm.read.ptx.sreg.nctaid.z',

    'warpsize': 'llvm.nvvm.read.ptx.sreg.warpsize',
    'laneid': 'llvm.nvvm.read.ptx.sreg.laneid',
}


def call_sreg(builder, name):
    """Emit a call to the intrinsic that ``SREG_MAPPING`` gives for ``name``.

    The intrinsic is declared as a zero-argument function returning ``i32``.
    Raises ``KeyError`` if ``name`` is not a key of ``SREG_MAPPING``.
    """
    reader_ty = ir.FunctionType(ir.IntType(32), ())
    reader = cgutils.get_or_insert_function(
        builder.module, reader_ty, SREG_MAPPING[name]
    )
    return builder.call(reader, ())


class SRegBuilder(object):
    """Convenience wrapper emitting special-register reads on a builder."""

    def __init__(self, builder):
        self.builder = builder

    def tid(self, xyz):
        """Emit a read of ``tid.<xyz>``."""
        return call_sreg(self.builder, 'tid.%s' % xyz)

    def ctaid(self, xyz):
        """Emit a read of ``ctaid.<xyz>``."""
        return call_sreg(self.builder, 'ctaid.%s' % xyz)

    def ntid(self, xyz):
        """Emit a read of ``ntid.<xyz>``."""
        return call_sreg(self.builder, 'ntid.%s' % xyz)

    def nctaid(self, xyz):
        """Emit a read of ``nctaid.<xyz>``."""
        return call_sreg(self.builder, 'nctaid.%s' % xyz)

    def getdim(self, xyz):
        """Emit ``ntid * ctaid + tid`` for axis ``xyz`` as a 64-bit value.

        Each register read is sign-extended to ``i64`` before the
        arithmetic is emitted.
        """
        builder = self.builder
        i64 = ir.IntType(64)
        # Reads are emitted in the order tid, ntid, ctaid.
        tid64 = builder.sext(self.tid(xyz), i64)
        ntid64 = builder.sext(self.ntid(xyz), i64)
        ctaid64 = builder.sext(self.ctaid(xyz), i64)
        return builder.add(builder.mul(ntid64, ctaid64), tid64)


def get_global_id(builder, dim):
    """Emit ``SRegBuilder.getdim`` for the first ``dim`` axes of x, y, z.

    Returns the single value when ``dim == 1``; otherwise returns a list
    with one value per axis, in x, y, z order.
    """
    sreg = SRegBuilder(builder)
    # islice limits the axes visited, so nothing is emitted for the rest.
    ids = [sreg.getdim(axis) for axis in itertools.islice('xyz', dim)]
    return ids[0] if dim == 1 else ids