Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / parallel_workers_test.py






import unittest

from caffe2.python import workspace, core
import caffe2.python.parallel_workers as parallel_workers


def create_queue():
    """Create the blobs queue and pre-create every blob the tests use.

    Runs a ``CreateBlobsQueue`` operator whose output blob is ``'queue'``
    (``num_blobs=1``, ``capacity=1000``), then creates ``blob_<i>`` and
    ``status_blob_<i>`` for ``i`` in ``0..99`` plus ``dequeue_blob`` and
    ``status_blob`` in the current workspace.

    Returns:
        The name of the queue blob, ``'queue'``.
    """
    queue_name = 'queue'

    create_queue_op = core.CreateOperator(
        "CreateBlobsQueue", [], [queue_name], num_blobs=1, capacity=1000
    )
    workspace.RunOperatorOnce(create_queue_op)

    # Blob creation is technically not thread safe. The tests drive the
    # operators through RunOperatorOnce rather than CreateNet + RunNet, so
    # every blob has to exist before any worker starts.
    blob_names = []
    for idx in range(100):
        blob_names.append("blob_{}".format(idx))
        blob_names.append("status_blob_{}".format(idx))
    blob_names.extend(["dequeue_blob", "status_blob"])

    for blob_name in blob_names:
        workspace.C.Workspace.current.create_blob(blob_name)

    return queue_name


def create_worker(queue, get_blob_data):
    """Build a worker callable that feeds one blob and enqueues it.

    Args:
        queue: Name of the queue blob passed to ``SafeEnqueueBlobs``.
        get_blob_data: Callable taking the worker id and returning the
            value to feed into that worker's blob.

    Returns:
        A function ``dummy_worker(worker_id)`` that feeds
        ``blob_<worker_id>`` with ``get_blob_data(worker_id)`` and runs a
        ``SafeEnqueueBlobs`` operator on it, using
        ``status_blob_<worker_id>`` as the status output.
    """
    def dummy_worker(worker_id):
        suffix = str(worker_id)
        blob = 'blob_' + suffix
        status_blob = 'status_blob_' + suffix

        payload = get_blob_data(worker_id)
        workspace.FeedBlob(blob, payload)

        enqueue_op = core.CreateOperator(
            'SafeEnqueueBlobs', [queue, blob], [blob, status_blob]
        )
        workspace.RunOperatorOnce(enqueue_op)

    return dummy_worker


def dequeue_value(queue):
    """Dequeue one entry from ``queue`` and return its fetched value.

    Runs a ``SafeDequeueBlobs`` operator that writes into ``dequeue_blob``
    and ``status_blob``, then fetches ``dequeue_blob`` from the workspace.

    Args:
        queue: Name of the queue blob to dequeue from.

    Returns:
        The result of ``workspace.FetchBlob('dequeue_blob')``.
    """
    outputs = ['dequeue_blob', 'status_blob']
    dequeue_op = core.CreateOperator("SafeDequeueBlobs", [queue], outputs)
    workspace.RunOperatorOnce(dequeue_op)

    return workspace.FetchBlob(outputs[0])


class ParallelWorkersTest(unittest.TestCase):
    """Tests for ``parallel_workers.init_workers`` using enqueueing workers.

    Each test resets the workspace, builds the queue and a worker function,
    starts the coordinator returned by ``init_workers`` and checks either
    the values the workers enqueue or the effect of the shutdown hook.
    """

    def testParallelWorkers(self):
        """Workers enqueue their stringified id; ten dequeues must match.

        NOTE(review): the accepted values ``b'0'`` / ``b'1'`` presumably
        correspond to the worker ids handed out by ``init_workers`` with its
        default settings -- confirm against ``parallel_workers``.
        """
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        worker_coordinator = parallel_workers.init_workers(dummy_worker)
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                self.assertIn(
                    value, [b'0', b'1'], 'Got unexpected value ' + str(value)
                )
        finally:
            # Stop exactly once on both the success and the failure path so
            # a failed assertion does not leave the started workers running.
            stopped = worker_coordinator.stop()

        self.assertTrue(stopped)

    def testParallelWorkersInitFun(self):
        """``init_fun`` must take effect before workers produce data.

        The ``data`` blob starts as ``'not initialized'`` and ``init_fun``
        overwrites it with ``'initialized'``; every dequeued value is
        expected to be ``b'initialized'``.
        """
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(
            queue, lambda worker_id: workspace.FetchBlob('data')
        )
        workspace.FeedBlob('data', 'not initialized')

        def init_fun(worker_coordinator, global_coordinator):
            workspace.FeedBlob('data', 'initialized')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, init_fun=init_fun
        )
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                self.assertEqual(
                    value, b'initialized', 'Got unexpected value ' + str(value)
                )
        finally:
            # A best effort attempt at a clean shutdown; the result is
            # deliberately not asserted. Runs even if an assertion failed.
            worker_coordinator.stop()

    def testParallelWorkersShutdownFun(self):
        """``shutdown_fun`` must have run once ``stop()`` has returned.

        The ``data`` blob starts as ``'not shutdown'`` and ``shutdown_fun``
        overwrites it with ``'shutdown'``; after a successful ``stop()`` the
        blob is expected to hold ``b'shutdown'``.
        """
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        workspace.FeedBlob('data', 'not shutdown')

        def shutdown_fun():
            workspace.FeedBlob('data', 'shutdown')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, shutdown_fun=shutdown_fun
        )
        worker_coordinator.start()

        self.assertTrue(worker_coordinator.stop())

        data = workspace.FetchBlob('data')
        self.assertEqual(data, b'shutdown', 'Got unexpected value ' + str(data))