Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensorflow / purelib / tensorflow / python / data / util / traverse.py
Size: Mime:
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes


def obtain_all_variant_tensor_ops(dataset):
  """Given an input dataset, finds all dataset ops used for construction.

  A series of transformations would have created this dataset with each
  transformation including zero or more Dataset ops, each producing a dataset
  variant tensor. This method outputs all of them.

  The graph is walked breadth-first, starting from the op that produces
  `dataset._variant_tensor` and following every input of every op reached.
  Each op is visited at most once. An op that is reachable along several
  paths therefore appears only once in the result. Examples are an op that
  feeds more than one input of a downstream op, or two branches that share
  an upstream op.

  Args:
    dataset: Dataset to find variant tensors for.

  Returns:
    A list of the ops reached by the traversal whose first output has dtype
    `variant`. The ops are in breadth-first order, starting from the dataset's
    own variant-producing op, and each op appears at most once.
  """
  all_variant_tensor_ops = []
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  bfs_q = Queue.Queue()
  bfs_q.put(root_op)
  # Ops are marked as seen when they are *enqueued*, not when they are
  # dequeued. Marking on dequeue lets an op reachable along several paths sit
  # in the queue more than once. Each copy then lands in the result again and
  # re-enqueues any of its inputs that have not been dequeued yet.
  # A set keeps the membership test O(1) instead of a linear list scan.
  # NOTE(review): this relies on ops being hashable (graph Operation objects
  # hash by identity) — confirm if non-graph op types can reach this code.
  visited = {root_op}
  while not bfs_q.empty():
    op = bfs_q.get()
    # We look for all ops that produce variant tensors as output. This is a bit
    # of overkill but the other dataset _inputs() traversal strategies can't
    # cover the case of function inputs that capture dataset variants.
    # TODO(b/120873778): Make this more efficient.
    if op.outputs[0].dtype == dtypes.variant:
      all_variant_tensor_ops.append(op)
    for i in op.inputs:
      input_op = i.op
      if input_op not in visited:
        visited.add(input_op)
        bfs_q.put(input_op)
  return all_variant_tensor_ops