// include/ATen/core/interned_strings.h  (torch 1.8.0)

#pragma once
#include <vector>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <algorithm>

#include <c10/macros/Macros.h>

#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
#include <ATen/core/aten_interned_strings.h>
#endif

namespace c10 {

#if !defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)
#define FORALL_NS_SYMBOLS(_)         \
  _(namespaces, prim)                \
  _(namespaces, aten)                \
  _(namespaces, cuda)                \
  _(namespaces, onnx)                \
  _(namespaces, attr)                \
  _(namespaces, scope)               \
  _(namespaces, user)                \
  _(namespaces, _caffe2)             \
  _(namespaces, dimname)             \
  _(namespaces, namespaces)          \
  _(prim, Assign)                    \
  _(prim, BroadcastingChunk)         \
  _(prim, BroadcastSizes)            \
  _(prim, ReductionSizes)            \
  _(prim, Constant)                  \
  _(prim, ChunkSizes)                \
  _(prim, Drop)                      \
  _(prim, Eval)                      \
  _(prim, Expand) /* onnx */         \
  _(prim, FusionGroup)               \
  _(prim, CudaFusionGroup)           \
  _(prim, CudaFusionGuard)           \
  _(prim, FunctionalGraph)           \
  _(prim, DifferentiableGraph)       \
  _(prim, TensorExprGroup)           \
  _(prim, StaticSubgraph)            \
  _(prim, If)                        \
  _(prim, Jump) /* debug */          \
  _(prim, JumpNZ) /* debug */        \
  _(prim, JumpZ) /* debug */         \
  _(prim, Load)                      \
  _(prim, Loop)                      \
  _(prim, Param)                     \
  _(prim, PackPadded) /* onnx */     \
  _(prim, PadPacked) /* onnx */      \
  _(prim, Placeholder) /* debug */   \
  _(prim, Print)                     \
  _(prim, PythonOp)                  \
  _(prim, IgnoredPythonOp)           \
  _(prim, Reverse)                   \
  _(prim, Return)                    \
  _(prim, ReturnStmt)                \
  _(prim, BreakStmt)                 \
  _(prim, ContinueStmt)              \
  _(prim, ComprehensionScope)        \
  _(prim, Store)                     \
  _(prim, AutogradZero)              \
  _(prim, AutogradAnyNonZero)        \
  _(prim, AutogradAllNonZero)        \
  _(prim, AutogradAllZero)           \
  _(prim, Starred)                   \
  _(prim, TupleConstruct)            \
  _(prim, TupleUnpack)               \
  _(prim, TupleIndex)                \
  _(prim, TupleSlice)                \
  _(prim, ListConstruct)             \
  _(prim, ListUnpack)                \
  _(prim, DictConstruct)             \
  _(prim, ModuleDictIndex)           \
  _(prim, EnumName)                  \
  _(prim, EnumValue)                 \
  _(prim, StringIndex)               \
  _(prim, NumToTensor)               \
  _(prim, Uninitialized)             \
  _(prim, With)                      \
  _(prim, Enter)                     \
  _(prim, Exit)                      \
  _(aten, Bool)                      \
  _(aten, Int)                       \
  _(aten, FloatImplicit)             \
  _(aten, IntImplicit)               \
  _(aten, ScalarImplicit)            \
  _(aten, Float)                     \
  _(aten, str)                       \
  _(aten, Delete)                    \
  _(prim, device)                    \
  _(prim, dtype)                     \
  _(prim, layout)                    \
  _(prim, id)                        \
  _(prim, requires_grad)             \
  _(prim, MakeTestTensor) /* test */ \
  _(prim, AutogradAdd)               \
  _(prim, GradOf)                    \
  _(aten, grad)                      \
  _(aten, backward)                  \
  _(prim, Guard)                     \
  _(prim, BailOut)                   \
  _(prim, TypeCheck)                 \
  _(prim, RequiresGradCheck)         \
  _(prim, FallbackGraph)             \
  _(prim, FusedConcat)               \
  _(prim, ConstantChunk)             \
  _(prim, MMTreeReduce)              \
  _(prim, MMBatchSide)               \
  _(prim, list)                      \
  _(prim, min)                       \
  _(prim, max)                       \
  _(prim, abs)                       \
  _(aten, divmod)                    \
  _(prim, zip)                       \
  _(prim, enumerate)                 \
  _(prim, range)                     \
  _(prim, rangelist)                 \
  _(prim, isinstance)                \
  _(prim, tolist)                    \
  _(prim, unchecked_cast)            \
  _(aten, _grad_sum_to_size)         \
  _(aten, _size_if_not_equal)        \
  _(aten, _ncf_unsqueeze)            \
  _(aten, warn)                      \
  _(aten, sorted)                    \
  _(aten, floordiv)                  \
  _(aten, __range_length)            \
  _(aten, __derive_index)            \
  _(aten, __round_to_zero_floordiv)  \
  _(aten, is_scripting)              \
  _(aten, _unwrap_optional)          \
  _(prim, fork)                      \
  _(prim, forkClosure)               \
  _(prim, RaiseException)            \
  _(prim, Closure)                   \
  _(prim, CreateObject)              \
  _(prim, SetAttr)                   \
  _(prim, GetAttr)                   \
  _(prim, HasAttr)                   \
  _(prim, profile)                   \
  _(prim, profile_ivalue)            \
  _(prim, AddStatValue)              \
  _(prim, TimePoint)                 \
  _(prim, CallFunction)              \
  _(prim, CallMethod)                \
  _(prim, LoopContinuation)          \
  _(prim, annotate)                  \
  _(prim, TracedModuleForward)       \
  _(prim, TracedFork)                \
  _(prim, TracedAttr)                \
  _(prim, rpc_async)                 \
  _(prim, rpc_sync)                  \
  _(prim, rpc_remote)                \
  _(prim, is_cuda)                   \
  _(aten, abs_)                      \
  _(aten, absolute)                  \
  _(aten, absolute_)                 \
  _(aten, acos)                      \
  _(aten, acos_)                     \
  _(aten, arccos)                    \
  _(aten, arccos_)                   \
  _(aten, acosh)                     \
  _(aten, acosh_)                    \
  _(aten, arccosh)                   \
  _(aten, arccosh_)                  \
  _(aten, asin)                      \
  _(aten, asin_)                     \
  _(aten, arcsin)                    \
  _(aten, arcsin_)                   \
  _(aten, asinh)                     \
  _(aten, asinh_)                    \
  _(aten, arcsinh)                   \
  _(aten, arcsinh_)                  \
  _(aten, atan)                      \
  _(aten, atan_)                     \
  _(aten, arctan)                    \
  _(aten, arctan_)                   \
  _(aten, atanh)                     \
  _(aten, atanh_)                    \
  _(aten, arctanh)                   \
  _(aten, arctanh_)                  \
  _(aten, clamp)                     \
  _(aten, clamp_)                    \
  _(aten, clip)                      \
  _(aten, clip_)                     \
  _(aten, det)                       \
  _(aten, linalg_det)                \
  _(aten, linalg_norm)               \
  _(aten, append)                    \
  _(aten, item)                      \
  _(aten, format)                    \
  _(aten, percentFormat)             \
  _(aten, __not__)                   \
  _(aten, __is__)                    \
  _(aten, __isnot__)                 \
  _(aten, copy)                      \
  _(aten, copy_)                     \
  _(aten, div)                       \
  _(aten, div_)                      \
  _(aten, divide)                    \
  _(aten, divide_)                   \
  _(aten, true_divide)               \
  _(aten, true_divide_)              \
  _(aten, t_)                        \
  _(aten, addbmm_)                   \
  _(aten, addcdiv_)                  \
  _(aten, addcmul_)                  \
  _(aten, addmv_)                    \
  _(aten, addr_)                     \
  _(aten, baddbmm_)                  \
  _(aten, ge)                        \
  _(aten, ge_)                       \
  _(aten, greater_equal)             \
  _(aten, greater_equal_)            \
  _(aten, gt)                        \
  _(aten, gt_)                       \
  _(aten, greater)                   \
  _(aten, greater_)                  \
  _(aten, le)                        \
  _(aten, le_)                       \
  _(aten, less_equal)                \
  _(aten, less_equal_)               \
  _(aten, lerp_)                     \
  _(aten, lt)                        \
  _(aten, lt_)                       \
  _(aten, less)                      \
  _(aten, less_)                     \
  _(aten, isnan)                     \
  _(aten, mul)                       \
  _(aten, mul_)                      \
  _(aten, multiply)                  \
  _(aten, multiply_)                 \
  _(aten, ne)                        \
  _(aten, ne_)                       \
  _(aten, not_equal)                 \
  _(aten, not_equal_)                \
  _(aten, _ger)                      \
  _(aten, ger)                       \
  _(aten, outer)                     \
  _(aten, transpose)                 \
  _(aten, transpose_)                \
  _(aten, unsqueeze_)                \
  _(aten, __getitem__)               \
  _(aten, _set_item)                 \
  _(aten, manual_seed)               \
  _(aten, set_)                      \
  _(aten, index_put_)                \
  _(aten, device)                    \
  _(aten, hash)                      \
  _(aten, len)                       \
  _(aten, list)                      \
  _(aten, wait)                      \
  _(aten, save)                      \
  _(aten, sub)                       \
  _(aten, sub_)                      \
  _(aten, subtract)                  \
  _(aten, subtract_)                 \
  _(aten, keys)                      \
  _(aten, ord)                       \
  _(aten, chr)                       \
  _(aten, hex)                       \
  _(aten, oct)                       \
  _(aten, clear)                     \
  _(aten, trunc)                     \
  _(aten, trunc_)                    \
  _(aten, fix)                       \
  _(aten, fix_)                      \
  _(aten, neg)                       \
  _(aten, neg_)                      \
  _(aten, negative)                  \
  _(aten, negative_)                 \
  _(aten, setdefault)                \
  _(aten, bin)                       \
  _(aten, pop)                       \
  _(aten, insert)                    \
  _(aten, vstack)                    \
  _(aten, row_stack)                 \
  _(prim, unchecked_unwrap_optional) \
  _(aten, __contains__)              \
  _(prim, BailoutTemplate)           \
  _(prim, grad)                      \
  _(aten, zero_)                     \
  _(aten, fill_)                     \
  _(aten, masked_fill_)              \
  _(cuda, _set_device)               \
  _(cuda, set_stream)                \
  _(cuda, _current_device)           \
  _(aten, swapaxes)                  \
  _(aten, swapaxes_)                 \
  _(aten, swapdims)                  \
  _(aten, swapdims_)                 \
  _(aten, movedim)                   \
  _(aten, moveaxis)                  \
  _(aten, has_torch_function)        \
  FORALL_ATEN_BASE_SYMBOLS(_)        \
  _(onnx, Add)                       \
  _(onnx, Concat)                    \
  _(onnx, Constant)                  \
  _(onnx, ConstantFill)              \
  _(onnx, Div)                       \
  _(onnx, GRU)                       \
  _(onnx, Gather)                    \
  _(onnx, Gemm)                      \
  _(onnx, LSTM)                      \
  _(onnx, Mul)                       \
  _(onnx, Pow)                       \
  _(onnx, RNN)                       \
  _(onnx, Shape)                     \
  _(onnx, Size)                      \
  _(onnx, Slice)                     \
  _(onnx, Squeeze)                   \
  _(onnx, Sub)                       \
  _(onnx, Transpose)                 \
  _(onnx, Unsqueeze)                 \
  _(onnx, Loop)                      \
  _(onnx, If)                        \
  _(onnx, Reshape)                   \
  _(onnx, Expand)                    \
  _(onnx, Equal)                     \
  _(onnx, Greater)                   \
  _(onnx, GreaterOrEqual)            \
  _(onnx, Less)                      \
  _(onnx, LessOrEqual)               \
  _(onnx, Not)                       \
  _(onnx, ATen)                      \
  _(onnx, Split)                     \
  _(onnx, ConstantOfShape)           \
  _(onnx, Cast)                      \
  _(onnx, Mod)                       \
  _(onnx, Sqrt)                      \
  _(onnx, SplitToSequence)           \
  _(onnx, SequenceAt)                \
  _(onnx, SequenceConstruct)         \
  _(onnx, SequenceEmpty)             \
  _(onnx, SequenceInsert)            \
  _(onnx, SequenceErase)             \
  _(onnx, ConcatFromSequence)        \
  _(onnx, Identity)                  \
  _(onnx, SoftmaxCrossEntropyLoss)   \
  _(onnx, NegativeLogLikelihoodLoss) \
  _(onnx, LogSoftmax)                \
  _(onnx, ReduceL1)                  \
  _(onnx, ReduceL2)                  \
  _(onnx, Conv)                      \
  _(onnx, BatchNormalization)        \
  _(onnx, ReduceProd)                \
  FORALL_ATTR_BASE_SYMBOLS(_)        \
  _(attr, Subgraph)                  \
  _(attr, ReverseSubgraph)           \
  _(attr, f_real_outputs)            \
  _(attr, df_input_vjps)             \
  _(attr, df_input_captured_inputs)  \
  _(attr, df_input_captured_outputs) \
  _(attr, df_output_vjps)            \
  _(attr, axes)                      \
  _(attr, axis)                      \
  _(attr, broadcast)                 \
  _(attr, direction)                 \
  _(attr, ends)                      \
  _(attr, inplace)                   \
  _(attr, input_as_shape)            \
  _(attr, is_zero)                   \
  _(attr, num_none)                  \
  _(attr, num_present)               \
  _(attr, perm)                      \
  _(attr, sizes)                     \
  _(attr, starts)                    \
  _(attr, profiled_type)             \
  _(attr, transA)                    \
  _(attr, transB)                    \
  _(attr, name)                      \
  _(attr, a)                         \
  _(attr, b)                         \
  _(attr, beg)                       \
  _(attr, idx)                       \
  _(attr, split)                     \
  _(attr, slot)                      \
  _(attr, kinds)                     \
  _(attr, types)                     \
  _(attr, scope)                     \
  _(attr, keepdims)                  \
  _(attr, cache_id)                  \
  _(attr, new_axis)                  \
  _(attr, warn_id)
#else
#define FORALL_NS_SYMBOLS(_) \
  _(namespaces, prim)        \
  _(namespaces, aten)        \
  _(namespaces, cuda)        \
  _(namespaces, onnx)        \
  _(namespaces, attr)        \
  _(namespaces, scope)       \
  _(namespaces, user)        \
  _(namespaces, _caffe2)     \
  _(namespaces, dimname)     \
  _(namespaces, namespaces)
#endif

// 'prim' symbols are synthetic operators that occur only in the IR
// and don't have corresponding implementations in ATen.

// 'onnx' symbols correspond to ONNX operators.  Their semantics
// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
// The particular version we are targeting is specified by '_onnx_opset_version'
// in torch.onnx.symbolic_helper
//
// In general, most ONNX operators won't get an entry here, because they
// are handled from the Python end.  However, you may occasionally need
// to intern an ONNX symbol here so that you can conveniently write an
// optimization on ONNX operations.
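//
// For example, a pass over the IR might test a node's kind against one of
// these interned symbols (node->kind() is the JIT IR accessor, not part of
// this header; shown here only as an illustrative sketch):
//
//   if (node->kind() == onnx::Gemm) {
//     // ... rewrite or fuse the Gemm node ...
//   }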

// 'attr' symbols are attribute keys.  They are shared between ONNX and ATen
// operators (you disambiguate their meaning by looking at the operator itself).
// In general, you only need to define attribute keys that are used by
// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
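//
// For example, a sketch of reading an integer attribute off a node through
// its interned key (node->i(...) is the JIT IR accessor, not part of this
// header):
//
//   int64_t axis = node->i(attr::axis);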

// Note [Symbol allocation]
// ~~~~~~~~~~~~~~~~~~~~~~~~
//
//  1. The Symbol space is split up into namespaces (prim, aten, onnx,
//  attr, ...).
//
//  2. The intended access pattern for built-in symbols is onnx::MatMul
//  in the c10 namespace (this is a Symbol).
//

// Built-in constant definition strategy:
// - Enum is the most convenient way to generate a contiguous sequence
//   of numbers for an identifier.
// - However, an enum gives you a fresh type.  We want onnx::MatMul to
//   be type Symbol, not some random enum type!
// - Therefore, after using enums to generate the sequence of integers,
//   we then declare constexpr Symbols to give everything the actual Symbol
//   type we want.  Symbols must be constexpr to be valid "case" labels.
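//
// Concretely, for an entry _(prim, Assign) in FORALL_NS_SYMBOLS, the
// machinery below first generates the enumerator _keys::prim_Assign, and
// the DEFINE_SYMBOL expansion further down then produces, roughly:
//
//   namespace prim {
//     constexpr Symbol Assign(static_cast<unique_t>(_keys::prim_Assign));
//   }
//
// so prim::Assign is an ordinary constexpr Symbol whose numeric value is
// the enumerator's position in the list.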

using unique_t = uint32_t;

const std::string& domain_prefix();

// A Symbol is like an interned string, but with a little extra
// structure; it is namespaced (attr::, aten::, onnx::, ...), and the
// interned values support efficient namespace testing.
struct TORCH_API Symbol {
  explicit constexpr Symbol() : value(0) {}
  explicit constexpr Symbol(unique_t uniq)
  : value(uniq) {}

  // Get a Symbol for a qualified string like "attr::bar"
  static Symbol fromQualString(const std::string & s);

  // Get a Symbol from a domain and an unqualified string, e.g.,
  // "org.pytorch.attr" and "bar" respectively
  static Symbol fromDomainAndUnqualString(const std::string & d, const std::string & s);
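
  // For instance, these two calls should yield the same Symbol (the
  // "org.pytorch" domain prefix matches the domainString() example below):
  //
  //   Symbol::fromQualString("aten::mm")
  //   Symbol::fromDomainAndUnqualString("org.pytorch.aten", "mm")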

  // Factory functions for our various namespaced strings.  Each constructs
  // the appropriate namespaced string, e.g., "attr::foo" for the
  // argument "foo", and then attempts to intern it.  DO NOT USE THESE
  // with a string literal; attr::foo should be available in that case
  // (and if it's not, you should add it to the built-ins list above).
  static Symbol attr(const std::string & s);
  static Symbol aten(const std::string & s);
  static Symbol cuda(const std::string & s);
  static Symbol onnx(const std::string & s);
  static Symbol prim(const std::string & s);
  static Symbol user(const std::string & s);
  static Symbol caffe2(const std::string & s);
  static Symbol dimname(const std::string & s);
  // TODO: eliminate me
  static Symbol scope(const std::string & s);

  bool is_attr() const;
  bool is_aten() const;
  bool is_cuda() const;
  bool is_prim() const;
  bool is_onnx() const;
  bool is_user() const;
  bool is_caffe2() const;
  bool is_dimname() const;

  // So we can switch on this
  constexpr operator unique_t() const {
    return value;
  }
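
  // A minimal sketch of what this enables (node->kind() is the JIT IR
  // accessor, shown for illustration only):
  //
  //   switch (node->kind()) {
  //     case prim::Constant: /* ... */ break;
  //     case onnx::Gemm:     /* ... */ break;
  //   }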

  Symbol ns() const;

  // Give a string corresponding to the unqualified version of this name, e.g.,
  // "mm". Use this in a context where the intended namespace of the string is
  // obvious; this is a *lossy* conversion.
  const char * toUnqualString() const;

  // Give a string corresponding to the qualified version of this name,
  // e.g., "aten::mm".  This string format is made available to Python bindings
  // (so we know how to parse it.)
  const char * toQualString() const;

  // This describes a symbol in a form meant for human consumption.  At the
  // moment it's the same as toQualString.  It returns a const char* because
  // a lot of printf-style macros consume it.
  const char * toDisplayString() const;

  // Give a string corresponding to the domain name for the symbol,
  // e.g., "org.pytorch.aten".
  std::string domainString() const;
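
  // Taken together, for the built-in aten::mm these accessors should give
  // (values shown as a sketch):
  //
  //   aten::mm.toQualString()   -> "aten::mm"
  //   aten::mm.toUnqualString() -> "mm"
  //   aten::mm.domainString()   -> "org.pytorch.aten"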

private:
  explicit Symbol(Symbol ns, const std::string & s);
  unique_t value;
};

static inline bool operator==(Symbol lhs, Symbol rhs) {
  return static_cast<unique_t>(lhs) == static_cast<unique_t>(rhs);
}

enum class _keys : unique_t {
    #define DEFINE_KEY(ns, s) ns##_##s,
    FORALL_NS_SYMBOLS(DEFINE_KEY)
    #undef DEFINE_KEY
    num_symbols
};

#define DEFINE_SYMBOL(ns, s) \
  namespace ns { constexpr Symbol s(static_cast<unique_t>(_keys::ns##_##s)); }
FORALL_NS_SYMBOLS(DEFINE_SYMBOL)
#undef DEFINE_SYMBOL

inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
inline Symbol Symbol::aten(const std::string & s)  { return Symbol::fromQualString("aten::" + s); }
inline Symbol Symbol::cuda(const std::string & s)  { return Symbol::fromQualString("cuda::" + s); }
inline Symbol Symbol::onnx(const std::string & s)  { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s)  { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); }
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
inline bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; }
inline bool Symbol::is_dimname() const { return ns() == namespaces::dimname; }

} // namespace c10

// Make Symbol behave like an integer in hash tables
namespace std {
template <>
struct hash<c10::Symbol> {
  size_t operator()(c10::Symbol s) const {
    return std::hash<uint32_t>()(static_cast<uint32_t>(s));
  }
};
}
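
// A minimal usage sketch: with the specialization above, Symbol works
// directly as a key in the standard hash containers, e.g.:
//
//   std::unordered_map<c10::Symbol, int> op_counts;
//   ++op_counts[c10::Symbol::fromQualString("aten::mm")];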