#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/graph.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
/**
* The Transform Base Object
*
* A Transform is an operation which manipulates a Caffe2 NetDef.
* You can consider it as a function: Transform.ApplyTo(NetDef) -> NetDef
*
* A Transform Operation does 4 things:
* 1) Creates a Graph object from a NetDef, which stores connections.
* 2) Pattern Matches on the Graph, to find subgraphs it wants to change.
* 3) Replaces the subgraphs that it's matched with new operators.
* 4) Creates a NetDef from the changed Graph, and returns it.
*
* The effect of a Transform is defined by its 3 protected virtual functions.
* 1) PatternRule determines for an ordered subgraph and a node, whether to
* consider adding the node to the subgraph.
* 2) ValidatorRule determines, for an ordered subgraph, whether it is a
* match.
* 3) ReplaceRule mutates the graph, based on a matched subgraph.
*
* This is the base class from which all transforms derive. To create your
* own transform, derive from Transform and provide your own implementations
* of PatternRule, ValidatorRule, and ReplaceRule.
*/
class TORCH_API Transform {
public:
Transform() {}
/**
* Apply a Transform onto a NetDef.
* Returns the transformed NetDef.
*/
NetDef ApplyTo(const NetDef& orig_net_def);
virtual ~Transform() {}
/**
* Determines the type of subgraphs that PatternMatch will find.
*
* CONNECTED_SUBGRAPH will only match subgraphs that are connected.
* These subgraphs satisfy that every node of the match is connected to the
* subgraph of the nodes that come before it.
* For example, in the graph (1) --> (2) --> (3) --> (4),
* This is capable of matching the subgraph [2, 3] and [4, 3]
* This is not capable of matching the subgraph [2, 4].
*
*
* SORTED_WRT_EXECUTION_ORDER will match subgraphs that guarantee
* sorted execution order.
* The nodes don't have to be connected. It is faster than General.
* For example, in the graph (1) --> (2) --> (3) --> (4),
* This is capable of matching the subgraph [2, 4], [3, 4].
* This is not capable of matching the subgraph [3, 1], [4, 3].
*
*
* GENERAL can match any subgraph.
* For example, in the graph (1) --> (2) --> (3) --> (4),
* This is capable of matching subgraphs [2, 4], [3, 4], [4, 2, 1].
* There is no ordered subgraph of G that cannot be matched by this.
*/
enum PatternMatchType {
CONNECTED_SUBGRAPH,
SORTED_WRT_EXECUTION_ORDER,
GENERAL
};
/**
* Generates all matches (stored as ordered subgraphs) and returns them.
*
* A match is stored as vector<int>, which is a mapping to OperatorDefs
* in Graph. The order matters.
*/
std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph);
/**
* Applies the replace rule onto each of the matches found.
*/
void ReplacePattern(
const std::vector<std::vector<int>>& matches,
transform::Graph* graph);
protected:
/**
* The PatternRule essentially answers:
* Given the current subgraph (ordered), should we append the new node at idx?
*/
virtual bool PatternRule(
const transform::Graph& g,
const std::vector<int>& subgraph,
int /*idx*/) {
CAFFE_NOT_IMPLEMENTED;
}
/**
* The ValidatorRule essentially answers:
* Given a subgraph, can we accept it?
*/
virtual bool ValidatorRule(
const transform::Graph& g,
const std::vector<int>& subgraph) {
CAFFE_NOT_IMPLEMENTED;
}
/**
* The ReplaceRule actually mutates the graph, and applies the transformation
* upon the subgraph.
*/
virtual bool ReplaceRule(
const std::vector<int>& subgraph,
transform::Graph* g_ptr) {
CAFFE_NOT_IMPLEMENTED;
}
void SetPatternMatchType(PatternMatchType type) {
pattern_match_type_ = type;
}
private:
/**
* A helper function for PatternMatch, which keeps track of the best subgraph
* so far.
*/
void PatternMatchHelper(
const transform::Graph& graph,
const std::vector<bool>& matched,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
/**
* Attempts to append each neighbor to the end of the subgraph.
*/
void TryNeighbors(
const transform::Graph& graph,
const std::map<int, std::vector<string>>& neighbors,
const std::vector<bool>& matched,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
PatternMatchType pattern_match_type_ = CONNECTED_SUBGRAPH;
};
// Creates the Transform associated with `key`, which should be defined in the
// registry (see REGISTER_TRANSFORM).
// The new object is returned in a unique_ptr, so the caller owns it.
// NOTE(review): what happens for a key that is not registered is decided by
// the implementation, which is not visible in this header -- confirm there.
TORCH_API unique_ptr<Transform> CreateTransform(string key);
// Declares TransformRegistry, a registry whose entries are Transform classes.
C10_DECLARE_REGISTRY(TransformRegistry, Transform);
// Registers a class in TransformRegistry under `name`. The remaining arguments
// are passed through unchanged to C10_REGISTER_CLASS.
// NOTE(review): presumably used as REGISTER_TRANSFORM(Name, TransformClass) --
// confirm against the C10 registry macros.
#define REGISTER_TRANSFORM(name, ...) \
C10_REGISTER_CLASS(TransformRegistry, name, __VA_ARGS__)
// Creates a Transform object from the registry entry named `key`
// and immediately applies it to a NetDef.
//
// `netdef` is taken by const reference; the result of the transform is
// returned as a new NetDef.
TORCH_API NetDef ApplyTransform(const string& key, const NetDef& netdef);
// Creates a Transform object from the registry and applies it to a NetDef,
// but only returns the transformed net if it is faster than the old net.
//
// This will run the init net first, then run the two nets warmup_runs times.
// Then, we will take the average time of main_runs runs, and only keep the
// transformed net if it is faster by a factor of improvement_threshold.
//
// Parameters:
//   key                    registry key of the Transform to apply.
//   netdef                 the net to transform and to time against.
//   init_netdef            the init net that is run before timing.
//   warmup_runs            how many times the two nets are run as warmup.
//   main_runs              number of runs whose times are averaged.
//   improvement_threshold  factor by which the transformed net must be faster.
//
// NOTE(review): when the transformed net is not fast enough, the original net
// is presumably what gets returned -- confirm against the implementation.
TORCH_API NetDef ApplyTransformIfFaster(
const string& key,
const NetDef& netdef,
const NetDef& init_netdef,
const int warmup_runs,
const int main_runs,
const double improvement_threshold);
} // namespace