Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensorflow / purelib / tensorflow / contrib / specs / python / summaries.py
Size: Mime:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for summarizing and describing TensorFlow graphs.

This contains functions that generate string descriptions from
TensorFlow graphs, for debugging, testing, and model size
estimation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib.specs.python import specs
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

# These are short abbreviations for common TensorFlow operations used
# in test cases with tf_structure to verify that specs_lib generates a
# graph structure with the right operations. Operations outside the
# scope of specs (e.g., Const and Placeholder) are just assigned "_"
# since they are not relevant to testing.

SHORT_NAMES_SRC = """
BiasAdd biasadd
Const _
Conv2D conv
MatMul dot
Placeholder _
Sigmoid sig
Variable var
""".split()

SHORT_NAMES = {
    x: y
    for x, y in zip(SHORT_NAMES_SRC[::2], SHORT_NAMES_SRC[1::2])
}


def _truncate_structure(x):
  """A helper function that disables recursion in tf_structure.

  Some constructs (e.g., HorizontalLstm) are complex unrolled
  structures and don't need to be represented in the output
  of tf_structure or tf_print. This helper function defines
  which tree branches should be pruned. This is a very imperfect
  way of dealing with unrolled LSTM's (since it truncates
  useful information as well), but it's not worth doing something
  better until the new fused and unrolled ops are ready.

  Args:
      x: a Tensor or Op

  Returns:
      A bool indicating whether the subtree should be pruned.
  """
  if "/HorizontalLstm/" in x.name:
    return True
  return False


def tf_structure(x, include_shapes=False, finished=None):
  """A postfix expression summarizing the TF graph.

  This is intended to be used as part of test cases to
  check for gross differences in the structure of the graph.
  The resulting string is not invertible or unabiguous
  and cannot be used to reconstruct the graph accurately.

  Args:
      x: a tf.Tensor or tf.Operation
      include_shapes: include shapes in the output string
      finished: a set of ops that have already been output

  Returns:
      A string representing the structure as a string of
      postfix operations.
  """
  if finished is None:
    finished = set()
  if isinstance(x, ops.Tensor):
    shape = x.get_shape().as_list()
    x = x.op
  else:
    shape = []
  if x in finished:
    return " <>"
  finished |= {x}
  result = ""
  if not _truncate_structure(x):
    for y in x.inputs:
      result += tf_structure(y, include_shapes, finished)
  if include_shapes:
    result += " %s" % (shape,)
  if x.type != "Identity":
    name = SHORT_NAMES.get(x.type, x.type.lower())
    result += " " + name
  return result


def tf_print(x, depth=0, finished=None, printer=print):
  """A simple print function for a TensorFlow graph.

  Args:
      x: a tf.Tensor or tf.Operation
      depth: current printing depth
      finished: set of nodes already output
      printer: print function to use

  Returns:
      Total number of parameters found in the
      subtree.
  """

  if finished is None:
    finished = set()
  if isinstance(x, ops.Tensor):
    shape = x.get_shape().as_list()
    x = x.op
  else:
    shape = ""
  if x.type == "Identity":
    x = x.inputs[0].op
  if x in finished:
    printer("%s<%s> %s %s" % ("  " * depth, x.name, x.type, shape))
    return
  finished |= {x}
  printer("%s%s %s %s" % ("  " * depth, x.name, x.type, shape))
  if not _truncate_structure(x):
    for y in x.inputs:
      tf_print(y, depth + 1, finished, printer=printer)


def tf_num_params(x):
  """Number of parameters in a TensorFlow subgraph.

  Args:
      x: root of the subgraph (Tensor, Operation)

  Returns:
      Total number of elements found in all Variables
      in the subgraph.
  """

  if isinstance(x, ops.Tensor):
    shape = x.get_shape()
    x = x.op
  if x.type in ["Variable", "VariableV2"]:
    return shape.num_elements()
  totals = [tf_num_params(y) for y in x.inputs]
  return sum(totals)


def tf_left_split(op):
  """Split the parameters of op for left recursion.

  Args:
    op: tf.Operation

  Returns:
    A tuple of the leftmost input tensor and a list of the
    remaining arguments.
  """

  if len(op.inputs) < 1:
    return None, []
  if op.type == "Concat":
    return op.inputs[1], op.inputs[2:]
  return op.inputs[0], op.inputs[1:]


def tf_parameter_iter(x):
  """Iterate over the left branches of a graph and yield sizes.

  Args:
      x: root of the subgraph (Tensor, Operation)

  Yields:
      A triple of name, number of params, and shape.
  """

  while 1:
    if isinstance(x, ops.Tensor):
      shape = x.get_shape().as_list()
      x = x.op
    else:
      shape = ""
    left, right = tf_left_split(x)
    totals = [tf_num_params(y) for y in right]
    total = sum(totals)
    yield x.name, total, shape
    if left is None:
      break
    x = left


def _combine_filter(x):
  """A filter for combining successive layers with similar names."""
  last_name = None
  last_total = 0
  last_shape = None
  for name, total, shape in x:
    name = re.sub("/.*", "", name)
    if name == last_name:
      last_total += total
      continue
    if last_name is not None:
      yield last_name, last_total, last_shape
    last_name = name
    last_total = total
    last_shape = shape
  if last_name is not None:
    yield last_name, last_total, last_shape


def tf_parameter_summary(x, printer=print, combine=True):
  """Summarize parameters by depth.

  Args:
      x: root of the subgraph (Tensor, Operation)
      printer: print function for output
      combine: combine layers by top-level scope
  """
  seq = tf_parameter_iter(x)
  if combine:
    seq = _combine_filter(seq)
  seq = reversed(list(seq))
  for name, total, shape in seq:
    printer("%10d %-20s %s" % (total, name, shape))


def tf_spec_structure(spec,
                      inputs=None,
                      input_shape=None,
                      input_type=dtypes.float32):
  """Return a postfix representation of the specification.

  This is intended to be used as part of test cases to
  check for gross differences in the structure of the graph.
  The resulting string is not invertible or unabiguous
  and cannot be used to reconstruct the graph accurately.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor)
      input_shape: tensor shape (in lieu of inputs)
      input_type: type of the input tensor

  Returns:
      A string with a postfix representation of the
      specification.
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  return str(tf_structure(outputs).strip())


def tf_spec_summary(spec,
                    inputs=None,
                    input_shape=None,
                    input_type=dtypes.float32):
  """Output a summary of the specification.

  This prints a list of left-most tensor operations and summarized the
  variables found in the right branches. This kind of representation
  is particularly useful for networks that are generally structured
  like pipelines.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor)
      input_shape: optional shape of input
      input_type: type of the input tensor
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  tf_parameter_summary(outputs)


def tf_spec_print(spec,
                  inputs=None,
                  input_shape=None,
                  input_type=dtypes.float32):
  """Print a tree representing the spec.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor)
      input_shape: optional shape of input
      input_type: type of the input tensor
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  tf_print(outputs)