Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / caffe2 / utils / zmq_helper.h

#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_
#define CAFFE2_UTILS_ZMQ_HELPER_H_

#include <zmq.h>

#include "caffe2/core/logging.h"

namespace caffe2 {

/**
 * Owner of a ZeroMQ context handle.
 *
 * The handle is obtained from zmq_ctx_new() in the constructor and handed to
 * zmq_ctx_destroy() in the destructor. Instances cannot be copied.
 */
class ZmqContext {
 public:
  /**
   * Creates a context, then sets ZMQ_IO_THREADS to `io_threads` and
   * ZMQ_MAX_SOCKETS to ZMQ_MAX_SOCKETS_DFLT.
   *
   * Every step is checked with CAFFE_ENFORCE; a failed check is expected to
   * raise an exception.
   */
  explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) {
    CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context.");
    try {
      int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads);
      CAFFE_ENFORCE_EQ(rc, 0);
      rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT);
      CAFFE_ENFORCE_EQ(rc, 0);
    } catch (...) {
      // A constructor that exits by exception never has its destructor run,
      // so the freshly created context must be released here to avoid a leak.
      zmq_ctx_destroy(ptr_);
      throw;
    }
  }

  /**
   * Destroys the context and checks that zmq_ctx_destroy() reported success.
   *
   * NOTE(review): the check uses CAFFE_ENFORCE_EQ inside a destructor; if the
   * macro throws, an implicitly noexcept destructor would terminate the
   * process. Confirm this is intended.
   */
  ~ZmqContext() {
    int rc = zmq_ctx_destroy(ptr_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /** Returns the raw context handle; ownership stays with this object. */
  void* ptr() { return ptr_; }

 private:
  void* ptr_;

  C10_DISABLE_COPY_AND_ASSIGN(ZmqContext);
};

class ZmqMessage {
 public:
  ZmqMessage() {
    int rc = zmq_msg_init(&msg_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  ~ZmqMessage() {
    int rc = zmq_msg_close(&msg_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  zmq_msg_t* msg() { return &msg_; }

  void* data() { return zmq_msg_data(&msg_); }
  size_t size() { return zmq_msg_size(&msg_); }

 private:
  zmq_msg_t msg_;
  C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage);
};

class ZmqSocket {
 public:
  explicit ZmqSocket(int type)
      : context_(1), ptr_(zmq_socket(context_.ptr(), type)) {
    CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq socket.");
  }

  ~ZmqSocket() {
    int rc = zmq_close(ptr_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  void Bind(const string& addr) {
    int rc = zmq_bind(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  void Unbind(const string& addr) {
    int rc = zmq_unbind(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  void Connect(const string& addr) {
    int rc = zmq_connect(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  void Disconnect(const string& addr) {
    int rc = zmq_disconnect(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  int Send(const string& msg, int flags) {
    int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags);
    if (nbytes) {
      return nbytes;
    } else if (zmq_errno() == EAGAIN) {
      return 0;
    } else {
      LOG(FATAL) << "Cannot send zmq message. Error number: "
                      << zmq_errno();
      return 0;
    }
  }

  int SendTillSuccess(const string& msg, int flags) {
    CAFFE_ENFORCE(msg.size(), "You cannot send an empty message.");
    int nbytes = 0;
    do {
      nbytes = Send(msg, flags);
    } while (nbytes == 0);
    return nbytes;
  }

  int Recv(ZmqMessage* msg) {
    int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0);
    if (nbytes >= 0) {
      return nbytes;
    } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
      return 0;
    } else {
      LOG(FATAL) << "Cannot receive zmq message. Error number: "
                      << zmq_errno();
      return 0;
    }
  }

  int RecvTillSuccess(ZmqMessage* msg) {
    int nbytes = 0;
    do {
      nbytes = Recv(msg);
    } while (nbytes == 0);
    return nbytes;
  }

 private:
  ZmqContext context_;
  void* ptr_;
};

}  // namespace caffe2


#endif  // CAFFE2_UTILS_ZMQ_HELPER_H_