Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / ATen / core / dispatch / OperatorEntry.h

#pragma once

#include <ATen/core/function_schema.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/either.h>
#include <c10/util/Optional.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>

#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>

#include <list>
#include <array>

namespace c10 {

class Dispatcher;

namespace impl {

// This data structure represents a kernel that was registered to us from a
// user.  Unlike KernelFunction, AnnotatedKernel contains some extra metadata
// about the kernel that isn't necessary for actual dispatching (this is why
// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
// giving good error messages.
struct AnnotatedKernel final {
  AnnotatedKernel(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d)
    : kernel(std::move(k))
    , inferred_function_schema(std::move(s))
    , debug(std::move(d))
    {}
  AnnotatedKernel() {}
  KernelFunction kernel;
  std::unique_ptr<FunctionSchema> inferred_function_schema;
  // A little debug string to help us identify the kernel in question.
  // Most importantly it records the TORCH_LIBRARY block that did the
  // registration.
  std::string debug;
};

// This data structure represents operator schema, with metadata specifying
// where the registration of this schema occurred
struct AnnotatedSchema final {
  AnnotatedSchema(FunctionSchema s, std::string d)
    : schema(std::move(s))
    , debug(std::move(d))
    {}
  FunctionSchema schema;
  std::string debug;
};

// Internal data structure that records information about a specific operator.
// It's not part of the public API; typically, users will interact with
// OperatorHandle instead.
//
// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher
// lock (this is important because some methods in OperatorEntry access
// dispatcher state)
class TORCH_API OperatorEntry final {
public:
  explicit OperatorEntry(OperatorName&& operator_name);

  // OperatorEntry is neither copyable nor movable.
  OperatorEntry(const OperatorEntry&) = delete;
  OperatorEntry(OperatorEntry&&) noexcept = delete;
  OperatorEntry& operator=(const OperatorEntry&) = delete;
  OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;

  // Returns the registered schema.
  // Precondition: hasSchema() is true (checked with TORCH_INTERNAL_ASSERT).
  const FunctionSchema& schema() const {
    TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema for ", name_, " which doesn't have a schema registered yet");
    return schema_->schema;
  }
  // Returns the debug string that was stored together with the schema.
  // Precondition: hasSchema() is true (checked with TORCH_INTERNAL_ASSERT).
  const std::string& debug() const {
    TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema debug string for ", name_, " which doesn't have a schema registered yet");
    return schema_->debug;
  }
  // True iff a schema is currently registered for this operator.
  bool hasSchema() const {
    return schema_.has_value();
  }

  // Returns the is_observed_ flag (see the member for its meaning).
  bool isObserved() const {
    return is_observed_;
  }

  // We may allocate an OperatorEntry for an operator even when we don't
  // have a schema.  When we receive the schema registration, we post
  // facto register a schema.
  //
  // NB: registerSchema/deregisterSchema are not idempotent; if you
  // attempt to register a schema when one is already present or vice
  // versa that is an error.  (Refcounting for the registrations is
  // handled in the OperatorHandle in Dispatcher)
  void registerSchema(FunctionSchema&&, std::string&& debug);
  void deregisterSchema();

  // Returns the name this entry was created for.
  const OperatorName& operator_name() const {
    return name_;
  }

  // Why are kernels and fallback asymmetric?  It has to do with ownership.
  // Kernels and the computed dispatch tables for them are canonically
  // owned by OperatorEntry, but backend fallbacks are specified once
  // and apply for all operators, so they should be owned by Dispatcher.
  // However, the registration of a backend fallback affects the
  // state of the computed dispatch table, so when a backend fallback
  // is updated, we need to update the operator tables too.  Thus,
  // registerKernel is the mechanism by which we give kernels to
  // operator entry to own (and update dispatch table), but we only
  // need a non-owning mechanism to update fallback.

  // Precondition: Dispatcher::mutex_ is held
  // Postcondition: caller is responsible for disposing of the kernel
  std::list<AnnotatedKernel>::iterator registerKernel(
    const Dispatcher& dispatcher,
    c10::optional<DispatchKey> dispatch_key,
    KernelFunction kernel,
    c10::optional<CppSignature> cpp_signature,
    std::unique_ptr<FunctionSchema> inferred_function_schema,
    std::string debug
  );

  // Precondition: Dispatcher::mutex_ is held
  void deregisterKernel_(
    const Dispatcher& dispatcher,
    c10::optional<DispatchKey> dispatch_key,
    std::list<AnnotatedKernel>::iterator kernel
  );

  // Precondition: Dispatcher::mutex_ is held
  void updateFallback(
    const Dispatcher& dispatcher,
    DispatchKey dispatch_key
  );

  // Overwrites the alias analysis kind stored on the registered schema.
  // Precondition: Dispatcher::mutex_ is held
  // Precondition: hasSchema() is true (checked with TORCH_INTERNAL_ASSERT).
  void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
    TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to update the alias analysis for ", name_, " which doesn't have a schema registered yet");
    schema_->schema.setAliasAnalysis(a);
  }

  // Introspection helpers; they are defined out of line.
  std::string dumpComputedTable() const;
  std::string dumpState() const;
  void checkInvariants() const;

