Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / torch / csrc / jit / passes / pass_manager.h

#pragma once

#include <torch/csrc/jit/ir/ir.h>

/* `getCustomPrePasses()` returns a vector of passes that will be executed
 * after differentiation but before any fusion. This is the de-facto location
 * for compiler backends to insert passes.
 *
 * `getCustomPostPasses()` returns a vector of passes that will be
 * executed after differentiation and after fusion (if any). This is the
 * location for fusion cleanup passes if they are needed.
 *
 * Static registration of a pass can be done by creating a global
 * `Register{Pre,Post}Pass r(Pass)` variable in a compilation unit.
 *
 * pass_manager.h uses a Meyers singleton to store a vector of `Pass`es, which
 * modify the IR graph in place.
 */

namespace torch {
namespace jit {

// A GraphPass is a callable that takes a reference to a
// std::shared_ptr<Graph> and returns nothing. A pass modifies the Graph in
// place.
using GraphPass = std::function<void(std::shared_ptr<Graph>&)>;

// Since passes are std::functions, we associate a numeric ID (the pass
// "name") with each pass; this way, if we want to deregister a pass, we have
// something to reference it by.
using GraphPassNameType = unsigned int;
// Pass ID counter; IDs start at 1.
// NOTE(review): this is a `static` namespace-scope variable defined in a
// header, so every translation unit that includes this file gets its own
// independent copy. Presumably only the file that implements the register*
// functions reads/advances it -- confirm before relying on it elsewhere.
static GraphPassNameType graphPassID = 1;

// A pass paired with the name (ID) associated with it: `first` is the
// GraphPass, `second` is its GraphPassNameType.
using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>;

// Return the currently registered passes. Passes are stored in a static
// vector; each accessor returns a mutable reference to a vector of
// (pass, ID) pairs -- the same element type as GraphPassEntry.
// Per the file header comment: post passes are executed after
// differentiation and after fusion (if any), pre passes after
// differentiation but before any fusion.
TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
getCustomPostPasses();
TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
getCustomPrePasses();

// Register `p` as a post / pre pass. The returned GraphPassNameType is the
// ID associated with the newly registered pass; PassManager below keeps the
// value returned by registerPostPass and later hands it to clearPostPass.
// NOTE(review): presumably registerPrePass / clearPrePass pair up the same
// way -- the definitions are not in this header; confirm.
TORCH_API GraphPassNameType registerPostPass(GraphPass p);
TORCH_API GraphPassNameType registerPrePass(GraphPass p);

// Look up the pass by the name (ID) passed in and remove it from the
// registered passes.
// NOTE(review): what happens for an ID that is not currently registered is
// not visible from this header.
TORCH_API void clearPostPass(GraphPassNameType p);
TORCH_API void clearPrePass(GraphPassNameType p);

// Remove all registered post / pre passes.
TORCH_API void clearAllPostPasses();
TORCH_API void clearAllPrePasses();

// LEGACY CALL
// Helper type for static registration: per the file header comment, a pass
// can be registered by creating a global `RegisterPostPass r(pass)` variable
// in a compilation unit. Only the constructor is declared here.
// NOTE(review): presumably the constructor registers `p` as a post pass --
// its definition is not in this header; confirm.
struct TORCH_API RegisterPostPass {
  RegisterPostPass(GraphPass p);
};

// Alternative name for RegisterPostPass.
using RegisterPass = RegisterPostPass;

/*
 * PassManager is a wrapper around the register/clear PostPass functions
 * above. registerPass() registers the pass it is given and holds on to the
 * name (ID) associated with it, so that a later call to clearPass() can
 * remove that same pass.
 *
 * PassManager is templated because we want a separate set of static
 * variables for each particular GraphPass. When deriving from PassManager,
 * pass your derived class as the template parameter, as you would for the
 * curiously recurring template pattern. The template parameter isn't
 * actually used; it exists only to prevent static members from being shared
 * across derived types.
 */
template <typename DerivedType>
struct C10_EXPORT PassManager {
 private:
  // Pure virtual member: its presence makes PassManager itself abstract.
  virtual void abstract() = 0;

 protected:
  /*
   * Registration flag for this PassManager instantiation.
   *
   * isRegistered()      returns whether a pass is currently registered.
   * isRegistered(true)  toggles the flag first and returns the new value.
   *
   * The flag is a function-local static, so classes deriving from
   * PassManager don't have to define or initialize any static member
   * themselves. It starts out as false.
   */
  static bool isRegistered(bool flip_bit = false) {
    static bool registered = false;
    // XOR with flip_bit: unchanged when flip_bit is false, inverted when it
    // is true.
    registered = (registered != flip_bit);
    return registered;
  }

  /*
   * ID of the pass registered through this PassManager instantiation.
   *
   * passID()          returns the stored ID (0 until one has been set).
   * passID(id, true)  stores `id` and returns it.
   *
   * As with isRegistered(), the value is kept in a function-local static.
   */
  static GraphPassNameType passID(
      GraphPassNameType PassID = 0,
      bool set = false) {
    static GraphPassNameType stored_id = 0;
    if (!set) {
      return stored_id;
    }
    stored_id = PassID;
    return stored_id;
  }

 public:
  // Registers `p` via registerPostPass() unless a pass is already registered
  // through this PassManager instantiation.
  // Returns true if a pass was already registered (in which case `p` is not
  // registered), false if `p` has just been registered.
  static bool registerPass(GraphPass p) {
    if (isRegistered()) {
      return true;
    }
    // Nothing registered yet: register the pass, remember the ID it was
    // given, then toggle the flag from false to true.
    const GraphPassNameType id = registerPostPass(std::move(p));
    passID(id, true);
    isRegistered(true);
    return false;
  }

  // Calls clearPostPass(passID()) if a pass is currently registered;
  // otherwise does nothing.
  static void clearPass() {
    if (!isRegistered()) {
      return;
    }
    clearPostPass(passID());
    // The flag is known to be true here, so toggling it sets it to false.
    isRegistered(true);
  }

  // clang-tidy requires a virtual destructor.
  virtual ~PassManager() = default;
};

} // namespace jit
} // namespace torch