Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / graph.h

#pragma once

#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace caffe2 {

namespace transform {

/**
 *  Graph representation of an operator.
 *
 *  A node holds one OperatorDef together with its edges to other nodes.
 *  Each edge is keyed by the index of the node at the other end and carries
 *  the names of the blobs that connect the two nodes.
 */
struct TORCH_API Node {
 public:
  // Default constructor, required so that containers of Node can be resized.
  Node() = default;

  // Constructs a fully specified node.
  //   op       - operator definition represented by this node (copied).
  //   active   - false if the operator was deleted through a transformation.
  //   parents  - edges to parent nodes, see the `parents` member.
  //   children - edges to child nodes, see the `children` member.
  // The edge maps are sink parameters: they are taken by value and moved into
  // the members, so a caller that passes rvalues pays for no copy at all.
  Node(
      const OperatorDef& op,
      bool active,
      std::map<int, std::vector<string>> parents,
      std::map<int, std::vector<string>> children)
      : op(op),
        active(active),
        parents(std::move(parents)),
        children(std::move(children)) {}

  // The OperatorDef which this node represents.
  OperatorDef op;

  // Keeps track of if an operator has been deleted through a transformation.
  bool active = true;

  // Each map stores pairs (idx, blob_list),
  //  idx = index of the node at the other end of the edge
  //        (the parent for `parents`, the child for `children`)
  //  blob_list = a list of strings, containing the blobs that connect the nodes
  std::map<int, std::vector<string>> parents;
  std::map<int, std::vector<string>> children;
};

/**
 *  Graph representation of a Netdef.
 */
struct TORCH_API Graph {
 public:
  /**
   * Given a subgraph, gets all of the parents of the subgraph, as well as
   * their associated blob names. Sorted by blob names.
   *
   * <string, int> := (name of blob writing into subgraph,
   *                  index of node that writes into subgraph using that blob)
   */
  const std::vector<std::pair<string, int>> GetSubgraphInput(
      const std::vector<int>& subgraph);

  /**
   * Given a subgraph, gets all of the children of the subgraph, as well as
   * their associated blob names. Sorted by blob names.
   *
   * <string, int> := (name of blob reading from subgraph,
   *                  index of node that reads from subgraph using that blob)
   */
  const std::vector<std::pair<string, int>> GetSubgraphOutput(
      const std::vector<int>& subgraph);

  /**
   * Graph generation.
   * Given a netdef, returns a Graph.
   *
   * Each node represents an operator.
   * An edge exists between two nodes if the parent op writes to a blob, which
   * is the input of the child blob, with no other op writing to the blob in
   * between the execution order.
   *
   * Time Complexity: O(E), where E is the number of blobs
   */
  explicit Graph(const NetDef& net_def);

  /**
   * Generates a NetDef Representation for the current graph.
   * Nodes are visited in topological order, which is proper Opdef ordering.
   * TODO(benz):
   * There exists conflicts with repeated blob names, where topological sorting
   * is not sufficient for correct netdef representation, unless blobs are
   * renamed.
   * For example, if after a transformation, We have operator ancestry:
   * A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the
   * same blob name, then A, B, D, E, C is a correct topological ordering,
   * but D will write to the blob that C reads from, instead of B.
   * Currently believe that there will always be ambiguity unless blobs are
   * renamed.
   * This is solved by performing SSA on all transformed blob names.
   */
  NetDef GetNetDef();

  /**
   * Deactivate a subgraph, and get rid of all edges into this subgraph.
   */
  void DeactivateSubgraph(std::vector<int> subgraph);

  size_t size() const {
    return nodes_.size();
  }

  void push_node(const Node& new_node) {
    return nodes_.push_back(new_node);
  }

  void resize_nodes(size_t new_size) {
    nodes_.resize(new_size);
  }

  // Index safe, less verbose way to access nodes
  inline const Node& node(size_t idx) const {
    return nodes_.at(idx);
  }

  inline Node& node(size_t idx) {
    return nodes_.at(idx);
  }

  inline bool is_node_active(size_t idx) {
    return node(idx).active;
  }

  inline const std::set<string>& external_input() const {
    return external_input_;
  }

  inline const std::set<string>& external_output() const {
    return external_output_;
  }

 private:
  const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
      bool from_children,
      const std::vector<int>& match);

  // Stores the netdef representation. Is updated upon calls to GetNetDef.
  NetDef netdef_;

  // Stores which blobs the graph reads from, and writes to.
  std::set<string> external_input_;
  std::set<string> external_output_;

  // Keeps track of all the Operators currently within graph, even if inactive.
  std::vector<Node> nodes_;
};

} // namespace transform

// Adds an operator def to a netdef.
// Returns a pointer to the added OperatorDef, in case you want to set
// anything extra on it (such as device_option).
//   netdef_ptr - the NetDef that receives the new operator.
//   op_type    - presumably the operator's type string -- confirm against the
//                definition.
//   inputs     - presumably the names of the operator's input blobs.
//   outputs    - presumably the names of the operator's output blobs.
TORCH_API OperatorDef* AddOp(
    NetDef* netdef_ptr,
    string op_type,
    std::vector<string> inputs,
    std::vector<string> outputs);

/**
 * This allows for the use of * and | to match operator types,
 * engines, or any other property that is represented by strings.
 *
 * For example, if we wanted to match an operator to Conv or FC, we can give:
 * "Conv|FC" as the type() of that op.
 *
 * NOTE(review): p is presumably the pattern and s the string being tested,
 * with true returned on a match -- confirm against the definition.
 */
TORCH_API bool MatchStrings(string p, string s);

/**
 * Checks that each named arg that exists in the pattern op (p_op) also exists
 * in g_op and is equal in value.
 *
 * NOTE(review): presumably returns true when that condition holds and false
 * otherwise -- confirm against the definition.
 */
TORCH_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op);

} // namespace caffe2