  const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; }

  // Asserts that the given FuncType is correct for calling this operator in an unboxed way.
  // Nothing is checked while no C++ signature has been recorded in
  // cpp_signature_.  On a mismatch reportSignatureError() is invoked with
  // the name of the signature built from FuncType; that call never returns.
  // The check only reads state, hence the method is const.
  template<class FuncType>
  void assertSignatureIsCorrect() const {
    if (!cpp_signature_.has_value()) {
      return;
    }
    // Build the caller's signature once and reuse it for the error report.
    const auto call_signature = CppSignature::make<FuncType>();
    if (C10_UNLIKELY(call_signature != cpp_signature_->signature)) {
      reportSignatureError(call_signature.name());
    }
  }

  // Never returns.  lookup() calls this when the dispatch table entry for
  // dispatchKey is not a valid kernel.
  [[noreturn]] void reportError(DispatchKey dispatchKey) const;

  // Returns the dispatch table entry for k.  If that entry is not a valid
  // kernel, reportError(k) is called instead of returning.
  const KernelFunction& lookup(DispatchKey k) const {
    const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
    // A valid kernel *always* has a boxed kernel and *may* have an
    // unboxed kernel. However, we typically do unboxed calls in at::
    // APIs, where the kernel 1) will very likely be valid and 2)
    // should have an unboxed kernel. Checking the unboxed kernel
    // first will allow us to avoid touching the boxed kernel at all
    // in the common case.
    if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
      if (!kernel.isValid()) {
        reportError(k);
      }
    }
    return kernel;
  }

  std::string listAllDispatchKeys() const;

private:

  OperatorName name_;
  // Empty until registerSchema() is called; see hasSchema().
  c10::optional<AnnotatedSchema> schema_;

  // One slot per DispatchKey, indexed by the key's integer value (see lookup()).
  std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
  DispatchKeyExtractor dispatchKeyExtractor_;

  // kernels_ stores all registered kernels for the corresponding dispatch key
  // and catchAllKernels_ stores the catch-all kernels.
  // NOTE(review): this class declares no catchAllKernels_ member, so the
  // mention above looks stale -- confirm how catch-all kernels are stored.
  // If an operator library gets loaded that overwrites an already existing kernel,
  // both kernels will be in that list but only the newer one will be in
  // dispatchTable. If any of the kernels go away (say the library gets
  // unloaded), we remove the kernel from this list and update the
  // dispatchTable if necessary.
  // Kernels in the list are ordered by registration time descendingly,
  // newer registrations are before older registrations.
  // We do not combine dispatchTable and kernels into one hash map because
  // kernels is a larger data structure and accessed quite infrequently
  // while dispatchTable is accessed often and should be kept small to fit
  // into CPU caches.
  // Invariants:
  //  - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
  //  - dispatchTable[dispatch_key] does not exist if and only if
  //    kernels_[dispatch_key] does not exist
  //  - If kernels_[dispatch_key] exists, then it has elements.
  //    It is never an empty list.
  //
  // Why do we do that?
  // -----
  // We mostly do this to enable Jupyter notebooks where a cell registering
  // a kernel could be executed multiple times and the later execution
  // should overwrite the earlier one. Note that this still fails when the
  // function schema changed between the executions, but it works as long
  // as the function schema didn't change. A better solution would be to
  // unload the old extension library from the Jupyter cell when the cell is
  // re-executed and then only allow one kernel here, i.e. error if a kernel
  // is already registered, but that's a lot of effort to implement and
  // currently not high-pri.
  ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;

  // NOTE(review): presumably placeholder kernels used by
  // computeDispatchTableEntryWithDebug() when no kernel applies or when the
  // AutogradOther choice is ambiguous -- confirm against the .cpp.
  AnnotatedKernel missingKernel_;
  static const AnnotatedKernel ambiguousAutogradOtherKernel_;

  // cpp_signature_ stores function signature if any of
  // the kernels was created in a way that allowed us to know the function
  // signature (i.e. by supplying an unboxed C++ kernel function).
  // If this is set, it will be used to check that future kernel
  // registrations match and it will be used in unboxed function calls
  // to verify their arguments against the known function signature.
  struct CppSignatureWithDebug {
    CppSignature signature;
    std::string debug;
    c10::optional<DispatchKey> dispatch_key;
  };
  c10::optional<CppSignatureWithDebug> cpp_signature_;

  // Whether this operator needs to be observed with RecordFunction
  const bool is_observed_;

  // Never returns.  `name` is the textual name of the signature the caller
  // tried to use (see assertSignatureIsCorrect()).
  [[noreturn]] void reportSignatureError(std::string name) const;
  const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
  std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
    const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
  ) const;
  // This function re-establishes the invariant that dispatchTable
  // contains the front element from the kernels list for a given runtime dispatch key.
  void updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
  // Like above, but also handles alias dispatch keys.
  void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
  // Like above, but for ALL entries in the dispatch table.
  void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);

  // Returns true if kernels_ has an entry for any key in ks.
  bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
  // Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front().
  c10::optional<const AnnotatedKernel*> getKernelForDispatchKey(DispatchKey dispatch_key) const;
};

} // namespace impl
} // namespace c10