#pragma once
namespace at::mps {
static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>
using namespace metal;
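
// IndexAB wraps the pointers to the index tensors. Before Metal 3.0 they are
// passed through an argument buffer holding up to 16 untyped pointers; with
// Metal 3.0 and later each IndexAB entry holds a single int64_t pointer and
// the kernels receive an array of these structs.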
#if __METAL_VERSION__ < 300
struct IndexAB {
    // Allow up to 16 indices
    metal::array<constant void *, 16> indexArray [[ id(0) ]];
};
#else
struct IndexAB {
    constant int64_t* indexArray;
};
#endif
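
// index_select: one thread per output element. Each thread walks the index
// tensors, wraps negative indices by the corresponding dimension size,
// accumulates the gathered byte offset, and copies one element from the
// strided input to the output. offsets[thread_index] packs the output (.x),
// input (.y) and index (.z) byte offsets for this element.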
template<typename T, typename OffsetsT>
kernel void index_select(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
    constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
    *out = *in;
}
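
// index_put_impl: shared body for the index_put kernels. Computes the gathered
// byte offset exactly like index_select, but applies it to the output side and
// writes the corresponding input element there.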
template<typename T, typename OffsetsT>
void index_put_impl(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB,
#else
    constant IndexAB & indexAB,
#endif
    constant int64_t * index_sizes,
    constant int64_t * index_strides,
    constant OffsetsT * offsets,
    constant void * inputData,
    device void * outputData,
    constant uint32_t & num_indices,
    uint thread_index) {
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
    *out = *in;
}
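
// index_put_serial: runs all numIters iterations in a single invocation,
// serializing writes that may target the same output location.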
template<typename T, typename OffsetsT>
kernel void index_put_serial(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    constant uint * numIters [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
        index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
    }
}
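
// index_put: parallel variant, one thread per element to be written.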
template<typename T, typename OffsetsT>
kernel void index_put(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}
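
// REGISTER_INDEX_OP: explicit instantiation of an indexing kernel, exposed to
// the host under a stable [[host_name]] such as "index_select_32bit_idx32".
// Instantiations cover 8/16/32/64-bit element types with 32-bit (uint3) and
// 64-bit (ulong3) offset packs.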
#if __METAL_VERSION__ < 300
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB & indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB * indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    uint thread_index [[thread_position_in_grid]]);
#endif
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
    REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
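
// Same registration scheme for the single-threaded kernels, which take the
// extra numIters buffer.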
#if __METAL_VERSION__ < 300
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB & indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    constant uint * numIters [[buffer(7)]], \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB * indexAB [[buffer(0)]], \
    constant void * indexSizes [[buffer(1)]], \
    constant void * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets [[buffer(3)]], \
    constant void * inputData [[buffer(4)]], \
    device void * outputData [[buffer(5)]], \
    constant uint32_t & num_indices [[buffer(6)]], \
    constant uint * numIters [[buffer(7)]], \
    uint thread_index [[thread_position_in_grid]]);
#endif
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
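
// kernel_index_offsets: decomposes a linear thread index against iter_shape
// and multiplies each coordinate by the corresponding stride, producing the
// per-element byte offsets consumed by the indexing kernels above.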
template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
                                 device DataT * data_offsets [[buffer(1)]],
                                 constant uint * iter_shape [[buffer(2)]],
                                 constant uint & num_dimensions [[buffer(3)]],
                                 uint thread_index [[thread_position_in_grid]]) {
    data_offsets[thread_index] = 0;
    uint32_t idx = thread_index;
    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
        uint32_t remainder = idx % iter_shape[dim];
        idx /= iter_shape[dim];
        data_offsets[thread_index] += remainder * DataT(strides[dim]);
    }
}
template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
    constant packed_uint3 * strides [[buffer(0)]],
    device uint3 * data_offsets [[buffer(1)]],
    constant uint * iter_shape [[buffer(2)]],
    constant uint & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);
template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
    constant packed_uint3 * strides [[buffer(0)]],
    device ulong3 * data_offsets [[buffer(1)]],
    constant uint * iter_shape [[buffer(2)]],
    constant uint & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);
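
// index_put_accumulate_native_dtypes: accumulating index_put for types with
// native Metal atomics (e.g. int via atomic_int); adds the input element into
// the gathered output location with a relaxed atomic fetch-add.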
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
    constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
    atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}
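
// atomic_fetch_add_relaxed: emulates an atomic add for 32-bit types without a
// native atomic (e.g. float) by looping on a relaxed compare-and-swap of the
// underlying 32-bit word.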
template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
    device atomic_uint* uintAddr = (device atomic_uint*)addr;
    uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
    T updated = as_type<T>(expected) + value;
    while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
        updated = as_type<T>(expected) + value;
    }
}
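
// atomic_index_put_accumulate: accumulating index_put built on the CAS-based
// helper above; instantiated for float via the host names registered below.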
template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
    atomic_fetch_add_relaxed<T>(out, *in);
}
template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant uint3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant ulong3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant uint3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB * indexAB [[buffer(0)]],
#else
    constant IndexAB & indexAB [[buffer(0)]],
#endif
    constant void * indexSizes [[buffer(1)]],
    constant void * indexStrides [[buffer(2)]],
    constant ulong3 * offsets [[buffer(3)]],
    constant void * inputData [[buffer(4)]],
    device void * outputData [[buffer(5)]],
    constant uint32_t & num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";
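
// Scatter kernels: copy a contiguous source into a strided destination for
// 1 to 5 dimensions. The string is a format template; the doubled braces are
// literal braces, and the host substitutes {0} (source scalar type),
// {1} (destination scalar type) and {2} (the cast expression) before compiling.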
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
    uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
template<typename Y, typename X>
Y cast(const X x);
template<>
{1} cast<{1}, {0}>(const {0} x) {{
    return {2};
}}
kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]],
                             constant void * src_ [[buffer(0)]],
                             device void * dst_ [[buffer(1)]],
                             constant packed_uint5 & size [[buffer(2)]],
                             constant packed_uint5 & stride [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;
    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;
    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
                             constant void * src_ [[buffer(0)]],
                             device void * dst_ [[buffer(1)]],
                             constant packed_uint4 & size [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];
    const packed_uint4 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
                             constant void * src_ [[buffer(0)]],
                             device void * dst_ [[buffer(1)]],
                             constant packed_uint3 & size [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];
    const packed_uint3 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
                             constant void * src_ [[buffer(0)]],
                             device void * dst_ [[buffer(1)]],
                             constant packed_uint2 & size [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];
    const packed_uint2 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
                             constant void * src_ [[buffer(0)]],
                             device void * dst_ [[buffer(1)]],
                             constant int & size [[buffer(2)]],
                             constant int & stride [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
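
// Gather kernels: the mirror image of the scatter template above, reading from
// a strided source into a contiguous destination with the same {0}/{1}/{2}
// placeholders.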
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
    uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
template<typename Y, typename X>
Y cast(const X x);
template<>
{1} cast<{1}, {0}>(const {0} x) {{
    return {2};
}}
kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]],
                            constant void * src_ [[buffer(0)]],
                            device void * dst_ [[buffer(1)]],
                            constant packed_uint5 & size [[buffer(2)]],
                            constant packed_uint5 & stride [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;
    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}
kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
                            constant void * src_ [[buffer(0)]],
                            device void * dst_ [[buffer(1)]],
                            constant packed_uint4 & size [[buffer(2)]],
                            constant packed_uint4 & stride [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];
    const packed_uint4 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}
kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
                            constant void * src_ [[buffer(0)]],
                            device void * dst_ [[buffer(1)]],
                            constant packed_uint3 & size [[buffer(2)]],
                            constant packed_uint3 & stride [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];
    const packed_uint3 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}
kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
                            constant void * src_ [[buffer(0)]],
                            device void * dst_ [[buffer(1)]],
                            constant packed_uint2 & size [[buffer(2)]],
                            constant packed_uint2 & stride [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];
    const packed_uint2 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}
kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
                            constant void * src_ [[buffer(0)]],
                            device void * dst_ [[buffer(1)]],
                            constant int & size [[buffer(2)]],
                            constant int & stride [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;
    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;
    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
} // namespace at::mps