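"""Hypothesis-based tests for the Caffe2 MatMul and BatchMatMul operators.

Each test builds random inputs, runs the operator, and compares the output
against a NumPy reference, checks it across available devices, and, where
applicable, verifies gradients with respect to each input.
"""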
import inspect

import numpy as np

from hypothesis import assume, given, settings
import hypothesis.strategies as st
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial


class TestMatMul(serial.SerializedTestCase):
    @serial.given(
        M=st.integers(min_value=1, max_value=10),
        K=st.integers(min_value=1, max_value=10),
        N=st.integers(min_value=1, max_value=10),
        trans_a=st.booleans(),
        trans_b=st.booleans(),
        **hu.gcs
    )
    def test_matmul(self, M, K, N, trans_a, trans_b, gc, dc):
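        # When trans_a/trans_b is set, the corresponding input is generated
        # pre-transposed; the operator is expected to transpose it back
        # internally, so the effective product is always (M, K) x (K, N).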
        X = np.random.rand(M, K).astype(np.float32) - 0.5
        if trans_a:
            X = X.transpose()

        Y = np.random.rand(K, N).astype(np.float32) - 0.5
        if trans_b:
            Y = Y.transpose()

        op = core.CreateOperator(
            'MatMul', ['X', 'Y'], 'out', trans_a=trans_a, trans_b=trans_b
        )

        def matmul_ref(X, Y, trans_a, trans_b):
            XX = X.transpose() if trans_a else X
            YY = Y.transpose() if trans_b else Y
            return (XX.dot(YY), )

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [X, Y, trans_a, trans_b], matmul_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X, Y], 0, [0])
        # Gradient check wrt Y
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])

    @given(
        M=st.integers(min_value=1, max_value=10),
        K=st.integers(min_value=1, max_value=10),
        N=st.integers(min_value=1, max_value=10),
        axis_a=st.sampled_from([-3, -2, -1, 1, 2, 3]),
        axis_b=st.sampled_from([-3, -2, -1, 1, 2, 3]),
        trans_a=st.booleans(),
        trans_b=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=1000)
    def test_matmul_axis(
        self, M, K, N, axis_a, axis_b, trans_a, trans_b, gc, dc
    ):
        X = np.random.rand(M, K).astype(np.float32) - 0.5
        if trans_a:
            X = X.transpose()
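        # Embed the 2-d matrix into a 4-d tensor: its second size is placed at
        # axis_a (axis_b for Y below) with singleton dims elsewhere, so MatMul
        # has to collapse the tensor back to 2-d around that axis.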
        shape_x = [X.shape[0], 1, 1, 1]
        shape_x[axis_a] = X.shape[1]
        X = X.reshape(*shape_x)

        Y = np.random.rand(K, N).astype(np.float32) - 0.5
        if trans_b:
            Y = Y.transpose()
        shape_y = [Y.shape[0], 1, 1, 1]
        shape_y[axis_b] = Y.shape[1]
        Y = Y.reshape(*shape_y)
        op = core.CreateOperator(
            'MatMul', ['X', 'Y'],
            'out',
            axis_a=axis_a,
            axis_b=axis_b,
            trans_a=trans_a,
            trans_b=trans_b
        )

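        # Helpers that reproduce how the axis_a/axis_b arguments coerce an N-d
        # tensor to 2-d: dims before `axis` collapse into rows, dims from
        # `axis` onward collapse into columns.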
        def size_to_dim(X, axis):
            dim = 1
            for i in range(axis):
                dim *= X.shape[i]
            return dim

        def size_from_dim(X, axis):
            dim = 1
            for i in range(axis, X.ndim):
                dim *= X.shape[i]
            return dim

        def reshape(X, axis):
            dim_0, dim_1 = size_to_dim(X, axis), size_from_dim(X, axis)
            return X.reshape(dim_0, dim_1)

        def canonical_axis(axis, ndim):
            return ndim + axis if axis < 0 else axis

        def matmul_ref(X, Y, axis_a, axis_b, trans_a, trans_b):
            can_axis_a = canonical_axis(axis_a, X.ndim)
            can_axis_b = canonical_axis(axis_b, Y.ndim)
            X, Y = reshape(X, can_axis_a), reshape(Y, can_axis_b)
            XX = X.transpose() if trans_a else X
            YY = Y.transpose() if trans_b else Y
            return (XX.dot(YY), )

        # Check against numpy reference
        self.assertReferenceChecks(
            gc, op, [X, Y, axis_a, axis_b, trans_a, trans_b], matmul_ref
        )
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, Y], [0])
        # Gradient check wrt X
        self.assertGradientChecks(gc, op, [X, Y], 0, [0])
        # Gradient check wrt Y
        self.assertGradientChecks(gc, op, [X, Y], 1, [0])


class TestBatchMatMul(serial.SerializedTestCase):
    @settings(max_examples=30, deadline=None)
    @given(
        C=st.integers(min_value=0, max_value=3),  # number of batch dims
        M=st.integers(min_value=1, max_value=10),
        K=st.integers(min_value=1, max_value=10),
        N=st.integers(min_value=1, max_value=10),
        trans_a=st.booleans(),
        trans_b=st.booleans(),
        dtype=st.sampled_from([np.float32, np.float16]),
        **hu.gcs
    )
    def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, dtype, gc, dc):
        if dtype == np.float16:
            # fp16 is only supported with CUDA/HIP
            assume(core.IsGPUDeviceType(gc.device_type))
            dc = [d for d in dc if core.IsGPUDeviceType(d.device_type)]

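        # Random batch shape with C leading dims, each of size 1 or 2.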
        batch_dims = np.random.randint(
            low=1,
            high=3,
            size=C,
            dtype=np.int64).tolist()
        X = np.random.rand(*(batch_dims + [M, K])).astype(dtype) - 0.5
        if trans_a:
            X = X.swapaxes(-1, -2)
        Y = np.random.rand(*(batch_dims + [K, N])).astype(dtype) - 0.5
        if trans_b:
            Y = Y.swapaxes(-1, -2)

        op = core.CreateOperator(
            'BatchMatMul', ['X', 'Y'], 'out', trans_a=trans_a, trans_b=trans_b
        )

        def matmul_ref(X, Y, trans_a, trans_b, dtype):
            XX = (X.swapaxes(-1, -2) if trans_a else X).astype(np.float32)
            YY = (Y.swapaxes(-1, -2) if trans_b else Y).astype(np.float32)
            return (np.matmul(XX, YY).astype(dtype),)

        # relaxing the "threshold" for fp16 to 150x of the default
        def relax_fp16_check(check_func, *args, **kwargs):
            # inspect the default "threshold" value in check_func
            argspec = inspect.getfullargspec(check_func)
            threshold = argspec.defaults[
                argspec.args.index('threshold') -
                (len(argspec.args) - len(argspec.defaults))]

            if dtype == np.float16:
                threshold = 150 * threshold
            check_func(*args, threshold=threshold, **kwargs)

        # Check against numpy reference
        relax_fp16_check(self.assertReferenceChecks, gc, op, [X, Y, trans_a, trans_b, dtype], matmul_ref)
        # Check over multiple devices
        relax_fp16_check(self.assertDeviceChecks, dc, op, [X, Y], [0])
        # Gradient check wrt X
        relax_fp16_check(self.assertGradientChecks, gc, op, [X, Y], 0, [0])
        # Gradient check wrt Y
        relax_fp16_check(self.assertGradientChecks, gc, op, [X, Y], 1, [0])

    def _test_batch_matmul_with_broadcast_common(
        self,
        X,
        Y,
        dtype,
        gc,
        dc,
        trans_a=None,
        trans_b=None,
    ):
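        """Run BatchMatMul with broadcast=1 and compare against np.matmul,
        which serves as the reference for the broadcasting behaviour over
        leading batch dims (and for the 1-d vector cases).
        """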
        if trans_a is not None and trans_b is not None:
            op = core.CreateOperator(
                'BatchMatMul', ['X', 'Y'], 'out', trans_a=trans_a, trans_b=trans_b, broadcast=1
            )
        else:
            op = core.CreateOperator(
                'BatchMatMul', ['X', 'Y'], 'out', broadcast=1
            )

        def matmul_ref(X, Y, trans_a, trans_b, dtype):
            XX = (X.swapaxes(-1, -2) if trans_a else X).astype(np.float32)
            YY = (Y.swapaxes(-1, -2) if trans_b else Y).astype(np.float32)
            return (np.matmul(XX, YY).astype(dtype),)

        # Check against numpy reference
        self.assertReferenceChecks(gc, op, [X, Y, trans_a, trans_b, dtype], matmul_ref)
        # Check over multiple devices
        self.assertDeviceChecks(dc, op, [X, Y], [0])

    @given(
        C_1=st.integers(min_value=0, max_value=3),  # number of batch dims
        C_2=st.integers(min_value=0, max_value=3),
        M=st.integers(min_value=1, max_value=10),
        K=st.integers(min_value=1, max_value=10),
        N=st.integers(min_value=1, max_value=10),
        trans_a=st.booleans(),
        trans_b=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_numpy_batch_matmul(self, C_1, C_2, M, K, N, trans_a, trans_b, gc, dc):
        dtype = np.float32
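        # X and Y share trailing batch dims but may differ in how many leading
        # batch dims they have (C_1 vs C_2), and a dim can be 0, exercising
        # rank broadcasting and empty batches.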
        batch_dims = np.random.randint(
            low=0,
            high=3,
            size=max(C_1, C_2),
            dtype=np.int64).tolist()
        lbd = len(batch_dims)
        X = np.random.rand(*(batch_dims[lbd - C_1:] + [M, K])).astype(dtype) - 0.5
        if trans_a:
            X = X.swapaxes(-1, -2)
        Y = np.random.rand(*(batch_dims[lbd - C_2:] + [K, N])).astype(dtype) - 0.5
        if trans_b:
            Y = Y.swapaxes(-1, -2)

        self._test_batch_matmul_with_broadcast_common(X, Y, dtype, gc, dc, trans_a, trans_b)

    @settings(max_examples=30, deadline=None)
    @given(
        K=st.integers(min_value=1, max_value=10),
        **hu.gcs
    )
    def test_numpy_batch_matmul_1d(self, K, gc, dc):
        dtype = np.float32
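        # Two 1-d inputs follow np.matmul's vector rules: the result is a
        # scalar dot product.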
        X = np.random.rand(K).astype(dtype) - 0.5
        # TODO: test trans_a and trans_b
        Y = np.random.rand(K).astype(dtype) - 0.5

        self._test_batch_matmul_with_broadcast_common(X, Y, dtype, gc, dc)

    @settings(max_examples=30, deadline=None)
    @given(
        K=st.integers(min_value=1, max_value=10),
        N=st.integers(min_value=1, max_value=10),
        **hu.gcs
    )
    def test_numpy_batch_matmul_1d_2d(self, K, N, gc, dc):
        dtype = np.float32
        X = np.random.rand(K).astype(dtype) - 0.5
        # TODO: test trans_a and trans_b
        Y = np.random.rand(*[K, N]).astype(dtype) - 0.5

        self._test_batch_matmul_with_broadcast_common(X, Y, dtype, gc, dc)

    @settings(max_examples=30, deadline=None)
    @given(
        M=st.integers(min_value=1, max_value=10),
        K=st.integers(min_value=1, max_value=10),
        **hu.gcs
    )
    def test_numpy_batch_matmul_2d_1d(self, M, K, gc, dc):
        dtype = np.float32
        X = np.random.rand(*[M, K]).astype(dtype) - 0.5
        # TODO: test trans_a and trans_b
        Y = np.random.rand(K).astype(dtype) - 0.5

        self._test_batch_matmul_with_broadcast_common(X, Y, dtype, gc, dc)


if __name__ == "__main__":
    import unittest
    unittest.main()