Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / ATen / core / NestedIntSymNodeImpl.h
Size: Mime:
#pragma once

#include <c10/core/ConstantSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <string>

namespace c10 {

// The motivating usecase for this is to represent the ragged size structure
// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
// allows us to simply return [B, j0, D] if someone queries for the size of our
// tensor.
//
// Morally we define comparison between two nested ints to return true if
// that comparison holds for all corresponding elements of the arrays they
// represent. Comparison between a nested int and a plain int is defined
// similarly.
//
// To simulate this desired behavior but also avoid the O(N) cost of checking,
// we associate each raggedness pattern with an integer "id" that can be used as
// a proxy to evaluate equality. We also constrain the range of values for this
// as to enable inequality checks.
//
// We also support a positive integer scalar "coeff" that is used for computing
// strides. For example given, a [B, j0, D] tensor, it can be strided in two
// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
// differentiate the two cases.
//
// During tracing the strides of the outputs need to be a function of the size
// and strides of the inputs so it is important that NestedIntSymNode itself is
// able to express this.
class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
 public:
  // CAUTION: you should probably not be constructing these directly; please
  // the higher-level API in python instead (TODO: actually introduce that).
  explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
      : val_(val), coeff_(coeff) {}

  bool bool_() override {
    return false;
  }

  bool is_int() override {
    return true;
  }

  bool is_float() override {
    return false;
  }

  bool is_bool() override {
    return false;
  }

  bool is_nested_int() const override {
    return true;
  }

  bool has_hint() override {
    return true;
  }

  c10::SymNode wrap_int(int64_t num) override {
    return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
  };

  int64_t guard_int(const char* file, int64_t line) override {
    TORCH_CHECK(false);
  }

  double guard_float(const char* file, int64_t line) override {
    TORCH_CHECK(false, "not a float");
  }

  bool guard_bool(const char* file, int64_t line) override {
    TORCH_CHECK(false, "not a bool");
  }

  int64_t int_() override {
    TORCH_CHECK(false);
  }

  std::string str() override {
    if (coeff_ == 1) {
      return "j" + std::to_string(val_);
    }
    return std::to_string(coeff_) + "*j" + std::to_string(val_);
  }

  // NOTE [ Inequalities with nested int ]
  //
  // The semantics of nested int when it comes to relations is that it is
  // treated as integer known to be within a certain range,
  //
  //     j0 \in [2, int64_t::max]
  //
  // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
  // This is a useful default range for the raggedness pattern of a jagged
  // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
  // specialization checks.
  //
  // [ Indeterminate inequalities error out ]
  //
  // Given the semantic defined above, certain relations like j0 < 3 are thus
  // indeterminable. In our impl today, evaluating such relations error
  //
  // It may seem convenient to just define indeterminate relations to return
  // False, but the implementation we maintain in parallel using sympy does not
  // allow this.
  //
  // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
  // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
  // would mean that means that if we define the indeterminate j0 >= 3 to be
  // False, the also indeterminate j0 < 3 will be evaluated to be True!
  //
  // [ Coefficient are assumed positive ]
  //
  // For the purpose of computing inequalities, we consider the coefficient of
  // the nested int to be a positive integer.
  //
  // Thus, no modifications are needed to the logic since
  // j0 >= k implies coeff * j0 >= k
  //
  c10::SymNode eq(const c10::SymNode& other) override;
  c10::SymNode ne(const c10::SymNode& other) override;
  c10::SymNode ge(const c10::SymNode& other) override;
  c10::SymNode gt(const c10::SymNode& other) override;
  c10::SymNode lt(const c10::SymNode& other) override;
  c10::SymNode le(const c10::SymNode& other) override;
  c10::SymNode mul(const c10::SymNode& other) override;

  std::optional<int64_t> nested_int() override {
    return val_;
  }

  std::optional<int64_t> nested_int_coeff() override {
    return coeff_;
  }

  bool is_symbolic() override {
    return false;
  }

  c10::SymNode clone() override;

#define DEFINE_BINARY_NOT_SUPPORTED(name)                           \
  c10::SymNode name(const c10::SymNode& other) override {           \
    TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
  }

  DEFINE_BINARY_NOT_SUPPORTED(add)
  DEFINE_BINARY_NOT_SUPPORTED(sub)
  DEFINE_BINARY_NOT_SUPPORTED(truediv)
  DEFINE_BINARY_NOT_SUPPORTED(pow)
  DEFINE_BINARY_NOT_SUPPORTED(floordiv)
  DEFINE_BINARY_NOT_SUPPORTED(mod)
  DEFINE_BINARY_NOT_SUPPORTED(sym_min)
  DEFINE_BINARY_NOT_SUPPORTED(sym_max)
  DEFINE_BINARY_NOT_SUPPORTED(sym_and)
  DEFINE_BINARY_NOT_SUPPORTED(sym_or)

#undef DEFINE_BINARY_NOT_SUPPORTED

#define DEFINE_NOT_SUPPORTED(name)                                     \
  c10::SymNode name() override {                                       \
    TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
  }

  DEFINE_NOT_SUPPORTED(sym_not)
  DEFINE_NOT_SUPPORTED(ceil)
  DEFINE_NOT_SUPPORTED(floor)
  DEFINE_NOT_SUPPORTED(neg)
  DEFINE_NOT_SUPPORTED(sym_float)

#undef DEFINE_NOT_SUPPORTED

 private:
  int64_t val_;
  int64_t coeff_;
};

} // namespace c10