# Repository URL to install this package:
# Version: 1.14.0
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from six.moves import queue as Queue  # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes


def obtain_all_variant_tensor_ops(dataset):
  """Given an input dataset, finds all dataset ops used for construction.

  A series of transformations would have created this dataset with each
  transformation including zero or more Dataset ops, each producing a dataset
  variant tensor. This method outputs all of them.

  The graph is walked breadth-first, starting from the op that produces
  `dataset._variant_tensor` and following the inputs of every op reached.
  Each op is visited at most once. An op reachable through several paths
  (e.g. an input shared by two branches) is therefore reported only once.

  Args:
    dataset: Dataset to find variant tensors for.

  Returns:
    A list of variant_tensor producing dataset ops used to construct this
    dataset, without duplicates, in breadth-first order starting from the op
    that produces `dataset`'s own variant tensor.
  """
  all_variant_tensor_ops = []
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  # Ops are marked as visited when they are *enqueued*, not when they are
  # dequeued. Otherwise an op reachable through several paths could be
  # enqueued, expanded and reported more than once before being processed.
  # A set gives O(1) membership tests; ops hash by identity.
  visited = {root_op}
  # Single-threaded BFS: a deque avoids the locking of a thread-safe Queue.
  bfs_q = collections.deque([root_op])
  while bfs_q:
    op = bfs_q.popleft()
    # We look for all ops that produce variant tensors as output. This is a bit
    # of overkill but the other dataset _inputs() traversal strategies can't
    # cover the case of function inputs that capture dataset variants.
    # TODO(b/120873778): Make this more efficient.
    if op.outputs[0].dtype == dtypes.variant:
      all_variant_tensor_ops.append(op)
    for input_tensor in op.inputs:
      input_op = input_tensor.op
      if input_op not in visited:
        visited.add(input_op)
        bfs_q.append(input_op)
  return all_variant_tensor_ops