631 lines
28 KiB
C++
631 lines
28 KiB
C++
#pragma once
|
|
|
|
namespace at::mps {
|
|
|
|
static const char * indexing_metal_shaders = R"INDEX_METAL(
|
|
#include <metal_stdlib>
|
|
#include <metal_atomic>
|
|
|
|
using namespace metal;
|
|
|
|
#if __METAL_VERSION__ < 300
|
|
struct IndexAB {
|
|
// Allow up to 16 indices
|
|
metal::array<constant void *, 16> indexArray [[ id(0) ]];
|
|
};
|
|
#else
|
|
struct IndexAB {
|
|
constant int64_t* indexArray;
|
|
};
|
|
|
|
#endif
|
|
|
|
template<typename T, typename OffsetsT>
|
|
kernel void index_select(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant OffsetsT * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]) {
|
|
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
|
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
|
int64_t offset = 0;
|
|
for (uint32_t i = 0; i < num_indices; i++) {
|
|
#if __METAL_VERSION__ >= 300
|
|
constant int64_t* indexArray = indexAB[i].indexArray;
|
|
#else
|
|
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
|
#endif
|
|
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
|
if (index < 0) {
|
|
index += index_sizes[i];
|
|
}
|
|
offset += index * index_strides[i];
|
|
}
|
|
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
|
|
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
|
|
*out = *in;
|
|
}
|
|
|
|
template<typename T, typename OffsetsT>
|
|
void index_put_impl(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB,
|
|
#else
|
|
constant IndexAB & indexAB,
|
|
#endif
|
|
constant int64_t * index_sizes,
|
|
constant int64_t * index_strides,
|
|
constant OffsetsT * offsets,
|
|
constant void * inputData,
|
|
device void * outputData,
|
|
constant uint32_t & num_indices,
|
|
uint thread_index) {
|
|
int64_t offset = 0;
|
|
for (uint32_t i = 0; i < num_indices; i++) {
|
|
#if __METAL_VERSION__ >= 300
|
|
constant int64_t* indexArray = indexAB[i].indexArray;
|
|
#else
|
|
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
|
#endif
|
|
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
|
|
|
if (index < 0) {
|
|
index += index_sizes[i];
|
|
}
|
|
offset += index * index_strides[i];
|
|
}
|
|
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
|
|
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
|
|
*out = *in;
|
|
}
|
|
|
|
template<typename T, typename OffsetsT>
|
|
kernel void index_put_serial(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant OffsetsT * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
constant uint * numIters [[buffer(7)]],
|
|
uint thread_index [[thread_position_in_grid]]) {
|
|
|
|
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
|
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
|
|
|
for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
|
|
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
|
|
}
|
|
}
|
|
|
|
template<typename T, typename OffsetsT>
|
|
kernel void index_put(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant OffsetsT * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]) {
|
|
|
|
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
|
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
|
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
|
|
}
|
|
|
|
// REGISTER_INDEX_OP explicitly instantiates index_<INDEX_OP_TYPE><DTYPE, IDX_DTYPE>
// and exposes it to the host under the name
// "index_<INDEX_OP_TYPE>_<DTYPE_SIZE>_<IDX_SIZE>".
// Only the indexAB parameter differs between Metal versions (pointer from
// Metal 3.0 on, reference before), so that one parameter is selected here and
// the macro body is written once.
#if __METAL_VERSION__ >= 300
#define INDEX_OP_INDEX_AB_PARAM constant IndexAB * indexAB [[buffer(0)]]
#else
#define INDEX_OP_INDEX_AB_PARAM constant IndexAB & indexAB [[buffer(0)]]
#endif

#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template                                                                         \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]             \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                           \
    INDEX_OP_INDEX_AB_PARAM,                                                     \
    constant void      * indexSizes   [[buffer(1)]],                             \
    constant void      * indexStrides [[buffer(2)]],                             \
    constant IDX_DTYPE * offsets      [[buffer(3)]],                             \
    constant void      * inputData    [[buffer(4)]],                             \
    device void        * outputData   [[buffer(5)]],                             \
    constant uint32_t  & num_indices  [[buffer(6)]],                             \
    uint thread_index [[thread_position_in_grid]]);
