Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / caffe2 / onnx / onnx_exporter.h

#pragma once

#include "caffe2/core/common.h"
#include "caffe2/core/tensor.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2_pb.h"
#include "onnx/onnx_pb.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace caffe2 {
namespace onnx {

// Bring the ONNX protobuf message types into caffe2::onnx so the declarations
// below can name them unqualified (e.g. `NodeProto`, `TensorProto`).
//
// These are plain namespace-scope using-declarations rather than members of an
// unnamed namespace: an unnamed namespace in a header produces a distinct
// namespace in every translation unit that includes it, which is discouraged
// for headers. Namespace-scope using-declarations may legally be repeated, so
// identical using-declarations elsewhere in caffe2::onnx do not conflict.
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::GraphProto;
using ::ONNX_NAMESPACE::ModelProto;
using ::ONNX_NAMESPACE::NodeProto;
using ::ONNX_NAMESPACE::TensorProto;

// Output of converting a single Caffe2 operator.
//   first  - the ONNX nodes emitted for the operator.
//   second - ONNX tensors emitted alongside those nodes
//            (NOTE(review): presumably constant/initializer tensors — confirm).
using ConvertedResult = std::pair<
    std::vector<NodeProto>,
    std::vector<TensorProto>>;

// Useful utility function
void rewriteSubnet(
    Argument* arg,
    std::map<std::string, std::string> oldname_to_newname);

// Rewrites the Caffe2 nets `init_net` and `pred_net` in place into static
// single assignment (SSA) form. The external output names of the predict net
// (`pred_net`) are preserved.
//
// `preserve_inplace_ops` (default: true) controls how in-place operators are
// treated during the rewrite (NOTE(review): exact semantics live in the .cc
// file — confirm there).
//
// Returns a string -> string name mapping (NOTE(review): presumably relates
// rewritten blob names to the original ones — confirm direction against the
// definition).
TORCH_API auto SsaRewrite(
    caffe2::NetDef* init_net,
    caffe2::NetDef* pred_net,
    bool preserve_inplace_ops = true)
    -> std::unordered_map<std::string, std::string>;

::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
    caffe2::TensorProto::DataType t);

class TORCH_API OnnxExporter {
  using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
      const caffe2::OperatorDef&,
      const std::unordered_map<std::string, caffe2::TensorShape>&);

 public:
  OnnxExporter(DummyName* dummy = nullptr) {
    if (dummy) {
      dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
    } else {
      dummy_ = std::make_shared<DummyName>();
    }
  }

  ConvertedResult Caffe2OpToOnnxNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);

 private:
  ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);

  ConvertedResult CreateArgMaxMinOpNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateBinaryElementwiseOpNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateCastNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateElementwiseLinearNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateConvPoolNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateGemmNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateReshapeNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateSliceNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateChannelShuffleNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateReduceMeanNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateConcatNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateMergeDimNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateLrnNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  ConvertedResult CreateUpsampleNodes(
      const caffe2::OperatorDef& def,
      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

  // \brief Check block listed arguments where we won't pass down when
  // converting to ONNX node
  bool IsBlockListed(const caffe2::Argument& arg);

  // \brief Convert Caffe2 argument to Onnx attribute
  void CopyCaffe2ArgToOnnxAttr(
      AttributeProto* attr,
      const std::string& op_type,
      const caffe2::Argument& arg);

  // LUT getters
  const std::unordered_map<std::string, std::string>& get_renamed_operators()
      const;
  const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
  const std::
      unordered_map<std::string, std::unordered_map<std::string, std::string>>&
      get_per_op_renamed_attrs() const;
  const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
  get_special_operators() const;

  // Dummy name generator
  std::shared_ptr<DummyName> dummy_;
};
} // namespace onnx
} // namespace caffe2