# Repository URL to install this package:
# Version: 1.14.0
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Test utils for tensorflow RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
class RaggedTensorTestCase(test_util.TensorFlowTestCase):
  """Base class for RaggedTensor test cases.

  Provides assertions that compare potentially ragged tensors, evaluated
  ragged values, NumPy arrays and nested Python lists by first converting
  both operands to nested Python lists.
  """

  def _GetPyList(self, a):
    """Converts a to a nested python list.

    Args:
      a: A `RaggedTensor`, a `Tensor`, a `RaggedTensorValue`, an `np.ndarray`,
        or any other value accepted by `np.array` (e.g. a nested list).

    Returns:
      The contents of `a` as produced by `to_list()` / `tolist()`. An
      evaluated `Tensor` that is not an `np.ndarray` is returned unchanged.
    """
    if isinstance(a, ragged_tensor.RaggedTensor):
      return self.evaluate(a).to_list()
    elif isinstance(a, ops.Tensor):
      a = self.evaluate(a)
      return a.tolist() if isinstance(a, np.ndarray) else a
    elif isinstance(a, np.ndarray):
      return a.tolist()
    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
      return a.to_list()
    else:
      return np.array(a).tolist()

  def _assertRaggedRankEqual(self, a, b):
    """Asserts that `a` and `b` have the same ragged rank.

    The check is skipped when either argument is a Python list or tuple, since
    such values have no `ragged_rank` to compare against. Values for which
    `ragged_tensor.is_ragged` is false are treated as having ragged rank 0.

    Args:
      a: The first value passed to the calling assertion.
      b: The second value passed to the calling assertion.
    """
    if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
      return
    a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
    b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
    self.assertEqual(a_ragged_rank, b_ragged_rank)

  def assertRaggedEqual(self, a, b):
    """Asserts that two potentially ragged tensors are equal.

    Both the nested-list contents and (unless either argument is a Python
    list or tuple) the ragged ranks must match.

    Args:
      a: Any value accepted by `_GetPyList`.
      b: Any value accepted by `_GetPyList`.
    """
    a_list = self._GetPyList(a)
    b_list = self._GetPyList(b)
    self.assertEqual(a_list, b_list)
    self._assertRaggedRankEqual(a, b)

  def assertRaggedAlmostEqual(self, a, b, places=7):
    """Asserts that two potentially ragged tensors are almost equal.

    Leaf values are compared with `assertAlmostEqual`; the nesting structure
    and (unless either argument is a Python list or tuple) the ragged ranks
    must match.

    Args:
      a: Any value accepted by `_GetPyList`.
      b: Any value accepted by `_GetPyList`.
      places: The `places` argument forwarded to `assertAlmostEqual`.
    """
    a_list = self._GetPyList(a)
    b_list = self._GetPyList(b)
    self.assertNestedListAlmostEqual(a_list, b_list, places, context='value')
    self._assertRaggedRankEqual(a, b)

  def assertNestedListAlmostEqual(self, a, b, places=7, context='value'):
    """Asserts that two nested lists are almost equal, element by element.

    `a` and `b` must have exactly the same type at every nesting level. Lists
    and tuples must have equal length and are compared recursively; any other
    value is compared with `assertAlmostEqual`.

    Args:
      a: A (possibly nested) list or tuple, or a leaf value.
      b: The value to compare `a` against.
      places: The `places` argument forwarded to `assertAlmostEqual`.
      context: A string naming the position of `a` within the outermost
        value; it is included in failure messages.
    """
    self.assertEqual(type(a), type(b))
    if isinstance(a, (list, tuple)):
      self.assertLen(a, len(b), 'Length differs for %s' % context)
      # The lengths are known to match here, so zip() drops no elements.
      for i, (a_item, b_item) in enumerate(zip(a, b)):
        self.assertNestedListAlmostEqual(a_item, b_item, places,
                                         '%s[%s]' % (context, i))
    else:
      self.assertAlmostEqual(
          a, b, places,
          '%s != %s within %s places at %s' % (a, b, places, context))

  def eval_to_list(self, tensor):
    """Evaluates `tensor` and converts the result to a nested Python list.

    Args:
      tensor: A value accepted by `self.evaluate`.

    Returns:
      `to_list()` of the evaluated value if it is ragged, `tolist()` if it is
      an `np.ndarray`, and the evaluated value itself otherwise.
    """
    value = self.evaluate(tensor)
    if ragged_tensor.is_ragged(value):
      return value.to_list()
    elif isinstance(value, np.ndarray):
      return value.tolist()
    else:
      return value

  def _eval_tensor(self, tensor):
    """Evaluates `tensor`, with added support for ragged tensors.

    A ragged input is evaluated by recursively evaluating its `values` and
    `row_splits` and wrapping the results in a `RaggedTensorValue`; anything
    else is delegated to `TensorFlowTestCase._eval_tensor`.

    Args:
      tensor: The tensor to evaluate.

    Returns:
      A `RaggedTensorValue` for ragged input, otherwise whatever the base
      class implementation returns.
    """
    if ragged_tensor.is_ragged(tensor):
      return ragged_tensor_value.RaggedTensorValue(
          self._eval_tensor(tensor.values),
          self._eval_tensor(tensor.row_splits))
    else:
      return test_util.TensorFlowTestCase._eval_tensor(self, tensor)

  @staticmethod
  def _normalize_pylist(item):
    """Convert all (possibly nested) np.arrays contained in item to list.

    Args:
      item: A value that may contain `np.ndarray` elements at any depth.

    Returns:
      `item` itself when `np.ndim(item)` is 0; otherwise a new list in which
      every directly contained `np.ndarray` has been replaced by its
      `tolist()` result and every element has been normalized recursively.
    """
    if np.ndim(item) == 0:
      # Zero-dimensional values are leaves and are returned untouched.
      return item
    _normalize = RaggedTensorTestCase._normalize_pylist
    # The recursive call already returns zero-dimensional elements unchanged,
    # so no separate ndim check is needed before recursing.
    return [
        _normalize(x.tolist() if isinstance(x, np.ndarray) else x)
        for x in item
    ]