Learn more » Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / lazy_dyndep_test.py

#!/usr/bin/env python






from hypothesis import given, settings
import hypothesis.strategies as st
from multiprocessing import Process

import numpy as np
import tempfile
import shutil

import caffe2.python.hypothesis_test_util as hu
import unittest

# Engine name placed in the ``engine`` entry of the rendezvous config that
# allcompare_process attaches to its model.
op_engine = "GLOO"

class TemporaryDirectory:
    """Context manager yielding a fresh scratch directory, removed on exit.

    ``with TemporaryDirectory() as path:`` binds ``path`` to the ``str`` path
    of a directory created by :func:`tempfile.mkdtemp`. When the block exits,
    normally or via an exception, the directory and all of its contents are
    deleted with :func:`shutil.rmtree`. Exceptions raised inside the block
    are not suppressed.

    Args:
        suffix: forwarded to :func:`tempfile.mkdtemp` (default ``None``).
        prefix: forwarded to :func:`tempfile.mkdtemp` (default ``None``).
        dir: forwarded to :func:`tempfile.mkdtemp` (default ``None``); the
            name deliberately mirrors mkdtemp's own keyword.
    """

    def __init__(self, suffix=None, prefix=None, dir=None):
        self._mkdtemp_kwargs = {"suffix": suffix, "prefix": prefix, "dir": dir}
        # Populated by __enter__; None until the context is entered.
        self.tmpdir = None

    def __enter__(self):
        self.tmpdir = tempfile.mkdtemp(**self._mkdtemp_kwargs)
        return self.tmpdir

    def __exit__(self, exc_type, exc_value, exc_tb):
        shutil.rmtree(self.tmpdir)
        # Returning a falsy value lets any exception from the block propagate.
        return False


def allcompare_process(filestore_dir, process_id, data, num_procs):
    """Worker body for one shard of the cross-process all-compare test.

    Steps, in order:

    1. Register the file-store handler op library through ``lazy_dyndep``.
    2. Run ``FileStoreHandlerCreate`` rooted at ``filestore_dir``, producing
       the ``store_handler`` blob.
    3. Attach a rendezvous config that uses that store to a fresh
       ``ModelHelper``.
    4. Feed ``data`` as the ``test_data`` blob.
    5. Hand the blob name to ``data_parallel_model._RunComparison`` on CPU
       device 0. That helper presumably verifies the blob matches across all
       shards (TODO confirm).

    Args:
        filestore_dir: directory shared by every worker, used as the store path.
        process_id: this worker's shard index (``shard_id``).
        data: array fed into the workspace as ``"test_data"``.
        num_procs: total number of participating shards (``num_shards``).
    """
    from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
    from caffe2.python.model_helper import ModelHelper
    from caffe2.proto import caffe2_pb2

    # Presumably provides FileStoreHandlerCreate; must be registered before
    # that operator is created below.
    lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    store_op = core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
    )
    workspace.RunOperatorOnce(store_op)

    model = ModelHelper()
    model._rendezvous = {
        "kv_handler": "store_handler",
        "shard_id": process_id,
        "num_shards": num_procs,
        "engine": op_engine,
        "exit_nets": None,
    }

    workspace.FeedBlob("test_data", data)

    cpu_device = core.DeviceOption(caffe2_pb2.CPU, 0)
    data_parallel_model._RunComparison(model, "test_data", cpu_device)


class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
    """All-compare across worker processes using a lazily registered op library."""

    @given(
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
    )
    @settings(deadline=None)
    def test_allcompare(self, d, n, num_procs):
        """Run ``allcompare_process`` in ``num_procs`` workers on one random tensor.

        The tensor has rank ``d``; each dimension is drawn from ``[1, n)``.
        All workers share a temporary directory as their file-store path.
        The test fails if any worker exits with a non-zero status.
        """
        dims = tuple(np.random.randint(1, high=n) for _ in range(d))
        # random_sample is the canonical name of the legacy ``ranf`` alias
        # (uniform floats in [0, 1)); the RNG draw sequence is unchanged.
        test_data = np.random.random_sample(size=dims).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            processes = [
                Process(
                    target=allcompare_process,
                    args=(tempdir, idx, test_data, num_procs),
                )
                for idx in range(num_procs)
            ]
            for process in processes:
                process.start()
            for process in processes:
                process.join()

        # An exception raised inside a multiprocessing.Process is not
        # re-raised in the parent; the child only exits with a non-zero
        # exitcode. Check it explicitly so worker failures fail the test
        # instead of passing silently.
        failed = [
            (idx, process.exitcode)
            for idx, process in enumerate(processes)
            if process.exitcode != 0
        ]
        self.assertEqual(
            failed, [], msg="worker processes failed (index, exitcode): {}".format(failed)
        )

class TestLazyDynDepError(unittest.TestCase):
    """Error-handler behaviour of ``lazy_dyndep`` for a library that fails to load.

    Each test registers an empty temporary file as an ops library and
    installs an error handler that raises ``ValueError("test")``. Registration
    itself is not expected to raise: it happens outside any assertRaises
    block. The handler's error is expected when the operator registry is
    refreshed, either directly or via ``workspace.CreateNet``.
    """

    @staticmethod
    def _raise_test_error(e):
        """Error handler that turns any load failure into ``ValueError("test")``."""
        raise ValueError("test")

    def test_errorhandler(self):
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)
            lazy_dyndep.SetErrorHandler(self._raise_test_error)
            # assertRaises(..., msg=...) only customises the failure text and
            # never inspects the exception; assertRaisesRegex verifies that
            # the handler's own error is what propagated.
            with self.assertRaisesRegex(ValueError, "^test$"):
                core.RefreshRegisteredOperators()

    def test_importaftererror(self):
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)
            lazy_dyndep.SetErrorHandler(self._raise_test_error)
            with self.assertRaisesRegex(ValueError, "^test$"):
                core.RefreshRegisteredOperators()

            def reraise_handler(e):
                # Always raises: a bare ``raise`` re-raises the exception being
                # handled (or RuntimeError if there is none). As a result, any
                # failure to load the real library below fails the test.
                raise
            lazy_dyndep.SetErrorHandler(reraise_handler)
            lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
            core.RefreshRegisteredOperators()

    def test_workspacecreatenet(self):
        from caffe2.python import workspace, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)
            lazy_dyndep.SetErrorHandler(self._raise_test_error)
            # Creating a net is expected to trigger the deferred library load.
            with self.assertRaisesRegex(ValueError, "^test$"):
                workspace.CreateNet("fake")


if __name__ == "__main__":
    # Entry point when the file is executed directly; unittest collects the
    # TestCase classes defined in this module (module="__main__" is the default).
    unittest.main(module="__main__")