# Repository URL to install this package: (URL missing from this copy)
# Version: 1.14.0
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various functions for graph editing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.ops import array_ops as tf_array_ops
# Public names exported by this module.
__all__ = [
    # Detaching (parts of) a subgraph from the rest of the graph.
    "detach_control_inputs",
    "detach_control_outputs",
    "detach_inputs",
    "detach_outputs",
    "detach",
    # Rewiring subgraphs to each other.
    "connect",
    "bypass",
]
def detach_control_inputs(sgv):
  """Detach all the external control inputs of the subgraph sgv.

  A control input is considered external when the op it comes from is not one
  of the ops of `sgv`. After this call, every op of `sgv` keeps only its
  control inputs that come from inside `sgv`.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  # Build the membership set once: testing `cop not in sgv.ops` against the
  # list for every control input of every op is quadratic in subgraph size.
  sgv_ops = frozenset(sgv.ops)
  for op in sgv.ops:
    external_cops = [cop for cop in op.control_inputs if cop not in sgv_ops]
    # Nothing to remove: leave the op untouched.
    if external_cops:
      reroute.remove_control_inputs(op, external_cops)
def detach_control_outputs(sgv, control_outputs):
  """Detach all the external control outputs of the subgraph sgv.

  For every op of `sgv`, each op outside of `sgv` that has it as a control
  input has that control input removed.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
    control_outputs: a util.ControlOutputs instance. `update()` is called on
      it before it is queried.

  Raises:
    TypeError: if control_outputs is not a util.ControlOutputs instance.
  """
  if not isinstance(control_outputs, util.ControlOutputs):
    # Format explicitly: passing the type as a second exception argument
    # would leave the "{}" placeholder unfilled in the message.
    raise TypeError("Expected a util.ControlOutputs, got: {}".format(
        type(control_outputs)))
  control_outputs.update()
  sgv = subgraph.make_view(sgv)
  # Hoisted membership set instead of a list scan per control output.
  sgv_ops = frozenset(sgv.ops)
  for op in sgv.ops:
    for cop in control_outputs.get(op):
      if cop not in sgv_ops:
        reroute.remove_control_inputs(cop, op)
def detach_inputs(sgv, control_inputs=False):
  """Detach the inputs of a subgraph view.

  One placeholder is created per input tensor of the view. Each placeholder
  has the same dtype and a name derived with util.placeholder_name. The
  view's inputs are then swapped with these placeholders.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: if True, the external control inputs are detached as well
      (see detach_control_inputs).

  Returns:
    A tuple `(sgv, input_placeholders)` where
      `sgv` is a new subgraph view of the detached subgraph;
      `input_placeholders` is a list of the created input placeholders.

  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  view = subgraph.make_view(sgv)
  graph = view.graph
  placeholders = []
  # The placeholders must live in the same graph as the subgraph.
  with graph.as_default():
    for ts in view.inputs:
      placeholders.append(
          tf_array_ops.placeholder(
              dtype=ts.dtype, name=util.placeholder_name(ts)))
  reroute.swap_inputs(view, placeholders)
  if control_inputs:
    detach_control_inputs(view)
  return view, placeholders
