#ifndef CAFFE2_OPT_CONVERTER_H
#define CAFFE2_OPT_CONVERTER_H
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/opt/annotations.h"
#include "caffe2/proto/caffe2_pb.h"
#include "nomnigraph/Graph/Graph.h"
#include "nomnigraph/Representations/ControlFlow.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace caffe2 {
/// Adds data-edge indicator information to `net`.
/// Takes a mutable pointer, so `net` is presumably rewritten in place.
/// NOTE(review): the indicator encoding is defined in the .cpp — confirm there.
TORCH_API void injectDataEdgeIndicators(caffe2::NetDef* net);

/// Presumably the inverse of injectDataEdgeIndicators(): strips the
/// indicators from `net` (mutable pointer). TODO confirm in the .cpp.
TORCH_API void removeDataEdgeIndicators(caffe2::NetDef* net);

/// Default conversion of a caffe2 NetDef into an NNModule.
///
/// @param net        network to convert.
/// @param strict     when true, the conversion also checks various input and
///                   output conditions of `net` (defaults to false).
/// @param opNodeVec  optional out-parameter (defaults to nullptr). When
///                   non-null it is updated to map the operators of `net`,
///                   by their position in the NetDef, to the NodeRefs created
///                   for them in the returned NNModule.
/// @return the resulting NNModule.
TORCH_API nom::repr::NNModule convertToNNModule(
    const caffe2::NetDef& net,
    bool strict = false,
    std::vector<nom::repr::NNGraph::NodeRef>* opNodeVec = nullptr);

/// Converts `nnModule` into a caffe2 NetDef.
TORCH_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule& nnModule);

/// Converts `nnModule` into a caffe2 NetDef, copying every attribute of
/// `oldNet` onto the result.
///
/// Caveat: if transformations changed the graph's inputs or outputs, those
/// changes are NOT reflected in external_input / external_output.
TORCH_API caffe2::NetDef convertToCaffe2Proto(
    nom::repr::NNModule& nnModule,
    const caffe2::NetDef& oldNet);

// Per-operator conversion entry points. Call these rather than using the
// converter registry directly.
TORCH_API std::unique_ptr<nom::repr::NeuralNetOperator>
convertToNeuralNetOperator(const caffe2::OperatorDef& op);
TORCH_API caffe2::OperatorDef convertToOperatorDef(
    const nom::repr::NNGraph::NodeRef& instrNode);

/// Returns the Caffe2Annotation of `instrNode`. If the node has none, an
/// attempt is made to add one first.
/// NOTE(review): whether nullptr can be returned when adding fails is not
/// visible from this header — check the definition.
TORCH_API Caffe2Annotation* getOrAddCaffe2Annotation(
    nom::repr::NNGraph::NodeRef& instrNode);
/// Abstract base for per-operator translators between caffe2::OperatorDef
/// and nom::repr::NeuralNetOperator. Concrete converters are registered in
/// ConverterRegistry via REGISTER_CONVERTER.
///
/// Note: the declaration order of the virtual members below is kept as-is so
/// the vtable layout does not change.
class TORCH_API Converter {
 public:
  // `explicit` on a nullary constructor of an abstract class had no effect.
  Converter() = default;

  /// Builds the NeuralNetOperator corresponding to the given OperatorDef.
  /// Pure virtual: every concrete converter must implement it.
  virtual std::unique_ptr<nom::repr::NeuralNetOperator>
  convertToNeuralNetOperator(const OperatorDef&) = 0;

  /// Converts a NeuralNetOperator back into an OperatorDef. Non-pure; the
  /// default implementation is defined out of line and may be overridden.
  virtual OperatorDef convertToOperatorDef(const nom::repr::NeuralNetOperator*);

  /// Collects the arguments of `op` into a map.
  /// NOTE(review): presumably keyed by argument name — confirm in the .cpp.
  /// `op` is taken by value; changing it to `const&` must be done together
  /// with the out-of-line definition to avoid a link error.
  static std::map<std::string, caffe2::Argument> getArgumentsFromOperator(
      caffe2::OperatorDef op);

  // Polymorphic base: converters are owned and destroyed through Converter*.
  virtual ~Converter() = default;

 protected:
  /// Helper for subclasses: returns the caffe2::DeviceOption associated with
  /// `nnOp`. NOTE(review): presumably read from the operator's annotation —
  /// confirm in the .cpp.
  caffe2::DeviceOption getDeviceOption(
      const nom::repr::NeuralNetOperator* nnOp) const;
};
// Registry of Converter implementations.
// NOTE(review): presumably keyed by the caffe2 operator type string — confirm
// at the REGISTER_CONVERTER call sites. Prefer the free functions
// convertToNeuralNetOperator() / convertToOperatorDef() over querying this
// registry directly.
C10_DECLARE_REGISTRY(ConverterRegistry, Converter);
// Registers converter class `cls` in ConverterRegistry under key `name`.
#define REGISTER_CONVERTER(name, cls) \
C10_REGISTER_CLASS(ConverterRegistry, name, cls)
// Defines class `opName##Converter`, a Converter whose
// convertToNeuralNetOperator() ignores the incoming OperatorDef and returns a
// default-constructed nom::repr::opName. Suitable only for operators that
// need no argument translation.
//
// The macro only defines the class; register it separately, e.g. for an
// operator type nom::repr::Foo:
//   TRIVIAL_CONVERTER(Foo)
//   REGISTER_CONVERTER(Foo, FooConverter)
//
// The parameter name is commented out so every expansion stays free of
// -Wunused-parameter warnings (a block comment is safe inside a macro; a
// line comment would swallow the continuation). No destructor is declared:
// the implicit one is already virtual (inherited) and public.
#define TRIVIAL_CONVERTER(opName)                                             \
  class opName##Converter : public Converter {                                \
    std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator( \
        const OperatorDef& /* op */) override {                               \
      return nom::util::make_unique<nom::repr::opName>();                     \
    }                                                                         \
  };
} // namespace caffe2
#endif // CAFFE2_OPT_CONVERTER_H