Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / schema_test.py






from caffe2.python import core, schema
import numpy as np

import unittest
import pickle
import random

class TestField(unittest.TestCase):
    """Checks the parent link and child offsets computed by schema.Field."""

    def testInitShouldSetEmptyParent(self):
        # A root field built on its own has no parent and sits at offset 0.
        root = schema.Field([])
        self.assertTupleEqual(root._parent, (None, 0))

    def testInitShouldSetFieldOffsets(self):
        def int_scalar():
            return schema.Scalar(dtype=np.int32)

        # Children alternate between int scalars and structs holding a
        # scalar plus a list of strings.
        children = [
            int_scalar(),
            schema.Struct(
                ('field1', int_scalar()),
                ('field2', schema.List(schema.Scalar(dtype=str))),
            ),
            int_scalar(),
            schema.Struct(
                ('field3', int_scalar()),
                ('field4', schema.List(schema.Scalar(dtype=str))),
            ),
            int_scalar(),
        ]
        root = schema.Field(children)
        # One offset per child plus a trailing entry (six values for five
        # children).
        self.assertListEqual(root._field_offsets, [0, 1, 4, 5, 8, 9])

    def testInitShouldSetFieldOffsetsIfNoChildren(self):
        # With no children only the starting offset is recorded.
        root = schema.Field([])
        self.assertListEqual(root._field_offsets, [0])


