Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various functions for graph editing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.ops import array_ops as tf_array_ops

# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "detach_control_inputs",
    "detach_control_outputs",
    "detach_inputs",
    "detach_outputs",
    "detach",
    "connect",
    "bypass",
]


def detach_control_inputs(sgv):
  """Remove every control dependency that reaches sgv from outside of it.

  For each op of the view, the control inputs that are not themselves ops of
  the view are removed with reroute.remove_control_inputs.

  Args:
    sgv: the subgraph view to be detached. It is converted with
      subgraph.make_view, so any argument that function accepts works here.
  """
  view = subgraph.make_view(sgv)
  inner_ops = view.ops
  for target_op in inner_ops:
    # Control inputs coming from ops that are not part of the view.
    external_deps = [
        dep for dep in target_op.control_inputs if dep not in inner_ops
    ]
    reroute.remove_control_inputs(target_op, external_deps)


def detach_control_outputs(sgv, control_outputs):
  """Detach all the external control outputs of the subgraph sgv.

  For every op inside sgv, each op outside sgv that lists it as a control
  input has that control dependency removed.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
    control_outputs: a util.ControlOutputs instance. It is refreshed with
      `update()` before being queried.
  Raises:
    TypeError: if control_outputs is not a util.ControlOutputs instance.
  """
  if not isinstance(control_outputs, util.ControlOutputs):
    # Format the message; passing the type as a second argument would leave a
    # literal "{}" in the exception text.
    raise TypeError("Expected a util.ControlOutputs, got: {}".format(
        type(control_outputs)))
  control_outputs.update()
  sgv = subgraph.make_view(sgv)
  for op in sgv.ops:
    for cop in control_outputs.get(op):
      # Only dependencies leaving the subgraph are cut; internal ones stay.
      if cop not in sgv.ops:
        reroute.remove_control_inputs(cop, op)


def detach_inputs(sgv, control_inputs=False):
  """Cut a subgraph view off from the tensors that feed it.

  One placeholder is created per input tensor of the view. Each placeholder
  copies the input's dtype, and its name comes from util.placeholder_name.
  The view's inputs are then swapped with these placeholders via
  reroute.swap_inputs.

  Args:
    sgv: the subgraph view to be detached. It is converted with
      subgraph.make_view. Note that sgv is modified in place.
    control_inputs: if True, the external control inputs are detached as
      well (see detach_control_inputs).
  Returns:
    A tuple `(sgv, input_placeholders)`:
      `sgv` is the subgraph view of the detached subgraph;
      `input_placeholders` lists the created placeholders, one per original
      input, in the order of the view's inputs.
  Raises:
    Whatever subgraph.make_view raises when sgv cannot be converted to a
    SubGraphView.
  """
  sgv = subgraph.make_view(sgv)

  input_placeholders = []
  # Create the placeholders in the view's own graph, not the default graph.
  with sgv.graph.as_default():
    for source_t in sgv.inputs:
      stand_in = tf_array_ops.placeholder(
          dtype=source_t.dtype, name=util.placeholder_name(source_t))
      input_placeholders.append(stand_in)

  reroute.swap_inputs(sgv, input_placeholders)
  if control_inputs:
    detach_control_inputs(sgv)
  return sgv, input_placeholders


def detach_outputs(sgv, control_outputs=None):
  """Detach the outputs of a subgraph view.

  Every output tensor of sgv that is consumed by at least one op outside sgv
  is swapped with a new placeholder. Each placeholder is built from the
  tensor by util.make_placeholder_from_tensor.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_outputs: a util.ControlOutputs instance or None. If not None the
      control outputs are also detached.
  Returns:
    A tuple `(sgv, output_placeholders)` where
      `sgv` is a new subgraph view restricted to the detached outputs;
      `output_placeholders` is a list of the created output placeholders,
      where `output_placeholders[i]` was built from `sgv.outputs[i]`.
  Raises:
    Whatever subgraph.make_view raises when sgv cannot be converted to a
    SubGraphView.
  """
  sgv = subgraph.make_view(sgv)
  inner_ops = sgv.ops
  # Only select outputs that have at least one consumer outside the subgraph.
  # Outputs with no external consumer have nothing to detach.
  sgv_ = sgv.remap_outputs([
      output_id for output_id, output_t in enumerate(sgv.outputs)
      if any(op not in inner_ops for op in output_t.consumers())
  ])

  # reroute.swap_outputs pairs sgv_.outputs with output_placeholders by
  # position, so the placeholders are built in sgv_.outputs order. Deriving
  # them from the consumers' input order could pair a placeholder with the
  # wrong tensor, or produce a list of a different length.
  with sgv_.graph.as_default():
    output_placeholders = [
        util.make_placeholder_from_tensor(output_t)
        for output_t in sgv_.outputs
    ]

  reroute.swap_outputs(sgv_, output_placeholders)
  if control_outputs is not None:
    detach_control_outputs(sgv_, control_outputs)
  return sgv_, output_placeholders


