# Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages
#
# neilisaac / torch   python
#
# Repository URL to install this package:
#
# Version: 1.8.0
#
# / python / operator_test / top_k_test.py






import hypothesis.strategies as st
import numpy as np

from caffe2.python import core
from hypothesis import given, settings
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial


class TestTopK(serial.SerializedTestCase):

    def top_k_ref(self, X, k, flatten_indices, axis=-1):
        """Compute a reference TopK result for ``X`` along ``axis``.

        For every 1-D slice of ``X`` taken along ``axis`` the ``k`` largest
        entries are selected in descending order; equal values are ordered by
        ascending position.  When ``k`` exceeds the length of the axis, the
        unused trailing slots keep their fill values (0 for values, -1 for
        both kinds of indices).

        Args:
            X: numpy array to select from.
            k: number of entries to keep along ``axis``.
            flatten_indices: when true, also return the flat (row-major)
                offsets of the selected entries within ``X``.
            axis: axis to select along.  Any in-range negative axis is
                accepted, not only -1.

        Returns:
            ``(values, indices)`` or, when ``flatten_indices`` is true,
            ``(values, indices, flat_indices)``.  ``values`` (float32) and
            ``indices`` (int64) have the shape of ``X`` with ``axis`` resized
            to ``k``; ``flat_indices`` (int64) is one-dimensional.

        Raises:
            IndexError: if ``axis`` is out of range for ``X``.
        """
        in_dims = X.shape
        out_dims = list(in_dims)
        out_dims[axis] = k  # also rejects an out-of-range axis (IndexError)
        out_dims = tuple(out_dims)

        # Map every negative axis (not just -1) onto its non-negative
        # position so the prefix/suffix products below cover the right dims.
        if axis < 0:
            axis += len(in_dims)

        # Collapse X to (dims before axis, axis, dims after axis).
        prev_dims = int(np.prod(in_dims[:axis], dtype=np.int64))
        next_dims = int(np.prod(in_dims[axis + 1:], dtype=np.int64))
        n = in_dims[axis]
        X_flat = X.reshape((prev_dims, n, next_dims))

        slot_shape = (prev_dims, k, next_dims)
        values_ref = np.zeros(slot_shape, dtype=np.float32)
        indices_ref = np.full(slot_shape, -1, dtype=np.int64)
        flatten_indices_ref = np.full(slot_shape, -1, dtype=np.int64)

        for i in range(prev_dims):
            for j in range(next_dims):
                column = X_flat[i, :, j]
                # Sorting (value, -position) in descending order yields the
                # largest values first and, among ties, the smallest position.
                ranked = sorted(
                    ((column[pos], -pos) for pos in range(n)), reverse=True)
                for rank, (val, neg_pos) in enumerate(ranked[:k]):
                    pos = -neg_pos
                    values_ref[i, rank, j] = val
                    indices_ref[i, rank, j] = pos
                    # Row-major offset of X_flat[i, pos, j] in the flat data.
                    flatten_indices_ref[i, rank, j] = (
                        (i * n + pos) * next_dims + j)

        values_ref = values_ref.reshape(out_dims)
        indices_ref = indices_ref.reshape(out_dims)
        flatten_indices_ref = flatten_indices_ref.flatten()

        if flatten_indices:
            return (values_ref, indices_ref, flatten_indices_ref)
        return (values_ref, indices_ref)

    @serial.given(
        X=hu.tensor(),
        flatten_indices=st.booleans(),
        seed=st.integers(0, 10),
        **hu.gcs
    )
    def test_top_k(self, X, flatten_indices, seed, gc, dc):
        """Check TopK over the last axis of an arbitrary tensor.

        ``k`` is drawn from numpy's global RNG after seeding it with ``seed``
        and may exceed the size of the last dimension.  The operator is
        compared with ``top_k_ref`` through the harness reference check and
        then passed to the device check.
        """
        np.random.seed(seed)
        X = X.astype(dtype=np.float32)
        last_dim = X.shape[-1]
        # The upper bound is exclusive, so k lies in [1, last_dim + 3] and
        # can therefore be larger than the axis it is taken over.
        k = np.random.randint(1, last_dim + 4)

        outputs = ["Values", "Indices"]
        outputs += ["FlattenIndices"] if flatten_indices else []
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        # NOTE(review): [0] presumably selects the output(s) compared across
        # devices -- confirm against the harness.
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 1), k=st.integers(1, 1),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_1(self, bs, n, k, flatten_indices, gc, dc):
        """Check TopK on a random ``(bs, n)`` float32 matrix.

        The strategies on this test pin both ``n`` and ``k`` to 1, so this is
        the single-column, single-selection case.
        """
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        all_outputs = ["Values", "Indices", "FlattenIndices"]
        outputs = all_outputs if flatten_indices else all_outputs[:2]
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 10000), k=st.integers(1, 1),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_2(self, bs, n, k, flatten_indices, gc, dc):
        """Check TopK on a random ``(bs, n)`` float32 matrix.

        The strategies on this test keep ``k`` at 1 while ``n`` ranges up to
        10000, i.e. selecting the single largest entry of each row.
        """
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        outputs = ["Values", "Indices"]
        outputs.extend(["FlattenIndices"] if flatten_indices else [])
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 10000),
           k=st.integers(1, 1024), flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_3(self, bs, n, k, flatten_indices, gc, dc):
        """Check TopK on a random ``(bs, n)`` float32 matrix.

        ``n`` (up to 10000) and ``k`` (up to 1024) are drawn independently by
        the strategies on this test, so ``k`` may be smaller or larger than
        the row length.
        """
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        if flatten_indices:
            outputs = ["Values", "Indices", "FlattenIndices"]
        else:
            outputs = ["Values", "Indices"]
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(100, 10000),
           flatten_indices=st.booleans(), **hu.gcs)
    @settings(deadline=1000)
    def test_top_k_4(self, bs, n, flatten_indices, gc, dc):
        """Check TopK with ``k`` drawn between a third and three quarters of n.

        ``k`` is sampled from ``[n // 3, 3 * n // 4)`` before the random
        ``(bs, n)`` float32 input is generated.
        """
        lower, upper = n // 3, 3 * n // 4
        # k is drawn first, then X, so the global RNG is consumed in a fixed
        # order.
        k = np.random.randint(lower, upper)
        X = np.random.rand(bs, n).astype(dtype=np.float32)

        outputs = ["Values", "Indices"]
        outputs += ["FlattenIndices"] if flatten_indices else []
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 1024),
           flatten_indices=st.booleans(), **hu.gcs)
    def test_top_k_5(self, bs, n, flatten_indices, gc, dc):
        """Check TopK when ``k`` equals the row length ``n``.

        Every entry of each row of the random ``(bs, n)`` float32 input is
        selected, so the operator has to return the whole row ranked.
        """
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        k = n

        all_outputs = ["Values", "Indices", "FlattenIndices"]
        outputs = all_outputs if flatten_indices else all_outputs[:2]
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(bs=st.integers(1, 3), n=st.integers(1, 5000),
           flatten_indices=st.booleans(), **hu.gcs)
    @settings(deadline=1000)
    def test_top_k_6(self, bs, n, flatten_indices, gc, dc):
        """Check TopK with ``k == n`` on longer rows.

        Same setup as selecting every entry of each row of a random
        ``(bs, n)`` float32 input; the strategies on this test allow ``n`` up
        to 5000.
        """
        X = np.random.rand(bs, n).astype(dtype=np.float32)
        k = n

        outputs = ["Values", "Indices"]
        outputs.extend(["FlattenIndices"] if flatten_indices else [])
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(X=hu.tensor(dtype=np.float32), k=st.integers(1, 5),
           axis=st.integers(-1, 5), flatten_indices=st.booleans(),
           **hu.gcs)
    def test_top_k_axis(self, X, k, axis, flatten_indices, gc, dc):
        """Check TopK along an explicitly chosen axis.

        An ``axis`` at or beyond the rank of ``X`` is wrapped with a modulo so
        it always names a real dimension; -1 is passed through unchanged.  The
        same axis is given to the operator and to ``top_k_ref``.
        """
        rank = len(X.shape)
        if axis >= rank:
            axis %= rank

        if flatten_indices:
            outputs = ["Values", "Indices", "FlattenIndices"]
        else:
            outputs = ["Values", "Indices"]
        top_k_op = core.CreateOperator(
            "TopK", ["X"], outputs, k=k, axis=axis, device_option=gc)

        def reference(data):
            return self.top_k_ref(data, k, flatten_indices, axis)

        self.assertReferenceChecks(gc, top_k_op, [X], reference)
        self.assertDeviceChecks(dc, top_k_op, [X], [0])

    @given(X=hu.tensor(dtype=np.float32), k=st.integers(1, 5),
           axis=st.integers(-1, 5), **hu.gcs)
    @settings(deadline=10000)
    def test_top_k_grad(self, X, k, axis, gc, dc):
        """Run the harness gradient check on TopK along ``axis``.

        The generated contents of ``X`` are overwritten: every 1-D slice along
        the chosen axis becomes a shuffled copy of ``0, 0.2, 0.4, ...`` so
        neighbouring values are 0.2 apart, well above the 0.05 step used by
        the gradient check.
        """
        shape = X.shape
        rank = len(shape)
        if axis >= rank:
            axis %= rank

        # Non-negative position of the axis, used for the size bookkeeping;
        # the operator itself still receives `axis` (possibly -1).
        input_axis = rank - 1 if axis == -1 else axis
        axis_len = shape[input_axis]

        outer = 1
        for extent in shape[:input_axis]:
            outer *= extent
        inner = 1
        for extent in shape[input_axis + 1:]:
            inner *= extent

        X_flat = X.reshape((outer, axis_len, inner))
        for i, j in np.ndindex(outer, inner):
            # This tries to make sure that adding the step size (0.05)
            # does not change the TopK selection at all.
            X_flat[i, :, j] = np.arange(axis_len, dtype=np.float32) / 5
            np.random.shuffle(X_flat[i, :, j])
        X = X_flat.reshape(shape)

        top_k_op = core.CreateOperator(
            "TopK", ["X"], ["Values", "Indices"], k=k, axis=axis,
            device_option=gc)

        # NOTE(review): the positional 0 and [0] presumably pick the input
        # to differentiate and the output carrying gradients -- confirm
        # against the harness.
        self.assertGradientChecks(gc, top_k_op, [X], 0, [0], stepsize=0.05)