Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ experiments / python / sparse_reshape_op_test.py

# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################






import numpy as np
from scipy.sparse import coo_matrix

from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase


def test_reshape(old_shape, new_shape, stride_only=False):
    """Run SparseMatrixReshape on random sparse data and verify the result.

    A dense random matrix of shape ``old_shape`` is generated and entries
    are zeroed wherever a second, independent uniform sample exceeds 0.5.
    The COO indices of that matrix are fed to the operator, and the indices
    it produces are compared with the COO indices of the same dense matrix
    reshaped by numpy to ``new_shape``.

    Args:
        old_shape: (rows, cols) of the generated matrix.
        new_shape: target shape, passed both to the operator and to
            ``ndarray.reshape``.
        stride_only: when true, the operator is told ``(-1, old_shape[1])``
            as its old shape instead of the full ``old_shape``.

    Raises:
        AssertionError: if the operator's column or row indices differ from
            the numpy-derived reference.
    """
    col_in, row_in = 'col', 'row'
    col_result, row_result = 'col_out', 'row_out'

    # With stride_only the row count is withheld from the operator.
    if stride_only:
        declared_old_shape = (-1, old_shape[1])
    else:
        declared_old_shape = old_shape

    reshape_op = core.CreateOperator(
        'SparseMatrixReshape',
        [col_in, row_in],
        [col_result, row_result],
        old_shape=declared_old_shape,
        new_shape=new_shape,
    )

    # Values first, then the sparsity mask: two separate RNG draws.
    dense = np.random.random_sample(old_shape)
    drop_mask = np.random.random_sample(old_shape) > .5
    dense[drop_mask] = 0
    source_coo = coo_matrix(dense)

    # Columns are fed as int64 and rows as int32.
    workspace.FeedBlob(col_in, source_coo.col.astype(np.int64))
    workspace.FeedBlob(row_in, source_coo.row.astype(np.int32))
    workspace.RunOperatorOnce(reshape_op)

    # Reference indices come from reshaping the dense matrix directly.
    expected_coo = coo_matrix(dense.reshape(new_shape))

    actual_col = workspace.FetchBlob(col_result)
    actual_row = workspace.FetchBlob(row_result)

    np.testing.assert_array_equal(actual_col, expected_coo.col)
    np.testing.assert_array_equal(actual_row, expected_coo.row)


class TestSparseMatrixReshapeOp(TestCase):
    """Test cases for the SparseMatrixReshape operator."""

    # Old and new shapes are both given in full.
    def test_basic_reshape(self):
        test_reshape(
            old_shape=(3, 4),
            new_shape=(4, 3),
        )

    # The leading dimension of the new shape is given as -1.
    def test_missing_dim(self):
        test_reshape(
            old_shape=(2, 8),
            new_shape=(-1, 4),
        )

    # As above, but the helper also withholds the old row count.
    def test_stride_only(self):
        test_reshape(
            old_shape=(2, 8),
            new_shape=(-1, 4),
            stride_only=True,
        )

    # Reshape sparse indices, feed them to
    # SparseUnsortedSegmentWeightedSum, and compare the output with the
    # equivalent dense matrix product.
    def test_sparse_reshape_mm(self):
        n_rows, n_out, n_inner = 300, 400, 500

        # RNG draw order: values, sparsity mask, then the dense operand.
        values = np.random.rand(n_rows, n_inner).astype(np.float32)
        keep_mask = np.random.rand(*values.shape) > .5
        stored = (values * keep_mask).reshape((n_inner, n_rows))
        stored_coo = coo_matrix(stored)
        dense_rhs = np.random.rand(n_inner, n_out).astype(np.float32)

        feeds = (
            ('col', stored_coo.col.astype(np.int64)),
            ('row', stored_coo.row.astype(np.int32)),
            ('B', dense_rhs),
            ('a', stored_coo.data),
        )
        for blob_name, blob_value in feeds:
            workspace.FeedBlob(blob_name, blob_value)

        # Indices describe a (n_inner, n_rows) matrix; the operator maps
        # them to the (n_rows, n_inner) layout.
        reshape_op = core.CreateOperator(
            'SparseMatrixReshape',
            ['col', 'row'],
            ['new_col', 'new_row'],
            old_shape=(n_inner, n_rows),
            new_shape=(n_rows, n_inner),
        )
        mm_op = core.CreateOperator(
            'SparseUnsortedSegmentWeightedSum',
            ['B', 'a', 'new_col', 'new_row'],
            ['Y'],
        )

        for operator in (reshape_op, mm_op):
            workspace.RunOperatorOnce(operator)

        actual = workspace.FetchBlob('Y')
        expected = stored.reshape(n_rows, n_inner).dot(dense_rhs)
        np.testing.assert_allclose(expected, actual, rtol=1e-4)