Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / native / DilatedConvolutionUtils.h

#pragma once

#include <algorithm>
#include <vector>

#include <ATen/div_rtn.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
  TORCH_CHECK(                                       \
      T.dim() == DIM && T.size(DIM_SIZE) == SIZE,    \
      "Need " #T " of dimension ",                   \
      DIM,                                           \
      " and " #T ".size[",                           \
      DIM_SIZE,                                      \
      "] == ",                                       \
      SIZE,                                          \
      " but got input to be of shape ",              \
      T.sizes())

namespace at {
namespace native {
namespace internal {
namespace {
inline bool all_positive(IntArrayRef& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
}

inline bool all_nonnegative(std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}

} // namespace

// calculate the rear part of output tensor sizes
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  std::vector<int64_t> sizes;
  for (const auto index : c10::irange(dim)) {
    sizes.push_back(
        div_rtn<int64_t>(
            input.size(index + input.dim() - dim) + 2 * pad_size[index] -
                (dilation_size[index] * (kernel_size[index] - 1) + 1),
            stride_size[index]) +
        1);
  }
  return sizes;
}

// calculate the sizes of output tensor
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  output_size.insert(output_size.begin(), weight.size(0));
  if (input.dim() == dim + 2) {
    output_size.insert(output_size.begin(), input.size(0));
  }
  return output_size;
}
/*
  slow_conv_dilated_shape_check - check user-input to dilated convolution
  forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  /*
    When the following tensors are defined:

    bias, grad_weight, grad_output

    then these are assumed to be contiguous without checking
    because of these tensors are made contiguous by calling
    .contiguous() method or by resizing of zero-sized tensors in
    forward/backward functions.

    When grad_weight is defined then it is assumed without
    checking to have the same shape as weight, see backward
    functions.
   */
  // Check size arguments
  TORCH_CHECK(
      kernel_size.size() == dim,
      "kernel sizes length should be ",
      dim,
      ", but got ",
      kernel_size.size());
  TORCH_CHECK(
      stride_size.size() == dim,
      "strides length should be ",
      dim,
      ", but got ",
      stride_size.size());
  TORCH_CHECK(
      dilation_size.size() == dim,
      "dilations length should be ",
      dim,
      ", but got ",
      dilation_size.size());
  TORCH_CHECK(
      pad_size.size() == dim,
      "pads length should be ",
      dim,
      ", but got ",
      pad_size.size());

  TORCH_CHECK(
      all_positive(kernel_size),
      "kernel size should be greater than zero, but got ",
      kernel_size);
  TORCH_CHECK(
      all_positive(stride_size),
      "stride should be greater than zero, but got ",
      stride_size);
  TORCH_CHECK(
      all_positive(dilation_size),
      "dilation should be greater than zero, but got ",
      dilation_size);

  // check input
  TORCH_CHECK(input.defined(), "input must be defined");
  bool is_batch = input.dim() == dim + 2;
  int64_t n = (is_batch ? 2 : 1);
  int64_t ndim = n + dim;
  if (!is_batch) {
    // input dim has to be dim + 1 if not batched
    TORCH_CHECK(
        input.dim() == dim + 1,
        "input must be 4D or 5D tensor but got ",
        input.dim(),
        "D tensor");
  }

  // check output sizes
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);

  TORCH_CHECK(
      all_nonnegative(output_size),
      "calculated output size ",
      output_size,
      " is too small (all sizes must be non-negative)");

  // check weight
  TORCH_CHECK(weight.defined(), "weight must be defined");
  TORCH_CHECK(
      weight.dim() == dim + 2,
      "weight must be ",
      dim + 2,
      "D tensor but got ",
      weight.dim(),
      "D tensor dim=",
      dim);
  TORCH_CHECK(
      weight.sizes().slice(2) == kernel_size,
      "weight[2:] shape ",
      weight.sizes().slice(2),
      " must be equal to kernel_size ",
      kernel_size);

  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));

  // check bias when present
  if (bias.defined()) {
    TORCH_CHECK(
        bias.dim() == 1,
        "bias must be 1D tensor but got ",
        bias.dim(),
        "D tensor");
    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  }

  // check grad_output when present
  if (grad_output.defined()) {
    TORCH_CHECK(
        grad_output.dim() == ndim,
        "grad_output must be ",
        ndim,
        "D tensor but got ",
        grad_output.dim(),
        "D tensor");
    if (is_batch) {
      TORCH_CHECK(
          grad_output.size(0) == input.size(0),
          "grad_output.size(0)=",
          grad_output.size(0),
          " must be input.size(0)=",
          input.size(0));
    }
    TORCH_CHECK(
        grad_output.size(n - 1) == weight.size(0),
        "grad_output.size(",
        n - 1,
        ")=",
        grad_output.size(n - 1),
        " must be weight.size(0)=",
        weight.size(0));
    TORCH_CHECK(
        grad_output.sizes().slice(n) == output_size,
        "grad_output[",
        n,
        ":] shape",
        grad_output.sizes().slice(n),
        " must be equal to output size ",
        output_size);
  }
}

} // namespace internal
} // namespace native
} // namespace at