// torch/include/c10d/logger.hpp

#pragma once
#include <c10/util/Logging.h>
#include <c10d/reducer.hpp>

#include <mutex>

namespace c10d {

class TORCH_API Logger {
 public:
  explicit Logger(std::shared_ptr<c10d::Reducer> reducer);
  // Set logging data that is collected at DistributedDataParallel
  // construction time.
  void set_construction_data_and_log(
      const std::string& module_name,
      const std::vector<int>& device_ids,
      int output_device,
      bool broadcast_buffers,
      bool has_sync_bn,
      bool static_graph
  );
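  //
  // Example (an illustrative sketch; all argument values are hypothetical):
  //   logger.set_construction_data_and_log(
  //       /*module_name=*/"MyModel",
  //       /*device_ids=*/{0},
  //       /*output_device=*/0,
  //       /*broadcast_buffers=*/true,
  //       /*has_sync_bn=*/false,
  //       /*static_graph=*/false);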

  void set_static_graph();

  // An interface for users to get DDPLoggingData and log it in their
  // applications. The logging fields are explained in
  // "struct DDPLoggingData" in "c10/util/Logging.h".
  at::DDPLoggingData get_ddp_logging_data();
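  //
  // Example (an illustrative sketch; the map keys shown are only the ones
  // this header itself references):
  //   at::DDPLoggingData data = logger.get_ddp_logging_data();
  //   bool had_error = data.ints_map.count("has_error") != 0;
  //   std::string err = had_error ? data.strs_map["error"] : "";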

  // Stream insertion operator for writing logging data to an output stream,
  // used under TORCH_DISTRIBUTED_DEBUG.
  friend std::ostream& operator<<(std::ostream& output, const Logger& logger);
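  //
  // Example (a minimal sketch):
  //   std::ostringstream oss;
  //   oss << logger;  // write the collected logging fields to the stream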

  ~Logger() noexcept(false) {
    // Log if DDP graph is static in Logger dtor instead of Reducer dtor since
    // Logger is deleted before Reducer.
    log_if_graph_static(reducer_->ddp_graph_static());
  }

  // Set environment variables.
  void set_env_variables();
  // Set parameters stats.
  void set_parameter_stats();
  // Get size of each bucket (Bytes).
  std::vector<int64_t> get_bucket_sizes();
  // Get variable indices for each bucket.
  std::vector<std::vector<size_t>> get_per_bucket_variable_indices();
  // Set the communication hook, if one is used.
  void set_comm_hook(const std::string& hook);
  // Record that the run uses uneven input detection (the model.join()
  // context manager).
  void set_uneven_input_join();

  // Reset performance stats at the current iteration.
  void reset_performance_stats();

  // Calculate average stats using the CPU and GPU timers
  // that have been recorded in the reducer.
  void calculate_avg_time(
      int64_t& avg_time,
      int64_t& time_duration,
      Timer& timer,
      Timer::Event start_event,
      Timer::Event end_event);

  // Set the absolute time of the event that has been recorded in the reducer.
  void set_event_time(
      int64_t& event_time,
      Timer& timer,
      Timer::Event event);
  // Set stats that can only be collected during the training loop.
  // This is called at the beginning of the forward pass to record the
  // runtime stats of sampled iterations that ran previously.
  // GPU performance stats are currently collected only for single-process
  // single-device programs and single-device modules.
  // TODO: to support single-process multi-device programs and multi-device
  // modules, events need to be created and recorded on multiple devices.
  void set_runtime_stats_and_log();
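  //
  // Example (sketch; in practice DDP itself invokes this at the start of
  // forward(), rather than user code):
  //   logger.set_runtime_stats_and_log();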

  // Called when DDP/reducer fails with an error. Two fields in the logging
  // data structure are filled in: "has_error", indicating that this iteration
  // encountered an error (and that the other fields are not valid), and
  // "error", a string containing the error message that DDP failed with.
  template <typename... Args>
  void set_error_and_log(const std::string& ddp_error, const Args&... args) {
    ddp_logging_data_->ints_map["has_error"] = 1;
    auto err = c10::str(ddp_error, args...);
    ddp_logging_data_->strs_map["error"] = err;
    // Report the iteration we errored at, so the user knows how many examples
    // were successfully processed before this error was hit.
    ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_;
    at::LogPyTorchDDPUsage(*ddp_logging_data_);
  }
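
  // Example (an illustrative sketch; the message pieces are hypothetical and
  // are concatenated via c10::str into the "error" field):
  //   logger.set_error_and_log(
  //       "DDP failed at iteration ", iter, ": ", exception_msg);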

  // When running without static graph, this is called when the reducer is
  // destroyed, to log whether the graph was actually static and is a
  // candidate for the static-graph optimization.
  void log_if_graph_static(bool is_static);


 private:
  // ddp_logging_data_ holds all the DDP-related logging data fields.
  std::unique_ptr<at::DDPLoggingData> ddp_logging_data_;
  std::shared_ptr<c10d::Reducer> reducer_;
  // Tracks the number of iterations for which runtime stats have been
  // collected so far.
  long num_iterations_stats_recorded_ = 0;
};

} // namespace c10d
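
// Overall usage sketch (hedged: the ordering is inferred from the
// declarations above, and the Reducer construction details are assumptions):
//
//   auto reducer = std::make_shared<c10d::Reducer>(/*...*/);
//   c10d::Logger logger(reducer);
//   logger.set_construction_data_and_log(
//       "MyModel", /*device_ids=*/{0}, /*output_device=*/0,
//       /*broadcast_buffers=*/true, /*has_sync_bn=*/false,
//       /*static_graph=*/false);
//   // Per sampled training iteration, at the start of forward():
//   logger.set_runtime_stats_and_log();
//   // Retrieve a snapshot of the collected fields at any point:
//   at::DDPLoggingData snapshot = logger.get_ddp_logging_data();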