Repository URL to install this package:
|
Version:
2.4.1 ▾
|
#pragma once
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
constexpr auto kProcessGroupCudaP2PDefaultTimeout =
std::chrono::milliseconds(10 * 60 * 1000);
namespace c10d {
class TORCH_API ProcessGroupCudaP2P : public Backend {
public:
struct Options : Backend::Options {
c10::intrusive_ptr<ProcessGroupNCCL::Options> nccl_options;
c10::optional<size_t> buffer_size;
explicit Options()
: Backend::Options("cuda_p2p", kProcessGroupCudaP2PDefaultTimeout) {}
};
bool is_p2p_available();
size_t get_buffer_size();
c10::Stream stream();
ProcessGroupCudaP2P(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options);
c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<Work> allreduce_sparse(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> _allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
c10::intrusive_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) override;
c10::intrusive_ptr<Work> _reduce_scatter_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> alltoall_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
c10::intrusive_ptr<Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
/* P2P-only */
c10::intrusive_ptr<Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
c10::intrusive_ptr<Work> intra_node_barrier(
c10::optional<std::vector<int64_t>> ranks = c10::nullopt);
at::Tensor get_p2p_buffer(
size_t rank,
const std::vector<int64_t>& sizes,
c10::ScalarType dtype,
int64_t storage_offest = 0);
void shutdown(c10::optional<std::string> reason = c10::nullopt);
private:
c10::intrusive_ptr<ProcessGroupNCCL> nccl_backend_;
c10::intrusive_ptr<c10d::intra_node_comm::IntraNodeComm> p2p_backend_;
c10::Stream stream_;
};
} // namespace c10d
#endif // USE_C10D_NCCL