def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
  """Detach both the inputs and the outputs of a subgraph view.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: A boolean indicating whether the external control inputs
      are detached as well.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      the external control outputs are detached as well.
    control_ios: An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are detached. This is equivalent to
      setting control_inputs to True and control_outputs to the
      util.ControlOutputs instance.
  Returns:
    A tuple `(sgv, detached_inputs, detached_outputs)` where:
    `sgv` is the subgraph view of the detached subgraph;
    `detached_inputs` is a list of the created input placeholders;
    `detached_outputs` is a list of the created output placeholders.
  Raises:
    Whatever select.check_cios raises for inconsistent control arguments, or
    whatever subgraph.make_view raises when sgv cannot be converted to a
    SubGraphView.
  """
  control_inputs, control_outputs = select.check_cios(control_inputs,
                                                      control_outputs,
                                                      control_ios)
  # Convert once so that both helpers work on the same view. This also
  # ensures the returned value is a SubGraphView, as documented, even when
  # the caller passed something else (e.g. a list of ops).
  sgv = subgraph.make_view(sgv)
  _, detached_inputs = detach_inputs(sgv, control_inputs)
  _, detached_outputs = detach_outputs(sgv, control_outputs)
  return sgv, detached_inputs, detached_outputs


def connect(sgv0, sgv1, disconnect_first=False):
  """Feed the outputs of sgv0 into the inputs of sgv1.

  Args:
    sgv0: the producing subgraph. Its outputs become the inputs of sgv1. It
      is converted with subgraph.make_view. Note that sgv0 is modified in
      place.
    sgv1: the consuming subgraph whose inputs are rerouted. It is converted
      with subgraph.make_view. Note that sgv1 is modified in place.
    disconnect_first: if True, detach_outputs is applied to sgv0 before
      connecting, so its outputs' current consumers are disconnected.
  Returns:
    A tuple `(sgv0, sgv1)` of the now connected subgraph views.
  Raises:
    Whatever subgraph.make_view raises when sgv0 or sgv1 cannot be converted
    to a SubGraphView, or whatever util.check_graphs raises for the pair.
  """
  producer, consumer = subgraph.make_view(sgv0), subgraph.make_view(sgv1)
  util.check_graphs(producer, consumer)

  if disconnect_first:
    detach_outputs(producer)

  # Wrap the producer's outputs in a view made only of passthrough tensors.
  # reroute_inputs then hands that view's inputs (the producer's outputs) to
  # the consumer. NOTE(review): relies on passthrough tensors being exposed
  # as view inputs.
  outputs_view = subgraph.SubGraphView(passthrough_ts=producer.outputs)
  reroute.reroute_inputs(outputs_view, consumer)
  return producer, consumer


def bypass(sgv):
  """Bypass a subgraph by connecting its original inputs to its outputs.

  The subgraph's inputs are first replaced by placeholders (detach_inputs).
  The original input tensors are then rerouted onto the subgraph's outputs
  with reroute.reroute_ts.

  Args:
    sgv: the subgraph view to be bypassed. It is converted with
      subgraph.make_view. Note that sgv is modified in place.
  Returns:
    A tuple `(sgv, detached_inputs)` where:
      `sgv` is the subgraph view of the bypassed subgraph;
      `detached_inputs` is a list of the created input placeholders.
  Raises:
    Whatever subgraph.make_view raises when sgv cannot be converted to a
    SubGraphView.
  """
  # TODO(fkp): allow plugging sgv.inputs into individual consumers of
  # sgv.outputs.
  view = subgraph.make_view(sgv)
  # Copy the inputs now: detach_inputs modifies the view in place.
  bypass_ts = list(view.inputs)
  view, placeholders = detach_inputs(view)
  reroute.reroute_ts(bypass_ts, view.outputs)
  return view, placeholders