Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / operator_test / bucketize_op_test.py





from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np


class TestBucketizeOp(hu.HypothesisTestCase):
    """Check the Caffe2 ``Bucketize`` operator against a NumPy reference."""

    @given(
        x=hu.tensor(
            min_dim=1, max_dim=2, dtype=np.float32,
            elements=hu.floats(min_value=-5, max_value=5)),
        # Boundaries are drawn through hypothesis (not np.random) so that a
        # failing example is fully reproducible and can be shrunk. The range
        # is wider than x's so some inputs fall below/above every boundary.
        raw_boundaries=st.lists(
            hu.floats(min_value=-10, max_value=10), min_size=1, max_size=4),
        **hu.gcs)
    def test_bucketize_op(self, x, raw_boundaries, gc, dc):
        """Bucketize ``x`` by sorted boundaries and compare with np.digitize.

        Args:
            x: float32 input tensor (rank 1 or 2), values in [-5, 5].
            raw_boundaries: 1-4 unsorted boundary values from hypothesis.
            gc: device option to run the operator on.
            dc: list of device options (unused here).
        """
        # Caffe2 float arguments are serialized as 32-bit protobuf floats, so
        # round the boundaries to float32 once and use that same array for
        # both the operator argument and the reference. Comparing x against
        # float64 boundaries could disagree with the op when x falls between
        # a boundary and its float32 rounding.
        boundaries = np.sort(np.asarray(raw_boundaries, dtype=np.float32))

        def ref(x, boundaries):
            # right=True picks i with boundaries[i-1] < x <= boundaries[i].
            bucket_idx = np.digitize(x, boundaries, right=True)
            return [bucket_idx]

        op = core.CreateOperator('Bucketize',
                                 ["X"], ["INDICES"],
                                 boundaries=boundaries)
        # The op declares a single input ("X"); boundaries is included in the
        # inputs list so that the reference function receives it as well.
        self.assertReferenceChecks(gc, op, [x, boundaries], ref)


if __name__ == "__main__":
    # Allow running this test module directly, without an external runner.
    from unittest import main as _run_unittest_main
    _run_unittest_main()