Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / operator_test / index_ops_test.py





import os
import tempfile

import numpy as np

from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase


class TestIndexOps(TestCase):
    """Exercise the Caffe2 index operators for string, int32 and int64 keys.

    Every public test drives the same scenario (create, load, lookup, freeze,
    size, store, reload into a second index, and a Save/Load round trip)
    through ``_test_index_ops`` with a different key type.
    """

    @staticmethod
    def _run_op(op_type, inputs, outputs, **kwargs):
        """Create a single operator and run it once in the current workspace."""
        workspace.RunOperatorOnce(
            core.CreateOperator(op_type, inputs, outputs, **kwargs))

    def _test_index_ops(self, entries, dtype, index_create_op):
        """Run the shared index scenario.

        Args:
            entries: at least eight distinct keys. ``entries[0:3]`` are loaded
                up front, ``entries[3:5]`` are added by lookups made before
                the index is frozen, and ``entries[5:8]`` are only queried
                after the freeze.
            dtype: numpy dtype (or ``str``) used to build the key arrays.
            index_create_op: name of the operator that creates an empty index
                for this key type, e.g. ``'IntIndexCreate'``.
        """
        self._run_op(index_create_op, [], ['index'], max_elements=10)
        my_entries = np.array(entries[:3], dtype=dtype)
        workspace.FeedBlob('entries', my_entries)
        self._run_op('IndexLoad', ['index', 'entries'], ['index'])

        # Before freezing, unseen keys are appended: the three loaded keys
        # hold ids 1..3, so entries[3] and entries[4] are expected at 4 and 5.
        query1 = np.array(
            [entries[0], entries[3], entries[0], entries[4]], dtype=dtype)
        workspace.FeedBlob('query1', query1)
        self._run_op('IndexGet', ['index', 'query1'], ['result1'])
        result1 = workspace.FetchBlob('result1')
        np.testing.assert_array_equal([1, 4, 1, 5], result1)

        self._run_op('IndexFreeze', ['index'], ['index'])

        # After freezing, unseen keys are expected to map to id 0 (and not
        # be added, as the size check below confirms).
        query2 = np.array(
            [entries[5], entries[4], entries[0], entries[6], entries[7]],
            dtype=dtype)
        workspace.FeedBlob('query2', query2)
        self._run_op('IndexGet', ['index', 'query2'], ['result2'])
        result2 = workspace.FetchBlob('result2')
        np.testing.assert_array_equal([0, 5, 1, 0, 0], result2)

        # Five distinct keys have been added; the expected size of 6
        # presumably also counts the reserved id 0.
        self._run_op('IndexSize', ['index'], ['index_size'])
        size = workspace.FetchBlob('index_size')
        self.assertEqual(size, 6)

        self._run_op('IndexStore', ['index'], ['stored_entries'])
        stored_actual = workspace.FetchBlob('stored_entries')
        new_entries = np.array(entries[3:5], dtype=dtype)
        expected = np.concatenate((my_entries, new_entries))
        if dtype is str:
            # we'll always get bytes back from Caffe2
            expected = np.array([
                x.item().encode('utf-8') if isinstance(x, np.str_) else x
                for x in expected
            ], dtype=object)
        np.testing.assert_array_equal(expected, stored_actual)

        # Reload the stored keys into a fresh index, skipping the first one.
        self._run_op(index_create_op, [], ['index2'])
        self._run_op(
            'IndexLoad', ['index2', 'stored_entries'], ['index2'],
            skip_first_entry=1)
        self._run_op('IndexSize', ['index2'], ['index2_size'])
        index2_size = workspace.FetchBlob('index2_size')
        self.assertEqual(index2_size, 5)

        # test serde. The DB lives inside a temporary *directory* so Python
        # does not hold the file open while minidb writes and reads it by
        # name (an open NamedTemporaryFile cannot be reopened on Windows).
        with tempfile.TemporaryDirectory() as tmp_dir:
            db_path = os.path.join(tmp_dir, 'index.minidb')
            self._run_op(
                'Save', ['index'], [],
                absolute_path=1, db_type='minidb', db=db_path)
            # frees up the blob
            workspace.FeedBlob('index', np.array([]))
            # reloads the index
            self._run_op(
                'Load', [], ['index'],
                absolute_path=1, db_type='minidb', db=db_path)
            query3 = np.array(
                [entries[0], entries[3], entries[0], entries[4], entries[4]],
                dtype=dtype)
            workspace.FeedBlob('query3', query3)
            self._run_op('IndexGet', ['index', 'query3'], ['result3'])
            result3 = workspace.FetchBlob('result3')
            np.testing.assert_array_equal([1, 4, 1, 5, 5], result3)

    def test_string_index_ops(self):
        """Run the shared scenario with string keys."""
        self._test_index_ops([
            'entry1', 'entry2', 'entry3', 'new_entry1',
            'new_entry2', 'miss1', 'miss2', 'miss3',
        ], str, 'StringIndexCreate')

    def test_int_index_ops(self):
        """Run the shared scenario with int32 keys."""
        self._test_index_ops(list(range(8)), np.int32, 'IntIndexCreate')

    def test_long_index_ops(self):
        """Run the shared scenario with int64 keys."""
        self._test_index_ops(list(range(8)), np.int64, 'LongIndexCreate')

if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    from unittest import main as run_all_tests

    run_all_tests()