Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensorflow / purelib / tensorflow / python / ops / critical_section_ops.py
Size: Mime:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.core.protobuf import critical_section_pb2

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


__all__ = ["CriticalSection"]


# Graph Keys
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"


class _ExecutionSignature(
    collections.namedtuple("_ExecutionSignature",
                           ("op", "handle",
                            "resources", "exclusive_resource_access"))):
  """A class storing an `ExecuteInCriticalResource` op and associated attrs."""
  pass


def _identity(x):
  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  elif isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  elif context.executing_eagerly() and x is None:
    return None
  else:
    return array_ops.identity(x)


def _get_colocation(op):
  """Get colocation symbol from op, if any."""
  try:
    return op.get_attr("_class")
  except ValueError:
    return None


@tf_export("CriticalSection")
class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
    accumulate, 1.0, exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    accumulate, 1.0, exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section."""
    context.ensure_initialized()
    if critical_section_def and name is not None:
      raise ValueError("critical_section_def and shared_name are "
                       "mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name, shared_name)

  def _init_from_proto(self, critical_section_def, import_scope):  # pylint: disable=invalid-name
    raise NotImplementedError("Not yet implemented")
    # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    # assert isinstance(
    #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    # # Create from critical_section_def.
    # g = ops.get_default_graph()
    # self._handle = g.as_graph_element(
    #     ops.prepend_name_scope(
    #         critical_section_def.critical_section_name,
    #         import_scope=import_scope))

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn``.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):

      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn()
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = set([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running.  "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: %s.  Either this "
            "lock (CriticalSection: %s) or lock '%s' "
            "(CriticalSection: %s) requested exclusive resource access "
            "of this resource.  Did you mean to call execute with keyword "
            "argument exclusive_resource_access=False?" %
            (list(resource_intersection), self._handle, sg, sg.handle))

  # TODO(ebrevdo): Re-enable once CriticalSection is in core.

  # def to_proto(self, export_scope=None):
  #   """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer.

  #   Args:
  #     export_scope: Optional `string`. Name scope to remove.

  #   Returns:
  #     A `CriticalSectionDef` protocol buffer, or `None` if the
  #     `CriticalSection` is not in the specified name scope.
  #   """
  #   if export_scope is None or self.handle.name.startswith(export_scope):
  #     cs_def = critical_section_pb2.CriticalSectionDef()
  #     cs_def.critical_section_name = ops.strip_name_scope(
  #         self._handle.name, export_scope)
  #     return cs_def
  #   else:
  #     return None

  # @staticmethod
  # def from_proto(critical_section_def, import_scope=None):
  #   return CriticalSection(
  #       critical_section_def=critical_section_def, import_scope=import_scope)


# TODO(ebrevdo): Re-enable once CriticalSection is in core.

# def _execution_to_proto_fn(execution_signature, export_scope=None):
#   """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
#   # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.

#   Args:
#     execution_signature: Instance of `_ExecutionSignature`.
#     export_scope: The export scope, if any.

#   Returns:
#     An instance of `CriticalSectionExecutionDef`.
#   """
#   if (export_scope is None
#       or execution_signature.op.name.startswith(export_scope)):
#     op_def = critical_section_pb2.CriticalSectionExecutionDef()
#     op_def.execute_in_critical_section_name = ops.strip_name_scope(
#         execution_signature.op.name, export_scope)
#     op_def.exclusive_resource_access = (
#         execution_signature.exclusive_resource_access)
#     return op_def
#   else:
#     return None


# def _execution_from_proto_fn(op_def, import_scope=None):
#   """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
#   # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
#   assert isinstance(
#       op_def, critical_section_pb2.CriticalSectionExecutionDef)

#   # Create from op_def.
#   g = ops.get_default_graph()
#   execution_op = g.as_graph_element(
#       ops.prepend_name_scope(
#           op_def.execute_in_critical_section_name,
#           import_scope=import_scope))
#   return _ExecutionSignature(
#       op=execution_op,
#       exclusive_resource_access=op_def.exclusive_resource_access)

# ops.register_proto_function(
#     CRITICAL_SECTIONS,
#     proto_type=critical_section_pb2.CriticalSectionDef,
#     to_proto=CriticalSection.to_proto,
#     from_proto=CriticalSection.from_proto)

# ops.register_proto_function(
#     CRITICAL_SECTION_EXECUTIONS,
#     proto_type=critical_section_pb2.CriticalSectionExecutionDef,
#     to_proto=_execution_to_proto_fn,
#     from_proto=_execution_from_proto_fn)