Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / net_async_base.h

#ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
#define CAFFE2_CORE_NET_ASYNC_BASE_H_

#include <c10/macros/Macros.h>
#include "c10/core/thread_pool.h"
#include "c10/util/Registry.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/net_dag_utils.h"
#include "caffe2/core/prof_dag_counters.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/prof_dag.pb.h"
#include "caffe2/utils/proto_utils.h"

C10_DECLARE_int(caffe2_streams_per_gpu);
C10_DECLARE_int(caffe2_net_async_max_gpus);
C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
C10_DECLARE_int(caffe2_net_async_thread_pool_size);
C10_DECLARE_bool(caffe2_net_async_check_stream_status);
C10_DECLARE_bool(caffe2_net_async_use_single_pool);
C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);
C10_DECLARE_bool(caffe2_net_async_run_root_tasks_inline);
C10_DECLARE_bool(caffe2_net_async_profile_operators);

namespace caffe2 {

class AsyncNetExecutorHelper;

namespace tracing {
class Tracer;
}

struct ExecutionOptions {
  explicit ExecutionOptions(const std::shared_ptr<const NetDef>& net_def);

  // number of gpu streams per gpu per cpu thread
  int streams_per_gpu_ = 1;
  // ops synchronization options
  bool finish_chain_ = false;
  bool always_schedule_child_ = false;
  // try to pick gpu stream that is not busy
  bool check_stream_status_ = false;
  // use single thread pool for all devices
  bool use_single_pool_ = false;
  // use per net instances thread pools instead of global ones
  bool use_per_net_pools_ = false;
  // whether RunAsync is blocking
  bool is_blocking_ = false;
  // prof_dag counters reporting
  bool report_stats_ = false;
  // immediately run children tasks inline whenever possible
  bool use_dfs_scheduling_ = false;
  // run net's root tasks in RunAsync thread instead of in thread pool
  bool run_root_tasks_inline_ = false;
};

struct TORCH_API AsyncNetCancelled : public std::exception {
  const char* what() const noexcept override {
    return "Cancelled";
  }
};

class TORCH_API AsyncNetBase : public NetBase {
 public:
  AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
  ~AsyncNetBase() override;

  bool SupportsAsync() override {
    return true;
  }

  vector<OperatorBase*> GetOperators() const override {
    return operators_;
  }

  bool RunAsync() override;

  const dag_utils::ExecutionChains& TEST_execution_chains() const {
    return execution_chains_;
  }

  ProfDAGProtos GetOperatorStats() const;
  ProfDAGProtos GetPerOperatorCost() const;
  ProfDAGReport GetProfReport() const;

 protected:
  bool canSchedule(
      int chain_id,
      const std::vector<EventStatus>* status = nullptr,
      bool* parent_failed = nullptr);
  bool canSchedule(int parent_id, int child_id);

  int tasksNum() const;
  Event& event(int task_id) const;
  EventStatus query(int task_id) const;
  const std::vector<int>& children(int task_id) const;
  const std::vector<int>& parents(int task_id) const;
  int updateParentCount(int child_id);
  int getParentCount(int child_id);
  bool testAndSetScheduled(int task_id);
  int numOps(int task_id) const;

  int firstTaskOpId(int task_id) const;
  int lastTaskOpId(int task_id) const;
  const OperatorBase* firstTaskOp(int task_id) const;
  const OperatorBase* lastTaskOp(int task_id) const;
  OperatorBase* firstTaskOp(int task_id);
  OperatorBase* lastTaskOp(int task_id);

  void asyncWait(
      int task_id,
      int stream_id,
      const std::vector<int>& wait_task_ids) const;
  bool run(int task_id, int stream_id) noexcept;
  int stream(int task_id);
  TaskThreadPoolBase* pool(const DeviceOption& device_option);
  TaskThreadPoolBase* pool();

  void finishTasks(const std::unordered_set<int>& task_ids);
  void finalizeEvents();

  bool isStreamFree(int task_id, int stream_id) const;

  virtual void reset();

  bool handleRunError() override;

  // Operator/task graph
  std::vector<OperatorBase*> operators_;
  std::vector<dag_utils::OperatorNode> operator_nodes_;
  std::vector<std::vector<int>> chains_;
  std::vector<dag_utils::OpGraphNode> chain_nodes_; // chains' parents/children
  dag_utils::ExecutionChains execution_chains_; // for testing

  // Pools and streams
  std::mutex pools_mutex_;
  // first int key - device id, second - pool size, one pool per (device, size)
  typedef std::unordered_map<
      int,
      std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
      PoolsMap;
  PoolsMap cpu_pools_;
  PoolsMap gpu_pools_;
  static std::vector<int>& getStreamCounters();
  int num_workers_;

  // Exception/error handling
  void handleChainError(
      int task_id,
      OperatorBase* op,
      const char* err_msg,
      bool save_exception = false) noexcept;
  std::atomic<bool> success_;

  // Tracing
  std::shared_ptr<tracing::Tracer> tracer_;

  // execution mode flags
  ExecutionOptions options_;

  ProfDAGCounters counters_;

  C10_DISABLE_COPY_AND_ASSIGN(AsyncNetBase);

 private:
  TaskThreadPoolBase*
  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);

  std::unique_ptr<AsyncNetExecutorHelper> helper_;

  friend class AsyncNetExecutorHelper;
  friend class tracing::Tracer;
};

class AsyncNetExecutorHelper : public ExecutorHelper {
 public:
  explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}
  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
    return net_->pool(option);
  }

 private:
  AsyncNetBase* net_;
};

template <class TaskThreadPoolImpl, int device_type>
std::shared_ptr<TaskThreadPoolBase>
GetAsyncNetThreadPool(int device_id, int pool_size, bool create_new) {
  static std::unordered_map<
      int,
      std::unordered_map<int, std::weak_ptr<TaskThreadPoolBase>>>
      pools;
  static std::mutex pool_mutex;

  const auto& device_type_name = DeviceTypeName(device_type);

  if (pool_size <= 0) {
    if (FLAGS_caffe2_net_async_thread_pool_size > 0) {
      pool_size = FLAGS_caffe2_net_async_thread_pool_size;
      LOG(INFO) << "Using default " << device_type_name
                << " pool size: " << pool_size << "; device id: " << device_id;
    } else {
      auto num_cores = std::thread::hardware_concurrency();
      CAFFE_ENFORCE(num_cores > 0, "Failed to get number of CPU cores");
      LOG(INFO) << "Using estimated " << device_type_name
                << " pool size: " << num_cores << "; device id: " << device_id;
      pool_size = num_cores;
    }
  } else {
    LOG(INFO) << "Using specified " << device_type_name
              << " pool size: " << pool_size << "; device id: " << device_id;
  }

  if (create_new) {
    LOG(INFO) << "Created new " << device_type_name
              << " pool, size: " << pool_size << "; device id: " << device_id;
    return std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
  } else {
    std::lock_guard<std::mutex> lock(pool_mutex);

    auto shared_pool = pools[device_id][pool_size].lock();
    if (!shared_pool) {
      LOG(INFO) << "Created shared " << device_type_name
                << " pool, size: " << pool_size << "; device id: " << device_id;
      shared_pool = std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
      pools[device_id][pool_size] = shared_pool;
    }
    return shared_pool;
  }
}

} // namespace caffe2

#endif // CAFFE2_CORE_NET_ASYNC_BASE_H_