Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / torch / csrc / autograd / utils / grad_layout_contract.h

#pragma once

#include <ATen/Tensor.h>

namespace torch {
namespace autograd {
namespace utils {

// Helper functions to enforce the "Gradient Layout Contract" described in
// torch/csrc/autograd/functions/accumulate_grad.h.

// Reports whether `grad` already satisfies the Gradient Layout Contract with
// respect to `variable`.
//
// `grad` must not be sparse or sparse-CSR, and `variable` must not be
// sparse-CSR; these preconditions are enforced with internal asserts.
// Returns false unconditionally for nested and sparse variables.
inline bool obeys_layout_contract(
    const at::Tensor& grad,
    const at::Tensor& variable) {
  TORCH_INTERNAL_ASSERT(!grad.is_sparse());
  TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());

  if (variable.is_nested()) {
    // TODO: Nested tensors have no detach implementation yet. The current
    // nested tensor implementation probably does satisfy the contract (so
    // returning true would be correct today), but that is expected to change
    // in the future.
    return false;
  }

  if (variable.is_sparse()) {
    // The Gradient Layout Contract does not apply to sparse layouts.
    return false;
  }

  if (!variable.is_non_overlapping_and_dense()) {
    // Fallback case: the grad only needs to be row-major contiguous.
    return grad.is_contiguous(at::MemoryFormat::Contiguous);
  }

  // Dense, non-overlapping variable: grad strides have to match the
  // variable's strides, except along dimensions of size 1.
  const auto& sizes = grad.sym_sizes();
  const auto& strides = grad.sym_strides();
  const auto& expected_strides = variable.sym_strides();
  for (const auto dim : c10::irange(sizes.size())) {
    if (sizes[dim] != 1) {
      if (strides[dim] != expected_strides[dim]) {
        return false;
      }
      continue;
    }
    // Size-1 dimension. Ideally this check would be unnecessary, but nothing
    // verifies whether a Tensor has views before it is stashed, and size-1
    // Tensors with stride 0 are in fact views for ops like cat.
    // TODO: Detect views inside the accumulateGrad function itself so that
    // such a Tensor is never considered here.
    if (strides[dim] == 0) {
      return false;
    }
  }
  return true;
}

// Returns a copy of `new_grad` laid out so that it obeys the contract with
// `variable`. When GradMode::is_enabled(), the copy is expected to be attached
// to new_grad's autograd history.
inline at::Tensor clone_obey_contract(
    const at::Tensor& new_grad,
    const at::Tensor& variable) {
  if (!variable.is_non_overlapping_and_dense()) {
    // (2)
    // Fall back to a plain contiguous clone.
    return new_grad.clone(at::MemoryFormat::Contiguous);
  }

  // (1)
  // Allocate a fresh tensor with the variable's sizes and strides, then copy
  // new_grad into it. Although it looks dicey, this allocate-then-copy_
  // sequence does attach the result to new_grad's history when
  // GradMode::is_enabled(), and per @alband that is the desired behavior.
  auto result = new_grad.new_empty_strided_symint(
      variable.sym_sizes(),
      variable.sym_strides(),
      variable.options().memory_format(c10::nullopt));
  result.copy_(new_grad);
  return result;
}

} // namespace utils
} // namespace autograd
} // namespace torch