from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
import numpy as np
import tempfile
class TestIndexOps(TestCase):
def _test_index_ops(self, entries, dtype, index_create_op):
workspace.RunOperatorOnce(core.CreateOperator(
index_create_op,
[],
['index'],
max_elements=10))
my_entries = np.array(
[entries[0], entries[1], entries[2]], dtype=dtype)
workspace.FeedBlob('entries', my_entries)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexLoad',
['index', 'entries'],
['index']))
query1 = np.array(
[entries[0], entries[3], entries[0], entries[4]],
dtype=dtype)
workspace.FeedBlob('query1', query1)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexGet',
['index', 'query1'],
['result1']))
result1 = workspace.FetchBlob('result1')
np.testing.assert_array_equal([1, 4, 1, 5], result1)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexFreeze',
['index'],
['index']))
query2 = np.array(
[entries[5], entries[4], entries[0], entries[6], entries[7]],
dtype=dtype)
workspace.FeedBlob('query2', query2)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexGet',
['index', 'query2'],
['result2']))
result2 = workspace.FetchBlob('result2')
np.testing.assert_array_equal([0, 5, 1, 0, 0], result2)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexSize',
['index'],
['index_size']))
size = workspace.FetchBlob('index_size')
self.assertEquals(size, 6)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexStore',
['index'],
['stored_entries']))
stored_actual = workspace.FetchBlob('stored_entries')
new_entries = np.array([entries[3], entries[4]], dtype=dtype)
expected = np.concatenate((my_entries, new_entries))
if dtype is str:
# we'll always get bytes back from Caffe2
expected = np.array([
x.item().encode('utf-8') if isinstance(x, np.str_) else x
for x in expected
], dtype=object)
np.testing.assert_array_equal(expected, stored_actual)
workspace.RunOperatorOnce(core.CreateOperator(
index_create_op,
[],
['index2']))
workspace.RunOperatorOnce(core.CreateOperator(
'IndexLoad',
['index2', 'stored_entries'],
['index2'],
skip_first_entry=1))
workspace.RunOperatorOnce(core.CreateOperator(
'IndexSize',
['index2'],
['index2_size']))
index2_size = workspace.FetchBlob('index2_size')
self.assertEquals(index2_size, 5)
# test serde
with tempfile.NamedTemporaryFile() as tmp:
workspace.RunOperatorOnce(core.CreateOperator(
'Save',
['index'],
[],
absolute_path=1,
db_type='minidb',
db=tmp.name))
# frees up the blob
workspace.FeedBlob('index', np.array([]))
# reloads the index
workspace.RunOperatorOnce(core.CreateOperator(
'Load',
[],
['index'],
absolute_path=1,
db_type='minidb',
db=tmp.name))
query3 = np.array(
[entries[0], entries[3], entries[0], entries[4], entries[4]],
dtype=dtype)
workspace.FeedBlob('query3', query3)
workspace.RunOperatorOnce(core.CreateOperator(
'IndexGet', ['index', 'query3'], ['result3']))
result3 = workspace.FetchBlob('result3')
np.testing.assert_array_equal([1, 4, 1, 5, 5], result3)
def test_string_index_ops(self):
self._test_index_ops([
'entry1', 'entry2', 'entry3', 'new_entry1',
'new_entry2', 'miss1', 'miss2', 'miss3',
], str, 'StringIndexCreate')
def test_int_index_ops(self):
self._test_index_ops(list(range(8)), np.int32, 'IntIndexCreate')
def test_long_index_ops(self):
self._test_index_ops(list(range(8)), np.int64, 'LongIndexCreate')
if __name__ == "__main__":
import unittest
unittest.main()