Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / operator_test / onnx_while_test.py





from caffe2.proto import caffe2_pb2
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from hypothesis import given, settings
import hypothesis.strategies as st
import numpy as np
import unittest


class TestONNXWhile(serial.SerializedTestCase):
    """Tests for the Caffe2 ``ONNXWhile`` operator."""

    @given(
        condition=st.booleans(),
        max_trip_count=st.integers(0, 100),
        save_scopes=st.booleans(),
        disable_scopes=st.booleans(),
        seed=st.integers(0, 65535),
        **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_onnx_while_fibb(
            self, condition, max_trip_count, save_scopes, disable_scopes, seed, gc, dc):
        """Run a Fibonacci-style loop through ONNXWhile and compare it with a
        pure-Python reference.

        Each iteration computes ``third = first + second``, shifts the two
        loop-carried values (``first <- second``, ``second <- third``) and
        emits ``third`` as a per-iteration output. The loop is skipped
        entirely when ``condition`` is False, runs at most ``max_trip_count``
        times, and stops after the first iteration whose ``third`` is not
        below ``limit``.
        """
        np.random.seed(seed)
        if disable_scopes:
            # Keep the flags consistent: scopes cannot be saved when they are
            # disabled. NOTE(review): presumably the op requires this — confirm.
            save_scopes = False

        # Single upper bound shared by the body net's LT check and the
        # Python reference so the two termination rules cannot drift apart.
        limit = 100.0

        # Body net inputs: iteration counter, loop condition, and the two
        # loop-carried values.
        body_net = caffe2_pb2.NetDef()
        body_net.external_input.extend(['i', 'cond', 'first', 'second'])
        # Body net outputs: the new condition, the updated carried values
        # ('second' becomes the next 'first', 'third' the next 'second'), and
        # 'third' once more as the per-iteration output collected by the loop.
        body_net.external_output.extend(['cond_new', 'second', 'third', 'third'])
        add_op = core.CreateOperator(
            'Add',
            ['first', 'second'],
            ['third'],
        )
        print3 = core.CreateOperator(
            'Print',
            ['third'],
            [],
        )
        limit_const = core.CreateOperator(
            'ConstantFill',
            [],
            ['limit_const'],
            shape=[1],
            dtype=caffe2_pb2.TensorProto.FLOAT,
            value=limit,
        )
        # Continue looping while third < limit.
        cond = core.CreateOperator(
            'LT',
            ['third', 'limit_const'],
            ['cond_new'],
        )
        body_net.op.extend([add_op, print3, limit_const, cond])

        while_op = core.CreateOperator(
            'ONNXWhile',
            ['max_trip_count', 'condition', 'first_init', 'second_init'],
            ['first_a', 'second_a', 'third_a'],
            body=body_net,
            has_cond=True,
            has_trip_count=True,
            save_scopes=save_scopes,
            disable_scopes=disable_scopes,
        )

        # `np.bool` was removed in NumPy 1.24; the builtin `bool` yields the
        # same numpy bool dtype.
        condition_arr = np.array(condition).astype(bool)
        max_trip_count_arr = np.array(max_trip_count).astype(np.int64)
        first_init = np.array([1]).astype(np.float32)
        second_init = np.array([1]).astype(np.float32)

        def ref(max_trip_count, condition, first_init, second_init):
            """Python model of the loop: returns the final carried values and
            the float32 array of every per-iteration 'third'."""
            first = float(first_init[0])
            second = float(second_init[0])
            results = []
            if condition:
                for _ in range(max_trip_count):
                    third = first + second
                    first, second = second, third
                    results.append(third)
                    # Mirror the body's `cond_new = third < limit` exactly.
                    if not third < limit:
                        break
            return (first, second, np.array(results).astype(np.float32))

        self.assertReferenceChecks(
            gc,
            while_op,
            [max_trip_count_arr, condition_arr, first_init, second_init],
            ref,
        )

if __name__ == "__main__":
    # Allow running this file directly; unittest discovers the TestCase above.
    unittest.main()