def detach_outputs(sgv, control_outputs=None):
  """Detach the outputs of a subgraph view.

  Only outputs that have at least one consumer are considered. A placeholder
  is created (with util.make_placeholder_from_tensor) for each such tensor
  as seen from the consumer side. The outputs are then swapped with these
  placeholders via reroute.swap_outputs, so the graph is rewired in place.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
    control_outputs: a util.ControlOutputs instance or None. If not None, the
      external control outputs are detached as well
      (see detach_control_outputs).

  Returns:
    A tuple `(sgv, output_placeholders)` where
      `sgv` is a subgraph view derived from the input one with
        remap_outputs, restricted to the outputs that had consumers;
      `output_placeholders` is a list of the created output placeholders.

  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  # Keep only the outputs that actually feed some consumer.
  consumed_ids = [
      idx for idx, out_t in enumerate(sgv.outputs) if out_t.consumers()
  ]
  live_sgv = sgv.remap_outputs(consumed_ids)
  # View made of all the consumers, restricted to the inputs it receives
  # from live_sgv.
  consumers_sgv = subgraph.SubGraphView(live_sgv.consumers())
  fed_ids = [
      idx for idx, in_t in enumerate(consumers_sgv.inputs)
      if in_t in live_sgv.outputs
  ]
  consumers_sgv = consumers_sgv.remap_inputs(fed_ids)
  with live_sgv.graph.as_default():
    output_placeholders = list(
        map(util.make_placeholder_from_tensor, consumers_sgv.inputs))
  reroute.swap_outputs(live_sgv, output_placeholders)
  if control_outputs is not None:
    detach_control_outputs(live_sgv, control_outputs)
  return live_sgv, output_placeholders
def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
  """Detach both the inputs and the outputs of a subgraph view.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios: An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to
      setting control_inputs to True and control_outputs to the
      util.ControlOutputs instance.

  Returns:
    A tuple `(sgv, detached_inputs, detached_outputs)` where:
      `sgv` is the subgraph view of the detached subgraph;
      `detached_inputs` is a list of the created input placeholders;
      `detached_outputs` is a list of the created output placeholders.

  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  control_inputs, control_outputs = select.check_cios(control_inputs,
                                                      control_outputs,
                                                      control_ios)
  # Convert once so both detach steps work on the same view. The caller then
  # always gets a SubGraphView back, even when sgv was given as raw ops or
  # tensors. Previously the unconverted argument was returned in that case.
  sgv = subgraph.make_view(sgv)
  _, detached_inputs = detach_inputs(sgv, control_inputs)
  _, detached_outputs = detach_outputs(sgv, control_outputs)
  return sgv, detached_inputs, detached_outputs
def connect(sgv0, sgv1, disconnect_first=False):
  """Connect the outputs of sgv0 to the inputs of sgv1.

  Args:
    sgv0: the subgraph whose outputs become the new inputs of sgv1. This
      argument is converted to a subgraph using the same rules as the function
      subgraph.make_view. When disconnect_first is True, its current outputs
      are detached from their consumers first (see detach_outputs).
    sgv1: the subgraph whose inputs are rerouted to the outputs of sgv0. This
      argument is converted to a subgraph using the same rules as the function
      subgraph.make_view.
      Note that sgv1 is modified in place.
    disconnect_first: if True, the current outputs of sgv0 are disconnected
      before the connection is made.

  Returns:
    A tuple `(sgv0, sgv1)` of the now connected subgraphs.

  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  first = subgraph.make_view(sgv0)
  second = subgraph.make_view(sgv1)
  util.check_graphs(first, second)
  if disconnect_first:
    detach_outputs(first)
  # Wrap the outputs of the first view as pass-through tensors so they can
  # serve as the source side of an input reroute.
  source = subgraph.SubGraphView(passthrough_ts=first.outputs)
  reroute.reroute_inputs(source, second)
  return first, second
def bypass(sgv):
  """Bypass the given subgraph by connecting its inputs to its outputs.

  The original input tensors of sgv are rerouted (reroute.reroute_ts) to take
  the place of its output tensors. The subgraph's own inputs are replaced by
  placeholders (detach_inputs).

  Args:
    sgv: the subgraph view to be bypassed. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.

  Returns:
    A tuple `(sgv, detached_inputs)` where:
      `sgv` is a new subgraph view of the bypassed subgraph;
      `detached_inputs` is a list of the created input placeholders.

  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  # TODO(fkp): allow plugging sgv.inputs to individual sgv.outputs consumers.
  view = subgraph.make_view(sgv)
  # Snapshot the original inputs before detach_inputs swaps them out.
  original_inputs = list(view.inputs)
  view, placeholders = detach_inputs(view)
  # NOTE(review): inputs and outputs are paired positionally, so this
  # presumably requires as many inputs as outputs; confirm in reroute_ts.
  reroute.reroute_ts(original_inputs, view.outputs)
  return view, placeholders