Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / caffe2 / mpi / mpi_common.h

#ifndef CAFFE2_MPI_MPI_COMMON_H_
#define CAFFE2_MPI_MPI_COMMON_H_

#include <mpi.h>
#include <mutex>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"

namespace caffe2 {

inline void CheckInitializedMPI() {
  int flag;
  MPI_Initialized(&flag);
  CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
}

/**
 * @brief Compile-time lookup from a C++ element type to an MPI datatype.
 *
 * The primary template is only declared here; each supported type gets an
 * explicit specialization whose static type() returns the MPI handle.
 */
template <typename T>
class MPIDataTypeWrapper;

// Stamps out one MPIDataTypeWrapper specialization: cpp_type is the C++
// element type, mpi_handle is the MPI_Datatype value returned by type().
#define MPI_DATATYPE_WRAPPER(cpp_type, mpi_handle) \
  template <>                                      \
  class MPIDataTypeWrapper<cpp_type> {             \
   public:                                         \
    static MPI_Datatype type() {                   \
      return mpi_handle;                           \
    }                                              \
  };

MPI_DATATYPE_WRAPPER(char, MPI_CHAR)
MPI_DATATYPE_WRAPPER(float, MPI_FLOAT)
MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE)
// Note(Yangqing): as necessary, add more specializations.
// The helper macro is an implementation detail of this header only.
#undef MPI_DATATYPE_WRAPPER

// For all Caffe MPI calls, we will wrap it inside an MPI mutex lock guard.
// MPIMutex() returns the mutex used for that purpose; the MPI_CHECK macro
// in this header locks it for the duration of each checked call.
TORCH_API std::mutex& MPIMutex();

// Evaluates the MPI call `condition` while holding ::caffe2::MPIMutex() and
// raises through CAFFE_ENFORCE (reporting file, line and the MPI error code)
// when the call does not return MPI_SUCCESS.
//
// The macro-local identifiers carry a long, reserved-looking prefix so they
// cannot capture names appearing in `condition`: with short names such as
// `error`, a call like MPI_CHECK(f(&error)) would silently bind to the
// macro's own variable instead of the caller's.
#define MPI_CHECK(condition)                             \
  do {                                                   \
    std::lock_guard<std::mutex> caffe2_mpi_check_guard_( \
        ::caffe2::MPIMutex());                           \
    const int caffe2_mpi_check_error_ = (condition);     \
    CAFFE_ENFORCE(                                       \
        caffe2_mpi_check_error_ == MPI_SUCCESS,          \
        "Caffe2 MPI Error at: ",                         \
        __FILE__,                                        \
        ":",                                             \
        __LINE__,                                        \
        ": ",                                            \
        caffe2_mpi_check_error_);                        \
  } while (0)

/**
 * @brief Gets the global MPI communicator used by Caffe2. By default, this
 * is MPI_COMM_WORLD; it is a different communicator only if
 * SetGlobalMPIComm() has been called.
 *
 * @return The communicator currently used as Caffe2's global communicator.
 */
TORCH_API MPI_Comm GlobalMPIComm();

/**
 * @brief Sets the global MPI communicator. Caffe2 takes over the ownership
 * of the passed-in communicator.
 *
 * @param new_comm The communicator to install as Caffe2's global
 *     communicator.
 */
TORCH_API void SetGlobalMPIComm(MPI_Comm new_comm);

/**
 * @brief A helper function to return the size of the given communicator.
 *
 * @param comm The communicator to query.
 * @return The size of comm.
 */
TORCH_API int MPICommSize(MPI_Comm comm);

/**
 * @brief A helper function to return the rank of the given communicator.
 *
 * @param comm The communicator to query.
 * @return The rank within comm.
 */
TORCH_API int MPICommRank(MPI_Comm comm);

/**
 * @brief A simple wrapper over an MPI common world.
 */
class MPICommonWorldWrapper {
 public:
  /**
   * @brief Creates a common world wrapper.
   *
   * The new common world is created by taking the existing communicator
   * passed in as src_comm, and splitting it using the color and the rank
   * specified. In default, we will split from Caffe2's global communicator,
   * and use color 0 as well as rank implicitly given by src_comm. As a result,
   * the default constructor basically creates a comm identical to the source
   * comm world.
   */
  explicit MPICommonWorldWrapper(
      MPI_Comm src_comm = MPI_COMM_NULL,
      int color = 0,
      int rank = -1) {
    if (src_comm == MPI_COMM_NULL) {
      src_comm = GlobalMPIComm();
    }
    if (rank == -1) {
      MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
    }
    MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
    MPI_CHECK(MPI_Comm_size(comm_, &size_));
    MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
  }

  ~MPICommonWorldWrapper() {
    int ret;
    MPI_CHECK(MPI_Finalized(&ret));
    if (!ret) {
      MPI_Comm_free(&comm_);
    }
  }

  /**
   * @brief Returns the common world held by the wrapper.
   */
  inline MPI_Comm comm() const {
    return comm_;
  }
  /**
   * @brief Returns the size of the world.
   */
  inline int size() const {
    return size_;
  }
  /**
   * @brief Returns the rank of this process in the world.
   */
  inline int rank() const {
    return rank_;
  }

 private:
  MPI_Comm comm_;
  int size_;
  int rank_;
};

/**
 * A function used to perform peer setup so one does not need to use
 * mpirun / mpiexec to run the binary. Note that if you use mpirun or mpiexec
 * to set up the common world, do not use this function - MPI_Init would have
 * already set that up.
 *
 * This also assumes that you have a common path (like NFS) that multiple
 * instances can read from.
 *
 * Inputs:
 *   replicas (int): the number of replicas that MPI will run with.
 *   role (string): the role of this process, "server" or "client".
 *   job_path (string): a file name that the server will write its port into
 *       and that the clients will read the server's port from.
 */
void MPISetupPeers(
    const int replicas,
    const string& role,
    const string& job_path);
} // namespace caffe2

#endif // CAFFE2_MPI_MPI_COMMON_H_