import unittest
import hypothesis.strategies as st
from hypothesis import given
import numpy as np
from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.mkl_test_util as mu
@unittest.skipIf(
    not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn."
)
class MKLConcatTest(hu.HypothesisTestCase):
    """Cross-device checks for the Caffe2 ``Concat`` operator.

    The whole class is skipped unless ``workspace.C.has_mkldnn`` is truthy.
    """

    @given(
        batch_size=st.integers(1, 10),
        channel_splits=st.lists(st.integers(1, 10), min_size=1, max_size=3),
        height=st.integers(1, 10),
        width=st.integers(1, 10),
        **mu.gcs
    )
    def test_mkl_concat(
        self, batch_size, channel_splits, height, width, gc, dc
    ):
        """Run ``Concat`` on 1-3 random float32 4-D tensors and device-check it.

        One tensor of shape ``(batch_size, c, height, width)`` is built for
        every ``c`` in ``channel_splits``. The tensors are fed as ``X_0``,
        ``X_1``, ... to a ``Concat`` op with ``order="NCHW"``. ``gc``/``dc``
        are supplied through ``mu.gcs``; presumably ``dc`` is the list of
        device options to compare (TODO confirm against mkl_test_util).
        """
        inputs = []
        for num_channels in channel_splits:
            shape = (batch_size, num_channels, height, width)
            inputs.append(np.random.rand(*shape).astype(np.float32))

        input_names = ["X_{}".format(idx) for idx, _ in enumerate(inputs)]
        concat_op = core.CreateOperator(
            "Concat",
            input_names,
            ["concat_result", "split_info"],
            order="NCHW",
        )
        # NOTE(review): [0] presumably limits the comparison to output 0
        # ("concat_result"), leaving "split_info" unchecked — confirm.
        self.assertDeviceChecks(dc, concat_op, inputs, [0])
if __name__ == "__main__":
    # Allow running this file directly with the standard unittest runner;
    # `unittest` is already imported at module level, so no re-import needed.
    unittest.main()