#pragma once

namespace at::mps {

static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>

using namespace metal;

#if __METAL_VERSION__ < 300
struct IndexAB {
    // Allow up to 16 indices
    metal::array<constant void *, 16> indexArray [[ id(0) ]];
};
#else
struct IndexAB {
    constant int64_t* indexArray;
};
#endif

template<typename T, typename OffsetsT>
kernel void index_select(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device   T * out = (device   T*)((device   char*)outputData + offsets[thread_index].x);
    constant T * in  = (constant T*)((constant char*)inputData  + offsets[thread_index].y + offset);
    *out = *in;
}

template<typename T, typename OffsetsT>
void index_put_impl(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB,
#else
    constant IndexAB  & indexAB,
#endif
    constant int64_t  * index_sizes,
    constant int64_t  * index_strides,
    constant OffsetsT * offsets,
    constant void     * inputData,
    device   void     * outputData,
    constant uint32_t & num_indices,
    uint thread_index) {
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device   T * out = (device   T*)((device   char*)outputData + offsets[thread_index].x + offset);
    constant T * in  = (constant T*)((constant char*)inputData  + offsets[thread_index].y);
    *out = *in;
}

template<typename T, typename OffsetsT>
kernel void index_put_serial(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    constant uint     * numIters          [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
        index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
    }
}
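// NOTE (descriptive comment, not part of the kernel ABI): each OffsetsT element consumed by
// the index kernels packs three byte offsets per logical element, filled on the host side
// (e.g. by the kernel_index_offsets kernels declared further below):
//   .x - byte offset of this thread's element in outputData
//   .y - byte offset of this thread's element in inputData
//   .z - byte offset of this thread's entry inside each index tensor
// For illustration only: with 4-byte data and a single index tensor, thread 3 might see
// offsets[3] = (12, 12, 24), so it loads the 64-bit index at indexArray[24 / 8] = indexArray[3],
// scales it by index_strides[0], and copies one element between the two buffers.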
template<typename T, typename OffsetsT>
kernel void index_put(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}

#if __METAL_VERSION__ < 300
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)  \
template                                                                          \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]              \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                            \
    constant IndexAB   & indexAB           [[buffer(0)]],                         \
    constant void      * indexSizes        [[buffer(1)]],                         \
    constant void      * indexStrides      [[buffer(2)]],                         \
    constant IDX_DTYPE * offsets           [[buffer(3)]],                         \
    constant void      * inputData         [[buffer(4)]],                         \
    device   void      * outputData        [[buffer(5)]],                         \
    constant uint32_t  & num_indices       [[buffer(6)]],                         \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)  \
template                                                                          \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]              \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                            \
    constant IndexAB   * indexAB           [[buffer(0)]],                         \
    constant void      * indexSizes        [[buffer(1)]],                         \
    constant void      * indexStrides      [[buffer(2)]],                         \
    constant IDX_DTYPE * offsets           [[buffer(3)]],                         \
    constant void      * inputData         [[buffer(4)]],                         \
    device   void      * outputData        [[buffer(5)]],                         \
    constant uint32_t  & num_indices       [[buffer(6)]],                         \
    uint thread_index [[thread_position_in_grid]]);
#endif

#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)                \
    REGISTER_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);  \
    REGISTER_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);  \
    REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);  \
    REGISTER_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);  \
    REGISTER_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
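// For example, REGISTER_INDEX_OP(32bit, idx32, int, put, uint3) declares the explicit
// instantiation index_put<int, uint3> and exports it under the host-visible function name
// "index_put_32bit_idx32"; the host side is presumably expected to pick a compute pipeline
// by assembling the same "index_<op>_<dtype size>_<index size>" string.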
#if __METAL_VERSION__ < 300
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)  \
template                                                                                          \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                              \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                            \
    constant IndexAB   & indexAB           [[buffer(0)]],                                         \
    constant void      * indexSizes        [[buffer(1)]],                                         \
    constant void      * indexStrides      [[buffer(2)]],                                         \
    constant IDX_DTYPE * offsets           [[buffer(3)]],                                         \
    constant void      * inputData         [[buffer(4)]],                                         \
    device   void      * outputData        [[buffer(5)]],                                         \
    constant uint32_t  & num_indices       [[buffer(6)]],                                         \
    constant uint      * numIters          [[buffer(7)]],                                         \
    uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)  \
template                                                                                          \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                              \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                            \
    constant IndexAB   * indexAB           [[buffer(0)]],                                         \
    constant void      * indexSizes        [[buffer(1)]],                                         \
    constant void      * indexStrides      [[buffer(2)]],                                         \
    constant IDX_DTYPE * offsets           [[buffer(3)]],                                         \
    constant void      * inputData         [[buffer(4)]],                                         \
    device   void      * outputData        [[buffer(5)]],                                         \
    constant uint32_t  & num_indices       [[buffer(6)]],                                         \
    constant uint      * numIters          [[buffer(7)]],                                         \
    uint thread_index [[thread_position_in_grid]]);
#endif

#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)                \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);  \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);  \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);  \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);  \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);

template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(constant StridesT * strides        [[buffer(0)]],
                                 device DataT      * data_offsets   [[buffer(1)]],
                                 constant uint     * iter_shape     [[buffer(2)]],
                                 constant uint     & num_dimensions [[buffer(3)]],
                                 uint thread_index [[thread_position_in_grid]]) {
    data_offsets[thread_index] = 0;
    uint32_t idx = thread_index;
    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
        uint32_t remainder = idx % iter_shape[dim];
        idx /= iter_shape[dim];
        data_offsets[thread_index] += remainder * DataT(strides[dim]);
    }
}

template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
    constant packed_uint3 * strides        [[buffer(0)]],
    device uint3          * data_offsets   [[buffer(1)]],
    constant uint         * iter_shape     [[buffer(2)]],
    constant uint         & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
    constant packed_uint3 * strides        [[buffer(0)]],
    device ulong3         * data_offsets   [[buffer(1)]],
    constant uint         * iter_shape     [[buffer(2)]],
    constant uint         & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);

template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device   T * out = (device   T*)((device   char*)outputData + offsets[thread_index].x + offset);
    constant E * in  = (constant E*)((constant char*)inputData  + offsets[thread_index].y);
    atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}
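// atomic_fetch_add_relaxed below emulates an atomic add for element types without a native
// fetch-add here (e.g. float): it reinterprets the 32-bit word as a uint and loops on a
// relaxed compare-exchange; when the exchange fails, `expected` is refreshed with the
// current value and the sum is recomputed before retrying.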
template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
    device atomic_uint* uintAddr = (device atomic_uint*)addr;
    uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
    T updated = as_type<T>(expected) + value;
    while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated),
                                                  memory_order_relaxed, memory_order_relaxed)) {
        updated = as_type<T>(expected) + value;
    }
}

template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
        constant int64_t* indexArray = indexAB[i].indexArray;
#else
        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T  * in  = (constant T*)((constant char*)inputData + offsets[thread_index].y);
    atomic_fetch_add_relaxed<T>(out, *in);
}

template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant uint3    * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant ulong3   * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant uint3    * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
#if __METAL_VERSION__ >= 300
    constant IndexAB  * indexAB           [[buffer(0)]],
#else
    constant IndexAB  & indexAB           [[buffer(0)]],
#endif
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant ulong3   * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";
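// SCATTER_OPS_TEMPLATE and GATHER_OPS_TEMPLATE below are format templates rather than complete
// shaders: "{0}" stands for the source scalar type, "{1}" for the destination scalar type,
// "{2}" for the cast expression, and "{{"/"}}" are escaped literal braces. A minimal host-side
// sketch, assuming an fmt/std::format-style formatter (variable names here are illustrative and
// not part of this header):
//
//   std::string src = fmt::format(SCATTER_OPS_TEMPLATE,
//                                 "float",                  // {0}: source type
//                                 "half",                   // {1}: destination type
//                                 "static_cast<half>(x)");  // {2}: body of cast<half, float>
//
// The resulting string would then be compiled as a Metal library containing scatter_kernel_1..5.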
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
 return {2};
}}

kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint5 & size   [[buffer(2)]],
                             constant packed_uint5 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;

    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;

    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}
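// Worked example of the linear-index unravelling used by these kernels, with illustrative
// values only: in the 3-d case (scatter_kernel_3 below), for size = [2, 3, 4] and
// linear_index = 17,
//   local_index.x = 17 / (4 * 3) % 2 = 1
//   local_index.y = 17 / 4       % 3 = 1
//   local_index.z = 17           % 4 = 1
// i.e. row-major coordinates (1, 1, 1). The destination element then sits
// stride.x + stride.y + stride.z elements into dst, while the source is read
// contiguously at src[17].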
kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint4 & size   [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];

    const packed_uint4 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint3 & size   [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];

    const packed_uint3 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint2 & size   [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];

    const packed_uint2 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant int & size            [[buffer(2)]],
                             constant int & stride          [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";

static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
 return {2};
}}

kernel void gather_kernel_5(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint5 & size    [[buffer(2)]],
                            constant packed_uint5 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;

    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;

    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}

kernel void gather_kernel_4(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint4 & size    [[buffer(2)]],
                            constant packed_uint4 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];

    const packed_uint4 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint3 & size    [[buffer(2)]],
                            constant packed_uint3 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];

    const packed_uint3 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint2 & size    [[buffer(2)]],
                            constant packed_uint2 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];

    const packed_uint2 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}
kernel void gather_kernel_1(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant int & size             [[buffer(2)]],
                            constant int & stride           [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";

} // namespace at::mps
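
// Usage note (descriptive, not part of this header): indexing_metal_shaders is a complete Metal
// source string, while the two *_OPS_TEMPLATE strings must first have their "{0}", "{1}" and
// "{2}" placeholders substituted per dtype pair. The host side is expected to compile the
// resulting source at runtime (e.g. via MTLDevice's newLibraryWithSource:options:error:) and
// fetch kernels by name, either a plain kernel name such as "gather_kernel_3" or a host_name
// generated by the REGISTER_* macros above, such as "index_put_serial_64bit_idx64".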