|
|
|
|
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
|
REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
|
|
REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
|
|
REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
|
|
REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
|
|
|
|
REGISTER_INDEX_OP_ALL_DTYPES(select);
|
|
REGISTER_INDEX_OP_ALL_DTYPES(put);
|
|
|
|
// REGISTER_SINGLE_THREADED_INDEX_OP explicitly instantiates
// index_<INDEX_OP_TYPE><DTYPE, IDX_DTYPE> for kernels that take the extra
// numIters buffer at index 7, exposing them to the host as
// "index_<INDEX_OP_TYPE>_<DTYPE_SIZE>_<IDX_SIZE>".
// Only the indexAB parameter differs between Metal versions (pointer from
// Metal 3.0 on, reference before), so that one parameter is selected here and
// the macro body is written once.
#if __METAL_VERSION__ >= 300
#define SINGLE_THREADED_INDEX_AB_PARAM constant IndexAB * indexAB [[buffer(0)]]
#else
#define SINGLE_THREADED_INDEX_AB_PARAM constant IndexAB & indexAB [[buffer(0)]]
#endif

#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template                                                                                         \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                             \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                           \
    SINGLE_THREADED_INDEX_AB_PARAM,                                                              \
    constant void      * indexSizes   [[buffer(1)]],                                             \
    constant void      * indexStrides [[buffer(2)]],                                             \
    constant IDX_DTYPE * offsets      [[buffer(3)]],                                             \
    constant void      * inputData    [[buffer(4)]],                                             \
    device void        * outputData   [[buffer(5)]],                                             \
    constant uint32_t  & num_indices  [[buffer(6)]],                                             \
    constant uint      * numIters     [[buffer(7)]],                                             \
    uint thread_index [[thread_position_in_grid]]);
