Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / pytools   python

Repository URL to install this package:

/ prefork.py

"""OpenMPI, once initialized, prohibits forking. This helper module
allows the forking of *one* helper child process before OpenMPI
initialization that can do the forking for the fork-challenged
parent process.

Since none of this is MPI-specific, it got parked in pytools.
"""
from __future__ import absolute_import


class ExecError(OSError):
    """Raised when a command cannot be started or exits unsuccessfully.

    This is a subclass of :class:`OSError`, so ``except OSError`` handlers
    catch it as well.
    """


class DirectForker(object):
    """Run external commands directly from the current process.

    Processes started through :meth:`call_async` are remembered in
    ``apids`` under an integer id until they are collected with
    :meth:`wait` or :meth:`waitall`.
    """

    def __init__(self):
        # Maps async id -> subprocess.Popen object not yet waited on.
        self.apids = {}
        # Most recently issued async id (the first id handed out is 1).
        self.count = 0

    @staticmethod
    def call(cmdline, cwd=None):
        """Run *cmdline* to completion and return its exit status.

        :arg cmdline: sequence of program arguments.
        :arg cwd: working directory for the child, or *None*.
        :raises ExecError: if the command could not be invoked.
        """
        from subprocess import call as spcall

        try:
            return spcall(cmdline, cwd=cwd)
        except OSError as e:
            raise ExecError("error invoking '%s': %s"
                            % (" ".join(cmdline), e))

    def call_async(self, cmdline, cwd=None):
        """Start *cmdline* without waiting for it and return its async id.

        The returned integer can be passed to :meth:`wait`.

        :raises ExecError: if the command could not be invoked.
        """
        from subprocess import Popen

        try:
            self.count += 1

            proc = Popen(cmdline, cwd=cwd)
            self.apids[self.count] = proc

            return self.count
        except OSError as e:
            raise ExecError("error invoking '%s': %s"
                             % (" ".join(cmdline), e))

    @staticmethod
    def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
        """Run *cmdline* and capture what it writes to stdout and stderr.

        :arg error_on_nonzero: if true, a nonzero exit status raises
            :class:`ExecError` carrying the decoded stderr output.
        :returns: a tuple ``(returncode, stdout_data, stderr_data)``; the
            two data items are the byte strings collected from the pipes.
        :raises ExecError: if the command could not be invoked, or if it
            exited with a nonzero status while *error_on_nonzero* is true.
        """
        from subprocess import Popen, PIPE

        # Only launching/communicating is guarded here.  ExecError is itself
        # an OSError, so raising the "status ..." error inside this ``try``
        # would get it caught below and re-wrapped as "error invoking ...".
        try:
            popen = Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE,
                          stderr=PIPE)
            stdout_data, stderr_data = popen.communicate()
        except OSError as e:
            raise ExecError("error invoking '%s': %s"
                            % (" ".join(cmdline), e))

        if error_on_nonzero and popen.returncode:
            raise ExecError("status %d invoking '%s': %s"
                            % (popen.returncode, " ".join(cmdline),
                               stderr_data.decode("utf-8", errors="replace")))

        return popen.returncode, stdout_data, stderr_data

    def wait(self, aid):
        """Wait for the process with async id *aid* and return its status.

        The id is removed from ``apids``; an unknown id raises
        :class:`KeyError`.
        """
        proc = self.apids.pop(aid)
        retc = proc.wait()

        return retc

    def waitall(self):
        """Wait for every outstanding async process.

        :returns: a dict mapping each async id to its exit status.
        """
        rets = {}

        # Iterate over a snapshot: wait() removes entries from self.apids.
        for aid in list(self.apids):
            rets[aid] = self.wait(aid)

        return rets


def _send_packet(sock, data):
    """Pickle *data* and write it to *sock* as one length-prefixed packet.

    The pickled bytes are preceded by their length, packed with the
    :mod:`struct` format ``"I"``.
    """
    import struct
    from six.moves import cPickle

    payload = cPickle.dumps(data)
    header = struct.pack("I", len(payload))

    # Header first, then the pickled body.
    sock.sendall(header)
    sock.sendall(payload)


def _recv_packet(sock, who="Process", partner="other end"):
    """Read one length-prefixed pickled packet from *sock* and unpickle it.

    The packet layout is a length header in :mod:`struct` format ``"I"``
    followed by that many bytes of pickle data.

    :arg who: name of the receiving side, used only in the warning text.
    :arg partner: name of the sending side, used only in the warning text.
    :returns: the unpickled object.
    :raises SystemExit: after issuing a warning, if the connection is closed
        before a complete packet (header and payload) has arrived.
    """
    from struct import calcsize, unpack

    def recv_exactly(nbytes):
        # recv() may hand back fewer bytes than requested, and returns an
        # empty string once the peer has closed the connection.  Ask only
        # for what is still missing so the next packet is never consumed,
        # and stop on EOF instead of looping forever.
        chunks = []
        remaining = nbytes
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                from warnings import warn
                warn("%s exiting upon apparent death of %s" % (who, partner))

                raise SystemExit
            chunks.append(chunk)
            remaining -= len(chunk)

        return b"".join(chunks)

    size, = unpack("I", recv_exactly(calcsize("I")))
    packet = recv_exactly(size)

    # NOTE: unpickling executes whatever the sender encoded, so *sock* must
    # be connected to a trusted peer (in this module: the socketpair()
    # created by enable_prefork()).
    from six.moves.cPickle import loads
    return loads(packet)


