Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / operator_test / lengths_pad_op_test.py






from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np


class TestLengthsPadOp(serial.SerializedTestCase):
    """Checks the Caffe2 ``LengthsPad`` operator against a NumPy reference."""

    @serial.given(
        inputs=hu.lengths_tensor(
            dtype=np.float32,
            min_value=1,
            max_value=5,
            allow_empty=True,
        ),
        delta_length=st.integers(0, 10),
        padding_value=st.floats(-10.0, 10.0),
        **hu.gcs
    )
    def test_lengths_pad(self, inputs, delta_length, padding_value, gc, dc):
        """Compare ``LengthsPad`` with a pure-NumPy reference implementation.

        ``inputs`` unpacks into ``(data, lengths)``.  The reference treats
        ``lengths[i]`` as the number of leading-axis rows of ``data`` that
        belong to segment ``i``, with segments stored back to back.
        ``dc`` is accepted because it is passed in alongside ``gc``, but it
        is not used by this test.
        """
        data, lengths = inputs

        # With no segments there is nothing to take a maximum of, so fall
        # back to zero; the final clamp keeps the target length at least 1.
        if len(lengths) > 0:
            longest_segment = np.max(lengths)
        else:
            longest_segment = 0
        target_length = max(longest_segment + delta_length, 1)

        def lengths_pad_op(data, lengths):
            """Reference: give every segment its own ``target_length`` rows."""
            num_segments = len(lengths)
            padded_shape = (target_length * num_segments, ) + data.shape[1:]
            # Start from an all-padding buffer, then overwrite the head of
            # each fixed-size slot with the rows of the matching segment.
            padded = np.full(padded_shape, padding_value, dtype=np.float32)

            src_start = 0
            for slot, segment_len in enumerate(lengths):
                dst_start = slot * target_length
                segment_rows = data[src_start:src_start + segment_len]
                padded[dst_start:dst_start + segment_len] = segment_rows
                src_start += segment_len

            return [padded]

        pad_operator = core.CreateOperator(
            "LengthsPad",
            ["data", "lengths"],
            ["data_padded"],
            target_length=target_length,
            padding_value=padding_value,
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=pad_operator,
            inputs=[data, lengths],
            reference=lengths_pad_op,
        )