|
|
|
|
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
|
|
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
|
|
|
|
REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
|
|
|
|
template<typename StridesT, typename DataT>
|
|
kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
|
|
device DataT * data_offsets [[buffer(1)]],
|
|
constant uint * iter_shape [[buffer(2)]],
|
|
constant uint & num_dimensions [[buffer(3)]],
|
|
uint thread_index [[thread_position_in_grid]]) {
|
|
data_offsets[thread_index] = 0;
|
|
uint32_t idx = thread_index;
|
|
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
|
|
uint32_t remainder = idx % iter_shape[dim];
|
|
idx /= iter_shape[dim];
|
|
|
|
data_offsets[thread_index] += remainder * DataT(strides[dim]);
|
|
}
|
|
}
|
|
|
|
template
|
|
[[host_name("kernel_index_offsets_32")]]
|
|
kernel void kernel_index_offsets<packed_uint3, uint3>(
|
|
constant packed_uint3 * strides [[buffer(0)]],
|
|
device uint3 * data_offsets [[buffer(1)]],
|
|
constant uint * iter_shape [[buffer(2)]],
|
|
constant uint & num_dimensions [[buffer(3)]],
|
|
uint thread_index [[thread_position_in_grid]]);
|
|
|
|
template
|
|
[[host_name("kernel_index_offsets_64")]]
|
|
kernel void kernel_index_offsets<packed_uint3, ulong3>(
|
|
constant packed_uint3 * strides [[buffer(0)]],
|
|
device ulong3 * data_offsets [[buffer(1)]],
|
|
constant uint * iter_shape [[buffer(2)]],
|
|
constant uint & num_dimensions [[buffer(3)]],
|
|
uint thread_index [[thread_position_in_grid]]);
|
|
|
|
template<typename T, typename E, typename OffsetsT>
|
|
kernel void index_put_accumulate_native_dtypes(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant OffsetsT * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]) {
|
|
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
|
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
|
int64_t offset = 0;
|
|
for (uint32_t i = 0; i < num_indices; i++) {
|
|
#if __METAL_VERSION__ >= 300
|
|
constant int64_t* indexArray = indexAB[i].indexArray;
|
|
#else
|
|
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
|
#endif
|
|
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
|
if (index < 0) {
|
|
index += index_sizes[i];
|
|
}
|
|
offset += index * index_strides[i];
|
|
}
|
|
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
|
|
constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
|
|
atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
|
|
}
|
|
|
|
template<typename T>
|
|
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
|
|
device atomic_uint* uintAddr = (device atomic_uint*)addr;
|
|
uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
|
|
T updated = as_type<T>(expected) + value;
|
|
while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
|
|
updated = as_type<T>(expected) + value;
|
|
}
|
|
}
|
|
|
|
template<typename T, typename OffsetsT>
|
|
kernel void atomic_index_put_accumulate(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant OffsetsT * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]) {
|
|
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
|
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
|
int64_t offset = 0;
|
|
for (uint32_t i = 0; i < num_indices; i++) {
|
|
#if __METAL_VERSION__ >= 300
|
|
constant int64_t* indexArray = indexAB[i].indexArray;
|
|
#else
|
|
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
|
|
#endif
|
|
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
|
if (index < 0) {
|
|
index += index_sizes[i];
|
|
}
|
|
offset += index * index_strides[i];
|
|
}
|
|
device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
|
|
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
|
|
atomic_fetch_add_relaxed<T>(out, *in);
|
|
}
|
|
|
|
template
|
|
[[host_name("index_put_accumulate_32bit_float_idx32")]]
|
|
kernel void atomic_index_put_accumulate<float, uint3>(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant uint3 * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]);
|
|
|
|
template
|
|
[[host_name("index_put_accumulate_32bit_float_idx64")]]
|
|
kernel void atomic_index_put_accumulate<float, ulong3>(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant ulong3 * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]);
|
|
|
|
template
|
|
[[host_name("index_put_accumulate_32bit_int_idx32")]]
|
|
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant uint3 * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]);
|
|
|
|
template
|
|
[[host_name("index_put_accumulate_32bit_int_idx64")]]
|
|
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
|
|
#if __METAL_VERSION__ >= 300
|
|
constant IndexAB * indexAB [[buffer(0)]],
|
|
#else
|
|
constant IndexAB & indexAB [[buffer(0)]],
|
|
#endif
|
|
constant void * indexSizes [[buffer(1)]],
|
|
constant void * indexStrides [[buffer(2)]],
|
|
constant ulong3 * offsets [[buffer(3)]],
|
|
constant void * inputData [[buffer(4)]],
|
|
device void * outputData [[buffer(5)]],
|
|
constant uint32_t & num_indices [[buffer(6)]],
|
|
uint thread_index [[thread_position_in_grid]]);
|
|
)INDEX_METAL";
|
|
|
|
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
|
|
// Five tightly packed 32-bit unsigned values, used for the sizes and strides
// of the 5-D kernel. Fields are ordered x, y, z, w, u.
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x;
  uint32_t y;
  uint32_t z;
  uint32_t w;
  uint32_t u;
}};
|
|
|
|
template<typename Y, typename X>
|
|
Y cast(const X x);
|
|
|
|
template<>
|
|
{1} cast<{1}, {0}>(const {0} x) {{
|
|
return {2};
|
|
}}
|
|
|
|
// Scatters a linearly read source into a strided 5-D destination, converting
// every element with cast(). One thread handles the source element at
// linear_index; threads with linear_index >= numel do nothing.
// size and stride are given per dimension in the order x, y, z, w, u, with x
// varying slowest and u fastest.
kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint5 & size   [[buffer(2)]],
                             constant packed_uint5 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Number of linear positions covered by one step along w, z, y and x.
  const uint32_t span_w = size.u;
  const uint32_t span_z = span_w * size.w;
  const uint32_t span_y = span_z * size.z;
  const uint32_t span_x = span_y * size.y;

  // Unravel the linear index into one coordinate per dimension.
  packed_uint5 coord;
  coord.x = linear_index / span_x % size.x;
  coord.y = linear_index / span_y % size.y;
  coord.z = linear_index / span_z % size.z;
  coord.w = linear_index / span_w % size.w;
  coord.u = linear_index % size.u;

  // Destination position, in elements: sum of coordinate * stride.
  const uint32_t dst_pos = coord.x * stride.x
                         + coord.y * stride.y
                         + coord.z * stride.z
                         + coord.w * stride.w
                         + coord.u * stride.u;

  dst[dst_pos] = cast<{1}>(src[linear_index]);
}}
|
|
|
|
// Scatters a linearly read source into a strided 4-D destination, converting
// every element with cast(). One thread handles the source element at
// linear_index; threads with linear_index >= numel do nothing.
// Dimension 0 of size/stride varies slowest, dimension 3 fastest.
kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint4 & size   [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Number of linear positions covered by one step along dimensions 2, 1 and 0.
  const uint span_2 = size[3];
  const uint span_1 = span_2 * size[2];
  const uint span_0 = span_1 * size[1];

  // Unravel the linear index into one coordinate per dimension.
  packed_uint4 coord;
  coord.x = linear_index / span_0 % size[0];
  coord.y = linear_index / span_1 % size[1];
  coord.z = linear_index / span_2 % size[2];
  coord.w = linear_index % size[3];

  // Component-wise coordinate * stride; the destination position is their sum.
  const packed_uint4 scaled = coord * stride;
  dst[scaled.x + scaled.y + scaled.z + scaled.w] = cast<{1}>(src[linear_index]);
}}
|
|
|
|
// Scatters a linearly read source into a strided 3-D destination, converting
// every element with cast(). One thread handles the source element at
// linear_index; threads with linear_index >= numel do nothing.
// Dimension 0 of size/stride varies slowest, dimension 2 fastest.
kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint3 & size   [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Number of linear positions covered by one step along dimensions 1 and 0.
  const uint span_1 = size[2];
  const uint span_0 = span_1 * size[1];

  // Unravel the linear index into one coordinate per dimension.
  packed_uint3 coord;
  coord.x = linear_index / span_0 % size[0];
  coord.y = linear_index / span_1 % size[1];
  coord.z = linear_index % size[2];

  // Component-wise coordinate * stride; the destination position is their sum.
  const packed_uint3 scaled = coord * stride;
  dst[scaled.x + scaled.y + scaled.z] = cast<{1}>(src[linear_index]);
}}
|
|
|
|
// Scatters a linearly read source into a strided 2-D destination, converting
// every element with cast(). One thread handles the source element at
// linear_index; threads with linear_index >= numel do nothing.
// Dimension 0 of size/stride varies slowest, dimension 1 fastest.
kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint2 & size   [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Unravel the linear index into a row (x) and column (y) coordinate.
  const uint inner_extent = size[1];
  packed_uint2 coord;
  coord.x = linear_index / inner_extent % size[0];
  coord.y = linear_index % inner_extent;

  // Component-wise coordinate * stride; the destination position is their sum.
  const packed_uint2 scaled = coord * stride;
  dst[scaled.x + scaled.y] = cast<{1}>(src[linear_index]);
}}
|
|
|
|
// Scatters a linearly read source into a strided 1-D destination, converting
// every element with cast(). One thread handles the source element at
// linear_index; threads with linear_index >= numel do nothing.
kernel void scatter_kernel_1(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant int & size            [[buffer(2)]],
                             constant int & stride          [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Position along the only dimension, then scaled by its stride.
  const int coord = linear_index % size;
  const int dst_pos = coord * stride;
  dst[dst_pos] = cast<{1}>(src[linear_index]);
}}
|
|
)METAL_SCATTER";
|
|
|
|
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
|
|
// Five tightly packed 32-bit unsigned values, used for the sizes and strides
// of the 5-D kernel. Fields are ordered x, y, z, w, u.
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x;
  uint32_t y;
  uint32_t z;
  uint32_t w;
  uint32_t u;
}};
|
|
|
|
template<typename Y, typename X>
|
|
Y cast(const X x);
|
|
|
|
template<>
|
|
{1} cast<{1}, {0}>(const {0} x) {{
|
|
return {2};
|
|
}}
|
|
|
|
// Gathers from a strided 5-D source into a linearly written destination,
// converting every element with cast(). One thread produces the destination
// element at linear_index; threads with linear_index >= numel do nothing.
// size and stride are given per dimension in the order x, y, z, w, u, with x
// varying slowest and u fastest.
kernel void gather_kernel_5(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint5 & size    [[buffer(2)]],
                            constant packed_uint5 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Number of linear positions covered by one step along w, z, y and x.
  const uint32_t span_w = size.u;
  const uint32_t span_z = span_w * size.w;
  const uint32_t span_y = span_z * size.z;
  const uint32_t span_x = span_y * size.y;

  // Unravel the linear index into one coordinate per dimension.
  packed_uint5 coord;
  coord.x = linear_index / span_x % size.x;
  coord.y = linear_index / span_y % size.y;
  coord.z = linear_index / span_z % size.z;
  coord.w = linear_index / span_w % size.w;
  coord.u = linear_index % size.u;

  // Source position, in elements: sum of coordinate * stride.
  const uint32_t src_pos = coord.x * stride.x
                         + coord.y * stride.y
                         + coord.z * stride.z
                         + coord.w * stride.w
                         + coord.u * stride.u;

  dst[linear_index] = cast<{1}>(src[src_pos]);
}}
|
|
|
|
// Gathers from a strided 4-D source into a linearly written destination,
// converting every element with cast(). One thread produces the destination
// element at linear_index; threads with linear_index >= numel do nothing.
// Dimension 0 of size/stride varies slowest, dimension 3 fastest.
kernel void gather_kernel_4(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint4 & size    [[buffer(2)]],
                            constant packed_uint4 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Number of linear positions covered by one step along dimensions 2, 1 and 0.
  const uint span_2 = size[3];
  const uint span_1 = span_2 * size[2];
  const uint span_0 = span_1 * size[1];

  // Unravel the linear index into one coordinate per dimension.
  packed_uint4 coord;
  coord.x = linear_index / span_0 % size[0];
  coord.y = linear_index / span_1 % size[1];
  coord.z = linear_index / span_2 % size[2];
  coord.w = linear_index % size[3];

  // Component-wise coordinate * stride; the source position is their sum.
  const packed_uint4 scaled = coord * stride;
  dst[linear_index] = cast<{1}>(src[scaled.x + scaled.y + scaled.z + scaled.w]);
}}
|
|
|
|
// Gathers from a strided 3-D source into a linearly written destination,
// converting every element with cast(). One thread produces the destination
// element at linear_index; threads with linear_index >= numel do nothing.
// Dimension 0 of size/stride varies slowest, dimension 2 fastest.
kernel void gather_kernel_3(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint3 & size    [[buffer(2)]],
                            constant packed_uint3 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Number of linear positions covered by one step along dimensions 1 and 0.
  const uint span_1 = size[2];
  const uint span_0 = span_1 * size[1];

  // Unravel the linear index into one coordinate per dimension.
  packed_uint3 coord;
  coord.x = linear_index / span_0 % size[0];
  coord.y = linear_index / span_1 % size[1];
  coord.z = linear_index % size[2];

  // Component-wise coordinate * stride; the source position is their sum.
  const packed_uint3 scaled = coord * stride;
  dst[linear_index] = cast<{1}>(src[scaled.x + scaled.y + scaled.z]);
}}
|
|
|
|
// Gathers from a strided 2-D source into a linearly written destination,
// converting every element with cast(). One thread produces the destination
// element at linear_index; threads with linear_index >= numel do nothing.
// Dimension 0 of size/stride varies slowest, dimension 1 fastest.
kernel void gather_kernel_2(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint2 & size    [[buffer(2)]],
                            constant packed_uint2 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Unravel the linear index into a row (x) and column (y) coordinate.
  const uint inner_extent = size[1];
  packed_uint2 coord;
  coord.x = linear_index / inner_extent % size[0];
  coord.y = linear_index % inner_extent;

  // Component-wise coordinate * stride; the source position is their sum.
  const packed_uint2 scaled = coord * stride;
  dst[linear_index] = cast<{1}>(src[scaled.x + scaled.y]);
}}
|
|
|
|
// Gathers from a strided 1-D source into a linearly written destination,
// converting every element with cast(). One thread produces the destination
// element at linear_index; threads with linear_index >= numel do nothing.
kernel void gather_kernel_1(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant int & size             [[buffer(2)]],
                            constant int & stride           [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
  if (linear_index >= numel) return;

  constant {0} * src = (constant {0} *)src_;
  device {1} * dst = (device {1} *)dst_;

  // Position along the only dimension, then scaled by its stride.
  const int coord = linear_index % size;
  const int src_pos = coord * stride;
  dst[linear_index] = cast<{1}>(src[src_pos]);
}}
|
|
)METAL_GATHER";
|
|
} // namespace at::mps
|