Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / net_async_tracing.h

/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CAFFE2_CORE_NET_ASYNC_TRACING_H_
#define CAFFE2_CORE_NET_ASYNC_TRACING_H_

#include "caffe2/core/common.h"
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"

// Command-line flags related to async net tracing. These lines only declare
// the flags so that code including this header can refer to them.
// NOTE(review): the matching definitions are presumably in the corresponding
// .cc file -- confirm.
C10_DECLARE_string(caffe2_net_async_tracing_filepath);
C10_DECLARE_string(caffe2_net_async_names_to_trace);
C10_DECLARE_int(caffe2_net_async_tracing_nth);

namespace caffe2 {
namespace tracing {

// One trace record. Every field starts at a "not set" sentinel (-1, nullptr
// or false) so that only the fields explicitly filled in carry information.
struct TORCH_API TracerEvent {
  int op_id_ = -1;
  int task_id_ = -1;
  int stream_id_ = -1;
  // Raw C-string pointers: this struct has no special member functions, so it
  // neither copies nor frees the pointed-to characters. Whoever fills these in
  // must keep the strings alive for as long as the event is used.
  const char* name_ = nullptr;
  const char* category_ = nullptr;
  // Integral timestamp, -1 when unset. The initializer is an integer literal
  // to match the member type (it used to be the floating-point literal -1.0).
  // NOTE(review): the time unit is not visible in this header -- confirm.
  long timestamp_ = -1;
  // True for the record that opens an event; false otherwise.
  bool is_beginning_ = false;
  long thread_label_ = -1;
  // Default-constructed std::thread::id, i.e. "no thread", until assigned.
  std::thread::id tid_;
  int iter_ = -1;
};

// Tags naming which piece of event data a value passed to
// TracerGuard::addArgument(field, value) refers to.
// The numeric values are spelled out to make the implicit numbering explicit;
// they are identical to the compiler-assigned ones.
enum TracingField {
  TRACE_OP = 0,
  TRACE_TASK = 1,
  TRACE_STREAM = 2,
  TRACE_THREAD = 3,
  TRACE_NAME = 4,
  TRACE_CATEGORY = 5,
  TRACE_ITER = 6,
};

// Selects which group of TracingConfig settings applies (see the
// per-mode fields in TracingConfig).
enum class TracingMode {
  // Uses the "*_every_nth_batch" settings.
  EVERY_K_ITERATIONS = 0,
  // Uses the "*_n_ms" settings.
  GLOBAL_TIMESLICE = 1,
};

// Settings bundle handed to a Tracer. All members carry defaults, so a
// default-constructed TracingConfig is immediately usable.
struct TracingConfig {
  TracingMode mode{TracingMode::EVERY_K_ITERATIONS};
  std::string filepath{"/tmp"};

  // Consulted when mode == TracingMode::EVERY_K_ITERATIONS.
  int64_t trace_every_nth_batch{100};
  int64_t dump_every_nth_batch{10000};

  // Consulted when mode == TracingMode::GLOBAL_TIMESLICE.
  // Both values are in milliseconds.
  int64_t trace_every_n_ms{2 * 60 * 1000}; // two minutes
  int64_t trace_for_n_ms{1000}; // one second
};

// Collects TracerEvent records for one net and can dump them to a file.
// Only the interface is visible here; the member functions (except config())
// are defined out of line.
class TORCH_API Tracer {
 public:
  // net:      the net being traced (stored as a non-owning raw pointer type).
  // net_name: name of that net.
  // config:   tracing settings; defaults to a default-constructed
  //           TracingConfig.
  Tracer(
      const NetBase* net,
      const std::string& net_name,
      TracingConfig config = TracingConfig{});

  void recordEvent(const TracerEvent& event);
  std::string opTraceName(const OperatorBase* op);
  std::string opBlobsInfo(const OperatorBase& op);
  std::string serializeEvent(const TracerEvent& event);
  void linearizeEvents();
  void renameThreads();
  void setEnabled(bool enabled);
  bool isEnabled() const;
  // Read-only view of the configuration. Declared const so that it can also
  // be called through a const Tracer.
  const TracingConfig& config() const {
    return config_;
  }
  int bumpIter();
  int getIter();
  int bumpDumpingIter();
  // Dump the tracing result to file with given suffix, and then
  // clear current events.
  void dumpTracingResultAndClearEvents(const std::string& file_suffix);

  virtual ~Tracer();

 private:
  const NetBase* net_ = nullptr;
  std::string filename_;
  std::vector<TracerEvent> events_;
  // NOTE(review): presumably protects events_ and the counters below; the
  // locking itself is in the out-of-line definitions -- confirm.
  std::mutex tracer_mutex_;
  bool enabled_ = false;
  Timer timer_;
  // Counters get an in-class default like the other members, so they never
  // hold an indeterminate value regardless of what the constructor does.
  int iter_ = 0;
  int dumping_iter_ = 0;
  TracingConfig config_;

  friend class TracerGuard;
};

// Scope-bound helper used by the TRACE_EVENT* macros: it is default
// constructed, optionally attached to a Tracer via init(), fed event
// arguments with addArgument(), and started with recordEventStart().
// NOTE(review): the destructor is defined out of line; presumably it records
// the end of the event -- confirm against the implementation.
class TORCH_API TracerGuard {
 public:
  TracerGuard() {}

