Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensorflow / purelib / tensorflow / python / debug / lib / debug_utils.py
Size: Mime:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Debugger (tfdbg) Utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from six.moves import xrange  # pylint: disable=redefined-builtin


def add_debug_tensor_watch(run_options,
                           node_name,
                           output_slot=0,
                           debug_ops="DebugIdentity",
                           debug_urls=None,
                           tolerate_debug_op_creation_failures=False,
                           global_step=-1):
  """Register a debug watch for a single `Tensor` on a `RunOptions`.

  A new entry is appended to
  `run_options.debug_options.debug_tensor_watch_opts` and populated from the
  arguments. `run_options.debug_options.global_step` is overwritten with
  `global_step` on every call.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iteration` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions`; modified in place.
    node_name: (`str`) name of the node whose output is to be watched.
    output_slot: (`int`) index of the watched output slot on that node.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). A single
      `str` is treated as a one-element `list`.
      For debug op types with customizable attributes, each debug op string can
      optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to,
      e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. A single
      `str` is treated as a one-element `list`; a falsy value adds no URLs.
    tolerate_debug_op_creation_failures: (`bool`) Whether to tolerate debug op
      creation failures by not throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
  """

  debug_options = run_options.debug_options
  tensor_watches = debug_options.debug_tensor_watch_opts
  debug_options.global_step = global_step

  new_watch = tensor_watches.add()
  new_watch.tolerate_debug_op_creation_failures = (
      tolerate_debug_op_creation_failures)
  new_watch.node_name = node_name
  new_watch.output_slot = output_slot

  # A bare string is shorthand for a list holding just that string.
  op_names = [debug_ops] if isinstance(debug_ops, str) else debug_ops
  new_watch.debug_ops.extend(op_names)

  # Nothing more to record when no URL was supplied (None, "" or []).
  if not debug_urls:
    return

  url_list = [debug_urls] if isinstance(debug_urls, str) else debug_urls
  new_watch.debug_urls.extend(url_list)


