Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / operator_test / length_split_op_test.py






from caffe2.python import core
from hypothesis import given, settings
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np


class TestLengthSplitOperator(serial.SerializedTestCase):
    """Checks the ``LengthsSplit`` operator against a pure-Python reference.

    Every test builds the operator, compares its output with
    ``_length_split_op_ref`` through ``assertReferenceChecks`` and then
    runs ``assertDeviceChecks`` over the device options from hypothesis.
    """

    def _length_split_op_ref(self, input_lengths, n_split_array):
        """Reference implementation used to validate ``LengthsSplit``.

        Each length ``x`` is expanded into ``n_split`` parts: the first
        ``x % n_split`` parts are ``x // n_split + 1`` and the remaining
        parts are ``x // n_split``.

        Args:
            input_lengths: iterable of integer lengths.
            n_split_array: sequence whose first element is the number of
                parts every length is split into.

        Returns:
            A one-element list holding the flattened parts as an int32
            numpy array.
        """
        # Plain Python ints keep ``list * count`` a list repetition; a numpy
        # integer on the right-hand side would turn it into array arithmetic.
        n_split = int(n_split_array[0])
        output = []
        for length in input_lengths:
            quotient, remainder = divmod(int(length), n_split)
            # The leading ``remainder`` parts absorb the leftover units.
            output.extend([quotient + 1] * remainder)
            output.extend([quotient] * (n_split - remainder))
        return [np.array(output).astype(np.int32)]

    def _run_lengths_split_checks(self, gc, dc, input_names, inputs,
                                  reference, **op_args):
        """Create a ``LengthsSplit`` op and run the shared checks on it.

        Args:
            gc: device option passed to ``assertReferenceChecks``.
            dc: device options passed to ``assertDeviceChecks``.
            input_names: input blob names given to the operator.
            inputs: arrays fed to the operator, in ``input_names`` order.
            reference: callable producing the expected outputs.
            **op_args: extra operator arguments (e.g. ``n_split``).
        """
        op = core.CreateOperator(
            'LengthsSplit',
            input_names,
            ['Y'],
            **op_args
        )

        # Check against numpy reference
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=inputs,
            reference=reference,
        )
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, inputs, [0])

    # NOTE(review): this deadline is 1000 while every other test in this
    # class uses 10000 - confirm the tighter limit is intentional.
    @given(**hu.gcs_cpu_only)
    @settings(deadline=1000)
    def test_length_split_edge(self, gc, dc):
        """Split count exceeds some lengths, so some parts are zero."""
        input_lengths = np.array([3, 4, 5]).astype(np.int32)
        n_split_ = np.array([5]).astype(np.int32)
        # Expected output:
        # [1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
        self._run_lengths_split_checks(
            gc, dc,
            ['input_lengths', 'n_split'],
            [input_lengths, n_split_],
            self._length_split_op_ref,
        )

    @given(**hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_length_split_arg(self, gc, dc):
        """Split count supplied as an operator argument, not an input."""
        input_lengths = np.array([9, 4, 5]).astype(np.int32)
        n_split = 3
        # Expected output:
        # [3, 3, 3, 2, 1, 1, 2, 2, 1]
        self._run_lengths_split_checks(
            gc, dc,
            ['input_lengths'],
            [input_lengths],
            # The op has a single input here, so bind n_split for the
            # two-argument reference.
            lambda x: self._length_split_op_ref(x, [n_split]),
            n_split=n_split,
        )

    @given(**hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_length_split_override_arg(self, gc, dc):
        """Both an argument and an input give a split count.

        The reference is computed from the input value, so the check
        passes only if the operator uses the input over the argument.
        """
        input_lengths = np.array([9, 4, 5]).astype(np.int32)
        n_split_ignored = 2
        n_split_used = np.array([3]).astype(np.int32)

        self._run_lengths_split_checks(
            gc, dc,
            ['input_lengths', 'n_split'],
            [input_lengths, n_split_used],
            self._length_split_op_ref,
            n_split=n_split_ignored,
        )

    @given(m=st.integers(1, 100), n_split=st.integers(1, 20),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_length_split_even_divide(self, m, n_split, gc, dc):
        """Random lengths that are exact multiples of the split count."""
        # multiples of n_split
        input_lengths = np.random.randint(100, size=m).astype(np.int32) * n_split
        n_split_ = np.array([n_split]).astype(np.int32)

        self._run_lengths_split_checks(
            gc, dc,
            ['input_lengths', 'n_split'],
            [input_lengths, n_split_],
            self._length_split_op_ref,
        )

    @given(m=st.integers(1, 100), n_split=st.integers(1, 20),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_length_split_random(self, m, n_split, gc, dc):
        """Random lengths in [0, 100) with a random split count."""
        input_lengths = np.random.randint(100, size=m).astype(np.int32)
        n_split_ = np.array([n_split]).astype(np.int32)

        self._run_lengths_split_checks(
            gc, dc,
            ['input_lengths', 'n_split'],
            [input_lengths, n_split_],
            self._length_split_op_ref,
        )


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main as run_tests
    run_tests()