Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple graph matching functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six import string_types
from tensorflow.contrib.graph_editor import select
from tensorflow.python.framework import ops as tf_ops
__all__ = [
"op_type",
"OpMatcher",
]
def _make_graph_match(graph_match):
"""Convert to a OpMatcher instance."""
if graph_match is None:
return None
if not isinstance(graph_match, OpMatcher):
graph_match = OpMatcher(graph_match)
return graph_match
def op_type(op_types, op=None):
"""Check if an op is of the given type.
Args:
op_types: tuple of strings containing the types to check against.
For instance: ("Add", "Const")
op: the operation to check (or None).
Returns:
if op is not None, return True if the op is of the correct type.
if op is None, return a lambda function which does the type checking.
"""
if isinstance(op_types, string_types):
op_types = (op_types)
if op is None:
return lambda op: op.node_def.op in op_types
else:
return op.node_def.op in op_types
class OpMatcher(object):
"""Graph match class."""
def __init__(self, positive_filter):
"""Graph match constructor."""
self.positive_filters = []
self.input_op_matches = None
self.control_input_op_matches = None
self.output_op_matches = None
positive_filter = self._finalize_positive_filter(positive_filter)
self.positive_filters.append(positive_filter)
def _finalize_positive_filter(self, elem):
"""Convert to a filter function."""
if select.can_be_regex(elem):
regex_ = select.make_regex(elem)
return lambda op, regex=regex_: regex.search(op.name) is not None
elif isinstance(elem, tf_ops.Operation):
return lambda op, match_op=elem: op is match_op
elif callable(elem):
return elem
elif elem is True:
return lambda op: True
else:
raise ValueError("Cannot finalize the positive filter: {}".format(elem))
def __call__(self, op):
"""Evaluate if the op matches or not."""
if not isinstance(op, tf_ops.Operation):
raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
for positive_filter in self.positive_filters:
if not positive_filter(op):
return False
if self.input_op_matches is not None:
if len(op.inputs) != len(self.input_op_matches):
return False
for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
if input_op_match is None:
continue
if not input_op_match(input_t.op):
return False
if self.control_input_op_matches is not None:
if len(op.control_inputs) != len(self.control_input_op_matches):
return False
for cinput_op, cinput_op_match in zip(op.control_inputs,
self.control_input_op_matches):
if cinput_op_match is None:
continue
if not cinput_op_match(cinput_op):
return False
if self.output_op_matches is not None:
if len(op.outputs) != len(self.output_op_matches):
return False
for output_t, output_op_matches in zip(op.outputs,
self.output_op_matches):
if output_op_matches is None:
continue
if len(output_t.consumers()) != len(output_op_matches):
return False
for consumer_op, consumer_op_match in zip(output_t.consumers(),
output_op_matches):
if consumer_op_match is None:
continue
if not consumer_op_match(consumer_op):
return False
return True
def input_ops(self, *args):
"""Add input matches."""
if self.input_op_matches is not None:
raise ValueError("input_op_matches is already set.")
self.input_op_matches = []
for input_match in args:
self.input_op_matches.append(_make_graph_match(input_match))
return self
def control_input_ops(self, *args):
"""Add input matches."""
if self.control_input_op_matches is not None:
raise ValueError("control_input_op_matches is already set.")
self.control_input_op_matches = []
for input_match in args:
self.control_input_op_matches.append(_make_graph_match(input_match))
return self
def output_ops(self, *args):
"""Add output matches."""
if self.output_op_matches is not None:
raise ValueError("output_op_matches is already set.")
self.output_op_matches = []
for consumer_op_matches in args:
if consumer_op_matches is None:
self.output_op_matches.append(None)
if not isinstance(consumer_op_matches, list):
consumer_op_matches = [consumer_op_matches]
consumer_op_matches = [_make_graph_match(consumer_op_match)
for consumer_op_match in consumer_op_matches]
self.output_op_matches.append(consumer_op_matches)
return self