# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
import numpy as np
from scipy.sparse import coo_matrix
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
def test_reshape(old_shape, new_shape, stride_only=False):
    """Check SparseMatrixReshape against a dense NumPy reshape.

    A random dense matrix of ``old_shape`` has roughly half of its entries
    zeroed. It is converted to COO form, and its column/row index vectors are
    fed to the ``SparseMatrixReshape`` operator. The operator's outputs must
    equal the COO indices of the same dense matrix reshaped with
    ``ndarray.reshape(new_shape)``.

    This is a parameterized helper, not a standalone test. It has required
    arguments, so it is flagged with ``__test__ = False`` right after its
    definition to keep collectors such as pytest/nose from running it directly.

    Args:
        old_shape: 2-D shape of the source matrix (``coo_matrix`` needs 2-D).
        new_shape: target shape passed both to the operator and to
            ``reshape``; may contain -1 for a dimension to be inferred.
        stride_only: if True, the operator is given ``(-1, old_shape[1])``
            instead of the full ``old_shape``, so it must infer the row count.

    Raises:
        AssertionError: if the operator's column or row indices differ from
            the expected ones.
    """
    blob_in0 = 'col'
    blob_out0 = 'col_out'
    blob_in1 = 'row'
    blob_out1 = 'row_out'
    # With stride_only the row count is hidden from the operator (-1).
    old_shape_for_op = (-1, old_shape[1]) if stride_only else old_shape
    op = core.CreateOperator('SparseMatrixReshape',
                             [blob_in0, blob_in1],
                             [blob_out0, blob_out1],
                             old_shape=old_shape_for_op,
                             new_shape=new_shape)

    # Random dense matrix with each entry independently zeroed when a fresh
    # uniform sample exceeds 0.5.
    A = np.random.random_sample(old_shape)
    A[np.random.random_sample(old_shape) > .5] = 0
    A_coo = coo_matrix(A)
    old_row, old_col = A_coo.row, A_coo.col

    # Column indices are fed as int64 and row indices as int32.
    workspace.FeedBlob(blob_in0, old_col.astype(np.int64))
    workspace.FeedBlob(blob_in1, old_row.astype(np.int32))
    workspace.RunOperatorOnce(op)

    # Expected indices come from reshaping the dense matrix directly.
    # NOTE(review): the element-wise comparison is order-sensitive. It relies
    # on coo_matrix(dense) listing nonzeros in row-major order for both the
    # input and the reshaped matrix, and on the operator preserving that
    # order -- confirm if either side changes.
    A_new_coo = coo_matrix(A.reshape(new_shape))
    new_row, new_col = A_new_coo.row, A_new_coo.col
    col_out = workspace.FetchBlob(blob_out0)
    row_out = workspace.FetchBlob(blob_out1)
    np.testing.assert_array_equal(col_out, new_col)
    np.testing.assert_array_equal(row_out, new_row)


# Parameterized helper with required arguments: keep pytest/nose from
# collecting it as a test of its own (it would fail on missing fixtures).
test_reshape.__test__ = False
class TestSparseMatrixReshapeOp(TestCase):
    """Unit tests for the SparseMatrixReshape operator."""

    def test_basic_reshape(self):
        """Fully specified source and target shapes: (3, 4) -> (4, 3)."""
        test_reshape((3, 4), (4, 3))

    def test_missing_dim(self):
        """Target shape whose leading dimension (-1) must be inferred."""
        test_reshape((2, 8), (-1, 4))

    def test_stride_only(self):
        """Same reshape as above, but the operator only sees the row stride."""
        test_reshape((2, 8), (-1, 4), stride_only=True)

    def test_sparse_reshape_mm(self):
        """Reshape sparse indices, then use them in a sparse x dense product.

        A (K, M) sparse matrix has its indices reshaped to (M, K) by
        SparseMatrixReshape. The new indices, together with the nonzero
        values, feed SparseUnsortedSegmentWeightedSum, whose output 'Y' is
        compared with the dense product ``A.reshape(M, K).dot(B)``.
        """
        M, N, K = 300, 400, 500

        # Random draws happen in a fixed order: values, sparsity mask, B.
        values = np.random.rand(M, K).astype(np.float32)
        keep = np.random.rand(*values.shape) > .5
        masked = (values * keep).reshape((K, M))

        coo = coo_matrix(masked)
        B = np.random.rand(K, N).astype(np.float32)

        # Presumably: 'new_col' indexes rows of B, 'new_row' gives output
        # segments, and 'a' holds the weights. This is implied by the dense
        # reference below; confirm against the operator schema.
        inputs = (
            ('col', coo.col.astype(np.int64)),
            ('row', coo.row.astype(np.int32)),
            ('B', B),
            ('a', coo.data),
        )
        for blob_name, blob_value in inputs:
            workspace.FeedBlob(blob_name, blob_value)

        pipeline = [
            core.CreateOperator(
                'SparseMatrixReshape',
                ['col', 'row'],
                ['new_col', 'new_row'],
                old_shape=(K, M),
                new_shape=(M, K),
            ),
            core.CreateOperator(
                'SparseUnsortedSegmentWeightedSum',
                ['B', 'a', 'new_col', 'new_row'],
                ['Y'],
            ),
        ]
        for step in pipeline:
            workspace.RunOperatorOnce(step)

        result = workspace.FetchBlob('Y')
        reference = masked.reshape(M, K).dot(B)
        np.testing.assert_allclose(reference, result, rtol=1e-4)