Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / SparseCsrTensorUtils.h

#pragma once

#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/Tensor.h>

#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] {                                                              \
    const auto& the_layout = LAYOUT;                                 \
    switch (the_layout) {                                            \
      case kSparseCsr:                                               \
      case kSparseCsc:                                               \
      case kSparseBsr:                                               \
      case kSparseBsc:                                               \
        return __VA_ARGS__();                                        \
      default:                                                       \
        AT_ERROR(                                                    \
            NAME,                                                    \
            " expected sparse compressed tensor layout but got ",    \
            the_layout);                                             \
    }                                                                \
  }()

#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(                \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION)              \
  [&]() {                                                         \
    const auto& the_layout = LAYOUT;                              \
    switch (the_layout) {                                         \
      case kSparseCsr:                                            \
      case kSparseBsr:                                            \
        return (ROW_DIM_ACTION)();                                \
      case kSparseCsc:                                            \
      case kSparseBsc:                                            \
        return (COLUMN_DIM_ACTION)();                             \
      default:                                                    \
        AT_ERROR(                                                 \
            NAME,                                                 \
            " expected sparse compressed tensor layout but got ", \
            the_layout);                                          \
    }                                                             \
  }()

#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(              \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION)                  \
  [&]() {                                                         \
    const auto& the_layout = LAYOUT;                              \
    switch (the_layout) {                                         \
      case kSparseCsr:                                            \
      case kSparseCsc:                                            \
        return (NO_BLOCK_ACTION)();                               \
      case kSparseBsr:                                            \
      case kSparseBsc:                                            \
        return (BLOCK_ACTION)();                                  \
      default:                                                    \
        AT_ERROR(                                                 \
            NAME,                                                 \
            " expected sparse compressed tensor layout but got ", \
            the_layout);                                          \
    }                                                             \
  }()

#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(                    \
    LAYOUT, NAME, ROW_DIM_ACTION)                                     \
  [&]() {                                                             \
    const auto& the_layout = LAYOUT;                                  \
    switch (the_layout) {                                             \
      case kSparseCsr:                                                \
      case kSparseBsr:                                                \
        return (ROW_DIM_ACTION)();                                    \
      default:                                                        \
        AT_ERROR(                                                     \
            NAME,                                                     \
            " expected sparse row compressed tensor layout but got ", \
            the_layout);                                              \
    }                                                                 \
  }()

#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(                       \
    LAYOUT, NAME, COL_DIM_ACTION)                                        \
  [&]() {                                                                \
    const auto& the_layout = LAYOUT;                                     \
    switch (the_layout) {                                                \
      case kSparseCsc:                                                   \
      case kSparseBsc:                                                   \
        return (COL_DIM_ACTION)();                                       \
      default:                                                           \
        AT_ERROR(                                                        \
            NAME,                                                        \
            " expected sparse column compressed tensor layout but got ", \
            the_layout);                                                 \
    }                                                                    \
  }()

#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION)  \
  [&]() {                                                                     \
    const auto& the_layout = LAYOUT;                                          \
    switch (the_layout) {                                                     \
      case kSparseCsr:                                                        \
      case kSparseCsc:                                                        \
        return (ACTION)();                                                    \
      default:                                                                \
        AT_ERROR(                                                             \
            NAME,                                                             \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout);                                                      \
    }                                                                         \
  }()

#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                                                 \
    const auto& the_layout = LAYOUT;                                      \
    switch (the_layout) {                                                 \
      case kSparseBsr:                                                    \
      case kSparseBsc:                                                    \
        return (ACTION)();                                                \
      default:                                                            \
        AT_ERROR(                                                         \
            NAME,                                                         \
            " expected sparse compressed block tensor layout but got ",   \
            the_layout);                                                  \
    }                                                                     \
  }()

#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(      \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))

namespace at {
namespace sparse_csr {

using SparseCsrTensor = Tensor;

inline bool is_sparse_compressed(const Layout& layout) {
  switch (layout) {
    case kSparseCsr:
    case kSparseCsc:
    case kSparseBsr:
    case kSparseBsc:
      return true;
    default:;
  }
  return false;
}

inline bool is_sparse_compressed(const Tensor& self) {
  return is_sparse_compressed(self.layout());
}

inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "get_sparse_csr_impl", [&] {});
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}

inline std::string layoutToString(
    Layout layout,
    bool upper = false,
    bool lower = false) {
  switch (layout) {
    case kSparseCsr:
      return (upper ? "CSR" : (lower ? "csr" : "Csr"));
    case kSparseCsc:
      return (upper ? "CSC" : (lower ? "csc" : "Csc"));
    case kSparseBsr:
      return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
    case kSparseBsc:
      return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline bool isCompressedRow(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
}

inline bool isCompressedColumn(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "isCompressedColumn",
      [&] { return false; },
      [&] { return true; });
}

inline std::string compressedIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "compressedIndicesName",
      [&] { return "crow_indices"; },
      [&] { return "ccol_indices"; });
}

inline std::string plainIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "plainIndicesName",
      [&] { return "col_indices"; },
      [&] { return "row_indices"; });
}

inline std::string compressedDimName(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return "row";
    case kSparseCsc:
      return "column";
    case kSparseBsr:
      return "row block";
    case kSparseBsc:
      return "column block";
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline std::string plainDimName(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return "column";
    case kSparseCsc:
      return "row";
    case kSparseBsr:
      return "column block";
    case kSparseBsc:
      return "row block";
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline int rowDimension(Layout layout, IntArrayRef size) {
  return size.size() - (isCompressedRow(layout) ? 2 : 1);
}

inline int columnDimension(Layout layout, IntArrayRef size) {
  return size.size() - (isCompressedColumn(layout) ? 2 : 1);
}

inline int compressedDimension(
    Layout layout,
    IntArrayRef size,
    size_t dense_ndim = 0) {
  return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
}

inline int plainDimension(
    Layout layout,
    IntArrayRef size,
    size_t dense_ndim = 0) {
  return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
}

inline int64_t numBatchDimensions(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "numBatchDimensions",
      [&self] { return self.crow_indices().dim() - 1; },
      [&self] { return self.ccol_indices().dim() - 1; });
}

inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "getCompressedPlainIndices",
      [&self] {
        return std::make_pair(self.crow_indices(), self.col_indices());
      },
      [&self] {
        return std::make_pair(self.ccol_indices(), self.row_indices());
      });
}

inline Layout flip_compressed_layout(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return kSparseCsc;
    case kSparseCsc:
      return kSparseCsr;
    case kSparseBsr:
      return kSparseBsc;
    case kSparseBsc:
      return kSparseBsr;
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return kSparseCsr;
  }
}

inline DimVector getBlockSize(Tensor const& self) {
  int64_t n_batch = numBatchDimensions(self);
  return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
}

inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
  if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
    int64_t n_batch = numBatchDimensions(self);
    return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
  } else {
    return {};
  }
}

} // namespace sparse_csr
} // namespace at