class TestDB(unittest.TestCase):
    def testPicklable(self):
        """A Struct survives a pickle round trip with its fields intact."""
        original = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.List(schema.Scalar(dtype=str)))
        )
        restored = pickle.loads(pickle.dumps(original))
        for candidate in [original, restored]:
            self.assertIsInstance(candidate.field1, schema.Scalar)
            self.assertIsInstance(candidate.field2, schema.List)
            # An unknown attribute name must not resolve to a field.
            self.assertIsNone(getattr(candidate, 'non_existent', None))

    def testListSubclassClone(self):
        """Cloning a List subclass returns an equal, distinct Subclass."""
        class Subclass(schema.List):
            pass

        source = Subclass(schema.Scalar())
        copied = source.clone()
        self.assertIsNot(copied, source)
        self.assertIsInstance(copied, Subclass)
        self.assertEqual(source, copied)

    def testListWithEvictedSubclassClone(self):
        """Cloning a ListWithEvicted subclass returns an equal, distinct copy."""
        class Subclass(schema.ListWithEvicted):
            pass

        source = Subclass(schema.Scalar())
        copied = source.clone()
        self.assertIsNot(copied, source)
        self.assertIsInstance(copied, Subclass)
        self.assertEqual(source, copied)

    def testStructSubclassClone(self):
        """Cloning a Struct subclass returns an equal, distinct Subclass."""
        class Subclass(schema.Struct):
            pass

        original = Subclass(('a', schema.Scalar()))
        duplicate = original.clone()
        self.assertIsNot(duplicate, original)
        self.assertIsInstance(duplicate, Subclass)
        self.assertEqual(original, duplicate)

    def testNormalizeField(self):
        """Bare dtypes given to Struct compare equal to explicit Scalars."""
        s = schema.Struct(('field1', np.int32), ('field2', str))
        # assertEquals is a deprecated alias removed in Python 3.12; use
        # assertEqual so the test runs on current interpreters.
        self.assertEqual(
            s,
            schema.Struct(
                ('field1', schema.Scalar(dtype=np.int32)),
                ('field2', schema.Scalar(dtype=str))
            )
        )

    def testTuple(self):
        """Tuple builds a Struct with positional ``field_<i>`` names.

        Integer indexing, tuple indexing and iteration must all agree with
        the equivalent explicitly-named Struct.
        """
        s = schema.Tuple(np.int32, str, np.float32)
        # np.str was merely an alias of the builtin str; it was deprecated
        # in NumPy 1.20 and removed in 1.24, so use str directly.
        s2 = schema.Struct(
            ('field_0', schema.Scalar(dtype=np.int32)),
            ('field_1', schema.Scalar(dtype=str)),
            ('field_2', schema.Scalar(dtype=np.float32))
        )
        # assertEquals (removed in Python 3.12) replaced by assertEqual.
        self.assertEqual(s, s2)
        self.assertEqual(s[0], schema.Scalar(dtype=np.int32))
        self.assertEqual(s[1], schema.Scalar(dtype=str))
        self.assertEqual(s[2], schema.Scalar(dtype=np.float32))
        # A tuple of indices selects those fields, in the order given.
        self.assertEqual(
            s[2, 0],
            schema.Struct(
                ('field_2', schema.Scalar(dtype=np.float32)),
                ('field_0', schema.Scalar(dtype=np.int32)),
            )
        )
        # test iterator behavior
        for i, (v1, v2) in enumerate(zip(s, s2)):
            self.assertEqual(v1, v2)
            self.assertEqual(s[i], v1)
            self.assertEqual(s2[i], v1)

    def testRawTuple(self):
        """RawTuple(n) equals a Struct of n default Scalars named field_<i>."""
        s = schema.RawTuple(2)
        # assertEquals (removed in Python 3.12) replaced by assertEqual.
        self.assertEqual(
            s, schema.Struct(
                ('field_0', schema.Scalar()), ('field_1', schema.Scalar())
            )
        )
        self.assertEqual(s[0], schema.Scalar())
        self.assertEqual(s[1], schema.Scalar())

    def testStructIndexing(self):
        """Struct supports name indexing and multi-name selection."""
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.List(schema.Scalar(dtype=str))),
            ('field3', schema.Struct()),
        )
        # assertEquals (removed in Python 3.12) replaced by assertEqual.
        self.assertEqual(s['field2'], s.field2)
        self.assertEqual(s['field2'], schema.List(schema.Scalar(dtype=str)))
        self.assertEqual(s['field3'], schema.Struct())
        # Indexing with several names builds a new Struct in that order.
        self.assertEqual(
            s['field2', 'field1'],
            schema.Struct(
                ('field2', schema.List(schema.Scalar(dtype=str))),
                ('field1', schema.Scalar(dtype=np.int32)),
            )
        )

    def testListInStructIndexing(self):
        """Nested names reach a List's lengths/values through a Struct."""
        a = schema.List(schema.Scalar(dtype=str))
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', a)
        )
        self.assertEqual(s['field2:lengths'], a.lengths)
        self.assertEqual(s['field2:values'], a.items)
        # The probes previously used the misspelled top-level name
        # 'fields2', so KeyError came from the unknown top-level field and
        # the nested List lookup was never exercised.
        with self.assertRaises(KeyError):
            s['field2:items:non_existent']
        with self.assertRaises(KeyError):
            s['field2:non_existent']

    def testListWithEvictedInStructIndexing(self):
        """Nested names reach a ListWithEvicted's sub-fields via a Struct."""
        a = schema.ListWithEvicted(schema.Scalar(dtype=str))
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', a)
        )
        self.assertEqual(s['field2:lengths'], a.lengths)
        self.assertEqual(s['field2:values'], a.items)
        self.assertEqual(s['field2:_evicted_values'], a._evicted_values)
        # Use the real field name 'field2' (not 'fields2') so the KeyError
        # is raised by the nested lookup rather than the top-level miss.
        with self.assertRaises(KeyError):
            s['field2:items:non_existent']
        with self.assertRaises(KeyError):
            s['field2:non_existent']

    def testMapInStructIndexing(self):
        """A Map's keys and values are addressable under 'values:' paths."""
        a = schema.Map(
            schema.Scalar(dtype=np.int32),
            schema.Scalar(dtype=np.float32),
        )
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', a)
        )
        self.assertEqual(s['field2:values:keys'], a.keys)
        self.assertEqual(s['field2:values:values'], a.values)
        # Use the real field name 'field2' (not 'fields2') so the nested
        # lookup, not the top-level miss, is what raises.
        with self.assertRaises(KeyError):
            s['field2:keys:non_existent']

    def testPreservesMetadata(self):
        """Field metadata survives clone() and from_blob_list()."""
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            ('b', schema.Scalar(
                np.int32, metadata=schema.Metadata(categorical_limit=5))),
            ('c', schema.List(schema.Scalar(
                np.int32, metadata=schema.Metadata(categorical_limit=6)))),
        )
        # Metadata can also be attached after construction; here it goes on
        # the list's lengths field.
        s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))

        def check_metadata(record):
            # 'a' has no metadata; b, c's values and c's lengths carry
            # categorical limits 5, 6 and 7 respectively.
            self.assertEqual(None, record.a.metadata)
            self.assertEqual(5, record.b.metadata.categorical_limit)
            self.assertEqual(6, record.c.value.metadata.categorical_limit)
            self.assertEqual(7, record.c.lengths.metadata.categorical_limit)

        check_metadata(s)
        check_metadata(s.clone())
        blobs = [
            np.array([3.4]), np.array([2]), np.array([3]),
            np.array([1, 2, 3]),
        ]
        check_metadata(schema.from_blob_list(s, blobs))

    def testDupField(self):
        """Declaring the same field name twice in a Struct is rejected."""
        duplicated = [('a', schema.Scalar()), ('a', schema.Scalar())]
        with self.assertRaises(ValueError):
            schema.Struct(*duplicated)

    def testAssignToField(self):
        """Rebinding an existing Struct field attribute raises TypeError.

        The Struct is built outside the assertRaises block so that a
        TypeError raised during construction cannot make the test pass
        vacuously; only the assignment itself is expected to fail.
        """
        s = schema.Struct(('a', schema.Scalar()))
        with self.assertRaises(TypeError):
            s.a = schema.Scalar()

    def testPreservesEmptyFields(self):
        """Empty sub-structs are kept by clone() and from_blob_list()."""
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            ('b', schema.Struct()),
        )
        cloned = s.clone()
        for name in ('a', 'b'):
            self.assertIn(name, cloned.fields)
        # A single blob is supplied, for scalar 'a'; 'b' gets none.
        restored = schema.from_blob_list(s, [np.array([3.4])])
        for name in ('a', 'b'):
            self.assertIn(name, restored.fields)
        self.assertEqual(0, len(restored.b.fields))

    def testStructSubstraction(self):
        """Subtracting a Struct drops the fields it names; non-Structs fail."""
        s1 = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Scalar()),
            ('c', schema.Scalar()),
        )
        cases = [
            (schema.Struct(('b', schema.Scalar())), ['a', 'c']),
            (schema.Struct(('a', schema.Scalar())), ['b', 'c']),
        ]
        for removed, expected in cases:
            self.assertEqual(expected, (s1 - removed).field_names())

        with self.assertRaises(TypeError):
            s1 - schema.Scalar()

    def testStructNestedSubstraction(self):
        """Subtraction recurses into nested Structs, removing named leaves."""
        minuend = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(
                *[(name, schema.Scalar()) for name in ('c', 'd', 'e', 'f')]
            )),
        )
        subtrahend = schema.Struct(
            ('b', schema.Struct(
                *[(name, schema.Scalar()) for name in ('d', 'e')]
            )),
        )
        difference = minuend - subtrahend
        self.assertEqual(['a', 'b:c', 'b:f'], difference.field_names())

    def testStructAddition(self):
        """Adding disjoint Structs merges fields; overlaps/non-Structs fail."""
        left = schema.Struct(('a', schema.Scalar()))
        right = schema.Struct(('b', schema.Scalar()))
        merged = left + right
        for name in ('a', 'b'):
            self.assertIn(name, merged.fields)
        # Adding a Struct to itself (overlapping names) or adding a
        # non-Struct operand must both raise TypeError.
        for other in (left, schema.Scalar()):
            with self.assertRaises(TypeError):
                left + other

    def testStructNestedAddition(self):
        """Addition merges nested Structs; Struct vs Scalar clash fails."""
        base = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(('c', schema.Scalar()))),
        )
        extra = schema.Struct(
            ('b', schema.Struct(('d', schema.Scalar()))),
        )
        self.assertEqual(['a', 'b:c', 'b:d'], (base + extra).field_names())

        # 'b' is a Struct in `base` but a Scalar here, so merging must fail.
        conflicting = schema.Struct(('b', schema.Scalar()))
        with self.assertRaises(TypeError):
            base + conflicting

    def testGetFieldByNestedName(self):
        """Struct lookup accepts ':'-separated nested field names."""
        st = schema.Struct(
            ('a', schema.Scalar()),
            ('b', schema.Struct(
                ('c', schema.Struct(
                    ('d', schema.Scalar()),
                )),
            )),
        )
        # Empty, unknown and partially-valid paths all raise KeyError.
        for missing in ('', 'x', 'x:y', 'b:c:x'):
            self.assertRaises(KeyError, st.__getitem__, missing)
        self.assertIsInstance(st['a'], schema.Scalar)
        self.assertIn('d', st['b:c'].fields)
        self.assertIsInstance(st['b:c:d'], schema.Scalar)

    def testAddFieldByNestedName(self):
        f_a = schema.Scalar(blob=core.BlobReference('blob1'))
Loading ...