Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / c10 / core / MemoryFormat.h

#pragma once

#include <c10/core/Backend.h>
#include <c10/util/Exception.h>
#include <c10/util/ArrayRef.h>

#include <iostream>

// Memory format is not the property of a Tensor. It is the way to tell an
// operator how the result should be organized in memory and nothing more. That
// means memory format should never be used as return value for any tensor state
// interrogation functions (internally and externally).
//
// Possible options are:
//  Preserve:
//    If any of the input tensors is in channels_last format, operator output
//    should be in channels_last format
//
//  Contiguous:
//    Regardless of input tensors format, the output should be contiguous Tensor.
//
//  ChannelsLast:
//    Regardless of input tensors format, the output should be in channels_last format.


namespace c10 {
enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };

// If you are seeing this, it means that this call site was not checked if
// the memory format could be preserved, and it was switched to old default
// behaviour of contiguous
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()

inline MemoryFormat get_contiguous_memory_format() {
  return MemoryFormat::Contiguous;
}

inline std::ostream& operator<<(
    std::ostream& stream,
    at::MemoryFormat memory_format) {
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      AT_ERROR("Unknown memory format");
  }
}

// Note: Hardcoded the channel last stride indices here to get better performance
inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
  std::vector<int64_t> strides(sizes.size());
  switch (sizes.size()) {
    case 4:
      strides[1] = 1;
      strides[3] = sizes[1];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 3:
      strides[0] = 1;
      strides[2] = sizes[0];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(false, "ChannelsLast2d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
  std::vector<int64_t> strides(sizes.size());
  switch (sizes.size()) {
    case 5:
      strides[1] = 1;
      strides[4] = sizes[1];
      strides[3] = strides[4] * sizes[4];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 4:
      strides[0] = 1;
      strides[3] = sizes[0];
      strides[2] = strides[3] * sizes[3];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(false, "ChannelsLast3d doesn't support size ", sizes.size());
  }
}

// NOTE:
// Below are Helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions, each helper function handles
// exactly one case of sizes + memory_format, by doing this, the strides indices
// will be a constant array and we can access it using constant index number,
// the compiler will fully unroll the loop on strides indices to gain a better
// performance.
// 2. No error check in helper function, caller ensures the correctness of the input
// 3. All helper functions have similar comments, only 1st helper function is commented here.
inline bool is_channels_last_strides_2d_s4(const IntArrayRef sizes, const IntArrayRef strides) {
  int64_t min = 0;
  // special case for trivial C dimension. default to NCHW
  if (strides[1]==0) {
    return false;
  }
  // loop strides indices
  for (auto& d : {1, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    // Fallback to NCHW as default layout for ambiguous cases
    // This is the flaw of implicit memory_format from strides.
    // N111 tensor with identical strides for size 1 dimension;
    // Two cases could lead us here:
    // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    // b. N11W contiguous Tensor sliced on the W-dimension. ([N,1,1,1]@[W,W,W,W])
    if (d==0 && min==strides[1]) {
      return false;
    }
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1;
    //     [H, 1, 1, 1] channels_last stride
    //     [H, H, 1, 1] contiguous stride
    // 2. permutation of 1C1W:
    //     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}

inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArrayRef strides) {
  int64_t min = 0;
  if (strides[1] == 0) {
    return false;
  }
  for (auto& d : {1, 4, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    if (d == 0 && min == strides[1]) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}

// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to WAR properly. issue #24090
// Without the history of permutation, we can't infer the memory_format of a
// tensor from the snapshot of its size & stride
// e.g.
//
// 1. We can NOT specify the memory_format of N111 tensor through strides in a
//  meaningful way;
//
// 2. Two path that ended up with identical size/stride
//  N11W contiguous tensor sliced at w-dimension becomes [N,1,1,1]@[W,W,W,W]
//  NC11 channels_last tensor sliced at c-dimension becomes [N,1,1,1]@[C,C,C,C]
//    So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
//    the memory_format of the original tensor.
//
// Due to the limitations, our temporary WAR `is_channels_last_strides` does the
// best effort to infer whether the original memory_format of a tensor is
// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
// by their importance):
//   1. Ensure that normal shape manipulation does not accidentally change the
//      MemoryFormat of an existing tensor.
//   2. Allows user to mark MemoryFormat::ChannelsLast to tensors;
//
// The function does so via checking strides of the tensor, including strides of
// size-1 dimensions. Although conventionally PyTorch implies no restriction on
// trivial stride (stride for size-1 dimension).
//
// Note that this approach is a compromise. We did not solve the problem
// completely. Many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` is to serve the objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted-in (no accidental
// conversion); Best effort to maintain the ChannelsLast flag.
//
// Due to the fact that this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp)
//   a. we ensure that the common tasks are supported;
//   a. we identify corner cases where the implementation compromises on.
//
// By the time accumulated permutation is enabled to replace implicit
// memory_format through strides, we should be updating our tests and fix the
// issues in our tests.
//
// We use Channels Last 2d as an example above.
// This is a general problem for all the is_channels_last_strides_xd implementation.
// Please check the helper functions (is_channels_last_strides_*d_s*) for more details.

inline bool is_channels_last_strides_2d(const IntArrayRef sizes, const IntArrayRef strides) {
  switch (sizes.size()) {
    case 4:
      return is_channels_last_strides_2d_s4(sizes, strides);
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

inline bool is_channels_last_strides_3d(const IntArrayRef sizes, const IntArrayRef strides) {
  switch (sizes.size()) {
    case 5:
      return is_channels_last_strides_3d_s5(sizes, strides);
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

} // namespace c10