Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / ATen / native / TriangularOpsUtils.h

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>

namespace at {
namespace native {

/*
 * Returns the number of matrices held in a batched tensor, for use by the
 * Triu and Tril implementations.
 *
 * Every dimension except the last two is treated as a batch dimension. The
 * result is the product of the sizes of those batch dimensions, except that
 * any batch dimension whose stride is 0 is left out of the product.
 * When there are no batch dimensions the result is 1.
 */
static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
  // The trailing two dimensions are the matrix itself, not part of the batch.
  const int64_t num_batch_dims = batched_matrices.ndimension() - 2;

  int64_t batch_count = 1;
  for (int64_t dim = 0; dim < num_batch_dims; ++dim) {
    // Stride-0 dimensions do not contribute to the count.
    if (batched_matrices.stride(dim) == 0) {
      continue;
    }
    batch_count *= batched_matrices.size(dim);
  }
  return batch_count;
}

/* Checks a layout property required by the triu and tril implementations
 * (hence the name) and returns a tensor that can be used by them.
 *
 * Returns a (flag, tensor) tuple:
 *   - (true, tensor): the input is returned unchanged. This happens when
 *       * the tensor is contiguous and its strides equal the default strides
 *         for its sizes, or
 *       * allow_zero_stride is set and the tensor has at most 3 dimensions, or
 *       * every batch dimension (all dimensions except the last two) has the
 *         stride expected for densely packed matrices.
 *   - (false, other): the input did not pass; `other` is either the input
 *     re-strided with the default strides (contiguous case) or the result of
 *     tensor.contiguous() (batch stride mismatch).
 *
 * When allow_zero_stride is set, the outermost dimension (index 0) is exempt
 * from the stride check if its stride is 0 or its size is 1.
 */
static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
  // Full contiguity is the preferred outcome, so it is tested first.
  if (tensor.is_contiguous()) {
    auto canonical_strides = at::detail::defaultStrides(tensor.sizes());
    const bool has_canonical_strides = (tensor.strides() == canonical_strides);
    if (!has_canonical_strides) {
      // Contiguous, but the recorded strides differ from the defaults
      // (presumably possible for size-1 dimensions -- TODO confirm);
      // hand back a view that carries the default strides.
      return std::make_tuple(false, tensor.as_strided(tensor.sizes(), canonical_strides));
    }
    return std::make_tuple(true, tensor);
  }

  const int64_t ndim = tensor.dim();

  // With zero strides allowed, tensors of up to 3 dimensions are accepted as is.
  if (allow_zero_stride && ndim <= 3) {
    return std::make_tuple(true, tensor);
  }

  // Walk the batch dimensions from innermost to outermost. The innermost one
  // must step over exactly one matrix; each subsequent one must step over the
  // whole extent of the dimensions already visited.
  int64_t required_stride = tensor.size(-1) * tensor.size(-2);
  for (int64_t d = ndim - 3; d >= 0; --d) {
    // The outermost dimension is trivial when it is broadcast (stride 0) or
    // has a single entry; it is skipped only if zero strides are allowed.
    const bool is_trivial_outer_dim =
        allow_zero_stride && d == 0 && (tensor.stride(d) == 0 || tensor.size(d) == 1);
    if (!is_trivial_outer_dim) {
      if (tensor.stride(d) != required_stride) {
        return std::make_tuple(false, tensor.contiguous());
      }
      required_stride *= tensor.size(d);
    }
  }
  return std::make_tuple(true, tensor);
}

}  // namespace native
}  // namespace at