Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
torch / include / torch / csrc / distributed / c10d / ProcessGroupGloo.hpp
Size: Mime:
#pragma once

#ifdef USE_C10D_GLOO

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

#include <gloo/algorithm.h>
#include <gloo/common/error.h>
#include <gloo/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/device.h>

#include <c10/util/hash.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

// Backend name string for the Gloo process group; returned (as std::string)
// by ProcessGroupGloo::getBackendName().
constexpr auto GLOO_BACKEND_NAME = "gloo";

// ProcessGroupGloo implements Gloo bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes. For
// multi-threaded usage of process groups, consider using multiple
// process group instances.
//
// The Gloo algorithms that this class calls into are cached by their
// signature (an AlgorithmKey). This cache works as follows: every
// function call instantiates an AlgorithmKey and looks in the cache for
// existing entries. If there is one, it is removed from the cache and
// returned to the caller. If there are none, a new entry is created and
// returned. If an entry was created before, but is still in use, the call
// will block and wait until the entry is returned to the cache.
// NOTE(review): AlgorithmKey is not declared in this header, and the
// AsyncWork note below says operations are being ported away from this
// caching approach -- confirm this paragraph still reflects the code.
//
// In the future, we hope to extend this to allow multiple entries per
// key, to enable parallelism for a single key. The number of entries
// per key must always be identical for all processes. This maximum
// number can be automatically tuned, but only if we let a single
// process take charge, and have it broadcast the limits.
//
class TORCH_API ProcessGroupGloo : public Backend {
 public:
  // AsyncWork is the Gloo specific superclass for asynchronous work items.
  // We can split asynchronous work into 3 phases:
  // 1) Sanity checks and prepare input (e.g. memcpy)
  // 2) Run operation on background thread
  // 3) Synchronize with completion on foreground thread
  //
  // There is state to be shared between these 3 phases and all of this state
  // is captured in the AsyncWork class and its derivatives.
  //
  // Note: while we are porting operations to use new style collectives, there
  // is a split between operations using the existing caching approach and
  // operations using the new AsyncWork base class. Over time we will port
  // all operations and perform needed cleanup.
  //
  // FIXME: This probably should be called WorkGloo since the work is executed
  // in sync mode by a background thread.
  class TORCH_API AsyncWork : public Work {
   public:
    explicit AsyncWork(
        std::vector<std::vector<at::Tensor>> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    ~AsyncWork() override = default;

    // Runs `work` (defined out of line; presumably invokes run() and
    // completes or fails the work accordingly -- confirm in the .cpp).
    static void execute(c10::intrusive_ptr<AsyncWork> work);

    // The collective itself; implemented by each concrete work type.
    virtual void run() = 0;

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupGloo;

   private:
    void finishWorkGloo();
    void finishWorkGlooError(std::exception_ptr eptr);
    inline void recordAsyncWorkProfilingInfo(
        const char* profilingTitle,
        const c10::optional<std::vector<at::Tensor>>& inputTensors);

    const std::vector<std::vector<at::Tensor>> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
    std::function<void()> recordFunctionBeforeCallback_;
  };

  // Wrap c10d store as Gloo store.
  //
  // Gloo exchanges payloads as std::vector<char> while the c10d Store uses
  // std::vector<uint8_t>; every method here forwards to the wrapped store_
  // and copies bytes between the two representations where needed.
  class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
   public:
    GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {}

    // Stores a uint8_t payload directly, with no char conversion.
    void setUint(const std::string& key, const std::vector<uint8_t>& value) {
      store_->set(key, value);
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->set(key, tmp);
    }

    // Returns the payload for `key` as uint8_t, with no char conversion.
    std::vector<uint8_t> getUint(const std::string& key) {
      return store_->get(key);
    }

    std::vector<char> get(const std::string& key) override {
      auto value = store_->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    // Waits for `keys` using the c10d store's default timeout.
    void wait(const std::vector<std::string>& keys) override {
      store_->wait(keys, ::c10d::Store::kDefaultTimeout);
    }

    void wait(
        const std::vector<std::string>& keys,
        const std::chrono::milliseconds& timeout) override {
      store_->wait(keys, timeout);
    }

#ifdef GLOO_STORE_HAS_STORE_V2
    // Extended (v2) API support is whatever the wrapped store reports.
    bool has_v2_support() override {
      return store_->hasExtendedApi();
    }

    std::vector<std::vector<char>> multi_get(
        const std::vector<std::string>& keys) override {
      auto values = store_->multiGet(keys);
      std::vector<std::vector<char>> res;
      // One allocation for the outer vector; each inner vector is built in
      // place from the byte range.
      res.reserve(values.size());
      for (const auto& value : values) {
        res.emplace_back(value.begin(), value.end());
      }
      return res;
    }

    void multi_set(
        const std::vector<std::string>& keys,
        const std::vector<std::vector<char>>& values) override {
      std::vector<std::vector<uint8_t>> u_values;
      u_values.reserve(values.size());
      for (const auto& value : values) {
        u_values.emplace_back(value.begin(), value.end());
      }
      store_->multiSet(keys, u_values);
    }

    void append(const std::string& key, const std::vector<char>& value)
        override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->append(key, tmp);
    }