def watch_graph(run_options,
                graph,
                debug_ops="DebugIdentity",
                debug_urls=None,
                node_name_regex_whitelist=None,
                op_type_regex_whitelist=None,
                tensor_dtype_regex_whitelist=None,
                tolerate_debug_op_creation_failures=False,
                global_step=-1,
                reset_disk_byte_usage=False):
  """Add debug watches to `RunOptions` for the tensors of a TensorFlow graph.

  Every output tensor of every op in `graph` that passes all of the supplied
  whitelists gets a watch via `add_debug_tensor_watch()`. Leaving both
  `node_name_regex_whitelist` and `op_type_regex_whitelist` at their default
  (`None`) watches all `Tensor`s on the graph.

  Each regular expression is applied with `match()`, i.e., it is anchored at
  the beginning of the string being tested.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iteration` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions`; modified in place.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
      A single `str` is treated as a one-element `list`.
      For debug op types with customizable attributes, each debug op name string
      can optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: URLs to send debug values to. Can be a list of strings,
      a single string, or None. The case of a single string is equivalent to
      a list consisting of a single string, e.g., `file:///tmp/tfdbg_dump_1`,
      `grpc://localhost:12345`.
    node_name_regex_whitelist: Regular-expression whitelist for node_name,
      e.g., `"(weight_[0-9]+|bias_.*)"`
    op_type_regex_whitelist: Regular-expression whitelist for the op type of
      nodes, e.g., `"(Variable|Add)"`.
      If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
      are set, the two filtering operations will occur in a logical `AND`
      relation. In other words, a node will be included if and only if it
      hits both whitelists.
    tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
      data type, e.g., `"^int.*"`.
      This whitelist operates in logical `AND` relations to the two whitelists
      above.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`).
  """

  def _compile_or_none(regex):
    # An unset (None or empty) whitelist means "no filtering on this axis".
    return re.compile(regex) if regex else None

  op_names = [debug_ops] if isinstance(debug_ops, str) else debug_ops

  name_re = _compile_or_none(node_name_regex_whitelist)
  type_re = _compile_or_none(op_type_regex_whitelist)
  dtype_re = _compile_or_none(tensor_dtype_regex_whitelist)

  for operation in graph.get_operations():
    tensors = operation.outputs
    # Ops that produce no tensors have nothing to watch.
    if not tensors:
      continue

    name, kind = operation.name, operation.type

    # Node-level whitelists: the op must hit every whitelist that is set.
    if name_re is not None and name_re.match(name) is None:
      continue
    if type_re is not None and type_re.match(kind) is None:
      continue

    for slot_index, tensor in enumerate(tensors):
      # Tensor-level whitelist: applied per output slot.
      if dtype_re is not None and dtype_re.match(tensor.dtype.name) is None:
        continue

      add_debug_tensor_watch(
          run_options,
          name,
          output_slot=slot_index,
          debug_ops=op_names,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage


def watch_graph_with_blacklists(run_options,
                                graph,
                                debug_ops="DebugIdentity",
                                debug_urls=None,
                                node_name_regex_blacklist=None,
                                op_type_regex_blacklist=None,
                                tensor_dtype_regex_blacklist=None,
                                tolerate_debug_op_creation_failures=False,
                                global_step=-1,
                                reset_disk_byte_usage=False):
  """Add debug tensor watches, blacklisting nodes and op types.

  This is similar to `watch_graph()`, but the node names, op types and tensor
  dtypes are blacklisted, instead of whitelisted: every output tensor of every
  op in `graph` gets a watch unless it hits one of the supplied blacklists.

  Each regular expression is applied with `match()`, i.e., it is anchored at
  the beginning of the string being tested.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iteration` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
      A single `str` is treated as a one-element `list`.
      See the documentation of `watch_graph` for more details.
    debug_urls: URL(s) to send debug values to, e.g.,
      `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
    node_name_regex_blacklist: Regular-expression blacklist for node_name.
      This should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`.
    op_type_regex_blacklist: Regular-expression blacklist for the op type of
      nodes, e.g., `"(Variable|Add)"`.
      If both node_name_regex_blacklist and op_type_regex_blacklist
      are set, the two filtering operations will occur in a logical `OR`
      relation. In other words, a node will be excluded if it hits either of
      the two blacklists; a node will be included if and only if it hits
      neither of the blacklists.
    tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor
      data type, e.g., `"^int.*"`.
      This blacklist operates in logical `OR` relations to the two blacklists
      above.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`). The flag is written to
      `run_options.debug_options` unconditionally, including when the graph is
      empty or every op is excluded.
  """

  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  # An unset (None or empty) blacklist excludes nothing on that axis.
  node_name_pattern = (re.compile(node_name_regex_blacklist) if
                       node_name_regex_blacklist else None)
  op_type_pattern = (re.compile(op_type_regex_blacklist) if
                     op_type_regex_blacklist else None)
  tensor_dtype_pattern = (re.compile(tensor_dtype_regex_blacklist) if
                          tensor_dtype_regex_blacklist else None)

  for op in graph.get_operations():
    # Skip nodes without any output tensors.
    if not op.outputs:
      continue

    node_name = op.name
    op_type = op.type

    # Node-level blacklists: hitting either one excludes the whole node.
    if node_name_pattern and node_name_pattern.match(node_name):
      continue
    if op_type_pattern and op_type_pattern.match(op_type):
      continue

    for slot, tensor in enumerate(op.outputs):
      # Tensor-level blacklist: applied per output slot.
      if (tensor_dtype_pattern and
          tensor_dtype_pattern.match(tensor.dtype.name)):
        continue

      add_debug_tensor_watch(
          run_options,
          node_name,
          output_slot=slot,
          debug_ops=debug_ops,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  # Set once, after the loop, so the flag is applied even when no op survives
  # the filters (or the graph has no ops at all); this mirrors `watch_graph`.
  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage