Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various ways of selecting operations and tensors in a graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from six import iteritems
from six import string_types
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
__all__ = [
"can_be_regex",
"make_regex",
"filter_ts",
"filter_ts_from_regex",
"filter_ops",
"filter_ops_from_regex",
"get_name_scope_ops",
"check_cios",
"get_ops_ios",
"compute_boundary_ts",
"get_within_boundary_ops",
"get_forward_walk_ops",
"get_backward_walk_ops",
"get_walks_intersection_ops",
"get_walks_union_ops",
"select_ops",
"select_ts",
"select_ops_and_ts",
]
_RE_TYPE = type(re.compile(""))
def can_be_regex(obj):
"""Return True if obj can be turned into a regular expression."""
return isinstance(obj, string_types + (_RE_TYPE,))
def make_regex(obj):
"""Return a compiled regular expression.
Args:
obj: a string or a regular expression.
Returns:
A compiled regular expression.
Raises:
ValueError: if obj could not be converted to a regular expression.
"""
if not can_be_regex(obj):
raise ValueError("Expected a string or a regex, got: {}".format(type(obj)))
if isinstance(obj, string_types):
return re.compile(obj)
else:
return obj
def _get_input_ts(ops):
"""Compute the list of unique input tensors of all the op in ops.
Args:
ops: an object convertible to a list of `tf.Operation`.
Returns:
The list of unique input tensors of all the op in ops.
Raises:
TypeError: if ops cannot be converted to a list of `tf.Operation`.
"""
ops = util.make_list_of_op(ops)
ts = []
ts_set = set()
for op in ops:
for t in op.inputs:
if t not in ts_set:
ts.append(t)
ts_set.add(t)
return ts
def _get_output_ts(ops):
"""Compute the list of unique output tensors of all the op in ops.
Args:
ops: an object convertible to a list of tf.Operation.
Returns:
The list of unique output tensors of all the op in ops.
Raises:
TypeError: if ops cannot be converted to a list of tf.Operation.
"""
ops = util.make_list_of_op(ops)
ts = []
for op in ops:
ts += op.outputs
return ts
def filter_ts(ops, positive_filter):
"""Get all the tensors which are input or output of an op in ops.
Args:
ops: an object convertible to a list of `tf.Operation`.
positive_filter: a function deciding whether to keep a tensor or not.
If `True`, all the tensors are returned.
Returns:
A list of `tf.Tensor`.
Raises:
TypeError: if ops cannot be converted to a list of `tf.Operation`.
"""
ops = util.make_list_of_op(ops)
ts = _get_input_ts(ops)
util.concatenate_unique(ts, _get_output_ts(ops))
if positive_filter is not True:
ts = [t for t in ts if positive_filter(t)]
return ts
def filter_ts_from_regex(ops, regex):
r"""Get all the tensors linked to ops that match the given regex.
Args:
ops: an object convertible to a list of tf.Operation.
regex: a regular expression matching the tensors' name.
For example, "^foo(/.*)?:\d+$" will match all the tensors in the "foo"
scope.
Returns:
A list of tf.Tensor.
Raises:
TypeError: if ops cannot be converted to a list of tf.Operation.
"""
ops = util.make_list_of_op(ops)
regex_obj = make_regex(regex)
return filter_ts(ops, positive_filter=lambda op: regex_obj.search(op.name))
def filter_ops(ops, positive_filter):
"""Get the ops passing the given filter.
Args:
ops: an object convertible to a list of tf.Operation.
positive_filter: a function deciding where to keep an operation or not.
If True, all the operations are returned.
Returns:
A list of selected tf.Operation.
Raises:
TypeError: if ops cannot be converted to a list of tf.Operation.
"""
ops = util.make_list_of_op(ops)
if positive_filter is not True: # pylint: disable=g-explicit-bool-comparison
ops = [op for op in ops if positive_filter(op)]
return ops
def filter_ops_from_regex(ops, regex):
"""Get all the operations that match the given regex.
Args:
ops: an object convertible to a list of `tf.Operation`.
regex: a regular expression matching the operation's name.
For example, `"^foo(/.*)?$"` will match all the operations in the "foo"
scope.
Returns:
A list of `tf.Operation`.
Raises:
TypeError: if ops cannot be converted to a list of `tf.Operation`.
"""
ops = util.make_list_of_op(ops)
regex_obj = make_regex(regex)
return filter_ops(ops, lambda op: regex_obj.search(op.name))
def get_name_scope_ops(ops, scope):
"""Get all the operations under the given scope path.
Args:
ops: an object convertible to a list of tf.Operation.
scope: a scope path.
Returns:
A list of tf.Operation.
Raises:
TypeError: if ops cannot be converted to a list of tf.Operation.
"""
if scope and scope[-1] == "/":
scope = scope[:-1]
return filter_ops_from_regex(ops, "^{}(/.*)?$".format(scope))
def check_cios(control_inputs=False, control_outputs=None, control_ios=None):
"""Do various check on control_inputs and control_outputs.
Args:
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
control_ios: An instance of util.ControlOutputs or None. If not None, both
control inputs and control outputs are enabled. This is equivalent to set
control_inputs to True and control_outputs to the util.ControlOutputs
instance.
Returns:
A tuple `(control_inputs, control_outputs)` where:
`control_inputs` is a boolean indicating whether to use control inputs.
`control_outputs` is an instance of util.ControlOutputs or None
Raises:
ValueError: if control_inputs is an instance of util.ControlOutputs but
control_outputs is not None
TypeError: if control_outputs is not None and is not a util.ControlOutputs.
"""
if control_ios is not None:
if not isinstance(control_ios, util.ControlOutputs):
raise TypeError("Expected a util.ControlOutputs, got: {}".format(
type(control_ios)))
if control_outputs is not None:
raise ValueError("control_outputs should be None when using control_ios.")
control_inputs = True
control_outputs = control_ios
elif control_outputs is not None:
if not isinstance(control_outputs, util.ControlOutputs):
raise TypeError("Expected a util.ControlOutputs, got: {}".format(
type(control_outputs)))
if control_outputs is not None:
control_outputs.update()
return control_inputs, control_outputs
def get_ops_ios(ops, control_inputs=False, control_outputs=None,
control_ios=None):
"""Return all the `tf.Operation` which are connected to an op in ops.
Args:
ops: an object convertible to a list of `tf.Operation`.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of `util.ControlOutputs` or `None`. If not
`None`, control outputs are enabled.
control_ios: An instance of `util.ControlOutputs` or `None`. If not `None`,
both control inputs and control outputs are enabled. This is equivalent to
set `control_inputs` to `True` and `control_outputs` to the
`util.ControlOutputs` instance.
Returns:
All the `tf.Operation` surrounding the given ops.
Raises:
TypeError: if `ops` cannot be converted to a list of `tf.Operation`.
"""
control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
control_ios)
ops = util.make_list_of_op(ops)
res = []
for op in ops:
util.concatenate_unique(res, [t.op for t in op.inputs])
for t in op.outputs:
util.concatenate_unique(res, t.consumers())
if control_outputs is not None:
util.concatenate_unique(res, control_outputs.get(op))
if control_inputs:
util.concatenate_unique(res, op.control_inputs)
return res
def compute_boundary_ts(ops):
"""Compute the tensors at the boundary of a set of ops.
This function looks at all the tensors connected to the given ops (in/out)
and classify them into three categories:
1) input tensors: tensors whose generating operation is not in ops.
2) output tensors: tensors whose consumer operations are not in ops
3) inside tensors: tensors which are neither input nor output tensors.
Note that a tensor can be both an inside tensor and an output tensor if it is
consumed by operations both outside and inside of `ops`.
Args:
ops: an object convertible to a list of tf.Operation.
Returns:
A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
`outside_input_ts` is a Python list of input tensors;
`outside_output_ts` is a python list of output tensors;
`inside_ts` is a python list of inside tensors.
Since a tensor can be both an inside tensor and an output tensor,
`outside_output_ts` and `inside_ts` might intersect.
Raises:
TypeError: if ops cannot be converted to a list of tf.Operation.
"""
ops = util.make_list_of_op(ops)
input_ts = _get_input_ts(ops)
output_ts = _get_output_ts(ops)
output_ts_set = frozenset(output_ts)
ops_set = frozenset(ops)
# Compute inside tensors.
inside_ts = []
only_inside_ts = []
for t in input_ts:
# Skip if the input tensor is not also an output tensor.
if t not in output_ts_set:
continue
# Mark as "inside".
inside_ts.append(t)
# Mark as "only inside" if the tensor is not both inside and output.
consumers = frozenset(t.consumers())
if consumers - ops_set:
continue
only_inside_ts.append(t)
inside_ts_set = frozenset(inside_ts)
only_inside_ts_set = frozenset(only_inside_ts)
outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set]
outside_input_ts = [t for t in input_ts if t not in inside_ts_set]
return outside_input_ts, outside_output_ts, inside_ts
def get_within_boundary_ops(ops,
seed_ops,
boundary_ops=(),
inclusive=True,
control_inputs=False,
control_outputs=None,
control_ios=None):
"""Return all the `tf.Operation` within the given boundary.
Args:
ops: an object convertible to a list of `tf.Operation`. those ops define the
set in which to perform the operation (if a `tf.Graph` is given, it
will be converted to the list of all its operations).
seed_ops: the operations from which to start expanding.
boundary_ops: the ops forming the boundary.
inclusive: if `True`, the result will also include the boundary ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of `util.ControlOutputs` or `None`. If not
`None`, control outputs are enabled.
control_ios: An instance of `util.ControlOutputs` or `None`. If not
`None`, both control inputs and control outputs are enabled. This is
equivalent to set control_inputs to True and control_outputs to
the `util.ControlOutputs` instance.
Returns:
All the `tf.Operation` surrounding the given ops.
Raises:
TypeError: if `ops` or `seed_ops` cannot be converted to a list of
`tf.Operation`.
ValueError: if the boundary is intersecting with the seeds.
"""
control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
control_ios)
ops = util.make_list_of_op(ops)
seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
boundary_ops = set(util.make_list_of_op(boundary_ops))
res = set(seed_ops)
if boundary_ops & res:
raise ValueError("Boundary is intersecting with the seeds.")
wave = set(seed_ops)
while wave:
new_wave = set()
ops_io = get_ops_ios(wave, control_inputs, control_outputs)
for op in ops_io:
if op in res:
continue
if op in boundary_ops:
if inclusive:
res.add(op)
else:
new_wave.add(op)
res.update(new_wave)
wave = new_wave
return [op for op in ops if op in res]
def get_forward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
within_ops_fn=None,
stop_at_ts=(),
control_outputs=None):
"""Do a forward graph walk and return all the visited ops.
Args:
seed_ops: an iterable of operations from which the forward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the consumers of those tensors.
inclusive: if True the given seed_ops are also part of the resulting set.
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
within_ops_fn: if provided, a function on ops that should return True iff
the op is within the graph traversal. This can be used along within_ops,
in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_outputs: a `util.ControlOutputs` instance or None.
If not `None`, it will be used while walking the graph forward.
Returns:
A Python set of all the `tf.Operation` ahead of `seed_ops`.
Raises:
TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
`tf.Operation`.
"""
_, control_outputs = check_cios(False, control_outputs)
if not util.is_iterable(seed_ops):
seed_ops = [seed_ops]
if not seed_ops:
return []
if isinstance(seed_ops[0], tf_ops.Tensor):
ts = util.make_list_of_t(seed_ops, allow_graph=False)
seed_ops = util.get_consuming_ops(ts)
else:
seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
seed_ops = frozenset(seed_ops)
stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
if within_ops:
within_ops = util.make_list_of_op(within_ops, allow_graph=False)
within_ops = frozenset(within_ops)
seed_ops &= within_ops
def is_within(op):
return (within_ops is None or op in within_ops) and (
within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
while wave:
new_wave = set()
for op in wave:
for new_t in op.outputs:
if new_t in stop_at_ts:
continue
for new_op in new_t.consumers():
if new_op not in result and is_within(new_op):
new_wave.add(new_op)
if control_outputs is not None:
for new_op in control_outputs.get(op):
if new_op not in result and is_within(new_op):
new_wave.add(new_op)
util.concatenate_unique(result, new_wave)
wave = new_wave
if not inclusive:
result = [op for op in result if op not in seed_ops]
return result
def get_backward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
within_ops_fn=None,
stop_at_ts=(),
control_inputs=False):
"""Do a backward graph walk and return all the visited ops.
Args:
seed_ops: an iterable of operations from which the backward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the generators of those tensors.
inclusive: if True the given seed_ops are also part of the resulting set.
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
within_ops_fn: if provided, a function on ops that should return True iff
the op is within the graph traversal. This can be used along within_ops,
in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_inputs: if True, control inputs will be used while moving backward.
Returns:
A Python set of all the `tf.Operation` behind `seed_ops`.
Raises:
TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
`tf.Operation`.
"""
if not util.is_iterable(seed_ops):
seed_ops = [seed_ops]
if not seed_ops:
return []
if isinstance(seed_ops[0], tf_ops.Tensor):
ts = util.make_list_of_t(seed_ops, allow_graph=False)
seed_ops = util.get_generating_ops(ts)
else:
seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
seed_ops = frozenset(util.make_list_of_op(seed_ops))
if within_ops:
within_ops = util.make_list_of_op(within_ops, allow_graph=False)
within_ops = frozenset(within_ops)
seed_ops &= within_ops
def is_within(op):
return (within_ops is None or op in within_ops) and (
within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
while wave:
new_wave = set()
for op in wave:
for new_t in op.inputs:
if new_t in stop_at_ts:
continue
if new_t.op not in result and is_within(new_t.op):
new_wave.add(new_t.op)
if control_inputs:
for new_op in op.control_inputs:
if new_op not in result and is_within(new_op):
new_wave.add(new_op)
util.concatenate_unique(result, new_wave)
wave = new_wave
if not inclusive:
result = [op for op in result if op not in seed_ops]
return result
def get_walks_intersection_ops(forward_seed_ops,
backward_seed_ops,
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
"""Return the intersection of a forward and a backward walk.
Args:
forward_seed_ops: an iterable of operations from which the forward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the consumers of those tensors.
backward_seed_ops: an iterable of operations from which the backward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the generators of those tensors.
forward_inclusive: if True the given forward_seed_ops are also part of the
resulting set.
backward_inclusive: if True the given backward_seed_ops are also part of the
resulting set.
within_ops: an iterable of tf.Operation within which the search is
restricted. If within_ops is None, the search is performed within
the whole graph.
within_ops_fn: if provided, a function on ops that should return True iff
the op is within the graph traversal. This can be used along within_ops,
in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
control_ios: An instance of util.ControlOutputs or None. If not None, both
control inputs and control outputs are enabled. This is equivalent to set
control_inputs to True and control_outputs to the util.ControlOutputs
instance.
Returns:
A Python set of all the tf.Operation in the intersection of a forward and a
backward walk.
Raises:
TypeError: if `forward_seed_ops` or `backward_seed_ops` or `within_ops`
cannot be converted to a list of `tf.Operation`.
"""
control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
control_ios)
forward_ops = get_forward_walk_ops(
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return [op for op in forward_ops if op in backward_ops]
def get_walks_union_ops(forward_seed_ops,
backward_seed_ops,
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
"""Return the union of a forward and a backward walk.
Args:
forward_seed_ops: an iterable of operations from which the forward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the consumers of those tensors.
backward_seed_ops: an iterable of operations from which the backward graph
walk starts. If a list of tensors is given instead, the seed_ops are set
to be the generators of those tensors.
forward_inclusive: if True the given forward_seed_ops are also part of the
resulting set.
backward_inclusive: if True the given backward_seed_ops are also part of the
resulting set.
within_ops: restrict the search within those operations. If within_ops is
None, the search is done within the whole graph.
within_ops_fn: if provided, a function on ops that should return True iff
the op is within the graph traversal. This can be used along within_ops,
in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
control_ios: An instance of util.ControlOutputs or None. If not None, both
control inputs and control outputs are enabled. This is equivalent to set
control_inputs to True and control_outputs to the util.ControlOutputs
instance.
Returns:
A Python set of all the tf.Operation in the union of a forward and a
backward walk.
Raises:
TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be
converted to a list of tf.Operation.
"""
control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
control_ios)
forward_ops = get_forward_walk_ops(
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return util.concatenate_unique(forward_ops, backward_ops)
def select_ops(*args, **kwargs):
"""Helper to select operations.
Args:
*args: list of 1) regular expressions (compiled or not) or 2) (array of)
`tf.Operation`. `tf.Tensor` instances are silently ignored.
**kwargs: 'graph': `tf.Graph` in which to perform the regex query.This is
required when using regex.
'positive_filter': an elem if selected only if `positive_filter(elem)` is
`True`. This is optional.
'restrict_ops_regex': a regular expression is ignored if it doesn't start
with the substring "(?#ops)".
Returns:
A list of `tf.Operation`.
Raises:
TypeError: if the optional keyword argument graph is not a `tf.Graph`
or if an argument in args is not an (array of) `tf.Operation`
or an (array of) `tf.Tensor` (silently ignored) or a string
or a regular expression.
ValueError: if one of the keyword arguments is unexpected or if a regular
expression is used without passing a graph as a keyword argument.
"""
# get keywords arguments
graph = None
positive_filter = None
restrict_ops_regex = False
for k, v in iteritems(kwargs):
if k == "graph":
graph = v
if graph is not None and not isinstance(graph, tf_ops.Graph):
raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
elif k == "positive_filter":
positive_filter = v
elif k == "restrict_ops_regex":
restrict_ops_regex = v
elif k == "restrict_ts_regex":
pass
else:
raise ValueError("Wrong keywords argument: {}.".format(k))
ops = []
for arg in args:
if can_be_regex(arg):
if graph is None:
raise ValueError("Use the keyword argument 'graph' to use regex.")
regex = make_regex(arg)
if regex.pattern.startswith("(?#ts)"):
continue
if restrict_ops_regex and not regex.pattern.startswith("(?#ops)"):
continue
ops_ = filter_ops_from_regex(graph, regex)
for op_ in ops_:
if op_ not in ops:
if positive_filter is None or positive_filter(op_):
ops.append(op_)
else:
ops_aux = util.make_list_of_op(arg, ignore_ts=True)
if positive_filter is not None:
ops_aux = [op for op in ops_aux if positive_filter(op)]
ops_aux = [op for op in ops_aux if op not in ops]
ops += ops_aux
return ops
def select_ts(*args, **kwargs):
"""Helper to select tensors.
Args:
*args: list of 1) regular expressions (compiled or not) or 2) (array of)
`tf.Tensor`. `tf.Operation` instances are silently ignored.
**kwargs: 'graph': `tf.Graph` in which to perform the regex query.This is
required when using regex.
'positive_filter': an elem if selected only if `positive_filter(elem)` is
`True`. This is optional.
'restrict_ts_regex': a regular expression is ignored if it doesn't start
with the substring "(?#ts)".
Returns:
A list of `tf.Tensor`.
Raises:
TypeError: if the optional keyword argument graph is not a `tf.Graph`
or if an argument in args is not an (array of) `tf.Tensor`
or an (array of) `tf.Operation` (silently ignored) or a string
or a regular expression.
ValueError: if one of the keyword arguments is unexpected or if a regular
expression is used without passing a graph as a keyword argument.
"""
# get keywords arguments
graph = None
positive_filter = None
restrict_ts_regex = False
for k, v in iteritems(kwargs):
if k == "graph":
graph = v
if graph is not None and not isinstance(graph, tf_ops.Graph):
raise TypeError("Expected a tf.Graph, got {}".format(type(graph)))
elif k == "positive_filter":
positive_filter = v
elif k == "restrict_ts_regex":
restrict_ts_regex = v
elif k == "restrict_ops_regex":
pass
else:
raise ValueError("Wrong keywords argument: {}.".format(k))
ts = []
for arg in args:
if can_be_regex(arg):
if graph is None:
raise ValueError("Use the keyword argument 'graph' to use regex.")
regex = make_regex(arg)
if regex.pattern.startswith("(?#ops)"):
continue
if restrict_ts_regex and not regex.pattern.startswith("(?#ts)"):
continue
ts_ = filter_ts_from_regex(graph, regex)
for t_ in ts_:
if t_ not in ts:
if positive_filter is None or positive_filter(t_):
ts.append(t_)
else:
ts_aux = util.make_list_of_t(arg, ignore_ops=True)
if positive_filter is not None:
ts_aux = [t for t in ts_aux if positive_filter(t)]
ts_aux = [t for t in ts_aux if t not in ts]
ts += ts_aux
return ts
def select_ops_and_ts(*args, **kwargs):
"""Helper to select operations and tensors.
Args:
*args: list of 1) regular expressions (compiled or not) or 2) (array of)
`tf.Operation` 3) (array of) tf.Tensor. Regular expressions matching
tensors must start with the comment `"(?#ts)"`, for instance:
`"(?#ts)^foo/.*"`.
**kwargs: 'graph': `tf.Graph` in which to perform the regex query.This is
required when using regex.
'positive_filter': an elem if selected only if `positive_filter(elem)` is
`True`. This is optional.
Returns:
A tuple `(ops, ts)` where:
`ops` is a list of `tf.Operation`, and
`ts` is a list of `tf.Tensor`
Raises:
TypeError: if the optional keyword argument graph is not a `tf.Graph`
or if an argument in args is not an (array of) `tf.Tensor`
or an (array of) `tf.Operation` or a string or a regular expression.
ValueError: if one of the keyword arguments is unexpected or if a regular
expression is used without passing a graph as a keyword argument.
"""
ops = select_ops(*args, restrict_ops_regex=False, **kwargs)
ts = select_ts(*args, restrict_ts_regex=True, **kwargs)
return ops, ts