    int64_t add(const std::string& key, int64_t value) override {
      return store_->add(key, value);
    }
#endif

   protected:
    c10::intrusive_ptr<::c10d::Store> store_;
  };

  // For send and recv operations there is no need to pass them to the
  // thread pool as they are entirely completed by the device thread.
  // This work object is used to synchronize completion of the send or
  // recv operation. It keeps a reference to the tensor it is
  // operating on to prevent it from being deallocated while the
  // operation is still in flight.
  class TORCH_API SendWork : public Work {
   public:
    explicit SendWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer);

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
  };

  // Receive-side counterpart of SendWork (see the comment above SendWork).
  class TORCH_API RecvWork : public Work {
   public:
    explicit RecvWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        const char* profilingTitle = nullptr);

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    // -1 until assigned, so it is never read uninitialized.
    // NOTE(review): confirm the constructor / wait() set the actual rank.
    int srcRank_{-1};
  };

  struct TORCH_API Options : public Backend::Options {
    explicit Options(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout);

    // return intrusive_ptr of the object
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout) {
      return c10::make_intrusive<Options>(timeout);
    }

    // Gloo transport devices. Per the contexts_ comment below, using more
    // than one device requires multiple contexts.
    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    // Returned by getNumThreads(). NOTE(review): no in-class default; the
    // value is presumably assigned by the out-of-line Options constructor.
    int threads;
  };

  const std::string getBackendName() const override {
    return std::string(GLOO_BACKEND_NAME);
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).

  // Create new device instance for specific interface.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& interface);

  // Create new device instance for specific hostname or address.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);

  // Create new device instance.
  // It tries to resolve this machine's hostname and bind to that address.
  // If that fails (i.e. the hostname doesn't resolve to an address), it
  // falls back to binding to the loopback address.
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();

  // Create ProcessGroupGloo instance.
  static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      std::chrono::milliseconds timeout);

  explicit ProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  ~ProcessGroupGloo() override;

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  // Returns the Gloo rendezvous store owned by this process group.
  const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
    return store_;
  }

  // Similar to barrier(), but blocks rank 0 until all other ranks have
  // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  // is able to report all failed ranks if waitAllRanks = true, otherwise
  // reports the first rank it detected as failed.
  void monitoredBarrier(
      const BarrierOptions& opts = BarrierOptions(),
      bool waitAllRanks = false) override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

  // Returns options_->threads.
  int getNumThreads() {
    return options_->threads;
  }

 protected:
  std::unique_ptr<::gloo::rendezvous::Store> store_;
  const c10::intrusive_ptr<Options> options_;

  // Every Gloo context represents a set of connections to its peers.
  // In order to use more than one device (or allow for parallelism on
  // a single device), you need multiple contexts.
  std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  std::vector<std::thread> threads_;
  // Defaults to false so it is never read uninitialized; presumably set to
  // true to ask the worker threads to exit (NOTE(review): confirm in .cpp).
  bool stop_{false};

  // Incremented for every collective we kick off.
  // The value is used as tag for collective operations. Collectives are kicked
  // off in identical order across processes. Therefore the tag can be used
  // to match up operations during concurrent execution.
  // Starts at 0 so every process begins from the same tag.
  uint32_t collectiveCounter_{0};

  // Returns next collective tag to use (uses collectiveCounter_).
  uint32_t nextTag();

  // Returns the context to use for the specified tag.
  // With `nextTag` returning an increasing number, this should lead
  // to contexts being used in a round-robin fashion.
  std::shared_ptr<::gloo::Context> getContext(uint32_t tag);

  // Entrypoint for worker threads.
  void runLoop(int workerIndex);

  // Queue work to run on worker thread.
  void enqueue(c10::intrusive_ptr<AsyncWork> work);

  // Keep both a queue of pending work, and a vector with in progress work.
  // Both of these can only be mutated when holding the queue lock.
  // We keep both around instead of just the queue, so we can grab a weak_ptr
  // to all in progress and pending work when executing a barrier.
  // When executing a barrier, we need to ensure that all prior work
  // has completed before completing itself.
  std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  std::mutex workMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;
};

} // namespace c10d

#endif // USE_C10D_GLOO