Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / caffe2 / mpi / mpi_ops.h

#ifndef CAFFE2_MPI_MPI_OPS_H_
#define CAFFE2_MPI_MPI_OPS_H_

#include <mpi.h>

#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"

namespace caffe2 {

// TODO(jiayq): if needed, write up the use of color and key with MPI split.
// Currently, the operator simply creates a communicator that has the
// same topology as the Caffe2 global communicator.
//
// MPICreateCommonWorldOp places a default-constructed MPICommonWorldWrapper
// into its first output blob.
template <class Context>
class MPICreateCommonWorldOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPICreateCommonWorldOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  // Resets output blob 0 to hold a new MPICommonWorldWrapper and reports
  // success.
  bool RunOnDevice() override {
    auto* world_blob = OperatorBase::Outputs()[0];
    world_blob->Reset(new MPICommonWorldWrapper());
    return true;
  }
};

// MPIBroadcastOp runs MPI_Bcast over the raw bytes of its output tensor, using
// the communicator carried by the MPICommonWorldWrapper in input 0.
// The output tensor is used in place as the broadcast buffer and therefore
// must already be allocated (non-empty) before the operator runs.
//
// Arguments:
//   root (int, default 0): rank passed to MPI_Bcast as the broadcast root.
template <class Context>
class MPIBroadcastOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPIBroadcastOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
  ~MPIBroadcastOp() {}

  // Broadcasts the output tensor's bytes from root_ to every rank of the
  // communicator. Fails via CAFFE_ENFORCE if the output is not a tensor on
  // this context's device, is empty, or is too large for an int byte count.
  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    CAFFE_ENFORCE(
        OperatorBase::OutputIsTensorType(0, Context::GetDeviceType()),
        "Output is of wrong type.");
    auto* output = Output(0);
    // Make sure that output is already allocated.
    CAFFE_ENFORCE(
        output->numel() > 0,
        "Broadcast op uses in-place operation so the output "
        "should be already allocated.");
    // MPI_Bcast takes its count as an int. Convert explicitly and verify the
    // value round-trips, so an oversized buffer is rejected instead of being
    // silently truncated to a wrapped count.
    const size_t nbytes = output->nbytes();
    const int count = static_cast<int>(nbytes);
    CAFFE_ENFORCE(
        count >= 0 && static_cast<size_t>(count) == nbytes,
        "Broadcast op: output size in bytes does not fit in the int count "
        "accepted by MPI_Bcast.");
    MPI_CHECK(MPI_Bcast(
        output->raw_mutable_data(),
        count,
        MPIDataTypeWrapper<char>::type(),
        root_,
        comm));
    return true;
  }

 protected:
  // Rank used as the broadcast root.
  int root_;
};

// MPIReduceOp does Reduce using MPI. Currently, only SUM is supported.
//
// Inputs:
//   0: MPICommonWorldWrapper providing the communicator.
//   1: tensor of element type T to be reduced.
// Output 0 is shaped like input 1 with element type T and is passed to
// MPI_Reduce as the receive buffer.
//
// Arguments:
//   root (int, default 0): rank passed to MPI_Reduce as the reduction root.
template <typename T, class Context>
class MPIReduceOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPIReduceOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
  ~MPIReduceOp() {}

  // Sums input 1 across the communicator into output 0 with MPI_Reduce.
  // Fails via CAFFE_ENFORCE if the element count does not fit in an int.
  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    auto& input = Input(1);
    // MPI_Reduce takes its count as an int. Convert explicitly and verify the
    // value round-trips, so an oversized tensor is rejected instead of being
    // silently truncated to a wrapped count.
    const auto numel = input.numel();
    const int count = static_cast<int>(numel);
    CAFFE_ENFORCE(
        count >= 0 && count == numel,
        "Reduce op: number of input elements does not fit in the int count "
        "accepted by MPI_Reduce.");
    auto* output = Output(0, input.sizes(), at::dtype<T>());
    // The const_cast accommodates MPI implementations whose send-buffer
    // parameter is declared non-const.
    MPI_CHECK(MPI_Reduce(
        const_cast<T*>(input.template data<T>()),
        output->template mutable_data<T>(),
        count,
        MPIDataTypeWrapper<T>::type(),
        MPI_SUM,
        root_,
        comm));
    return true;
  }

 protected:
  // Rank used as the reduction root.
  int root_;
};

