140 lines
4.8 KiB
C++
140 lines
4.8 KiB
C++
#include <ATen/core/Tensor.h>
|
|
#include <ATen/Config.h>
|
|
#include <cstdint>
|
|
|
|
#ifdef USE_FBGEMM
|
|
#include <fbgemm/FbgemmEmbedding.h>
|
|
#endif
|
|
|
|
namespace at::native {
|
|
|
|
void check_arguments(
|
|
const Tensor& weight,
|
|
const Tensor& indices,
|
|
const Tensor& offsets,
|
|
const int64_t mode,
|
|
const c10::optional<Tensor>& per_sample_weights,
|
|
bool include_last_offset);
|
|
|
|
// Declaration only; the definition is not part of this header.
// `bag_size_out` is taken by mutable reference, the remaining tensors are
// read-only.
// NOTE(review): presumably fills `bag_size_out` with per-bag sizes derived
// from `offsets`/`indices` -- confirm against the definition.
void make_bag_size_out(
    Tensor& bag_size_out,
    const Tensor& offsets, const Tensor& indices,
    const int64_t mode,
    const bool include_last_offset, const bool requires_grad);
|
|
|
|
// Declaration only; the definition is not part of this header.
// `max_indices_out` is taken by mutable reference, the remaining tensors are
// read-only.
// NOTE(review): presumably prepares the max-indices buffer used by the "max"
// reduction mode -- confirm against the definition.
void make_max_indices_out(
    Tensor& max_indices_out,
    const Tensor& weight, const Tensor& indices, const Tensor& offsets,
    const Tensor& bag_size,
    const int64_t mode,
    bool include_last_offset);
|
|
|
|
// Declaration only; the definition is not part of this header.
// `offset2bag` and `output` are taken by mutable reference, the remaining
// tensors are read-only. `padding_idx` defaults to -1 when omitted.
// NOTE(review): presumably builds the index -> bag mapping in `offset2bag`
// -- confirm against the definition.
void make_offset2bag_out(
    Tensor& offset2bag, Tensor& output,
    const Tensor& weight, const Tensor& indices, const Tensor& offsets,
    const int64_t mode,
    const c10::optional<Tensor>& per_sample_weights,
    const int64_t padding_idx = -1);
|
|
|
|
#ifdef USE_FBGEMM
|
|
|
|
template<bool has_weight, typename TIndex, typename TData>
|
|
struct _CallbackAndBlockSize {
|
|
using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
|
|
|
|
int64_t blockSize = -1;
|
|
TCallback callback = nullptr;
|
|
|
|
static TCallback generateCallback(int64_t block_size) {
|
|
return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
|
|
block_size,
|
|
has_weight,
|
|
/* normalize_by_lengths */false,
|
|
/* prefetch */16,
|
|
/* is_weight_positional */false,
|
|
/* use_offsets */true);
|
|
}
|
|
|
|
_CallbackAndBlockSize() = default;
|
|
|
|
explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
|
|
: blockSize(maybe_block_size.value_or(-1))
|
|
, callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
|
|
{}
|
|
};
|
|
|
|
template<typename... StorageMixins>
|
|
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
|
|
|
|
_EmbeddingBagKernelCacheImpl() = default;
|
|
// use each of the mixins to store corresponding kernel and block size
|
|
explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
|
|
: StorageMixins(maybe_block_size)...
|
|
{}
|
|
|
|
// this method is thread safe (call sites may call from different threads)
|
|
template<bool has_weight, typename TIndex, typename TData>
|
|
typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
|
|
getCallback(int64_t block_size) const {
|
|
// if the cache doesn't store the kernel for the incoming block size
|
|
// (so it is different from the one stored in corresponding mixin)
|
|
// regenerate the kernel (not writing it into the cache so we avoid locks)
|
|
if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
|
|
return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
|
|
}
|
|
// else retrieve the cached kernel from the corresponding mixin
|
|
return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
|
|
}
|
|
};
|
|
|
|
// instantiate the cache with the list of storage mixins
|
|
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
|
|
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
|
|
_CallbackAndBlockSize<true, int32_t, float>,
|
|
_CallbackAndBlockSize<false, int32_t, float>,
|
|
_CallbackAndBlockSize<true, int64_t, float>,
|
|
_CallbackAndBlockSize<false, int64_t, float>,
|
|
_CallbackAndBlockSize<true, int32_t, unsigned short>,
|
|
_CallbackAndBlockSize<false, int32_t, unsigned short>,
|
|
_CallbackAndBlockSize<true, int64_t, unsigned short>,
|
|
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
|
|
#else
|
|
struct _EmbeddingBagKernelCache {
|
|
explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
|
|
};
|
|
#endif
|
|
|
|
// Declaration only; the definition is not part of this header.
// `output`, `offset2bag` and `bag_size` are taken by mutable reference and
// `max_indices` by pointer; the remaining tensors are read-only.
// Defaults when omitted: mode = 0, no per_sample_weights,
// include_last_offset = false, padding_idx = -1, no FBGEMM kernel cache.
void _embedding_bag_cpu_impl_out(
    Tensor& output,
    Tensor& offset2bag,
    Tensor& bag_size,
    Tensor* max_indices,
    const Tensor &weight,
    const Tensor &indices,
    const Tensor &offsets,
    const int64_t mode = 0,
    const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
    bool include_last_offset = false,
    int64_t padding_idx = -1,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
|
|
|
|
// Declaration only; the definition is not part of this header.
// `output`, `offset2bag` and `bag_size` are taken by mutable reference and
// `p_max_indices` by pointer; the remaining tensors are read-only.
// `fbgemm_kernel_cache` is optional and defaults to nullptr.
void _embedding_bag_cpu_out(
    at::Tensor& output, at::Tensor& offset2bag, at::Tensor& bag_size,
    at::Tensor* p_max_indices,
    const at::Tensor& weight, const at::Tensor& indices, const at::Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    const bool sparse,
    const c10::optional<at::Tensor>& per_sample_weights,
    const bool include_last_offset,
    const c10::optional<int64_t>& padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
|
|
|
|
} // namespace at::native
|