#pragma once
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace caffe2 {
namespace transform {
/**
 * Graph representation of an operator.
 *
 * A Node wraps one OperatorDef together with its adjacency information:
 * which other nodes feed it (parents) and which nodes it feeds (children),
 * plus the blob names that form each of those edges.
 */
struct TORCH_API Node {
 public:
  // Default constructor; required so that a container of nodes can be resized.
  Node() = default;

  // Builds a fully specified node.
  //   op       - the operator this node represents (copied).
  //   active   - false if the operator has been removed by a transformation.
  //   parents  - edges coming into this node, keyed by the other node's index.
  //   children - edges going out of this node, keyed by the other node's index.
  // The edge maps are taken by value and moved into place, so callers passing
  // temporaries or std::move'd maps pay for no additional copy.
  Node(
      const OperatorDef& op,
      bool active,
      std::map<int, std::vector<string>> parents,
      std::map<int, std::vector<string>> children)
      : op(op),
        active(active),
        parents(std::move(parents)),
        children(std::move(children)) {}

  // The OperatorDef which this node represents.
  OperatorDef op;

  // Keeps track of if an operator has been deleted through a transformation.
  bool active = true;

  // Each map stores pairs (idx, blob_list):
  //   idx       = index of the neighboring node (a parent for `parents`,
  //               a child for `children`)
  //   blob_list = a list of strings, containing the blobs that connect the
  //               two nodes
  std::map<int, std::vector<string>> parents;
  std::map<int, std::vector<string>> children;
};
/**
* Graph representation of a Netdef.
*/
struct TORCH_API Graph {
public:
/**
* Given a subgraph, gets all of the parents of the subgraph, as well as
* their associated blob names. Sorted by blob names.
*
* <string, int> := (name of blob writing into subgraph,
* index of node that writes into subgraph using that blob)
*/
const std::vector<std::pair<string, int>> GetSubgraphInput(
const std::vector<int>& subgraph);
/**
* Given a subgraph, gets all of the children of the subgraph, as well as
* their associated blob names. Sorted by blob names.
*
* <string, int> := (name of blob reading from subgraph,
* index of node that reads from subgraph using that blob)
*/
const std::vector<std::pair<string, int>> GetSubgraphOutput(
const std::vector<int>& subgraph);
/**
* Graph generation.
* Given a netdef, returns a Graph.
*
* Each node represents an operator.
* An edge exists between two nodes if the parent op writes to a blob, which
* is the input of the child blob, with no other op writing to the blob in
* between the execution order.
*
* Time Complexity: O(E), where E is the number of blobs
*/
explicit Graph(const NetDef& net_def);
/**
* Generates a NetDef Representation for the current graph.
* Nodes are visited in topological order, which is proper Opdef ordering.
* TODO(benz):
* There exists conflicts with repeated blob names, where topological sorting
* is not sufficient for correct netdef representation, unless blobs are
* renamed.
* For example, if after a transformation, We have operator ancestry:
* A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the
* same blob name, then A, B, D, E, C is a correct topological ordering,
* but D will write to the blob that C reads from, instead of B.
* Currently believe that there will always be ambiguity unless blobs are
* renamed.
* This is solved by performing SSA on all transformed blob names.
*/
NetDef GetNetDef();
/**
* Deactivate a subgraph, and get rid of all edges into this subgraph.
*/
void DeactivateSubgraph(std::vector<int> subgraph);
size_t size() const {
return nodes_.size();
}
void push_node(const Node& new_node) {
return nodes_.push_back(new_node);
}
void resize_nodes(size_t new_size) {
nodes_.resize(new_size);
}
// Index safe, less verbose way to access nodes
inline const Node& node(size_t idx) const {
return nodes_.at(idx);
}
inline Node& node(size_t idx) {
return nodes_.at(idx);
}
inline bool is_node_active(size_t idx) {
return node(idx).active;
}
inline const std::set<string>& external_input() const {
return external_input_;
}
inline const std::set<string>& external_output() const {
return external_output_;
}
private:
const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
bool from_children,
const std::vector<int>& match);
// Stores the netdef representation. Is updated upon calls to GetNetDef.
NetDef netdef_;
// Stores which blobs the graph reads from, and writes to.
std::set<string> external_input_;
std::set<string> external_output_;
// Keeps track of all the Operators currently within graph, even if inactive.
std::vector<Node> nodes_;
};
} // namespace transform
// Adds an operator def to a netdef.
//   netdef_ptr - the NetDef to add the operator to.
//   op_type    - the type of the operator to add.
//   inputs     - names of the operator's input blobs.
//   outputs    - names of the operator's output blobs.
// Returns the ptr, if you want to add anything extra (such as device_option).
// NOTE(review): the returned pointer presumably refers to the operator stored
// inside *netdef_ptr rather than a separately owned object - confirm against
// the definition.
TORCH_API OperatorDef* AddOp(
NetDef* netdef_ptr,
string op_type,
std::vector<string> inputs,
std::vector<string> outputs);
/**
* String matching that allows the use of * and | in order to match operator
* types, engines, or any other property that is represented by strings.
*
* For example, if we wanted to match an operator to Conv or FC, we can give:
* "Conv|FC" as the type() of that op.
*
* NOTE(review): p is presumably the pattern (which may contain * and |) and
* s the string being tested; * presumably acts as a wildcard - confirm
* against the definition.
*/
TORCH_API bool MatchStrings(string p, string s);
/**
* This ensures that each named arg that exists in the pattern op (p_op) also
* exists in g_op, and is equal in value.
*
* NOTE(review): g_op is presumably the operator from the graph being matched
* against the pattern - confirm against the definition.
*/
TORCH_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op);
} // namespace caffe2