#pragma once
#include <limits>

#include <c10d/Utils.hpp>
namespace c10d {
namespace tcputil {
// Address family selected by this Linux/Unix socket implementation.
// AF_UNSPEC means "any family" (IPv4 or IPv6).
// NOTE(review): presumably used as the `ai_family` hint for getaddrinfo() —
// confirm at the call sites.
#define AF_SELECTED AF_UNSPEC
// NOTE(review): the meaning of this offset is not visible in this header;
// presumably consumed by the connect logic elsewhere — confirm against its
// users before changing it.
#define CONNECT_SOCKET_OFFSET 2
/// Closes the socket descriptor `socket`.
///
/// @param socket descriptor to release.
/// The result of close() is not reported to the caller.
inline void closeSocket(int socket) {
  static_cast<void>(::close(socket));
}
/// Enables the SO_REUSEADDR option on `socket`.
///
/// @param socket descriptor to configure.
/// @return the value returned by setsockopt(): 0 on success, -1 on failure
///         (with errno set).
inline int setSocketAddrReUse(int socket) {
  const int enable = 1;
  const int rc =
      ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
  return rc;
}
/// Thin forwarding wrapper around the POSIX `::poll`.
///
/// @param fds     array of `nfds` entries describing the descriptors to watch.
/// @param nfds    number of entries in `fds`.
/// @param timeout wait limit in milliseconds; a negative value waits
///                indefinitely.
/// @return the number of ready entries, 0 on timeout, or -1 with errno set.
inline int poll(struct pollfd *fds, unsigned long nfds, int timeout) {
  const int numReady = ::poll(fds, nfds, timeout);
  return numReady;
}
/// Appends an entry watching `socket` for `events` to the poll set `fds`.
///
/// @param fds    poll set to extend; existing entries are left untouched.
/// @param socket descriptor to watch.
/// @param events requested event mask (e.g. POLLIN, POLLOUT).
inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
                      short events) {
  // Value-initialise so `revents` starts at zero, then assign members.
  // Designated initialisers are only standard from C++20.
  struct ::pollfd entry{};
  entry.fd = socket;
  entry.events = events;
  fds.push_back(entry);
}
/// Completes a non-blocking connect() of `socket` to `nextAddr`.
///
/// It waits for at most the part of `timeout` that has not yet elapsed since
/// `startTime`. The socket is put into non-blocking mode for the connect
/// attempt, and its other file status flags are preserved. Once the connection
/// is established, the socket is switched back to blocking mode.
///
/// @param socket    descriptor of the socket to connect.
/// @param nextAddr  target address; only `ai_addr` and `ai_addrlen` are read.
/// @param timeout   overall time budget, or `kNoTimeout` to wait indefinitely.
/// @param startTime point in time from which `timeout` is measured.
/// @throws std::system_error if connect(), poll() or getsockopt() fail, or if
///         the socket reports a pending connection error. In every case
///         `errno` holds that error code when the exception is thrown.
/// @throws std::runtime_error carrying `kConnectTimeoutMsg` if the budget
///         expires before the socket becomes writable. `errno` is set to 0
///         first.
/// fcntl() failures are reported through SYSCHECK_ERR_RETURN_NEG1.
inline void waitSocketConnected(
    int socket,
    struct ::addrinfo *nextAddr,
    std::chrono::milliseconds timeout,
    std::chrono::time_point<std::chrono::high_resolution_clock> startTime) {
  // Switch to non-blocking mode without clobbering the other status flags.
  // F_SETFL replaces the whole flag set.
  int flags = 0;
  SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL));
  SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags | O_NONBLOCK));

  int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
  // A non-blocking connect that is still in progress reports EINPROGRESS.
  // Any other error is a real failure.
  if (ret != 0 && errno != EINPROGRESS) {
    throw std::system_error(errno, std::system_category());
  }

  // The socket becomes writable once the connect attempt has finished.
  struct ::pollfd pfd{};
  pfd.fd = socket;
  pfd.events = POLLOUT;

  int numReady;
  do {
    int pollTimeout = -1; // negative: block indefinitely
    if (timeout != kNoTimeout) {
      // Recompute the remaining budget on every attempt, so a retry after
      // EINTR cannot extend the overall timeout.
      const auto elapsed =
          std::chrono::high_resolution_clock::now() - startTime;
      const auto remaining =
          std::chrono::duration_cast<std::chrono::milliseconds>(timeout) -
          std::chrono::duration_cast<std::chrono::milliseconds>(elapsed);
      // poll() takes an int. Clamp to [0, INT_MAX] so a long budget is not
      // truncated into a negative (infinite) or otherwise wrong timeout.
      pollTimeout = static_cast<int>(std::min<int64_t>(
          std::max<int64_t>(0, remaining.count()),
          std::numeric_limits<int>::max()));
    }
    numReady = ::poll(&pfd, 1, pollTimeout);
  } while (numReady < 0 && errno == EINTR); // retry if interrupted by a signal

  if (numReady < 0) {
    throw std::system_error(errno, std::system_category());
  } else if (numReady == 0) {
    errno = 0;
    throw std::runtime_error(kConnectTimeoutMsg);
  }

  // Retrieve the outcome of the asynchronous connect into a dedicated
  // variable instead of writing through `&errno`.
  int soError = 0;
  socklen_t soErrorLen = sizeof(soError);
  if (::getsockopt(socket, SOL_SOCKET, SO_ERROR, &soError, &soErrorLen) != 0) {
    throw std::system_error(errno, std::system_category());
  }
  if (soError != 0) {
    // Keep errno equal to the connection error, as the previous
    // implementation did, for callers that inspect it after catching.
    errno = soError;
    throw std::system_error(soError, std::system_category());
  }

  // Disable non-blocking mode and keep the socket's other status flags.
  SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & ~O_NONBLOCK));
}
/// Prepares the socket layer for use.
///
/// Linux sockets need no library initialisation before use, so this is a
/// no-op.
inline void socketInitialize() {
  // Intentionally empty.
}
/// Builds a poll entry watching `socket` for `events`.
///
/// @param socket descriptor to watch.
/// @param events requested event mask (e.g. POLLIN, POLLOUT).
/// @return a pollfd with `fd` and `events` set and `revents` zeroed.
inline struct ::pollfd getPollfd(int socket, short events) {
  // Value-initialise so `revents` starts at zero, then assign members.
  // Designated initialisers are only standard from C++20.
  struct ::pollfd res{};
  res.fd = socket;
  res.events = events;
  return res;
}
} // namespace tcputil
} // namespace c10d