Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / opt / tvm_transformer.h

#pragma once

#include "caffe2/opt/backend_transformer_base.h"

#include <unordered_set>

namespace caffe2 {

struct TvmTransformOptions final : public BackendTransformOptions {
  explicit TvmTransformOptions() : BackendTransformOptions() {}

  //  Whether to enable profiling based jit
  bool profiling_based_jit{false};
};

class TORCH_API TvmTransformer final : public BackendTransformerBase {
 public:
  explicit TvmTransformer(const TvmTransformOptions& opts)
      : BackendTransformerBase(), opts_(opts) {}

  ~TvmTransformer() override {}

  // Given workspace and predict net, cluster continuous parts that can be run
  // by TVM and create one TVMJit op for each clustered subgraph.
  // \param ws c2 workspace
  // \param pred_net c2 predict net
  // \param weight_names list of the names of the constant weights
  // \param shape_hints User provided shape info, usually for primary inputs so
  // that bound shape inference can have something to start
  // \param blocklisted_ops a set of ops that we don't want to lower to TVM in
  // terms of their net positions. This is very useful for debugging but for
  // normal runs it should be empty
  void transform(
      Workspace* ws,
      NetDef* pred_net,
      const std::vector<std::string>& weight_names,
      const ShapeInfoMap& shape_hints,
      const std::unordered_set<int>& blocklisted_ops) override;

  static const std::unordered_set<std::string>& getSupportedOps();

  static bool canConvertFullGraph(
      const caffe2::NetDef& net,
      const std::unordered_set<int>& blocklisted_ops);

 private:
  // Given TVM runnable subnets, contract them into one TVMJitOp
  NetDef buildTvmOp(
      const caffe2::NetDef& net,
      const std::unordered_set<std::string>& weights,
      const ShapeInfoMap& shape_hints);

  // Apply transform to cluster connected TVM runnable ops into one TVMJitOp
  NetDef applyTvmTransform(
      NetDef* pred_net,
      const std::unordered_set<std::string>& weights,
      const std::unordered_set<int>& blocklisted_ops,
      const ShapeInfoMap& shape_hints);

  // Options
  TvmTransformOptions opts_;

  // Track number of TVMJitOp we created
  int tvm_op_id_{0};

  // Model id
  std::string model_id_;
};

// Helper function to clean up a net and run tvm transform.
TORCH_API void tvmTransform(
    NetDef* net,
    Workspace* ws,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::vector<std::string>& weight_names,
    const ShapeInfoMap& shape_hints,
    const std::unordered_set<int>& blocklisted_ops,
    int32_t max_batch_size,
    int32_t max_seq_size,
    int32_t num_embeddings,
    int32_t embedding_size,
    int32_t tvm_min_ops,
    bool tvm_profiling_based_jit,
    bool debug);

TORCH_API void cleanUpPredictNet(
    NetDef* net,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::vector<std::string>& weight_names);

} // namespace caffe2