Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensorflow / purelib / tensorflow / python / ops / cond_v2.py
Size: Mime:
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""cond_v2 and gradient.

This is a version of cond that emits a single If op, as well as the gradient
function for If ops produced by cond_v2. This will eventually replace the
current tf.cond implementation once it reaches feature and performance parity.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
# that they aren't part of the official public API. These protected members
# often need to be used by implementation code however. Rather than litter the
# code with pylint comments, we ignore protected access violations for
# readability.
# pylint: disable=protected-access

_COND = 1
_CASE = 2


def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Builds a conditional as a single If op, mirroring tf.cond.

  Each branch callable is traced into its own `CondBranchFuncGraph`, the
  branches are checked for cross-branch captures, and the If op is then
  assembled by `_build_cond`.

  Args:
    pred: the predicate; converted to a tensor. Python bools are rejected.
    true_fn: callable traced to produce the "then" branch.
    false_fn: callable traced to produce the "else" branch.
    name: name scope for the emitted ops; a falsy value falls back to "cond".

  Returns:
    The result of `_build_cond` for the two traced branches.

  Raises:
    TypeError: if `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  with ops.name_scope(name or "cond") as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Defuns add automatic control dependencies while v1 graphs do not; the
    # branch graphs inherit whichever behavior the enclosing graph has.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    pred = ops.convert_to_tensor(pred)

    def _trace_branch(branch_name, branch_fn):
      """Traces `branch_fn` into a fresh branch graph named `branch_name`."""
      branch_graph = util.CondBranchFuncGraph(
          branch_name, collections=ops.get_default_graph()._collections)
      return func_graph_module.func_graph_from_py_func(
          branch_name,
          branch_fn, [], {},
          func_graph=branch_graph,
          add_control_dependencies=add_control_dependencies,
          op_return_value=pred)

    true_graph = _trace_branch(true_name, true_fn)
    false_graph = _trace_branch(false_name, false_fn)

    verify_captures(_COND, [true_graph, false_graph])
    return _build_cond(pred, true_graph, false_graph,
                       true_graph.external_captures,
                       false_graph.external_captures,
                       name=scope)


@ops.RegisterGradient("If")
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2.

  Builds one gradient function per forward branch and emits a new If op that
  selects between them using the forward op's predicate. If either gradient
  function captured forward intermediates, the forward If op is rewritten in
  place so that it also outputs those intermediates.

  Args:
    op: the forward If operation (or a stand-in such as a MockOp whose first
      output's `.op` is the real If op).
    *grads: the incoming gradients, passed through to the per-branch gradient
      functions.

  Returns:
    A list whose first element is None (the predicate has no gradient),
    followed by the outputs of the gradient If op.
  """
  # Get the if operator (this logic handles the case where op is a MockOp)
  if_op = op.outputs[0].op
  true_graph, false_graph = _get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  false_grad_graph = _create_grad_func(
      false_graph, grads, util.unique_grad_fn_name(false_graph.name))

  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          [true_graph, false_graph], [true_intermediates, false_intermediates])
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          [true_graph, false_graph], [true_intermediates, false_intermediates])

    # Append the extra outputs to both forward branches so their signatures
    # stay aligned.
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Rename the branch graphs before registering them as new functions on the
    # forward op.
    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    # Point the forward If op at the rewritten branch functions and update its
    # output signature to include the extra outputs.
    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs(
        [t.dtype for t in extra_true_outputs],
        [t.shape for t in extra_true_outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # This modifies true_grad_graph and false_grad_graph.
  _make_output_composite_tensors_match(_COND,
                                       [true_grad_graph, false_grad_graph])

  # The gradient If op reuses the forward op's predicate (its first input).
  outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                        true_grad_inputs, false_grad_inputs)

  # The predicate has no gradient.
  return [None] + outputs


def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to make the dtypes of any IndexedSlices indices outputs match.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    The outputs of the If op, each passed through an identity op and packed
    into the structure of `true_graph.structured_outputs`.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])

  # Create the If op. It is built under the control captures of both branches.
  with ops.control_dependencies(
      list(true_graph.control_captures) + list(false_graph.control_captures)):
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)
  util.maybe_propagate_compile_time_consts_in_xla(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)


def _get_func_graphs(op):
  """Rebuilds the branch `FuncGraph`s of a conditional op.

  Args:
    op: The If or Case Operation.

  Returns:
    For an If op, a `(then_branch, else_branch)` tuple of `FuncGraph`s. For a
    Case op, a list with one `FuncGraph` per branch.

  Raises:
    ValueError: if `op` is neither an If nor a Case op.
  """

  def _branch_to_func_graph(branch_attr):
    """Converts one branch function attribute of `op` into a FuncGraph."""
    # Input 0 is the predicate / branch index; the rest feed the branch.
    branch_inputs = op.inputs[1:]
    shapes = [tensor.shape for tensor in branch_inputs]
    fdef = op.graph._get_function(branch_attr.name).definition
    # `op.graph` may differ from the default graph (nested conds, gradients
    # computed inside a Defun). Building under `op.graph` makes it the
    # `outer_graph` of the result, as in the forward pass, which is what
    # `_resolve_grad_inputs` relies on to resolve forward tensors.
    with op.graph.as_default():
      branch_graph = function_def_to_graph.function_def_to_graph(fdef, shapes)
    branch_graph.captures = collections.OrderedDict(
        zip(branch_inputs, branch_graph.inputs))
    # Link the op so that the gradient code can use it.
    branch_graph._forward_cond = op
    return branch_graph

  op_type = op.type
  if op_type == "If":
    then_graph = _branch_to_func_graph(op.get_attr("then_branch"))
    else_graph = _branch_to_func_graph(op.get_attr("else_branch"))
    return (then_graph, else_graph)
  if op_type == "Case":
    case_graphs = []
    for branch_fn in op.get_attr("branches"):
      case_graphs.append(_branch_to_func_graph(branch_fn))
    return case_graphs
  raise ValueError("Unsupported op type: {}".format(op_type))


def _grad_fn(func_graph, grads):
  """Builds the gradient computation for one conditional branch.

  Differentiates the trainable outputs of the forward branch `func_graph`
  with respect to its inputs, in the current default graph.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors, one per output of `func_graph`.

  Returns:
    The output gradient Tensors, one per input of `func_graph`. Inputs with no
    gradient get a zero tensor instead of None.
  """
  assert len(func_graph.outputs) == len(grads)

  # Drop untrainable outputs together with their incoming gradients.
  # NOTE(skyewm): otherwise the untrainable tensors can sometimes cause
  # _GradientsHelper to raise an exception (e.g. the implementation doesn't
  # expect 'ys' to contain boolean tensors).
  trainable_pairs = [
      (y, grad_y)
      for y, grad_y in zip(func_graph.outputs, grads)
      if gradients_util.IsTrainable(y)
  ]
  ys = [y for y, _ in trainable_pairs]
  grad_ys = [grad_y for _, grad_y in trainable_pairs]

  # The gradient ops are built in the current graph and therefore capture
  # tensors from func_graph; those captures are mapped to external tensors
  # later, in _resolve_grad_inputs.
  input_grads = gradients_util._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys,
      src_graph=func_graph)

  # Functions can't return None; replace Nones with zero tensors.
  # TODO(b/80444525): don't return anything here and make _IfGrad return None if
  # both branches have zero gradient.
  for idx, grad in enumerate(input_grads):
    if grad is not None:
      continue
    fwd_input = func_graph.inputs[idx]
    if fwd_input.dtype == dtypes.resource:
      # A resource handle has no shape of its own; use the variable's shape.
      input_grads[idx] = array_ops.zeros(
          gen_resource_variable_ops.variable_shape(fwd_input))
    else:
      input_grads[idx] = array_ops.zeros_like(fwd_input)

  return input_grads


def _create_grad_func(func_graph, grads, name):
  """Traces `_grad_fn(func_graph, grads)` into a `_CondGradFuncGraph`.

  Args:
    func_graph: FuncGraph. The forward-pass branch to differentiate.
    grads: the incoming gradient Tensors handed to `_grad_fn`.
    name: name of the gradient function and of its graph.

  Returns:
    The result of `func_graph_from_py_func` for the traced gradient function.
  """
  grad_graph = _CondGradFuncGraph(name, func_graph)
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads),
      [],
      {},
      func_graph=grad_graph)


def _resolve_grad_inputs(cond_graph, grad_graph):
  """Returns the tensors to pass as inputs to `grad_graph`.

  Every external capture of `grad_graph` lives in one of two places:
  1. `grad_graph.outer_graph` (e.g. the incoming gradients). Such tensors are
     passed through unchanged.
  2. The forward graph `cond_graph`. Such tensors may not be "live" when the
     gradient runs, so each is replaced by its counterpart in
     `cond_graph.outer_graph`: the matching output of the forward op if the
     tensor is a forward output, otherwise the external tensor behind the
     matching internal capture. With nested control flow or functions, the
     gradient logic handling `grad_graph.outer_graph` is relied upon to
     capture that outer tensor correctly.

  Args:
    cond_graph: FuncGraph. The forward-pass function.
    grad_graph: FuncGraph. The gradients function.

  Returns:
    A list of inputs tensors to be passed to grad_graph.
  """
  resolved = []

  for captured in grad_graph.external_captures:
    if captured.graph != grad_graph.outer_graph:
      # Not from the outer graph, so it has to come from the forward branch.
      fwd_graph = captured.graph
      assert fwd_graph == cond_graph
      # `internal_captures` are not treated as intermediates and hence are not
      # among the forward op's outputs; for those, fall back to the
      # corresponding entry of `external_captures`.
      try:
        fwd_op_outputs = fwd_graph._forward_cond.outputs
        captured = fwd_op_outputs[fwd_graph.outputs.index(captured)]
      except ValueError:
        capture_idx = fwd_graph.internal_captures.index(captured)
        captured = fwd_graph.external_captures[capture_idx]

      # Note: We rely on the capturing logic of the gradient If op graph to
      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert captured.graph == cond_graph.outer_graph
    resolved.append(captured)

  return resolved


def _get_intermediates(func_graph):
  """Collects op outputs of `func_graph` that are neither inputs nor outputs.

  Args:
    func_graph: FuncGraph whose operations are scanned.

  Returns:
    A list of tensors, in operation order, that appear in neither
    `func_graph.inputs` nor `func_graph.outputs`.
  """
  return [
      tensor
      for operation in func_graph.get_operations()
      for tensor in operation.outputs
      if tensor not in func_graph.inputs and tensor not in func_graph.outputs
  ]


def _make_intermediates_match(branch_graphs, branch_optionals):
  """Pads each branch's optionals so all branches have equally long lists.

  Shorter lists are extended with none optionals created in the corresponding
  branch graph. There is no merging of like optionals.

  Args:
    branch_graphs: `list` of `FuncGraph`.
    branch_optionals: `list` of `list`s of optional `Tensor`s, one list per
      graph in `branch_graphs`.

  Returns:
    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
    same number of `Tensor`s, all of which will be optionals of the same
    shape/type.
  """
  # Since the intermediates are optionals with dtype variant, only the length
  # of the longest list matters.
  longest = max(len(optionals) for optionals in branch_optionals)
  return [
      branch_optionals[idx] + _create_none_optionals(
          graph, longest - len(branch_optionals[idx]))
      for idx, graph in enumerate(branch_graphs)
  ]


def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
  """Like _make_intermediates_match but for the XLA case.

  Each branch keeps its own intermediates and gets FakeParams standing in for
  the intermediates of every other branch, placed so that the own
  intermediates come after the stand-ins for all preceding branches.

  Args:
    branch_graphs: `list` of `FuncGraph`.
    branch_intermediates: `list` of `list`s of `Tensor`s, one list per graph
      in `branch_graphs`.

  Returns:
    A `list` with one `list` of `Tensor`s per branch graph.
  """
  matched = []
  for idx, graph in enumerate(branch_graphs):
    own = branch_intermediates[idx]
    # Templates are the intermediates of all the other branches, in order.
    foreign = [
        tensor
        for other in branch_intermediates if other is not own
        for tensor in other
    ]
    placeholders = _create_fakeparams(graph, foreign)
    # Number of stand-ins that belong in front of this branch's own tensors.
    split = sum(len(preceding) for preceding in branch_intermediates[:idx])
    matched.append(placeholders[:split] + own + placeholders[split:])
  return matched


def _make_inputs_match(branch_graphs, branch_inputs):
  """Modifies branch_graphs so they have the same input signature.

  This method reorders and/or adds parameters to each graph in branch_graphs so
  they have the same input signature, and updates the 'inputs' and 'captures'
  fields of each graph accordingly. It uses the input tensors from the outer
  graph to avoid duplicating shared arguments.

  The combined input list is built in first-seen order (branch by branch), so
  the resulting signature is deterministic rather than dependent on set
  iteration order.

  Args:
    branch_graphs: a `list` of `FuncGraph`
    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
      inputs for the corresponding graph in `branch_graphs`.

  Returns:
    A new list of Tensors from the outer graph that are the new inputs for each
    branch_graph. This is a deduped version of `sum(branch_inputs)` that keeps
    the order of first occurrence.
  """
  assert len(branch_graphs) == len(branch_inputs)

  # Order-preserving dedupe; `seen` is only used for membership tests.
  seen = set()
  new_inputs = []
  for branch_in in branch_inputs:
    for in_t in branch_in:
      if in_t not in seen:
        seen.add(in_t)
        new_inputs.append(in_t)

  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
    # Maps each outer tensor this branch already takes to its parameter.
    branch_input_to_param = dict(zip(branch_in, branch_graph.inputs))
    input_list = []
    for in_t in new_inputs:
      param = branch_input_to_param.get(in_t)
      if param is None:
        # This branch does not use `in_t`; add a placeholder so the signatures
        # of all branches line up.
        param = _create_dummy_input(branch_graph, in_t)
      input_list.append(param)

    branch_graph.inputs = input_list

    # Rewrite the FuncGraphs' state to reflect the new inputs.
    branch_graph.captures = collections.OrderedDict(
        zip(new_inputs, branch_graph.inputs))

  return new_inputs


def _make_output_composite_tensors_match(op_type, branch_graphs):
  """Rewrites branch outputs so that all branches share an output signature.

  The only rewrite performed is turning a Tensor output into an equivalent
  IndexedSlices when some other branch returns an IndexedSlices at the same
  position. `outputs` and `structured_outputs` of every graph in
  `branch_graphs` are reassigned.

  Args:
    op_type: _COND or _CASE
    branch_graphs: `list` of `FuncGraph`

  Raises:
    TypeError: if a set of outputs cannot be rewritten.
  """
  # Note: since this is only used for gradient graphs, we do not expect the
  # outputs to be structured (e.g. nested lists), and thus do not need to use
  # nest.flatten, etc.
  assert branch_graphs
  branch_outputs = [graph.structured_outputs for graph in branch_graphs]
  outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(outputs_per_branch)) == 1, outputs_per_branch

  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    # Nothing to reconcile when every branch returns the same type, or when no
    # branch returns an IndexedSlices at this position.
    if (len({type(out) for out in branch_outs}) == 1 or
        not any(isinstance(out, ops.IndexedSlices) for out in branch_outs)):
      continue
    for branch_idx, branch_out in enumerate(branch_outs):
      if isinstance(branch_out, ops.IndexedSlices):
        continue
      if not isinstance(branch_out, ops.Tensor):
        raise TypeError(
            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
            "  outputs from all branches: {outputs}".format(
                op_name="tf.cond" if op_type == _COND else "tf.case",
                output_idx=output_idx,
                outputs=branch_outs))
      # The conversion op has to be created inside the owning branch graph.
      with branch_graphs[branch_idx].as_default():
        converted = math_ops._as_indexed_slices(branch_out)
        branch_outputs[branch_idx][output_idx] = converted

  for graph, outs in zip(branch_graphs, branch_outputs):
    graph.structured_outputs = outs
    graph.outputs = func_graph_module.flatten(outs)


def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Match dtype of IndexedSlices.indices in outputs of branch_graphs.

  When branches disagree on the dtype of an `IndexedSlices.indices` output, the
  int32 ones are cast to int64 inside their own branch graph. If any output is
  an IndexedSlices, `structured_outputs` of every branch graph is repacked from
  its (possibly updated) flat `outputs`; otherwise nothing is modified.

  Args:
    op_type: _COND or _CASE; only used to word error messages.
    branch_graphs: non-empty `list` of `FuncGraph`.

  Raises:
    TypeError: if an output is an IndexedSlices in some branches but not in
      others, or if an indices dtype is neither int32 nor int64.
    ValueError: if the number of flat outputs counted from
      `structured_outputs` differs from `len(branch_graphs[0].outputs)`.
  """
  assert branch_graphs
  # Positions, within the flat `outputs` list, of IndexedSlices.indices.
  indexed_slice_indices = []
  # Running position in the flat `outputs` list.
  current_index = 0
  # Flatten nested structure but keep composite tensors as single elements.
  branch_outputs_flat_with_composites = [
      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
      for branch_graph in branch_graphs
  ]
  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
  assert len(set(outs_per_branch)) == 1, outs_per_branch
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for output_idx, branch_outs in enumerate(
      zip(*branch_outputs_flat_with_composites)):
    # Either all branches return an IndexedSlices here, or none does.
    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
      raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
                      "  branches returned: {outputs}".format(
                          op_name="tf.cond" if op_type == _COND else "tf.case",
                          output_idx=output_idx,
                          outputs=branch_outs))
    if isinstance(branch_outs[0], ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    # Advance by the number of flat tensors this output expands to.
    if nest.is_sequence_or_composite(branch_outs[0]):
      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
    else:
      current_index += 1

  if not indexed_slice_indices:
    return

  # Sanity check: the flat positions computed above must cover `outputs`.
  if current_index != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" %
                     (current_index, len(branch_graphs[0].outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
           for bg in branch_graphs):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" %
                      str([bg.outputs[index].dtype for bg in branch_graphs]))
    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
      for branch_graph in branch_graphs:
        if branch_graph.outputs[index].dtype == dtypes.int32:
          # The cast op has to be created inside the owning branch graph.
          with branch_graph.as_default():
            branch_graph.outputs[index] = math_ops.cast(
                branch_graph.outputs[index], dtypes.int64)

  # Rebuild the structured outputs from the flat outputs.
  for branch_graph in branch_graphs:
    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
        branch_graph.structured_outputs, branch_graph.outputs)


def _wrap_intermediates(func_graph, intermediates):
  """Wraps every intermediate in an optional created inside `func_graph`.

  Args:
    func_graph: FuncGraph in which the optional ops are built.
    intermediates: the tensors to wrap.

  Returns:
    A list with one optional per tensor in `intermediates`, in order.
  """
  wrapped = []
  with func_graph.as_default():
    for tensor in intermediates:
      wrapped.append(gen_dataset_ops.optional_from_value([tensor]))
  return wrapped


def _create_dummy_input(func_graph, template_tensor):
  """Creates a placeholder in func_graph that mirrors template_tensor.

  Args:
    func_graph: FuncGraph.
    template_tensor: a tensor in the outer graph; its dtype and shape are
      copied onto the placeholder.

  Returns:
    A tensor in func_graph.
  """
  with func_graph.as_default():
    dummy = array_ops.placeholder(
        template_tensor.dtype, shape=template_tensor.shape)
  return dummy


def _create_none_optionals(func_graph, n):
  """Builds `n` empty (`None`) optionals inside func_graph.

  Args:
    func_graph: FuncGraph.
    n: `int` the number of `None` optionals to make.

  Returns:
    A list of `n` tensors in func_graph.
  """
  none_optionals = []
  with func_graph.as_default():
    for _ in range(n):
      none_optionals.append(gen_dataset_ops.optional_none())
  return none_optionals


def _create_fakeparams(func_graph, template_tensors):
  """Create FakeParams for the XLA case.

  Args:
    func_graph: FuncGraph in which the FakeParam ops are built.
    template_tensors: tensors whose dtype and shape each FakeParam copies.

  Returns:
    A list with one FakeParam tensor per template tensor, in order.
  """
  fake_params = []
  with func_graph.as_default():
    for template in template_tensors:
      fake_params.append(
          gen_functional_ops.fake_param(
              dtype=template.dtype, shape=template.shape))
  return fake_params


def _check_same_outputs(op_type, graphs):
  """Raises a TypeError unless every graph's outputs match those of graphs[0].

  Args:
    op_type: _COND or _CASE; only used to word the error message.
    graphs: `list` of `FuncGraph`; each one after the first is compared with
      the first, on output structure and then on flat output dtypes.

  Raises:
    TypeError: if the structures or the dtypes of the outputs differ.
  """

  def _fail(branch_idx, error_detail):
    """Raises the TypeError describing a mismatch with `graphs[branch_idx]`."""
    if op_type == _COND:
      b0_name, bn_name, op_name = "true_fn", "false_fn", "tf.cond"
    else:
      b0_name = "branches[0]"
      bn_name = "branches[{}]".format(branch_idx)
      op_name = "tf.case"
    raise TypeError(
        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
        "number, type, and overall structure of return values.\n"
        "\n"
        "{b0_name} output: {b0_out}\n"
        "{bn_name} output: {bn_out}\n"
        "\n"
        "Error details:\n"
        "{detail}".format(
            b0_name=b0_name,
            bn_name=bn_name,
            op_name=op_name,
            b0_out=graphs[0].structured_outputs,
            bn_out=graphs[branch_idx].structured_outputs,
            detail=error_detail))

  for branch_idx in range(1, len(graphs)):
    first, current = graphs[0], graphs[branch_idx]
    try:
      nest.assert_same_structure(
          first.structured_outputs,
          current.structured_outputs,
          expand_composites=True)
    except (ValueError, TypeError) as e:
      _fail(branch_idx, str(e))

    assert len(first.outputs) == len(current.outputs)
    for b0_out, bn_out in zip(first.outputs, current.outputs):
      if b0_out.dtype != bn_out.dtype:
        _fail(branch_idx, "%s and %s have different types" % (b0_out, bn_out))


def _get_output_shapes(*branch_graph_outputs):
  """Computes, per output position, a shape compatible with all branches.

  Args:
    *branch_graph_outputs: one sequence of output tensors per branch.

  Returns:
    A list of shapes, one per output position. Each is the first branch's
    shape folded with every other branch's shape through
    `most_specific_compatible_shape`.
  """

  def _merge(outs_at_position):
    merged = outs_at_position[0].shape
    for other in outs_at_position[1:]:
      merged = merged.most_specific_compatible_shape(other.shape)
    return merged

  return [_merge(outs) for outs in zip(*branch_graph_outputs)]


def verify_captures(op_type, branch_graphs):
  """Checks that no branch captures a tensor belonging to another branch.

  Args:
    op_type: _COND or _CASE; only used to word the error message.
    branch_graphs: `list` of `FuncGraph`, one per branch.

  Raises:
    ValueError: if a non-eager external capture of one branch lives in one of
      the graphs in `branch_graphs`.
  """
  # Note: It is technically not possible for lower-branch_index branches to
  # capture tensors from higher-branch_index branches, because of the order of
  # branch graph construction, but we check all for completeness and to
  # guard against potential future changes.
  graph_to_index = {}
  for idx, graph in enumerate(branch_graphs):
    graph_to_index[graph] = idx

  for branch_idx, branch_graph in enumerate(branch_graphs):
    for captured in branch_graph.external_captures:
      if isinstance(captured, ops.EagerTensor):
        continue
      if captured.graph not in graph_to_index:
        continue
      if op_type == _COND:
        branch_names = ["true_fn", "false_fn"]
      else:
        branch_names = [
            "branch {}".format(bi) for bi in range(len(branch_graphs))]
      raise ValueError(
          "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
              tname=captured.name,
              b0name=branch_names[graph_to_index[captured.graph]],
              b1name=branch_names[branch_idx]))


class _CondGradFuncGraph(util.CondBranchFuncGraph):
  """FuncGraph for the gradient function of the branch of an If op.

  Handles wrapping and unwrapping intermediate values that are captured by the
  gradient computation in optionals.

  Attributes:
    op_needs_rewrite: True if any intermediates were captured, meaning the
      forward If op needs to be rewritten to output the wrapped intermediates.
  """

  def __init__(self, name, forward_graph):
    """Creates the gradient graph for the forward branch `forward_graph`.

    Args:
      name: name of this graph, forwarded to the base class.
      forward_graph: the forward-pass branch FuncGraph whose tensors this
        gradient graph may capture.
    """
    super(_CondGradFuncGraph, self).__init__(
        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
    self.op_needs_rewrite = False
    self._forward_graph = forward_graph
    # Maps from forward intermediate tensor -> the unwrapped captured
    # intermediate.
    self._indirect_captures = {}
    # Maps unwrapped intermediate -> optional-wrapped intermediate in the
    # forward graph.
    self._wrapped_intermediates = collections.OrderedDict()
    # Raw intermediates captured from the forward graph. Populated iff we're in
    # an XLA context.
    self._xla_intermediates = []

  @property
  def wrapped_intermediates(self):
    """The optional-wrapped intermediates captured from the forward graph."""
    return list(self._wrapped_intermediates.values())

  @property
  def xla_intermediates(self):
    """Raw intermediates captured from the forward graph if XLA is enabled."""
    return self._xla_intermediates

  def _capture_helper(self, tensor, name):
    """Captures `tensor`, with special handling for forward intermediates.

    Tensors that are not in the forward graph, or that are inputs or outputs
    of the forward graph, are captured by the base class unchanged. Forward
    intermediates are recorded as raw XLA intermediates (in an XLA context),
    traced back to a forward graph input (resources), or wrapped in an
    optional in the forward graph and unwrapped here (everything else).

    Args:
      tensor: the tensor to capture.
      name: forwarded to the base class `_capture_helper`.

    Returns:
      The tensor in this graph that stands for `tensor`.
    """
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    # Reuse the result if this intermediate has been captured before.
    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      # Unwrap the captured optional to recover the raw intermediate value.
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    # Remember the result so later captures of `tensor` reuse it.
    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor


def indexed_case(branch_index, branch_fns, name="indexed_case"):
  """Like cond_v2, except emits a Case op instead of an If.

  Args:
    branch_index: Value selecting which branch to run. It is converted to a
      Tensor; a Python int is rejected.
    branch_fns: Sequence of callables, one per branch. Each is traced with no
      arguments into its own `util.CondBranchFuncGraph`.
    name: Name scope under which the branch functions and the Case op are
      created.

  Returns:
    The result of `_build_case` for the traced branch graphs.

  Raises:
    TypeError: If `branch_index` is a Python int.
  """
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    # All branch function names are generated up front, before any tracing.
    fn_names = [
        util.unique_fn_name(scope, "branch{}".format(i))
        for i in range(len(branch_fns))
    ]

    outer_graph = ops.get_default_graph()
    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = outer_graph._add_control_dependencies  # pylint: disable=protected-access
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    def _trace_branch(fn_name, fn):
      """Traces `fn` (called with no arguments) into a fresh branch graph."""
      target_graph = util.CondBranchFuncGraph(
          fn_name,
          collections=outer_graph._collections)  # pylint: disable=protected-access
      return func_graph_module.func_graph_from_py_func(
          fn_name,
          fn,
          [],
          {},
          func_graph=target_graph,
          add_control_dependencies=add_control_dependencies,
          op_return_value=branch_index)

    branch_graphs = [
        _trace_branch(fn_name, fn)
        for fn_name, fn in zip(fn_names, branch_fns)
    ]

    verify_captures(_CASE, branch_graphs)
    captures_per_branch = [g.external_captures for g in branch_graphs]
    return _build_case(
        branch_index, branch_graphs, captures_per_branch, name=scope)


@ops.RegisterGradient("Case")
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a Case op produced (w/ branch_index) by tf.case.

  Args:
    op: The forward op (or a stand-in such as a MockOp) whose first output
      belongs to the Case op being differentiated.
    *grads: Incoming gradients, forwarded to `_create_grad_func` for every
      branch.

  Returns:
    `[None]` followed by the outputs of the gradient Case op; the leading
    `None` corresponds to the Case op's first input, which has no gradient.
  """
  # Going through the first output recovers the real Case op (this handles the
  # situation where `op` is a MockOp).
  case_op = op.outputs[0].op
  fwd_graphs = _get_func_graphs(case_op)
  assert fwd_graphs
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  for fwd_graph in fwd_graphs:
    assert fwd_graph.outer_graph == case_op.graph

  # Build one gradient function per forward branch graph. These functions
  # capture tensors from the forward pass functions.
  grad_graphs = [
      _create_grad_func(fwd_graph, grads,
                        util.unique_grad_fn_name(fwd_graph.name))
      for fwd_graph in fwd_graphs
  ]

  needs_rewrite = any(g.op_needs_rewrite for g in grad_graphs)
  if needs_rewrite:
    # Modify the Case op so it also outputs the intermediates the grad
    # functions need. Outside XLA every intermediate is wrapped in an optional,
    # which holds a value iff its branch was the one taken.
    # NOTE(bjp): if there are any active sessions, this modification to `op`
    # may make them unrunnable!
    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(bjp,jpienaar): can XLA support optionals?
      extra_outputs_per_branch = _make_intermediates_match_xla(
          fwd_graphs, [g.xla_intermediates for g in grad_graphs])
    else:
      # Make outputs match by adding none optionals.
      extra_outputs_per_branch = _make_intermediates_match(
          fwd_graphs, [g.wrapped_intermediates for g in grad_graphs])

    for fwd_graph, extras in zip(fwd_graphs, extra_outputs_per_branch):
      fwd_graph.outputs.extend(extras)
    _make_indexed_slices_indices_types_match(_CASE, fwd_graphs)
    # TODO(bjp): indicate it's an internal bug if this fails.
    _check_same_outputs(_CASE, fwd_graphs)

    # Rename the modified graphs before new functions are created from them.
    for fwd_graph in fwd_graphs:
      fwd_graph.name += "_rewritten"

    rewritten_fns = [util.create_new_tf_function(g) for g in fwd_graphs]
    case_op._set_func_list_attr("branches", rewritten_fns)

    # The first branch is used as the reference for the op's output signature.
    reference_graph = fwd_graphs[0]
    case_op._set_type_list_attr("Tout", reference_graph.output_types)
    case_op._set_shape_list_attr("output_shapes",
                                 reference_graph.output_shapes)

    reference_extras = extra_outputs_per_branch[0]
    case_op._add_outputs([t.dtype for t in reference_extras],
                         [t.shape for t in reference_extras])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  grad_inputs_per_branch = [
      _resolve_grad_inputs(fwd_graph, grad_graph)
      for fwd_graph, grad_graph in zip(fwd_graphs, grad_graphs)
  ]

  # This modifies the graphs in grad_graphs.
  _make_output_composite_tensors_match(_CASE, grad_graphs)

  grad_outputs = _build_case(case_op.inputs[0], grad_graphs,
                             grad_inputs_per_branch)

  # The Case op's first input has no gradient.
  return [None] + grad_outputs


def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediates values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  # Gather the control captures of every branch, in branch order, so the Case
  # op is created under all of them.
  control_deps = [
      dep for graph in branch_graphs for dep in graph.control_captures
  ]

  # Create the Case op.
  with ops.control_dependencies(control_deps):
    output_dtypes = [t.dtype for t in branch_graphs[0].outputs]
    branch_functions = [util.create_new_tf_function(g) for g in branch_graphs]
    output_shapes = _get_output_shapes(*[g.outputs for g in branch_graphs])
    case_outputs = gen_functional_ops.case(
        branch_index,
        case_inputs,
        output_dtypes,
        branch_functions,
        output_shapes=output_shapes,
        name=name)

  # TODO(b/110167197) this requires Case to have at least 1 output
  case_op = case_outputs[0].op
  # TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
  # util.maybe_set_lowering_attr(case_op)
  util.maybe_propagate_compile_time_consts_in_xla(case_op)

  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of select_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
  # outputs, which if fetched will cause all ops in the taken branch to be run
  # (since it takes all merge ops as input). After lowering, each output
  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  identity_outputs = [array_ops.identity(t) for t in case_outputs]

  # Prevent fetching since the variant outputs can't be fetched directly.
  case_op.graph.prevent_fetching(case_op)

  # Restore the nested structure of the first branch's outputs.
  output_structure = branch_graphs[0].structured_outputs
  return func_graph_module.pack_sequence_as(output_structure,
                                            identity_outputs)