  void init(Tracer* tracer);

  // Zero-argument overload: terminates the recursion of the variadic
  // template below once all (field, value) pairs have been consumed.
  void addArgument();
  void addArgument(TracingField field, const char* value);
  void addArgument(TracingField field, int value);

  // Accepts a flat list "field1, value1, field2, value2, ...": handles the
  // first pair through one of the two-argument overloads, then recurses on
  // the remaining arguments.
  template <typename T, typename... Args>
  void addArgument(TracingField field, const T& value, const Args&... args) {
    addArgument(field, value);
    addArgument(args...);
  }

  void recordEventStart();

  virtual ~TracerGuard();

  static TracerGuard* getCurrentTracerGuard();
  void disable();

 private:
  bool enabled_ = false;
  TracerEvent event_;
  // Defaults to nullptr so that a guard on which init() was never called does
  // not carry an indeterminate pointer.
  Tracer* tracer_ = nullptr;
};

// Extracts the shard id from a name of the form "...shard:123...".
// Returns the id found in `name`, or -1 if the name contains no shard.
TORCH_API int extractShardId(const std::string& name);

// Checks whether `net_name` is white-listed for tracing. The white-list is
// specified via a command line flag.
// Returns true if the net may be traced.
TORCH_API bool isTraceableNetName(const std::string& net_name);

// Factory returning a shared_ptr<Tracer> for the given net and net name.
// NOTE(review): whether this can return an empty pointer (e.g. for nets that
// are not traceable) is not visible in this header; the TRACE_EVENT macros
// below do null-check their tracer -- confirm against the implementation.
TORCH_API std::shared_ptr<Tracer> create(
    const NetBase* net,
    const std::string& net_name);
// Takes the tracer by const reference and returns a bool.
// NOTE(review): presumably called at the start of each net iteration and
// returns whether that iteration is traced -- confirm.
TORCH_API bool startIter(const std::shared_ptr<Tracer>& tracer);

} // namespace tracing

// Pastes two tokens together exactly as written. Operands of ## are never
// macro-expanded, so passing __LINE__ here directly would yield the literal
// identifier suffix "__LINE__" rather than a line number.
#define TRACE_NAME_CONCATENATE(s1, s2) s1##s2
// Extra level of indirection: the arguments are fully macro-expanded (e.g.
// __LINE__ becomes the actual line number) before being handed to the
// pasting macro above.
#define TRACE_NAME_CONCATENATE_EXPANDED(s1, s2) TRACE_NAME_CONCATENATE(s1, s2)
// Builds an identifier of the form <str><line number>, so that uses on
// different source lines produce different names.
#define TRACE_ANONYMOUS_NAME(str) \
  TRACE_NAME_CONCATENATE_EXPANDED(str, __LINE__)

// Helper for TRACE_EVENT / TRACE_EVENT_IF; not a standalone statement.
// Operates on the guard variable named TRACE_ANONYMOUS_NAME(trace_guard),
// which the caller must already have declared, and on an identifier
// `tracer_` in the enclosing scope that provides .get().
// Expands to three statements: attach the guard to the tracer, forward the
// macro arguments to addArgument(), and record the event start. The
// expansion already ends in ';', so callers add none.
#define TRACE_EVENT_INIT(...)                                 \
  TRACE_ANONYMOUS_NAME(trace_guard).init(tracer_.get());      \
  TRACE_ANONYMOUS_NAME(trace_guard).addArgument(__VA_ARGS__); \
  TRACE_ANONYMOUS_NAME(trace_guard).recordEventStart();

// Supposed to be used only once per scope in AsyncNetBase-derived nets.
// Always declares a tracing::TracerGuard local in the current scope (so the
// guard lives until the end of that scope); the guard is initialized with
// the given arguments and started only when `tracer_` is non-null and
// reports isEnabled(). The arguments are forwarded to
// TracerGuard::addArgument().
// Because the expansion is a declaration followed by an `if`, it cannot be
// used as the sole body of an unbraced if/else/loop.
#define TRACE_EVENT(...)                                  \
  tracing::TracerGuard TRACE_ANONYMOUS_NAME(trace_guard); \
  if (tracer_ && tracer_->isEnabled()) {                  \
    TRACE_EVENT_INIT(__VA_ARGS__)                         \
  }

// Same as TRACE_EVENT, with an additional caller-supplied condition: the
// guard is always declared, but it is initialized and started only when
// `tracer_` is non-null, reports isEnabled(), and `cond` evaluates to true.
// `cond` is evaluated last, and only if the first two checks pass
// (short-circuit &&).
#define TRACE_EVENT_IF(cond, ...)                         \
  tracing::TracerGuard TRACE_ANONYMOUS_NAME(trace_guard); \
  if (tracer_ && tracer_->isEnabled() && (cond)) {        \
    TRACE_EVENT_INIT(__VA_ARGS__)                         \
  }

} // namespace caffe2

#endif // CAFFE2_CORE_NET_ASYNC_TRACING_H_