Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / operator_test / checkpoint_test.py






from caffe2.python import core, workspace, test_util
import os
import shutil
import tempfile
import unittest


class CheckpointTest(test_util.TestCase):
    """A simple test case to make sure that the checkpoint behavior is correct.
    """

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def testCheckpoint(self):
        temp_root = tempfile.mkdtemp()
        net = core.Net("test_checkpoint")
        # Note(jiayq): I am being a bit lazy here and am using the old iter
        # convention that does not have an input. Optionally change it to the
        # new style if needed.
        net.Iter([], "iter")
        net.ConstantFill([], "value", shape=[1, 2, 3])
        net.Checkpoint(["iter", "value"], [],
                       db=os.path.join(temp_root, "test_checkpoint_at_%05d"),
                       db_type="leveldb", every=10, absolute_path=True)
        self.assertTrue(workspace.CreateNet(net))
        for i in range(100):
            self.assertTrue(workspace.RunNet("test_checkpoint"))
        for i in range(1, 10):
            # Print statements are only for debugging purposes.
            # print("Asserting %d" % i)
            # print(os.path.join(temp_root, "test_checkpoint_at_%05d" % (i * 10)))
            self.assertTrue(os.path.exists(
                os.path.join(temp_root, "test_checkpoint_at_%05d" % (i * 10))))

        # Finally, clean up.
        shutil.rmtree(temp_root)


if __name__ == "__main__":
    unittest.main()