def _fork_server(sock):
    """Serve forking requests arriving on *sock* until told to quit.

    Each request is a ``(func_name, args, kwargs)`` packet.  The named
    method of a private :class:`DirectForker` is invoked and the outcome is
    sent back as ``("ok", result)`` or ``("exception", exc)``.  A ``"quit"``
    request waits for all outstanding children, acknowledges with
    ``("ok", None)`` and ends the loop, after which the process is
    terminated through ``os._exit(0)``.
    """
    import os
    import signal

    # Ignore keyboard interrupts, we'll get notified by the parent.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # The local DirectForker does the actual work for every request.
    worker = DirectForker()
    dispatch = dict(
        (method_name, getattr(worker, method_name))
        for method_name in (
            "call", "call_async", "call_capture_output", "wait", "waitall")
    )

    try:
        while True:
            request = _recv_packet(
                sock, who="Prefork server", partner="parent")
            func_name, args, kwargs = request

            if func_name == "quit":
                worker.waitall()
                _send_packet(sock, ("ok", None))
                break

            try:
                outcome = dispatch[func_name](*args, **kwargs)
            # FIXME: Is catching all exceptions the right course of action?
            except Exception as exc:  # pylint:disable=broad-except
                _send_packet(sock, ("exception", exc))
            else:
                _send_packet(sock, ("ok", outcome))
    finally:
        sock.close()

    os._exit(0)  # pylint:disable=protected-access


class IndirectForker(object):
    """Forker that forwards every request over a socket to a server process.

    Offers the same methods as :class:`DirectForker`; each one is sent as a
    packet to the process identified by ``server_pid`` and the reply is
    returned (or, for an ``"exception"`` reply, raised) locally.
    """

    def __init__(self, server_pid, sock):
        import atexit

        self.socket = sock
        self.server_pid = server_pid

        # Ask the server to shut down when this interpreter exits.
        atexit.register(self._quit)

    def _remote_invoke(self, name, *args, **kwargs):
        """Send request *name* with its arguments and return the reply."""
        request = (name, args, kwargs)
        _send_packet(self.socket, request)

        status, result = _recv_packet(
            self.socket, who="Prefork client", partner="prefork server")

        if status != "exception":
            assert status == "ok"
            return result

        # The server sent back the exception object it caught.
        raise result

    def _quit(self):
        """Tell the server to quit, then wait for its process to end."""
        import os

        self._remote_invoke("quit")
        os.waitpid(self.server_pid, 0)

    def call(self, cmdline, cwd=None):
        """Forward a ``call`` request to the server."""
        reply = self._remote_invoke("call", cmdline, cwd)
        return reply

    def call_async(self, cmdline, cwd=None):
        """Forward a ``call_async`` request to the server."""
        reply = self._remote_invoke("call_async", cmdline, cwd)
        return reply

    def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True):
        """Forward a ``call_capture_output`` request to the server."""
        reply = self._remote_invoke(
            "call_capture_output", cmdline, cwd, error_on_nonzero)
        return reply

    def wait(self, aid):
        """Forward a ``wait`` request for async id *aid* to the server."""
        reply = self._remote_invoke("wait", aid)
        return reply

    def waitall(self):
        """Forward a ``waitall`` request to the server."""
        reply = self._remote_invoke("waitall")
        return reply


# Module-level forker that the call*/wait* functions below delegate to.
# It starts out as a DirectForker; enable_prefork() rebinds it to an
# IndirectForker in the parent process.
forker = DirectForker()


def enable_prefork():
    """Fork a helper process and route the module-level forker through it.

    Does nothing if the module-level ``forker`` already is an
    :class:`IndirectForker`.  Otherwise a socket pair is created and the
    process forks: the child runs :func:`_fork_server` on its end of the
    pair, while the parent rebinds ``forker`` to an :class:`IndirectForker`
    that talks to the child.
    """
    global forker  # pylint:disable=global-statement

    if isinstance(forker, IndirectForker):
        return

    import os
    import socket

    parent_end, child_end = socket.socketpair()
    child_pid = os.fork()

    if child_pid != 0:
        # Parent: keep only our end and switch to the indirect forker.
        child_end.close()
        forker = IndirectForker(child_pid, parent_end)
        return

    # Child: keep only our end and start serving requests.
    parent_end.close()
    _fork_server(child_end)


def call(cmdline, cwd=None):
    """Delegate to the ``call`` method of the current module-level forker."""
    active = forker
    return active.call(cmdline, cwd)


def call_async(cmdline, cwd=None):
    """Delegate to ``call_async`` of the current module-level forker."""
    active = forker
    return active.call_async(cmdline, cwd)


def call_capture_output(cmdline, cwd=None, error_on_nonzero=True):
    """Delegate to ``call_capture_output`` of the current module-level forker."""
    active = forker
    return active.call_capture_output(cmdline, cwd, error_on_nonzero)


def wait(aid):
    """Delegate to the ``wait`` method of the current module-level forker."""
    active = forker
    return active.wait(aid)


def waitall():
    """Delegate to the ``waitall`` method of the current module-level forker."""
    active = forker
    return active.waitall()