Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / opt / backend_transformer_base.h

#pragma once

#include "caffe2/core/common.h"
#include "caffe2/core/workspace.h"
#include "caffe2/opt/bound_shape_inferencer.h"
#include "caffe2/proto/caffe2_pb.h"

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace caffe2 {
// String constants shared by the backend transformers. Presumably kNetPos is
// the op argument written by annotateOpIndex() and kModelId the net argument
// read by getModelId() -- confirm against the .cc implementation.
//
// No unnamed namespace: namespace-scope constexpr (hence const) variables
// already have internal linkage, and unnamed namespaces in headers are
// discouraged. Switch to `inline constexpr` once the codebase is on C++17 to
// get a single shared definition.
constexpr char kNetPos[] = "net_pos";
constexpr char kModelId[] = "model_id";

// Options shared by backend transformers when lowering ops to a backend.
struct BackendTransformOptions {
  // Deliberately not `explicit`: on a zero-argument constructor `explicit`
  // only serves to reject copy-list-initialization such as
  // `BackendTransformOptions opts = {};` or `return {};`.
  BackendTransformOptions() : bound_shape_spec(0, 0) {}

  // Enable debugging by dumping more intermediate graphs
  bool debug{false};

  // Minimum number of ops to create a backend op. If the subgraph is too
  // small, it doesn't make sense to lower it to backend.
  size_t min_ops{1};

  // Bound shape spec, constructed here as BoundShapeSpec(0, 0).
  // NOTE(review): the two arguments are presumably max batch size and max
  // sequence size -- confirm against bound_shape_inferencer.h.
  BoundShapeSpec bound_shape_spec;
};

// Wraps the shape information held in `shape_info` (a ShapeInfo, not a bare
// TensorShape) into a TensorProto. `name` presumably becomes the proto's
// name -- confirm in the .cc implementation.
TensorProto wrapShapeInfoIntoTensorProto(
    const std::string& name,
    const ShapeInfo& shape_info);

// Quantized counterpart of wrapShapeInfoIntoTensorProto: wraps `shape_info`
// into a QTensorProto. Presumably any quantization parameters carried by
// `shape_info` are copied as well -- confirm in the .cc implementation.
QTensorProto wrapShapeInfoIntoQTensorProto(
    const std::string& name,
    const ShapeInfo& shape_info);

// Common base for backend lowering and graph cutting. Concrete backends
// implement transform(); the protected helpers and the input name mappings
// are shared by all of them.
class BackendTransformerBase {
 public:
  BackendTransformerBase() = default;
  // Polymorphic base: keep the destructor virtual so derived transformers are
  // destroyed correctly through a base-class pointer.
  virtual ~BackendTransformerBase() = default;

  // Returns the mapping of input name -> original input name.
  const std::unordered_map<std::string, std::string>& input_mapping() const {
    return input_mapping_;
  }

  // Returns the mapping of original input name -> input name.
  const std::unordered_map<std::string, std::string>& reverse_input_mapping()
      const {
    return reverse_input_mapping_;
  }

  // Backend-specific transformation entry point.
  //   ws              workspace the net is associated with
  //   pred_net        net to transform; passed by non-const pointer, so
  //                   implementations may rewrite it in place
  //   weight_names    names of the weight blobs
  //   shape_hints     shape information for the net's inputs
  //   blocklisted_ops ops that must not be lowered; presumably identified by
  //                   the index assigned in annotateOpIndex() -- confirm in
  //                   the implementations
  virtual void transform(
      Workspace* ws,
      NetDef* pred_net,
      const std::vector<std::string>& weight_names,
      const ShapeInfoMap& shape_hints,
      const std::unordered_set<int>& blocklisted_ops) = 0;

  // Annotates the ops of `net` with their index. Presumably stored as the
  // kNetPos ("net_pos") argument -- confirm in the .cc implementation.
  static void annotateOpIndex(NetDef* net);

  // Get model ID from the NetDef. Presumably read from the kModelId
  // ("model_id") argument -- confirm in the .cc implementation.
  static std::string getModelId(const NetDef& net);

 protected:
  // Add the shape info in `shape_hints` to `shape_net` (modified in place).
  void addShapeToNet(NetDef& shape_net, const ShapeInfoMap& shape_hints) const;

  // Dump `pred_net` together with the shape info in `map`; `fname` is
  // presumably the output file name.
  void dumpNet(
      const NetDef& pred_net,
      const ShapeInfoMap& map,
      const std::string& fname) const;

  // SSA-rewrite `pred_net` in place. Note that the return value is a
  // ShapeInfoMap, not the name mapping itself. It is presumably
  // `input_shape_hints` re-keyed to the rewritten names, with the name
  // mapping recorded in input_mapping_ / reverse_input_mapping_ -- confirm
  // in the .cc implementation.
  ShapeInfoMap ssaRewriteAndMapNames(
      Workspace* ws,
      NetDef* pred_net,
      const ShapeInfoMap& input_shape_hints);

  // Run bound shape inference on `pred_net` using `spec` and the shape hints
  // (presumably keyed by the SSA-rewritten names), and return the collected
  // shape infos.
  ShapeInfoMap inferShapes(
      Workspace* ws,
      NetDef* pred_net,
      const ShapeInfoMap& shape_hints_mapped,
      const BoundShapeSpec& spec);

  // Input mapping of input name -> original input name
  std::unordered_map<std::string, std::string> input_mapping_;

  // Input mapping of original input name -> input name
  std::unordered_map<std::string, std::string> reverse_input_mapping_;
};
} // namespace caffe2