230 lines
6.3 KiB
C++
230 lines
6.3 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
|
|
#include <vector>
|
|
|
|
#include <ATen/div_rtn.h>
|
|
#include <ATen/core/Tensor.h>
|
|
#include <c10/util/irange.h>
|
|
|
|
// TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)
//
// Fails a TORCH_CHECK unless tensor `T` has exactly `DIM` dimensions and
// `T.size(DIM_SIZE) == SIZE`. The error message names the checked tensor
// (via stringification of `T`) and prints its actual shape.
//
// Every argument is parenthesized so that callers may pass arbitrary
// expressions (e.g. an unparenthesized conditional) without operator
// precedence silently altering the condition. Note that `T` is evaluated
// more than once, so pass a plain expression without side effects.
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)       \
  TORCH_CHECK(                                             \
      (T).dim() == (DIM) && (T).size(DIM_SIZE) == (SIZE),  \
      "Need " #T " of dimension ",                         \
      (DIM),                                               \
      " and " #T ".size[",                                 \
      (DIM_SIZE),                                          \
      "] == ",                                             \
      (SIZE),                                              \
      " but got " #T " to be of shape ",                   \
      (T).sizes())
|
|
|
|
namespace at::native::internal {
|
|
namespace {
|
|
inline bool all_positive(IntArrayRef& arr) {
|
|
return std::all_of(
|
|
arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
|
|
}
|
|
|
|
inline bool all_nonnegative(std::vector<int64_t>& arr) {
|
|
return std::all_of(
|
|
arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// calculate the rear part of output tensor sizes
|
|
template <int64_t dim>
|
|
std::vector<int64_t> get_output_size(
|
|
const Tensor& input,
|
|
IntArrayRef kernel_size,
|
|
IntArrayRef stride_size,
|
|
IntArrayRef pad_size,
|
|
IntArrayRef dilation_size) {
|
|
std::vector<int64_t> sizes;
|
|
for (const auto index : c10::irange(dim)) {
|
|
sizes.push_back(
|
|
div_rtn<int64_t>(
|
|
input.size(index + input.dim() - dim) + 2 * pad_size[index] -
|
|
(dilation_size[index] * (kernel_size[index] - 1) + 1),
|
|
stride_size[index]) +
|
|
1);
|
|
}
|
|
return sizes;
|
|
}
|
|
|
|
// calculate the sizes of output tensor
|
|
template <int64_t dim>
|
|
std::vector<int64_t> get_output_size(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
IntArrayRef kernel_size,
|
|
IntArrayRef stride_size,
|
|
IntArrayRef pad_size,
|
|
IntArrayRef dilation_size) {
|
|
auto output_size = get_output_size<dim>(
|
|
input, kernel_size, stride_size, pad_size, dilation_size);
|
|
output_size.insert(output_size.begin(), weight.size(0));
|
|
if (input.dim() == dim + 2) {
|
|
output_size.insert(output_size.begin(), input.size(0));
|
|
}
|
|
return output_size;
|
|
}
|
|
/*
|
|
slow_conv_dilated_shape_check - check user-input to dilated convolution
|
|
forward and backward functions.
|
|
*/
|
|
template <int64_t dim>
|
|
void slow_conv_dilated_shape_check(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const Tensor& bias,
|
|
const Tensor& grad_output,
|
|
IntArrayRef kernel_size,
|
|
IntArrayRef stride_size,
|
|
IntArrayRef pad_size,
|
|
IntArrayRef dilation_size) {
|
|
/*
|
|
When the following tensors are defined:
|
|
|
|
bias, grad_weight, grad_output
|
|
|
|
then these are assumed to be contiguous without checking
|
|
because of these tensors are made contiguous by calling
|
|
.contiguous() method or by resizing of zero-sized tensors in
|
|
forward/backward functions.
|
|
|
|
When grad_weight is defined then it is assumed without
|
|
checking to have the same shape as weight, see backward
|
|
functions.
|
|
*/
|
|
// Check size arguments
|
|
TORCH_CHECK(
|
|
kernel_size.size() == dim,
|
|
"kernel sizes length should be ",
|
|
dim,
|
|
", but got ",
|
|
kernel_size.size());
|
|
TORCH_CHECK(
|
|
stride_size.size() == dim,
|
|
"strides length should be ",
|
|
dim,
|
|
", but got ",
|
|
stride_size.size());
|
|
TORCH_CHECK(
|
|
dilation_size.size() == dim,
|
|
"dilations length should be ",
|
|
dim,
|
|
", but got ",
|
|
dilation_size.size());
|
|
TORCH_CHECK(
|
|
pad_size.size() == dim,
|
|
"pads length should be ",
|
|
dim,
|
|
", but got ",
|
|
pad_size.size());
|
|
|
|
TORCH_CHECK(
|
|
all_positive(kernel_size),
|
|
"kernel size should be greater than zero, but got ",
|
|
kernel_size);
|
|
TORCH_CHECK(
|
|
all_positive(stride_size),
|
|
"stride should be greater than zero, but got ",
|
|
stride_size);
|
|
TORCH_CHECK(
|
|
all_positive(dilation_size),
|
|
"dilation should be greater than zero, but got ",
|
|
dilation_size);
|
|
|
|
// check input
|
|
TORCH_CHECK(input.defined(), "input must be defined");
|
|
bool is_batch = input.dim() == dim + 2;
|
|
int64_t n = (is_batch ? 2 : 1);
|
|
int64_t ndim = n + dim;
|
|
if (!is_batch) {
|
|
// input dim has to be dim + 1 if not batched
|
|
TORCH_CHECK(
|
|
input.dim() == dim + 1,
|
|
"input must be 4D or 5D tensor but got ",
|
|
input.dim(),
|
|
"D tensor");
|
|
}
|
|
|
|
// check output sizes
|
|
auto output_size = get_output_size<dim>(
|
|
input, kernel_size, stride_size, pad_size, dilation_size);
|
|
|
|
TORCH_CHECK(
|
|
all_nonnegative(output_size),
|
|
"calculated output size ",
|
|
output_size,
|
|
" is too small (all sizes must be non-negative)");
|
|
|
|
// check weight
|
|
TORCH_CHECK(weight.defined(), "weight must be defined");
|
|
TORCH_CHECK(
|
|
weight.dim() == dim + 2,
|
|
"weight must be ",
|
|
dim + 2,
|
|
"D tensor but got ",
|
|
weight.dim(),
|
|
"D tensor dim=",
|
|
dim);
|
|
TORCH_CHECK(
|
|
weight.sizes().slice(2) == kernel_size,
|
|
"weight[2:] shape ",
|
|
weight.sizes().slice(2),
|
|
" must be equal to kernel_size ",
|
|
kernel_size);
|
|
|
|
TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
|
|
|
|
// check bias when present
|
|
if (bias.defined()) {
|
|
TORCH_CHECK(
|
|
bias.dim() == 1,
|
|
"bias must be 1D tensor but got ",
|
|
bias.dim(),
|
|
"D tensor");
|
|
TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
|
|
}
|
|
|
|
// check grad_output when present
|
|
if (grad_output.defined()) {
|
|
TORCH_CHECK(
|
|
grad_output.dim() == ndim,
|
|
"grad_output must be ",
|
|
ndim,
|
|
"D tensor but got ",
|
|
grad_output.dim(),
|
|
"D tensor");
|
|
if (is_batch) {
|
|
TORCH_CHECK(
|
|
grad_output.size(0) == input.size(0),
|
|
"grad_output.size(0)=",
|
|
grad_output.size(0),
|
|
" must be input.size(0)=",
|
|
input.size(0));
|
|
}
|
|
TORCH_CHECK(
|
|
grad_output.size(n - 1) == weight.size(0),
|
|
"grad_output.size(",
|
|
n - 1,
|
|
")=",
|
|
grad_output.size(n - 1),
|
|
" must be weight.size(0)=",
|
|
weight.size(0));
|
|
TORCH_CHECK(
|
|
grad_output.sizes().slice(n) == output_size,
|
|
"grad_output[",
|
|
n,
|
|
":] shape",
|
|
grad_output.sizes().slice(n),
|
|
" must be equal to output size ",
|
|
output_size);
|
|
}
|
|
}
|
|
|
|
} // namespace at::native::internal
|