Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / test / inference_lstm_op_test.py

#!/usr/bin/env python3

import hypothesis.strategies as st
import numpy as np
import torch
from caffe2.python import core
from caffe2.python.test_util import TestCase
from hypothesis import given, settings
from torch import nn


class TestC2LSTM(TestCase):
    """Compare Caffe2's ``InferenceLSTM`` op against PyTorch's ``nn.LSTM``."""

    @given(
        bsz=st.integers(1, 5),
        seq_lens=st.integers(1, 6),
        emb_lens=st.integers(5, 10),
        hidden_size=st.integers(3, 7),
        num_layers=st.integers(1, 4),
        has_biases=st.booleans(),
        is_bidirectional=st.booleans(),
        batch_first=st.booleans(),
    )
    @settings(deadline=10000)
    def test_c2_lstm(
        self,
        bsz,
        seq_lens,
        emb_lens,
        hidden_size,
        num_layers,
        has_biases,
        is_bidirectional,
        batch_first,
    ):
        """Feed identical inputs and weights to both LSTMs and compare outputs.

        Args:
            bsz: batch size.
            seq_lens: number of time steps in the input sequence.
            emb_lens: input feature size per time step (``nn.LSTM`` input_size).
            hidden_size: LSTM hidden-state size.
            num_layers: number of stacked LSTM layers.
            has_biases: whether the LSTM layers use bias terms.
            is_bidirectional: whether the LSTM runs in both directions.
            batch_first: whether the input is laid out batch-major
                (batch, seq, feature) instead of (seq, batch, feature).
        """
        num_directions = 2 if is_bidirectional else 1
        py_lstm = nn.LSTM(
            emb_lens,
            hidden_size,
            batch_first=batch_first,
            bidirectional=is_bidirectional,
            bias=has_biases,
            num_layers=num_layers,
        )

        # The PyTorch reference call below passes no initial state, so it
        # defaults to zeros. The same zero array is therefore fed to the
        # Caffe2 op as both h0 and c0. The state shape does not depend on
        # batch_first.
        hx = torch.from_numpy(
            np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)
        )

        if batch_first:
            input_shape = (bsz, seq_lens, emb_lens)
        else:
            input_shape = (seq_lens, bsz, emb_lens)
        inputs = torch.from_numpy(np.random.randn(*input_shape).astype(np.float32))

        py_output, (py_h, py_c) = py_lstm(inputs)

        # InferenceLSTM consumes [input, h0, c0, *flat_weights]. It reuses the
        # nn.LSTM parameters (detached from autograd) so that both
        # implementations share identical weights. Note: _flat_weights is a
        # private nn.LSTM attribute.
        lstm_in = [inputs, hx, hx] + [
            param.detach() for param in py_lstm._flat_weights
        ]
        c2_results = torch.ops._caffe2.InferenceLSTM(
            lstm_in, num_layers, has_biases, batch_first, is_bidirectional
        )

        # Compare (output sequence, final hidden state, final cell state).
        for expected, actual in (
            (py_output, c2_results[0]),
            (py_h, c2_results[1]),
            (py_c, c2_results[2]),
        ):
            np.testing.assert_array_almost_equal(
                expected.detach().numpy(), actual.detach().numpy()
            )