// MPIAllgatherOp does MPIAllgather using MPI.
//
// Inputs:
//   0: MPICommonWorldWrapper providing the communicator and its size.
//   1: tensor of element type T contributed by this rank.
// Output 0 has the shape of input 1 with its first dimension multiplied by
// the common world size, and is passed to MPI_Allgather as the receive
// buffer.
template <typename T, class Context>
class MPIAllgatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MPIAllgatherOp);

  // Gathers input 1 from every rank into output 0. Fails via CAFFE_ENFORCE if
  // the input has no dimensions or its element count does not fit in an int.
  bool RunOnDevice() override {
    // Look the wrapper up once; both the communicator and its size are needed.
    const auto& common_world = OperatorBase::Input<MPICommonWorldWrapper>(0);
    MPI_Comm comm = common_world.comm();
    auto& input = Input(1);
    auto* output = Output(0);
    vector<int64_t> output_dims = input.sizes().vec();
    // The gathered result grows along the first dimension, so a
    // zero-dimensional input has nothing to scale; indexing [0] on an empty
    // vector would be out of bounds.
    CAFFE_ENFORCE(
        !output_dims.empty(),
        "Allgather op requires an input with at least one dimension.");
    // MPI_Allgather takes its counts as ints. Convert explicitly and verify
    // the value round-trips, so an oversized tensor is rejected instead of
    // being silently truncated to a wrapped count.
    const auto numel = input.numel();
    const int count = static_cast<int>(numel);
    CAFFE_ENFORCE(
        count >= 0 && count == numel,
        "Allgather op: number of input elements does not fit in the int "
        "count accepted by MPI_Allgather.");
    output_dims[0] *= common_world.size();
    output->Resize(output_dims);
    // The const_cast accommodates MPI implementations whose send-buffer
    // parameter is declared non-const.
    MPI_CHECK(MPI_Allgather(
        const_cast<T*>(input.template data<T>()),
        count,
        MPIDataTypeWrapper<T>::type(),
        output->template mutable_data<T>(),
        count,
        MPIDataTypeWrapper<T>::type(),
        comm));
    return true;
  }
};

// MPIAllreduceOp does MPIAllreduce using MPI. Currently, only SUM is supported.
//
// Inputs:
//   0: MPICommonWorldWrapper providing the communicator.
//   1: tensor of element type T to be reduced.
// Output 0 is shaped like input 1 with element type T. When the output buffer
// is the same memory as the input buffer, the call is made with MPI_IN_PLACE.
template <typename T, class Context>
class MPIAllreduceOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MPIAllreduceOp);

  // Sums input 1 across the communicator into output 0 with MPI_Allreduce.
  // Fails via CAFFE_ENFORCE if the element count does not fit in an int.
  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    auto& input = Input(1);
    // MPI_Allreduce takes its count as an int. Convert explicitly and verify
    // the value round-trips, so an oversized tensor is rejected instead of
    // being silently truncated to a wrapped count.
    const auto numel = input.numel();
    const int count = static_cast<int>(numel);
    CAFFE_ENFORCE(
        count >= 0 && count == numel,
        "Allreduce op: number of input elements does not fit in the int "
        "count accepted by MPI_Allreduce.");
    auto* output = Output(0, input.sizes(), at::dtype<T>());
    // Fetch the output pointer first, then the input pointer, and reuse both
    // for the aliasing test and for the MPI call.
    T* dest = output->template mutable_data<T>();
    const T* src = input.template data<T>();
    void* source;
    if (dest == src) {
      // We are doing in-place call. Special case handling.
      source = MPI_IN_PLACE;
    } else {
      // Normal allreduce takes the source from the input.
      source = const_cast<T*>(src);
    }
    MPI_CHECK(MPI_Allreduce(
        source,
        dest,
        count,
        MPIDataTypeWrapper<T>::type(),
        MPI_SUM,
        comm));
    return true;
  }
};

// MPISendTensorOp sends the raw bytes of a tensor to another rank with
// MPI_Send.
//
// Inputs (see INPUT_TAGS):
//   COMM:  MPICommonWorldWrapper providing the communicator.
//   INPUT: tensor whose raw buffer is sent.
//   DST, TAG (only when the op has 4 inputs): CPU tensors whose first int
//          element overrides the destination rank and the message tag.
//
// Arguments:
//   dst (int):         destination rank; MPI_ANY_SOURCE marks "not given".
//   tag (int):         message tag; MPI_ANY_TAG marks "not given".
//   raw_buffer (bool): must be true; non-raw transfer is not implemented.
template <class Context>
class MPISendTensorOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Validates at construction that raw-buffer mode is requested and that the
  // destination and tag are supplied either as arguments or through the
  // 4-input form.
  MPISendTensorOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        OP_SINGLE_ARG(int, "dst", dst_, MPI_ANY_SOURCE),
        OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
        OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
    CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
    CAFFE_ENFORCE(
        dst_ != MPI_ANY_SOURCE || def.input_size() == 4,
        "You should explicitly specify the to rank either via "
        "argument or via input blobs.");
    CAFFE_ENFORCE(
        tag_ != MPI_ANY_TAG || def.input_size() == 4,
        "You should explicitly specify the tag either via "
        "argument or via input blobs.");
  }

  // Sends INPUT's bytes to dst_ with tag_. Fails via CAFFE_ENFORCE if the
  // byte size does not fit in an int.
  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
    auto& input = Input(INPUT);
    if (InputSize() == 4) {
      // Input blobs take precedence over the constructor arguments.
      dst_ = OperatorBase::Input<Tensor>(DST, CPU).template data<int>()[0];
      tag_ = OperatorBase::Input<Tensor>(TAG, CPU).template data<int>()[0];
    }
    if (raw_buffer_) {
      // MPI_Send takes its count as an int. Convert explicitly and verify the
      // value round-trips, so an oversized buffer is rejected instead of
      // being silently truncated to a wrapped count.
      const size_t nbytes = input.nbytes();
      const int count = static_cast<int>(nbytes);
      CAFFE_ENFORCE(
          count >= 0 && static_cast<size_t>(count) == nbytes,
          "SendTensor op: input size in bytes does not fit in the int count "
          "accepted by MPI_Send.");
      // We need to do a const cast to cope with the fact that, before OpenMPI
      // 1.7, MPI_Send expects a non-const pointer although it uses it in a
      // const way.
      MPI_CHECK(MPI_Send(
          const_cast<void*>(input.raw_data()),
          count,
          MPI_CHAR,
          dst_,
          tag_,
          comm));
    } else {
      CAFFE_NOT_IMPLEMENTED;
    }
    return true;
  }

 protected:
  // Destination rank and message tag; refreshed from the DST/TAG inputs on
  // every run when the op has 4 inputs.
  int dst_;
  int tag_;
  // Whether to send the tensor's raw bytes (the only supported mode).
  bool raw_buffer_;

  INPUT_TAGS(COMM, INPUT, DST, TAG);
};

