Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple graph matching functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import string_types

from tensorflow.contrib.graph_editor import select
from tensorflow.python.framework import ops as tf_ops

__all__ = [
    "op_type",
    "OpMatcher",
]


def _make_graph_match(graph_match):
  """Convert to a OpMatcher instance."""
  if graph_match is None:
    return None
  if not isinstance(graph_match, OpMatcher):
    graph_match = OpMatcher(graph_match)
  return graph_match


def op_type(op_types, op=None):
  """Check if an op is of the given type.

  Args:
    op_types: tuple of strings containing the types to check against.
      For instance: ("Add", "Const")
    op: the operation to check (or None).
  Returns:
    if op is not None, return True if the op is of the correct type.
    if op is None, return a lambda function which does the type checking.
  """
  if isinstance(op_types, string_types):
    op_types = (op_types)
  if op is None:
    return lambda op: op.node_def.op in op_types
  else:
    return op.node_def.op in op_types


class OpMatcher(object):
  """Graph match class."""

  def __init__(self, positive_filter):
    """Graph match constructor."""
    self.positive_filters = []
    self.input_op_matches = None
    self.control_input_op_matches = None
    self.output_op_matches = None
    positive_filter = self._finalize_positive_filter(positive_filter)
    self.positive_filters.append(positive_filter)

  def _finalize_positive_filter(self, elem):
    """Convert to a filter function."""
    if select.can_be_regex(elem):
      regex_ = select.make_regex(elem)
      return lambda op, regex=regex_: regex.search(op.name) is not None
    elif isinstance(elem, tf_ops.Operation):
      return lambda op, match_op=elem: op is match_op
    elif callable(elem):
      return elem
    elif elem is True:
      return lambda op: True
    else:
      raise ValueError("Cannot finalize the positive filter: {}".format(elem))

  def __call__(self, op):
    """Evaluate if the op matches or not."""
    if not isinstance(op, tf_ops.Operation):
      raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
    for positive_filter in self.positive_filters:
      if not positive_filter(op):
        return False
    if self.input_op_matches is not None:
      if len(op.inputs) != len(self.input_op_matches):
        return False
      for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
        if input_op_match is None:
          continue
        if not input_op_match(input_t.op):
          return False
    if self.control_input_op_matches is not None:
      if len(op.control_inputs) != len(self.control_input_op_matches):
        return False
      for cinput_op, cinput_op_match in zip(op.control_inputs,
                                            self.control_input_op_matches):
        if cinput_op_match is None:
          continue
        if not cinput_op_match(cinput_op):
          return False
    if self.output_op_matches is not None:
      if len(op.outputs) != len(self.output_op_matches):
        return False
      for output_t, output_op_matches in zip(op.outputs,
                                             self.output_op_matches):
        if output_op_matches is None:
          continue
        if len(output_t.consumers()) != len(output_op_matches):
          return False
        for consumer_op, consumer_op_match in zip(output_t.consumers(),
                                                  output_op_matches):
          if consumer_op_match is None:
            continue
          if not consumer_op_match(consumer_op):
            return False
    return True

  def input_ops(self, *args):
    """Add input matches."""
    if self.input_op_matches is not None:
      raise ValueError("input_op_matches is already set.")
    self.input_op_matches = []
    for input_match in args:
      self.input_op_matches.append(_make_graph_match(input_match))
    return self

  def control_input_ops(self, *args):
    """Add input matches."""
    if self.control_input_op_matches is not None:
      raise ValueError("control_input_op_matches is already set.")
    self.control_input_op_matches = []
    for input_match in args:
      self.control_input_op_matches.append(_make_graph_match(input_match))
    return self

  def output_ops(self, *args):
    """Add output matches."""
    if self.output_op_matches is not None:
      raise ValueError("output_op_matches is already set.")
    self.output_op_matches = []
    for consumer_op_matches in args:
      if consumer_op_matches is None:
        self.output_op_matches.append(None)
      if not isinstance(consumer_op_matches, list):
        consumer_op_matches = [consumer_op_matches]
      consumer_op_matches = [_make_graph_match(consumer_op_match)
                             for consumer_op_match in consumer_op_matches]
      self.output_op_matches.append(consumer_op_matches)
    return self