// MPIReceiveTensorOp receives raw bytes into its OUTPUT tensor with MPI_Recv
// and reports the source rank and tag of the received message.
//
// Inputs (see INPUT_TAGS):
//   COMM:  MPICommonWorldWrapper providing the communicator.
//   SRC_IN, TAG_IN (only when the op has 4 inputs): CPU tensors whose first
//          int element overrides the source rank and the message tag.
// Outputs (see OUTPUT_TAGS):
//   OUTPUT:  tensor whose existing buffer (nbytes()) receives the message.
//   SRC_OUT: CPU scalar int tensor set to status.MPI_SOURCE.
//   TAG_OUT: CPU scalar int tensor set to status.MPI_TAG.
//
// Arguments:
//   src (int, default MPI_ANY_SOURCE): rank to receive from.
//   tag (int, default MPI_ANY_TAG):    tag to match.
//   raw_buffer (bool): must be true; non-raw transfer is not implemented.
template <class Context>
class MPIReceiveTensorOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPIReceiveTensorOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        OP_SINGLE_ARG(int, "src", src_, MPI_ANY_SOURCE),
        OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
        OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
    CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
  }

  // Receives into OUTPUT, then writes the message's actual source and tag to
  // SRC_OUT and TAG_OUT. Fails via CAFFE_ENFORCE if OUTPUT's byte size does
  // not fit in an int.
  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
    if (InputSize() == 4) {
      // Input blobs take precedence over the constructor arguments.
      src_ = OperatorBase::Input<Tensor>(SRC_IN, CPU).template data<int>()[0];
      tag_ = OperatorBase::Input<Tensor>(TAG_IN, CPU).template data<int>()[0];
    }
    MPI_Status status;
    if (raw_buffer_) {
      auto* output = Output(OUTPUT);
      // MPI_Recv takes its count as an int. Convert explicitly and verify the
      // value round-trips, so an oversized buffer is rejected instead of
      // being silently truncated to a wrapped count.
      const size_t nbytes = output->nbytes();
      const int count = static_cast<int>(nbytes);
      CAFFE_ENFORCE(
          count >= 0 && static_cast<size_t>(count) == nbytes,
          "ReceiveTensor op: output size in bytes does not fit in the int "
          "count accepted by MPI_Recv.");
      MPI_CHECK(MPI_Recv(
          output->raw_mutable_data(),
          count,
          MPI_CHAR,
          src_,
          tag_,
          comm,
          &status));
    } else {
      CAFFE_NOT_IMPLEMENTED;
    }
    // Report who actually sent the message and with which tag; this matters
    // when receiving with MPI_ANY_SOURCE / MPI_ANY_TAG.
    auto* src_out = OperatorBase::Output<Tensor>(SRC_OUT, CPU);
    src_out->Resize();
    src_out->template mutable_data<int>()[0] = status.MPI_SOURCE;
    auto* tag_out = OperatorBase::Output<Tensor>(TAG_OUT, CPU);
    tag_out->Resize();
    tag_out->template mutable_data<int>()[0] = status.MPI_TAG;
    return true;
  }

 protected:
  // Source rank and message tag to match; refreshed from the SRC_IN/TAG_IN
  // inputs on every run when the op has 4 inputs.
  int src_;
  int tag_;
  // Whether to receive into the tensor's raw bytes (the only supported mode).
  bool raw_buffer_;
  INPUT_TAGS(COMM, INPUT, SRC_IN, TAG_IN);
  OUTPUT_TAGS(OUTPUT, SRC_OUT, TAG_OUT);
};

} // namespace caffe2

#endif // CAFFE2_MPI_MPI_OPS_H_