Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running a computation across multiple devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import enum # pylint: disable=g-bad-import-order
import threading
import weakref
import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context as eager_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
# ------------------------------------------------------------------------------
# Context tracking whether in a strategy.update() or .update_non_slot() call.
# Per-thread storage; its `current` attribute only exists on a thread once an
# `UpdateContext` has been entered there.
_update_device = threading.local()


def get_update_device():
  """Returns the device recorded for this thread's active update call.

  Returns:
    The value stored in the thread-local `current` slot, or `None` when the
    calling thread has never had that slot assigned.
  """
  return getattr(_update_device, "current", None)
class UpdateContext(object):
  """Context manager active while inside `update()` or `update_non_slot()`.

  On entry the thread-local update device is switched to `device`; on exit
  the value that was current at entry time is restored, so contexts nest.
  """

  def __init__(self, device):
    # Filled in by `__enter__` with whatever device was active before us.
    self._old_device = None
    self._device = device

  def __enter__(self):
    previous_device = get_update_device()
    self._old_device = previous_device
    _update_device.current = self._device

  def __exit__(self, exception_type, exception_value, traceback):
    # Exceptions are never suppressed; the arguments are intentionally unused.
    del exception_type, exception_value, traceback
    _update_device.current = self._old_device
# ------------------------------------------------------------------------------
# Public utility functions.
@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
  """Returns the `tf.distribute.ReduceOp` matching the last loss reduction.

  The optimizer uses this to decide whether it has to scale the loss, which
  is only relevant for the Estimator + v1 optimizer combination.

  Returns:
    `tf.distribute.ReduceOp.SUM` when the current strategy does not have
    Estimator loss scaling enabled, or when the last loss reduction recorded
    on the default graph is a sum reduction; `tf.distribute.ReduceOp.MEAN`
    otherwise.
  """
  strategy = distribution_strategy_context.get_strategy()
  if not strategy._scale_loss_for_estimator:  # pylint: disable=protected-access
    # Outside of an Estimator context the optimizer never needs to scale the
    # loss, so report 'SUM'.
    return reduce_util.ReduceOp.SUM

  graph = ops.get_default_graph()
  last_reduction = graph._last_loss_reduction  # pylint: disable=protected-access
  reduced_by_sum = (last_reduction == losses_impl.Reduction.SUM or
                    last_reduction == loss_reduction.ReductionV2.SUM)
  return (reduce_util.ReduceOp.SUM if reduced_by_sum
          else reduce_util.ReduceOp.MEAN)
# ------------------------------------------------------------------------------
# Internal API for validating the current thread mode
def _require_cross_replica_or_default_context_extended(extended):
  """Checks that the thread is in `extended`'s cross-replica (or default) mode.

  Args:
    extended: The `StrategyExtended` object whose cross-replica context is
      required.

  Raises:
    RuntimeError: If the current thread is in a different strategy's scope,
      in no strategy scope when one is needed, or in a replica context.
  """
  mode = _get_per_thread_mode()
  cross_replica_ctx = mode.cross_replica_context
  in_own_cross_replica_ctx = (
      cross_replica_ctx is not None and cross_replica_ctx.extended is extended)
  if in_own_cross_replica_ctx or mode is _get_default_replica_mode():
    return

  # Something is wrong; work out which message describes it best.
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  if mode.strategy is not strategy:
    _wrong_strategy_scope(strategy, mode)
  assert cross_replica_ctx is None
  raise RuntimeError("Method requires being in cross-replica context, use "
                     "get_replica_context().merge_call()")
def _wrong_strategy_scope(strategy, context):
  """Raises the error for `strategy` not being the one currently in scope.

  Args:
    strategy: The strategy the caller required to be current.
    context: The per-thread mode object; its `strategy` attribute is quoted
      in the message when a different strategy is in scope.

  Raises:
    RuntimeError: Always. The message depends on whether any strategy scope
      is active at all.
  """
  if distribution_strategy_context.has_strategy():
    message = (
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (context.strategy, strategy))
  else:
    message = 'Need to be inside "with strategy.scope()" for %s' % (strategy,)
  raise RuntimeError(message)
def require_replica_context(replica_ctx):
  """Checks that the thread is currently inside the `replica_ctx` context.

  Args:
    replica_ctx: The replica context that must be the active one.

  Raises:
    RuntimeError: If no replica context is active, or the active one is not
      `replica_ctx` (the message says whether the strategies differ too).
  """
  mode = _get_per_thread_mode()
  active_ctx = mode.replica_context
  if active_ctx is replica_ctx:
    return

  # Pick the most specific description of the mismatch.
  if active_ctx is None:
    raise RuntimeError("Need to be inside `call_for_each_replica()`")
  if mode.strategy is not replica_ctx.strategy:
    raise RuntimeError(
        "Mismatching tf.distribute.Strategy objects: %s is not %s." %
        (mode.strategy, replica_ctx.strategy))
  # Same tf.distribute.Strategy, but two distinct ReplicaContext objects.
  raise RuntimeError("Mismatching ReplicaContext.")
def _require_strategy_scope_strategy(strategy):
  """Checks that this thread is inside `strategy.scope()`.

  Args:
    strategy: The strategy expected to be current.

  Raises:
    RuntimeError: If a different strategy, or none, is current.
  """
  mode = _get_per_thread_mode()
  if mode.strategy is not strategy:
    _wrong_strategy_scope(strategy, mode)
def _require_strategy_scope_extended(extended):
  """Checks that this thread is inside the scope of `extended`'s strategy.

  Args:
    extended: The `StrategyExtended` object whose owning strategy is
      expected to be current.

  Raises:
    RuntimeError: If the current strategy's `extended` is a different object.
  """
  mode = _get_per_thread_mode()
  if mode.strategy.extended is not extended:
    # Report the error in terms of the strategy that owns `extended`.
    owner = extended._container_strategy()  # pylint: disable=protected-access
    _wrong_strategy_scope(owner, mode)
# ------------------------------------------------------------------------------
# Internal context managers used to implement the DistributionStrategy
# base class
class _CurrentDistributionContext(object):
  """Context manager that makes a `tf.distribute.Strategy` the current one.

  Besides pushing the strategy's cross-replica mode for this thread, it
  enters the supplied variable creator scope and, when given, a variable
  scope and a default device scope. Entering it again while a strategy is
  already in scope is allowed (after validation) and is only counted, so the
  matching exits are no-ops.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               default_device=None):
    mode_cls = distribution_strategy_context._CrossReplicaThreadMode  # pylint: disable=protected-access
    self._context = mode_cls(strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._device_scope = ops.device(default_device) if default_device else None
    # Number of nested re-entries whose `__exit__` must not unwind anything.
    self._same_scope_again_count = 0

  def __enter__(self):
    strategy = self._context.strategy
    if not distribution_strategy_context.has_strategy():
      # First entry: install the mode, then the scopes from outer to inner.
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._device_scope:
        self._device_scope.__enter__()
      return strategy

    # A strategy is already in scope; re-entering is only fine if it is ours.
    _require_cross_replica_or_default_context_extended(strategy.extended)
    self._same_scope_again_count += 1
    return strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return

    exc_details = (exception_type, exception_value, traceback)

    # Unwind in the reverse order of `__enter__`: device, creator, var scope.
    if self._device_scope:
      try:
        self._device_scope.__exit__(*exc_details)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)

    try:
      self._var_creator_scope.__exit__(*exc_details)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of `with` scope."),
          e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(*exc_details)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)

    _pop_per_thread_mode()
# TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum):
  """Replication mode for input function.

  * `PER_WORKER`: The input function will be called on each worker
    independently, creating as many input pipelines as number of workers.
    Replicas will dequeue from the local Dataset on their worker.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  """
  # Currently the only member; the enum value is the member's own name.
  PER_WORKER = "PER_WORKER"
@tf_export("distribute.InputContext")
class InputContext(object):
  """Information handed to a user's input function.

  An instance describes the compute replicas and the input pipelines. The
  number of replicas in sync lets the input function derive a per-pipeline
  batch size from the desired global batch size, and the pipeline id/count
  can be used to give each pipeline a different subset of the input (e.g.
  sharding, or reading from a different source).
  """

  def __init__(self,
               num_input_pipelines=1,
               input_pipeline_id=0,
               num_replicas_in_sync=1):
    """Creates an `InputContext`.

    Args:
      num_input_pipelines: the number of input pipelines in a cluster.
      input_pipeline_id: the id of the current input pipeline; should be an
        int in [0, `num_input_pipelines`).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._num_replicas_in_sync = num_replicas_in_sync
    self._input_pipeline_id = input_pipeline_id
    self._num_input_pipelines = num_input_pipelines

  @property
  def num_replicas_in_sync(self):
    """The number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def input_pipeline_id(self):
    """The id of this input pipeline."""
    return self._input_pipeline_id

  @property
  def num_input_pipelines(self):
    """The total number of input pipelines."""
    return self._num_input_pipelines

  def get_per_replica_batch_size(self, global_batch_size):
    """Splits a global batch size evenly over the replicas in sync.

    Args:
      global_batch_size: the global batch size; must be divisible by
        `num_replicas_in_sync`.

    Returns:
      `global_batch_size` floor-divided by `num_replicas_in_sync`.

    Raises:
      ValueError: if `global_batch_size` is not divisible by
        `num_replicas_in_sync`.
    """
    replicas = self._num_replicas_in_sync
    if global_batch_size % replicas != 0:
      raise ValueError("The `global_batch_size` %r is not divisible by "
                       "`num_replicas_in_sync` %r " %
                       (global_batch_size, replicas))
    return global_batch_size // replicas
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.
@tf_export("distribute.Strategy", v1=[])
class Strategy(object):
  """A list of devices with a state & compute distribution policy.

  See [the guide](https://www.tensorflow.org/alpha/guide/distribute_strategy)
  for overview and examples.

  Almost every method here is a thin wrapper that forwards to the
  `StrategyExtended` object passed to the constructor (available as
  `extended`).
  """

  # TODO(josh11b): Raise an exception if variable partitioning requested before
  #   we add support.
  # TODO(josh11b): Also `parameter_device_index` property?
  # TODO(josh11b): `map()`
  # TODO(josh11b): ClusterSpec/ClusterResolver
  # TODO(josh11b): Partitioned computations, state; sharding
  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
  # TODO(josh11b): List of replicas with their worker and parameter devices
  #   (where the parameter devices may overlap in the ps case).

  def __init__(self, extended):
    """Wraps `extended`, the object that implements this strategy.

    Args:
      extended: The `tf.distribute.StrategyExtended` instance that the public
        methods of this class delegate to.
    """
    self._extended = extended

    # Flag that is used to indicate whether distribution strategy is used with
    # Estimator. This is required for backward compatibility of loss scaling
    # when using v1 optimizer with estimator.
    self._scale_loss_for_estimator = False

  @property
  def extended(self):
    """`tf.distribute.StrategyExtended` with additional methods."""
    return self._extended

  @tf_contextlib.contextmanager
  def _scale_loss_for_estimator_enabled(self):
    """Scope which sets a flag used for scaling losses in optimizer.

    The flag is set to `True` on entry and always reset to `False` on exit
    (not restored to its previous value), even if the body raises.

    Yields:
      `_scale_loss_for_estimator_enabled` is a context manager with a
      side effect, but doesn't return a value.
    """
    self._scale_loss_for_estimator = True
    try:
      yield
    finally:
      self._scale_loss_for_estimator = False

  def scope(self):
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    Returns:
      A context manager.
    """
    return self._extended._scope(self)  # pylint: disable=protected-access

  @doc_controls.do_not_doc_inheritable  # DEPRECATED, moving to `extended`
  def colocate_vars_with(self, colocate_with_variable):
    """DEPRECATED: use extended.colocate_vars_with() instead."""
    return self._extended.colocate_vars_with(colocate_with_variable)

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def make_dataset_iterator(self, dataset):
    """DEPRECATED TF 1.x ONLY."""
    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def make_input_fn_iterator(self,
                             input_fn,
                             replication_mode=InputReplicationMode.PER_WORKER):
    """DEPRECATED TF 1.x ONLY."""
    # `PER_WORKER` is the only replication mode accepted here.
    if replication_mode != InputReplicationMode.PER_WORKER:
      raise ValueError(
          "Input replication mode not supported: %r" % replication_mode)
    with self.scope():
      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
          input_fn, replication_mode=replication_mode)

  def experimental_make_numpy_dataset(self, numpy_input):
    """Makes a dataset for input provided via a numpy array.

    This avoids adding `numpy_input` as a large constant in the graph,
    and copies the data to the machine or machines that will be processing
    the input.

    Args:
      numpy_input: A nest of NumPy input arrays that will be distributed evenly
        across all replicas. Note that lists of Numpy arrays are stacked,
        as that is normal `tf.data.Dataset` behavior.

    Returns:
      A `tf.data.Dataset` representing `numpy_input`.
    """
    # This variant takes no session argument and always forwards
    # `session=None`.
    return self.extended.experimental_make_numpy_dataset(
        numpy_input, session=None)

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def experimental_run(self, fn, input_iterator=None):
    """DEPRECATED TF 1.x ONLY."""
    # One `get_next()` call supplies the single positional argument for `fn`;
    # without an iterator `fn` is called with no arguments.
    with self.scope():
      args = (input_iterator.get_next(),) if input_iterator is not None else ()
    return self.experimental_run_v2(fn, args=args)

  def experimental_distribute_dataset(self, dataset):
    """Distributes a tf.data.Dataset instance provided via `dataset`.

    In a multi-worker setting, we will first attempt to distribute the dataset
    by attempting to detect whether the dataset is being created out of
    ReaderDatasets (e.g. TFRecordDataset, TextLineDataset, etc.) and if so,
    attempting to shard the input files. Note that there has to be at least one
    input file per worker. If you have fewer than one input file per worker, we
    suggest that you should disable distributing your dataset using the method
    below.

    If that attempt is unsuccessful (e.g. the dataset is created from a
    Dataset.range), we will shard the dataset evenly at the end by appending a
    `.shard` operation to the end of the processing pipeline. This will cause
    the entire preprocessing pipeline for all the data to be run on every
    worker, and each worker will do redundant work. We will print a warning
    if this method of sharding is selected.

    You can disable dataset distribution using the `auto_shard` option in
    `tf.data.experimental.DistributeOptions`.

    Within each host, we will also split the data among all the worker devices
    (if more than one is present), and this will happen even if multi-worker
    sharding is disabled using the method above.

    The following is an example:

    ```python
    strategy = tf.distribute.MirroredStrategy()

    # Create a dataset
    dataset = tf.data.TFRecordDataset([
        "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])

    # Distribute that dataset
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    # Iterate over the distributed dataset
    for x in dist_dataset:
      # process dataset elements
      strategy.experimental_run_v2(train_step, args=(x,))
    ```

    Args:
      dataset: `tf.data.Dataset` that will be sharded across all replicas using
        the rules stated above.

    Returns:
      A `DistributedDataset` which returns inputs for each step of the
      computation.
    """
    return self._extended._experimental_distribute_dataset(dataset)  # pylint: disable=protected-access

  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """Runs ops in `fn` on each replica, with the given arguments.

    When eager execution is enabled, executes ops specified by `fn` on each
    replica. Otherwise, builds a graph to execute the ops on each replica.

    `fn` may call `tf.distribute.get_replica_context()` to access members such
    as `replica_id_in_sync_group`.

    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
    used, and whether eager execution is enabled, `fn` may be called one or more
    times (once for each replica).

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.

    Returns:
      Merged return value of `fn` across replicas. The structure of the return
      value is the same as the return value from `fn`. Each element in the
      structure can either be `PerReplica` (if the values are unsynchronized),
      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
      single replica).
    """
    with self.scope():
      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)

  def reduce(self, reduce_op, value, axis):
    """Reduce `value` across replicas.

    Given a per-replica value returned by `experimental_run_v2`, say a
    per-example loss, the batch will be divided across all the replicas. This
    function allows you to aggregate across replicas and optionally also across
    batch elements. For example, if you have a global batch size of 8 and 2
    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
    `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just
    aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful
    when each replica is computing a scalar or some other value that doesn't
    have a "batch" dimension (like a gradient). More often you will want to
    aggregate across the global batch, which you can get by specifying the batch
    dimension as the `axis`, typically `axis=0`. In this case it would return a
    scalar `0+1+2+3+4+5+6+7`.

    If there is a last partial batch, you will need to specify an axis so
    that the resulting shape is consistent across replicas. So if the last
    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
    would get a shape mismatch unless you specify `axis=0`. If you specify
    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
    denominator of 6. Contrast this with computing `reduce_mean` to get a
    scalar value on each replica and this function to average those means,
    which will weigh some values `1/8` and others `1/4`.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined. A string is also accepted; it is upper-cased and
        converted to a `tf.distribute.ReduceOp`.
      value: A "per replica" value, e.g. returned by `experimental_run_v2` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.

    Raises:
      TypeError: If `axis` is not `None` and `reduce_op` is neither `SUM` nor
        `MEAN`, or if `reduce_op` is `MEAN` and `axis` is not an integer.
      ValueError: If `reduce_op` is `MEAN` and `axis` is out of range for a
        `value` of statically known rank.
    """
    # TODO(josh11b): support `value` being a nest.
    _require_cross_replica_or_default_context_extended(self._extended)
    if isinstance(reduce_op, six.string_types):
      reduce_op = reduce_util.ReduceOp(reduce_op.upper())
    if axis is None:
      # Nothing to reduce within a replica; only combine across replicas.
      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
    if reduce_op == reduce_util.ReduceOp.SUM:
      # Sum along `axis` on every replica first, then sum across replicas.
      value = self.experimental_run_v2(
          lambda v: math_ops.reduce_sum(v, axis=axis), args=(value,))
      return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
    if reduce_op != reduce_util.ReduceOp.MEAN:
      raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
                      "not: %r" % reduce_op)
    # TODO(josh11b): Support list/tuple and tensor axis values.
    if not isinstance(axis, six.integer_types):
      raise TypeError("Expected `axis` to be an integer not: %r" % axis)

    # `axis` is bound as a default argument because the helper reassigns it;
    # a plain closure reference would make that assignment fail.
    def mean_reduce_helper(v, axis=axis):
      """Computes the numerator and denominator on each replica."""
      numer = math_ops.reduce_sum(v, axis=axis)
      if v.shape.rank is not None:
        # Note(joshl): We support axis < 0 to be consistent with the
        # tf.math.reduce_* operations.
        if axis < 0:
          if axis + v.shape.rank < 0:
            raise ValueError(
                "`axis` = %r out of range for `value` with rank %d" %
                (axis, v.shape.rank))
          axis += v.shape.rank
        elif axis >= v.shape.rank:
          raise ValueError(
              "`axis` = %r out of range for `value` with rank %d" %
              (axis, v.shape.rank))
        # TF v2 returns `None` for unknown dimensions and an integer for
        # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
        # or tensor_shape.Dimension(integer). `dimension_value` hides this
        # difference, always returning `None` or an integer.
        dim = tensor_shape.dimension_value(v.shape[axis])
        if dim is not None:
          # By returning a python value in the static shape case, we can
          # maybe get a fast path for reducing the denominator.
          return numer, dim
      elif axis < 0:
        # Rank is not known statically, so normalize a negative axis using
        # the rank op.
        axis = axis + array_ops.rank(v)
      # The size along `axis` is not known statically; read it from the
      # shape op.
      denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis]
      # TODO(josh11b): Should we cast denom to v.dtype here instead of after the
      # reduce is complete?
      return numer, denom

    numer, denom = self.experimental_run_v2(mean_reduce_helper, args=(value,))
    # The mean is the cross-replica sum of numerators over the cross-replica
    # sum of denominators.
    # TODO(josh11b): Should batch reduce here instead of doing two.
    numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer)  # pylint: disable=protected-access
    denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom)  # pylint: disable=protected-access
    denom = math_ops.cast(denom, numer.dtype)
    return math_ops.truediv(numer, denom)

  @doc_controls.do_not_doc_inheritable  # DEPRECATED
  def unwrap(self, value):
    """Returns the list of all local per-replica values contained in `value`.

    DEPRECATED: Please use `experimental_local_results` instead.

    Note: This only returns values on the workers initiated by this client.
    When using a `Strategy` like
    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
    will be its own client, and this function will only return values
    computed on that worker.

    Args:
      value: A value returned by `experimental_run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return self._extended._local_results(value)  # pylint: disable=protected-access

  def experimental_local_results(self, value):
    """Returns the list of all local per-replica values contained in `value`.

    Note: This only returns values on the workers initiated by this client.
    When using a `Strategy` like
    `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
    will be its own client, and this function will only return values
    computed on that worker.

    Args:
      value: A value returned by `experimental_run()`, `experimental_run_v2()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return self._extended._local_results(value)  # pylint: disable=protected-access

  @doc_controls.do_not_doc_inheritable  # DEPRECATED: TF v1.x only
  def group(self, value, name=None):
    """Shortcut for `tf.group(self.experimental_local_results(value))`."""
    return self._extended._group(value, name)  # pylint: disable=protected-access

  @property
  def num_replicas_in_sync(self):
    """Returns number of replicas over which gradients are aggregated."""
    return self._extended._num_replicas_in_sync  # pylint: disable=protected-access

  @doc_controls.do_not_doc_inheritable  # DEPRECATED: see doc string
  def configure(self,
                session_config=None,
                cluster_spec=None,
                task_type=None,
                task_id=None):
    # pylint: disable=g-doc-return-or-yield,g-doc-args
    """DEPRECATED: use `update_config_proto` instead.

    Configures the strategy class.

    DEPRECATED: This method's functionality has been split into the strategy
    constructor and `update_config_proto`. In the future, we will allow passing
    cluster and config_proto to the constructor to configure the strategy. And
    `update_config_proto` can be used to update the config_proto based on the
    specific strategy.
    """
    return self._extended._configure(  # pylint: disable=protected-access
        session_config, cluster_spec, task_type, task_id)

  @doc_controls.do_not_generate_docs  # DEPRECATED
  def update_config_proto(self, config_proto):
    """DEPRECATED TF 1.x ONLY."""
    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access

  def __deepcopy__(self, memo):
    """Deep-copies the strategy and re-links the copied `extended` to it.

    Args:
      memo: The memo dictionary supplied by `copy.deepcopy`.

    Returns:
      A new instance of the same class with every attribute deep-copied.
    """
    # First do a regular deepcopy of `self`.
    cls = self.__class__
    result = cls.__new__(cls)
    # Register the copy before copying attributes so that references back to
    # `self` inside them resolve to `result`.
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      setattr(result, k, copy.deepcopy(v, memo))
    # One little fix-up: we want `result._extended` to reference `result`
    # instead of `self`.
    result._extended._container_strategy_weakref = weakref.ref(result)  # pylint: disable=protected-access
    return result

  def __copy__(self):
    """Rejects shallow copies; only `copy.deepcopy` is supported.

    Raises:
      RuntimeError: Always.
    """
    raise RuntimeError("Must only deepcopy DistributionStrategy.")
# TF v1.x version has additional deprecated APIs
@tf_export(v1=["distribute.Strategy"])
class StrategyV1(Strategy):
  """A list of devices with a state & compute distribution policy.

  See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
  for overview and examples.

  This is the class exported for the TF 1.x API. Its overrides mainly attach
  full documentation to methods that `Strategy` marks as deprecated, and add
  TF 1.x-only arguments (`session`, a default for `axis`).
  """

  def make_dataset_iterator(self, dataset):
    """Makes an iterator for input provided via `dataset`.

    DEPRECATED: This method is not available in TF 2.x.

    Data from the given dataset will be distributed evenly across all the
    compute replicas. We will assume that the input dataset is batched by the
    global batch size. With this assumption, we will make a best effort to
    divide each batch across all the replicas (one or more workers).
    If this effort fails, an error will be thrown, and the user should instead
    use `make_input_fn_iterator` which provides more control to the user, and
    does not try to divide a batch across replicas.

    The user could also use `make_input_fn_iterator` if they want to
    customize which input is fed to which replica/worker etc.

    Args:
      dataset: `tf.data.Dataset` that will be distributed evenly across all
        replicas.

    Returns:
      A `tf.distribute.InputIterator` which returns inputs for each step of the
      computation. User should call `initialize` on the returned iterator.
    """
    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access

  def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
                             input_fn,
                             replication_mode=InputReplicationMode.PER_WORKER):
    """Returns an iterator split across replicas created from an input function.

    DEPRECATED: This method is not available in TF 2.x.

    The `input_fn` should take a `tf.distribute.InputContext` object where
    information about batching and input sharding can be accessed:

    ```
    def input_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(input_context.num_input_pipelines,
                     input_context.input_pipeline_id)
    with strategy.scope():
      iterator = strategy.make_input_fn_iterator(input_fn)
      replica_results = strategy.experimental_run(replica_fn, iterator)
    ```

    The `tf.data.Dataset` returned by `input_fn` should have a per-replica
    batch size, which may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      input_fn: A function taking a `tf.distribute.InputContext` object and
        returning a `tf.data.Dataset`.
      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
        Only `PER_WORKER` is supported currently, which means there will be
        a single call to `input_fn` per worker. Replicas will dequeue from the
        local `tf.data.Dataset` on their worker.

    Returns:
      An iterator object that should first be `.initialize()`-ed. It may then
      either be passed to `strategy.experimental_run()` or you can
      `iterator.get_next()` to get the next value to pass to
      `strategy.extended.call_for_each_replica()`.
    """
    # Pure delegation; the override exists to carry the docstring above.
    return super(StrategyV1, self).make_input_fn_iterator(
        input_fn, replication_mode)

  def experimental_make_numpy_dataset(self, numpy_input, session=None):
    """Makes a dataset for input provided via a numpy array.

    This avoids adding `numpy_input` as a large constant in the graph,
    and copies the data to the machine or machines that will be processing
    the input.

    Args:
      numpy_input: A nest of NumPy input arrays that will be distributed evenly
        across all replicas. Note that lists of Numpy arrays are stacked,
        as that is normal `tf.data.Dataset` behavior.
      session: (TensorFlow v1.x graph execution only) A session used for
        initialization.

    Returns:
      A `tf.data.Dataset` representing `numpy_input`.
    """
    return self.extended.experimental_make_numpy_dataset(
        numpy_input, session=session)

  def experimental_run(self, fn, input_iterator=None):  # pylint: disable=useless-super-delegation
    """Runs ops in `fn` on each replica, with inputs from `input_iterator`.

    DEPRECATED: This method is not available in TF 2.x. Please switch
    to using `experimental_run_v2` instead.

    When eager execution is enabled, executes ops specified by `fn` on each
    replica. Otherwise, builds a graph to execute the ops on each replica.

    Each replica will take a single, different input from the inputs provided by
    one `get_next` call on the input iterator.

    `fn` may call `tf.distribute.get_replica_context()` to access members such
    as `replica_id_in_sync_group`.

    IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
    used, and whether eager execution is enabled, `fn` may be called one or more
    times (once for each replica).

    Args:
      fn: The function to run. The inputs to the function must match the outputs
        of `input_iterator.get_next()`. The output must be a `tf.nest` of
        `Tensor`s.
      input_iterator: (Optional) input iterator from which the inputs are taken.

    Returns:
      Merged return value of `fn` across replicas. The structure of the return
      value is the same as the return value from `fn`. Each element in the
      structure can either be `PerReplica` (if the values are unsynchronized),
      `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
      single replica).
    """
    # Pure delegation; the override exists to carry the docstring above.
    return super(StrategyV1, self).experimental_run(
        fn, input_iterator)

  # Same behavior as `Strategy.reduce`, except that `axis` defaults to `None`
  # here instead of being a required argument.
  def reduce(self, reduce_op, value, axis=None):
    return super(StrategyV1, self).reduce(reduce_op, value, axis)

  # Reuse the base class documentation rather than duplicating it.
  reduce.__doc__ = Strategy.reduce.__doc__

  def update_config_proto(self, config_proto):
    """Returns a copy of `config_proto` modified for use with this strategy.

    DEPRECATED: This method is not available in TF 2.x.

    The updated config has something needed to run a strategy, e.g.
    configuration to run collective ops, or device filters to improve
    distributed training performance.

    Args:
      config_proto: a `tf.ConfigProto` object.

    Returns:
      The updated copy of the `config_proto`.
    """
    return self._extended._update_config_proto(config_proto)  # pylint: disable=protected-access
# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
# instead descend from StrategyExtendedV1.
@tf_export("distribute.StrategyExtended", v1=[])
class StrategyExtendedV2(object):
"""Additional APIs for algorithms that need to be distribution-aware.
The intent is that you can write an algorithm in a stylized way and
it will be usable with a variety of different
`tf.distribute.Strategy`
implementations. Each descendant will implement a different strategy
for distributing the algorithm across multiple devices/machines.
Furthermore, these changes can be hidden inside the specific layers
and other library classes that need special treatment to run in a
distributed setting, so that most users' model definition code can
run unchanged. The `tf.distribute.Strategy` API works the same way
with eager and graph execution.
First let's introduce a few high-level concepts:
* _Data parallelism_ is where we run multiple copies of the model
on different slices of the input data. This is in contrast to
_model parallelism_ where we divide up a single copy of a model
across multiple devices.
Note: we only support data parallelism for now, but
hope to add support for model parallelism in the future.
* A _replica_ is one copy of the model, running on one slice of the
input data.
* _Synchronous_, or more commonly _sync_, training is where the
updates from each replica are aggregated together before updating
the model variables. This is in contrast to _asynchronous_, or
_async_ training, where each replica updates the model variables
independently.
* Furthermore you might run your computation on multiple devices
on one machine (or "host"), or on multiple machines/hosts.
If you are running on multiple machines, you might have a
single master host that drives computation across all of them,
or you might have multiple clients driving the computation
asynchronously.
To distribute an algorithm, we might use some of these ingredients:
* Parameter servers: These are hosts that hold a single copy of
parameters/variables. All replicas that want to operate on a variable
retrieve it at the beginning of a step and send an update to be
applied at the end of the step. Can support either sync or async
training.
* Mirrored variables: These are variables that are copied to multiple
devices, where we keep the copies in sync by applying the same
updates to every copy. Normally would only be used with sync training.
* Reductions and Allreduce: A _reduction_ is some method of
aggregating multiple values into one value, like "sum" or
"mean". If doing sync training, we will perform a reduction on the
gradients to a parameter from all replicas before applying the
update. Allreduce is an algorithm for performing a reduction on
values from multiple devices and making the result available on
all of those devices.
* In the future we will have support for TensorFlow's partitioned
variables, where a single variable is split across multiple
devices.
We have then a few approaches we want to support:
* Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
This code should work as before, even if some of the layers, etc.
used by that code are written to be distribution-aware. This is done
by having a default `tf.distribute.Strategy` that gives ordinary behavior,
and by default being in a single replica context.
* Ordinary model code that you want to run using a specific
`tf.distribute.Strategy`. This can be as simple as:
```
with my_strategy.scope():
iterator = my_strategy.make_dataset_iterator(dataset)
session.run(iterator.initialize())
replica_train_ops = my_strategy.experimental_run_v2(
replica_fn, args=(iterator.get_next(),))
train_op = my_strategy.group(replica_train_ops)
```
This takes an ordinary `dataset` and `replica_fn` and runs it
distributed using a particular `tf.distribute.Strategy` in
`my_strategy`. Any variables created in `replica_fn` are created
using `my_strategy`'s policy, and library functions called by
`replica_fn` can use the `get_replica_context()` API to get enhanced
behavior in this case.
* If you want to write a distributed algorithm, you may use any of
the `tf.distribute.Strategy` APIs inside a
`with my_strategy.scope():` block of code.
Lower-level concepts:
* Wrapped values: In order to represent values parallel across devices
(either replicas or the devices associated with a particular value), we
wrap them in a "PerReplica" or "Mirrored" object that contains a map
from device to values. "PerReplica" is used when the value may be
different across replicas, and "Mirrored" when the values are the same.
* Unwrapping and merging: Consider calling a function `fn` on multiple
replicas, like `experimental_run_v2(fn, args=[w])` with an
argument `w` that is a wrapped value. This means `w` will have a map taking
replica device `d0` to `w0`, replica device `d1` to `w1`,
etc. `experimental_run_v2()` unwraps `w` before calling `fn`, so
it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return
values from `fn()`, which can possibly result in wrapped values. For
example, let's say `fn()` returns a tuple with three components: `(x, a,
v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component
is the same object `x` from every replica, then the first component of the
merged result will also be `x`. If the second component is different (`a`,
`b`, ...) from each replica, then the merged value will have a wrapped map
from replica device to the different values. If the third component is the
members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.),
then the merged result will be that mirrored variable (`v`).
* Replica context vs. Cross-replica context: _replica context_ is when we
are in some function that is being called once for each replica.
Otherwise we are in cross-replica context, which is useful for
calling `tf.distribute.Strategy` methods which operate across the
replicas (like `reduce_to()`). By default you start in a replica context
(the default "single replica context") and then some methods can
switch you back and forth, as described below.
* Worker devices vs. parameter devices: Most replica computations will
happen on worker devices. Since we don't yet support model
parallelism, there will be one worker device per replica. When using
parameter servers (see above), the set of devices holding
variables may be different, otherwise the parameter devices might
match the worker devices.
* Non-slot devices are some subset of the parameter devices where we
put all the non-slot variables. We need to ensure that all
non-slot variables are allocated on the same device, or mirrored
across the same set of devices. If you have some variable you want
to colocate all the non-slot variables with, you can use
`colocate_vars_with()` to get the remaining non-slot variables on
the same device. Otherwise you can use `non_slot_devices()` to
pick a consistent set of devices to pass to both
`colocate_vars_with()` and `update_non_slot()`.
When using a `tf.distribute.Strategy`, we have a new type dimension
called _locality_ that says what values are compatible with which
APIs:
* T: different value for each replica (e.g. a PerReplica-wrapped value).
* M: value is "mirrored" across replicas, i.e. there are copies with the
same value on each replica (e.g. a Mirrored-wrapped value).
* V(`v`): value is "mirrored" across all the devices which have a
copy of variable `v` (also a Mirrored-wrapped value, but over
parameter devices instead of worker devices).
* N: value is "mirrored" across all the "non-slot" devices
Rules for methods with respect to locality and single-replica vs.
cross-replica context:
* `with d.scope()`: default single-replica context -> cross-replica context
for `d`
* `with d.extended.colocate_vars_with(v)`: in replica/cross-replica context,
variables will be created with locality V(`v`). That is, if we write
`with d.extended.colocate_vars_with(v1):
v2 = tf.Variable(...)`, then `v2` will have locality V(`v1`),
i.e. locality V(`v2`) will equal V(`v1`).
* `with d.extended.colocate_vars_with(d.extended.non_slot_devices(...))`: in
replica/cross-replica context, variables will be created with locality N
* `v = tf.Variable(...)`: in replica/cross-replica context,
creates a variable (which by definition will have locality V(`v`), though
will match another locality if inside a `colocate_vars_with`
scope).
* `d.make_dataset_iterator(dataset)`: in cross-replica
context, produces an iterator with locality T
* `d.experimental_run_v2(fn, ...)`: in cross-replica context, runs
`fn()` in a replica context (and so may call `get_replica_context()` and
use its API, including `merge_call()` to get back to cross-replica
context), once for each replica. May use values with locality T or
M, and any variable.
* `d.extended.reduce_to(m, t, t)`: in cross-replica context, accepts t with
locality T and produces a value with locality M.
* `d.extended.reduce_to(m, t, v)`: in cross-replica context, accepts t with
locality T and produces a value with locality V(`v`).
* `d.extended.batch_reduce_to(m, [(t, v)])`: see `d.extended.reduce_to()`
* `d.extended.update(v, fn, ...)`: in cross-replica context, runs `fn()` once
for each device `v` is copied to, all inputs should have locality
V(`v`), output will have locality V(`v`) as well.
* `d.extended.update_non_slot(d.extended.non_slot_devices(), fn)`: in
cross-replica context, like `d.extended.update()` except with locality N.
The standard pattern for updating variables is to:
1. Create an input iterator with `d.make_dataset_iterator()`.
2. Define each replica `d.experimental_run_v2()` up to the point of
getting a list of gradient, variable pairs.
3. Call `d.extended.reduce_to(tf.distribute.ReduceOp.SUM, t, v)` or
`d.extended.batch_reduce_to()` to sum the gradients (with locality T)
into values with locality V(`v`).
4. Call `d.extended.update(v)` for each variable to update its value.
Steps 3 and 4 are done automatically by class `Optimizer` if you call
its `apply_gradients` method in a replica context. Otherwise you can
manually call its `_distributed_apply` method in a cross-replica context.
Another thing you might want to do in the middle of your replica function is
an all-reduce of some intermediate value, using `d.extended.reduce_to()` or
`d.extended.batch_reduce_to()`. You simply provide the same tensor as the
input and destination.
Layers should expect to be called in a replica context, and can use
the `tf.distribute.get_replica_context` function to get a
`tf.distribute.ReplicaContext` object. The
`ReplicaContext` object has a `merge_call()` method for entering
cross-replica context where you can use `reduce_to()` (or
`batch_reduce_to()`) and then optionally `update()` to update state.
You may use this API whether or not a `tf.distribute.Strategy` is
being used, since there is a default implementation of
`ReplicaContext` and `tf.distribute.Strategy`.
NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
the `tf.distribute.Strategy` subclass is for instantiating your subclass of
`tf.distribute.StrategyExtended` in the `__init__` method.
"""
def __init__(self, container_strategy):
  """Initializes the extended object for `container_strategy`.

  Args:
    container_strategy: The `tf.distribute.Strategy` whose `extended` this
      object is. Only a weak reference to it is stored.
  """
  # Dereferenced again by `_container_strategy()`.
  self._container_strategy_weakref = weakref.ref(container_strategy)
  self._default_device = None
  # This property is used to determine if we should set drop_remainder=True
  # when creating Datasets from numpy array inputs.
  self._require_static_shapes = False
def _container_strategy(self):
"""Get the containing `tf.distribute.Strategy`.
This should not generally be needed except when creating a new
`ReplicaContext` and to validate that the caller is in the correct
`scope()`.
Returns:
The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
"""
container_strategy = self._container_strategy_weakref()
assert container_strategy is not None
return container_strategy
def _scope(self, strategy):
  """Implementation of tf.distribute.Strategy.scope().

  Builds a `_CurrentDistributionContext` from `strategy`, a variable creator
  scope, a variable scope with a custom getter, and `self._default_device`.
  """

  def creator_with_resource_vars(*args, **kwargs):
    # Variables are routed to `_create_variable` as resource variables tagged
    # with the strategy that owns the scope.
    _require_strategy_scope_extended(self)
    kwargs.update(use_resource=True, distribute_strategy=strategy)
    return self._create_variable(*args, **kwargs)

  def distributed_getter(getter, *args, **kwargs):
    # When partitioning is not allowed, strip any `partitioner` argument and
    # report it through `log_first_n` before calling the wrapped getter.
    partitioning_disabled = not self._allow_variable_partition()
    if partitioning_disabled and kwargs.pop("partitioner", None) is not None:
      tf_logging.log_first_n(
          tf_logging.WARN, "Partitioned variables are disabled when using "
          "current tf.distribute.Strategy.", 1)
    return getter(*args, **kwargs)

  creator_scope = variable_scope.variable_creator_scope(
      creator_with_resource_vars)
  getter_scope = variable_scope.variable_scope(
      variable_scope.get_variable_scope(), custom_getter=distributed_getter)
  return _CurrentDistributionContext(
      strategy, creator_scope, getter_scope, self._default_device)
def _allow_variable_partition(self):
  """Returns whether partitioned variables are allowed; False in this class.

  The custom getter installed by `_scope()` consults this to decide whether a
  `partitioner` argument is dropped.
  """
  return False
def _create_variable(self, next_creator, *args, **kwargs):
  """Creates a variable for this strategy; descendants must override.

  Invoked by the variable creator that `_scope()` installs.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  # Note: should support "colocate_with" argument.
  raise NotImplementedError("must be implemented in descendants")
def variable_created_in_scope(self, v):
  """Tests whether `v` was created while this strategy scope was active.

  Variables created inside the strategy scope are "owned" by it:

  >>> with strategy.scope():
  ...   v = tf.Variable(1.)
  >>> strategy.extended.variable_created_in_scope(v)
  True

  Variables created outside the strategy are not owned by it:

  >>> v = tf.Variable(1.)
  >>> strategy.extended.variable_created_in_scope(v)
  False

  Args:
    v: A `tf.Variable` instance.

  Returns:
    True if `v` was created inside the scope, False if not.
  """
  creating_strategy = v._distribute_strategy  # pylint: disable=protected-access
  return creating_strategy == self._container_strategy_weakref()
def colocate_vars_with(self, colocate_with_variable):
  """Scope that controls which devices variables will be created on.

  Use the returned scope only for creating variables; no other operations
  should be added to the graph inside it (some implementations work by
  changing variable creation, others work by using a
  tf.compat.v1.colocate_with() scope).

  This may only be used inside `self.scope()`.

  Example usage:

  ```
  with strategy.scope():
    var1 = tf.Variable(...)

    with strategy.extended.colocate_vars_with(var1):
      # var2 and var3 will be created on the same device(s) as var1
      var2 = tf.Variable(...)
      var3 = tf.Variable(...)

    def fn(v1, v2, v3):
      # operates on v1 from var1, v2 from var2, and v3 from var3

    # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
    # too.
    strategy.extended.update(var1, fn, args=(var2, var3))
  ```

  Args:
    colocate_with_variable: A variable created in this strategy's `scope()`.
      Variables created while in the returned context manager will be on the
      same set of devices as `colocate_with_variable`.

  Returns:
    A context manager.
  """
  _require_strategy_scope_extended(self)
  self._validate_colocate_with_variable(colocate_with_variable)

  def create_colocated_variable(next_creator, *args, **kwargs):
    # The scope check is repeated at variable-creation time, not only when
    # the context manager is built.
    _require_strategy_scope_extended(self)
    kwargs.update(use_resource=True, colocate_with=colocate_with_variable)
    return next_creator(*args, **kwargs)

  return variable_scope.variable_creator_scope(create_colocated_variable)
def _validate_colocate_with_variable(self, colocate_with_variable):
  """Validate `colocate_with_variable` argument to `colocate_vars_with`.

  The base implementation accepts any value; it is a no-op.
  """
  pass
def _make_dataset_iterator(self, dataset):
  """Hook for descendants; the base class provides no implementation.

  NOTE(review): presumably backs the public `make_dataset_iterator()` API;
  the caller is not visible in this part of the file -- confirm.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def _make_input_fn_iterator(self, input_fn, replication_mode):
  """Hook for descendants; the base class provides no implementation.

  NOTE(review): presumably builds an iterator from `input_fn` for the public
  input-fn iterator API; the caller is not visible here -- confirm.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_dataset(self, dataset):
  """Hook for descendants; the base class provides no implementation.

  NOTE(review): presumably backs a public `experimental_distribute_dataset()`
  API; the caller is not visible here -- confirm.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def _reduce(self, reduce_op, value):
  """Reduces `value` to a single device and returns its first local result.

  Default implementation until we have an implementation for each strategy.
  The destination is `device_util.current()`, or "/device:CPU:0" when that
  is falsy.
  """
  destination = device_util.current() or "/device:CPU:0"
  reduced = self._reduce_to(reduce_op, value, destination)
  return self._local_results(reduced)[0]
def reduce_to(self, reduce_op, value, destinations):
  """Combine (via e.g. sum or mean) values across replicas.

  The calling context is checked with
  `_require_cross_replica_or_default_context_extended` before anything else.

  Args:
    reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
      A string is also accepted and is upper-cased and converted to the enum.
      `VariableAggregation` values are rejected, and only SUM and MEAN pass
      the final check.
    value: A per-replica value with one value per replica.
    destinations: A mirrored variable, a per-replica tensor, or a device
      string. The return value will be copied to all destination devices (or
      all the devices where the `destinations` value resides). To perform an
      all-reduction, pass `value` to `destinations`. Must not be a list or
      tuple.

  Returns:
    A value mirrored to `destinations`.
  """
  _require_cross_replica_or_default_context_extended(self)
  assert not isinstance(destinations, (list, tuple))
  assert not isinstance(reduce_op, variable_scope.VariableAggregation)
  op = reduce_op
  if isinstance(op, six.string_types):
    op = reduce_util.ReduceOp(op.upper())
  assert (op == reduce_util.ReduceOp.SUM or
          op == reduce_util.ReduceOp.MEAN)
  return self._reduce_to(op, value, destinations)
def _reduce_to(self, reduce_op, value, destinations):
  """Strategy-specific implementation behind `reduce_to()` and `_reduce()`.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def batch_reduce_to(self, reduce_op, value_destination_pairs):
  """Combine multiple `reduce_to` calls into one for faster execution.

  The calling context is checked with
  `_require_cross_replica_or_default_context_extended` before anything else.

  Args:
    reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
      A string is also accepted and is upper-cased and converted to the enum;
      `VariableAggregation` values are rejected.
    value_destination_pairs: A sequence of (value, destinations)
      pairs. See `reduce_to()` for a description.

  Returns:
    A list of mirrored values, one per pair in `value_destination_pairs`.
  """
  _require_cross_replica_or_default_context_extended(self)
  assert not isinstance(reduce_op, variable_scope.VariableAggregation)
  op = reduce_op
  if isinstance(op, six.string_types):
    op = reduce_util.ReduceOp(op.upper())
  return self._batch_reduce_to(op, value_destination_pairs)
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
return [
self.reduce_to(reduce_op, t, destinations=v)
for t, v in value_destination_pairs
]
def update(self, var, fn, args=(), kwargs=None, group=True):
  """Run `fn` to update `var` using inputs mirrored to the same devices.

  If `var` is mirrored across multiple devices, then this implements
  logic like:

  ```
  results = {}
  for device, v in var:
    with tf.device(device):
      # args and kwargs will be unwrapped if they are mirrored.
      results[device] = fn(v, *args, **kwargs)
  return merged(results)
  ```

  Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.

  Neither `args` nor `kwargs` may contain per-replica values.
  If they contain mirrored values, they will be unwrapped before
  calling `fn`.

  Args:
    var: Variable, possibly mirrored to multiple devices, to operate on.
    fn: Function to call. Should take the variable as the first argument.
    args: Tuple or list. Additional positional arguments to pass to `fn()`.
    kwargs: Dict with keyword arguments to pass to `fn()`. `None` is treated
      as an empty dict.
    group: Boolean. Defaults to True. If False, the return value will be
      unwrapped.

  Returns:
    By default, the merged return value of `fn` across all replicas. The
    merged result has dependencies to make sure that if it is evaluated at
    all, the side effects (updates) will happen on every replica. If instead
    "group=False" is specified, this function will return a nest of lists
    where each list has an element per replica, and the caller is responsible
    for ensuring all elements are executed.
  """
  _require_cross_replica_or_default_context_extended(self)
  fn_kwargs = {} if kwargs is None else kwargs
  # The strategy-specific `_update` runs inside the owning strategy's scope.
  with self._container_strategy().scope():
    return self._update(var, fn, args, fn_kwargs, group)
def _update(self, var, fn, args, kwargs, group):
  """Strategy-specific implementation behind `update()`.

  `update()` calls this inside the container strategy's scope, with `kwargs`
  already defaulted to a dict.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def update_non_slot(
    self, colocate_with, fn, args=(), kwargs=None, group=True):
  """Runs `fn(*args, **kwargs)` on `colocate_with` devices.

  Args:
    colocate_with: The return value of `non_slot_devices()`.
    fn: Function to execute.
    args: Tuple or list. Positional arguments to pass to `fn()`.
    kwargs: Dict with keyword arguments to pass to `fn()`. `None` is treated
      as an empty dict.
    group: Boolean. Defaults to True. If False, the return value will be
      unwrapped.

  Returns:
    Return value of `fn`, possibly merged across devices.
  """
  _require_cross_replica_or_default_context_extended(self)
  fn_kwargs = {} if kwargs is None else kwargs
  # The strategy-specific hook runs inside the owning strategy's scope.
  with self._container_strategy().scope():
    return self._update_non_slot(colocate_with, fn, args, fn_kwargs, group)
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
  """Strategy-specific implementation behind `update_non_slot()`.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def _local_results(self, distributed_value):
  """Strategy-specific hook used by `_reduce()` and `_group()`.

  Both callers treat the result as a collection of values (`_reduce()` indexes
  it, `_group()` flattens it).

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def value_container(self, value):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    value: A value returned by `experimental_run_v2()` or a variable
      created in `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
    `value in experimental_local_results(value_container(value))` will
    always be true.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
def _group(self, value, name=None):
  """Implementation of `group`.

  Flattens the local results of `value`. A single unnamed value is returned
  directly (as its `.op` when it has one); anything else is combined with
  `control_flow_ops.group`.
  """
  flat_values = nest.flatten(self._local_results(value))
  # Special handling for the common case of one op.
  if name is None and len(flat_values) == 1:
    only = flat_values[0]
    return only.op if hasattr(only, "op") else only
  return control_flow_ops.group(flat_values, name=name)
@property
def experimental_require_static_shapes(self):
  """Returns the `_require_static_shapes` flag set in `__init__`.

  Per the comment in `__init__`, the flag determines whether
  `drop_remainder=True` is set when creating Datasets from numpy array inputs.
  """
  return self._require_static_shapes
@property
def _num_replicas_in_sync(self):
  """Returns number of replicas over which gradients are aggregated.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
@property
def worker_devices(self):
  """Returns the tuple of all devices used for compute replica execution.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  # TODO(josh11b): More docstring
  raise NotImplementedError("must be implemented in descendants")
@property
def parameter_devices(self):
  """Returns the tuple of all devices used to place variables.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  # TODO(josh11b): More docstring
  raise NotImplementedError("must be implemented in descendants")
def non_slot_devices(self, var_list):
  """Device(s) for non-slot variables.

  Create variables on these devices in a
  `with colocate_vars_with(non_slot_devices(...)):` block.
  Update those using `update_non_slot()`.

  Args:
    var_list: The list of variables being optimized, needed with the
      default `tf.distribute.Strategy`.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
def _configure(self,
               session_config=None,
               cluster_spec=None,
               task_type=None,
               task_id=None):
  """Configures the strategy class.

  The base implementation is a no-op: it discards all four arguments and
  returns None.
  """
  del session_config, cluster_spec, task_type, task_id
def _update_config_proto(self, config_proto):
  """Returns a deep copy of `config_proto`, leaving the argument untouched.

  The base implementation makes no modification beyond copying.
  """
  return copy.deepcopy(config_proto)
@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring
class StrategyExtendedV1(StrategyExtendedV2):
__doc__ = StrategyExtendedV2.__doc__
def experimental_make_numpy_dataset(self, numpy_input, session=None):
  """Makes a dataset for input provided via a numpy array.

  This avoids adding `numpy_input` as a large constant in the graph,
  and copies the data to the machine or machines that will be processing
  the input.

  Args:
    numpy_input: A nest of NumPy input arrays that will be distributed evenly
      across all replicas. Note that lists of Numpy arrays are stacked, as
      that is normal `tf.data.Dataset` behavior.
    session: (TensorFlow v1.x graph execution only) A session used for
      initialization.

  Returns:
    A `tf.data.Dataset` representing `numpy_input`.
  """
  # Check the calling context, then hand off to the strategy-specific hook.
  _require_cross_replica_or_default_context_extended(self)
  return self._experimental_make_numpy_dataset(numpy_input, session=session)
def _experimental_make_numpy_dataset(self, numpy_input, session):
  """Strategy-specific implementation of `experimental_make_numpy_dataset()`.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def broadcast_to(self, tensor, destinations):
  """Mirror a tensor on one device to all worker devices.

  Args:
    tensor: A Tensor value to broadcast.
    destinations: A mirrored variable or device string specifying the
      destination devices to copy `tensor` to. Must not be `None`, a list,
      or a tuple.

  Returns:
    A value mirrored to `destinations` devices.

  Raises:
    AssertionError: If `destinations` is `None`, a list, or a tuple.
  """
  assert destinations is not None  # from old strategy.broadcast()
  # TODO(josh11b): More docstring
  _require_cross_replica_or_default_context_extended(self)
  assert not isinstance(destinations, (list, tuple))
  return self._broadcast_to(tensor, destinations)
def _broadcast_to(self, tensor, destinations):
  """Strategy-specific implementation behind `broadcast_to()`.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def experimental_run_steps_on_iterator(self,
                                       fn,
                                       iterator,
                                       iterations=1,
                                       initial_loop_values=None):
  """Run `fn` with input from `iterator` for `iterations` times.

  This method can be used to run a step function for training a number of
  times using input from a dataset.

  Args:
    fn: function to run using this distribution strategy. The function must
      have the following signature: `def fn(context, inputs)`. `context` is an
      instance of `MultiStepContext` that will be passed when `fn` is run.
      `context` can be used to specify the outputs to be returned from `fn`
      by calling `context.set_last_step_output`. It can also be used to
      capture non tensor outputs by `context.set_non_tensor_output`. See
      `MultiStepContext` documentation for more information. `inputs` will
      have same type/structure as `iterator.get_next()`. Typically, `fn`
      will use `call_for_each_replica` method of the strategy to distribute
      the computation over multiple replicas.
    iterator: Iterator of a dataset that represents the input for `fn`. The
      caller is responsible for initializing the iterator as needed.
    iterations: (Optional) Number of iterations that `fn` should be run.
      Defaults to 1.
    initial_loop_values: (Optional) Initial values to be passed into the
      loop that runs `fn`. Defaults to `None`.

  Returns:
    Returns the `MultiStepContext` object which has the following properties,
    among other things:
      - run_op: An op that runs `fn` `iterations` times.
      - last_step_outputs: A dictionary containing tensors set using
        `context.set_last_step_output`. Evaluating this returns the value of
        the tensors after the last iteration.
      - non_tensor_outputs: A dictionary containing anything that was set by
        `fn` by calling `context.set_non_tensor_output`.
  """
  # TODO(priyag): Remove initial_loop_values argument when we have a mechanism
  # to infer the outputs of `fn`.
  _require_cross_replica_or_default_context_extended(self)
  # The strategy-specific hook runs inside the owning strategy's scope.
  with self._container_strategy().scope():
    return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
                                                    initial_loop_values)
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                        initial_loop_values):
  """Strategy-specific hook behind `experimental_run_steps_on_iterator()`.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def call_for_each_replica(self, fn, args=(), kwargs=None):
  """Run `fn` once per replica.

  `fn` may call `tf.distribute.get_replica_context()` to access methods such
  as `replica_id_in_sync_group` and `merge_call()`.

  `merge_call()` is used to communicate between the replicas and
  re-enter the cross-replica context. All replicas pause their execution
  having encountered a `merge_call()` call. After that the
  `merge_fn`-function is executed. Its results are then unwrapped and
  given back to each replica call. After that execution resumes until
  `fn` is complete or encounters another `merge_call()`.  Example:

  ```python
  # Called once in "cross-replica" context.
  def merge_fn(distribution, three_plus_replica_id):
    # sum the values across replicas
    return sum(distribution.experimental_local_results(three_plus_replica_id))

  # Called once per replica in `distribution`, in a "replica" context.
  def fn(three):
    replica_ctx = tf.distribute.get_replica_context()
    v = three + replica_ctx.replica_id_in_sync_group
    # Computes the sum of the `v` values across all replicas.
    s = replica_ctx.merge_call(merge_fn, args=(v,))
    return s + v

  with distribution.scope():
    # in "cross-replica" context
    ...
    merged_results = distribution.experimental_run_v2(fn, args=[3])
    # merged_results has the values from every replica execution of `fn`.
    # This statement prints a list:
    print(distribution.experimental_local_results(merged_results))
  ```

  Args:
    fn: function to run (will be run once per replica).
    args: Tuple or list with positional arguments for `fn`.
    kwargs: Dict with keyword arguments for `fn`. `None` is treated as an
      empty dict.

  Returns:
    Merged return value of `fn` across all replicas.
  """
  _require_cross_replica_or_default_context_extended(self)
  fn_kwargs = {} if kwargs is None else kwargs
  # The strategy-specific hook runs inside the owning strategy's scope.
  with self._container_strategy().scope():
    return self._call_for_each_replica(fn, args, fn_kwargs)
def _call_for_each_replica(self, fn, args, kwargs):
  """Strategy-specific implementation behind `call_for_each_replica()`.

  `call_for_each_replica()` calls this inside the container strategy's scope,
  with `kwargs` already defaulted to a dict.

  Raises:
    NotImplementedError: Always, in this base class.
  """
  raise NotImplementedError("must be implemented in descendants")
def read_var(self, v):
  """Reads the value of a variable.

  Returns the aggregate value of a replica-local variable, or the
  (read-only) value of any other variable.

  Args:
    v: A variable allocated within the scope of this `tf.distribute.Strategy`.

  Returns:
    A tensor representing the value of `v`, aggregated across replicas if
    necessary.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
@property
def experimental_between_graph(self):
  """Whether the strategy uses between-graph replication or not.

  This is expected to return a constant value that will not be changed
  throughout its life cycle.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
@property
def experimental_should_init(self):
  """Whether initialization is needed.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
@property
def should_checkpoint(self):
  """Whether checkpointing is needed.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
@property
def should_save_summary(self):
  """Whether saving summaries is needed.

  Raises:
    NotImplementedError: Always, in this base class; descendants must
      override.
  """
  raise NotImplementedError("must be implemented in descendants")
# A note about the difference between the context managers
# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
# (defined above) used by `tf.distribute.Strategy.scope()`:
#
# * a ReplicaContext is only present during a `experimental_run_v2()`
# call (except during a `merge_call` call) and in such a scope it
# will be returned by calls to `get_replica_context()`. Implementers of new
# Strategy descendants will frequently also need to
# define a descendant of ReplicaContext, and are responsible for
# entering and exiting this context.
#
# * Strategy.scope() sets up a variable_creator scope that
# changes variable creation calls (e.g. to make mirrored
# variables). This is intended as an outer scope that users enter once
# around their model creation and graph definition. There is no
# anticipated need to define descendants of _CurrentDistributionContext.
# It sets the current Strategy for purposes of
# `get_strategy()` and `has_strategy()`
# and switches the thread mode to a "cross-replica context".
@tf_export("distribute.ReplicaContext")
class ReplicaContext(object):
"""`tf.distribute.Strategy` API when in a replica context.
To be used inside your replicated step function, such as in a
`tf.distribute.Strategy.experimental_run_v2` call.
"""
def __init__(self, strategy, replica_id_in_sync_group):
  """Stores the strategy and replica id and prepares the thread mode.

  Args:
    strategy: The `tf.distribute.Strategy` exposed through the `strategy`
      property.
    replica_id_in_sync_group: Identifier of the replica being defined; exposed
      through the `replica_id_in_sync_group` property.
  """
  self._strategy = strategy
  # Pushed as the per-thread mode by `__enter__`.
  self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
      self)
  self._replica_id_in_sync_group = replica_id_in_sync_group
  # Holds the eager context's previous setting between `__enter__` and
  # `__exit__`.
  self._summary_recording_distribution_strategy = None
def __enter__(self):
  """Enters replica mode and installs the summary-recording predicate.

  Pushes this context's thread mode, saves the eager context's current
  `summary_recording_distribution_strategy`, and replaces it with a function
  that tests whether this replica's id equals 0. `__exit__` reverses both
  steps. Nothing is returned, so `with ctx as x:` binds `x` to `None`.
  """
  _push_per_thread_mode(self._thread_context)
  ctx = eager_context.context()

  def replica_id_is_zero():
    return math_ops.equal(self._replica_id_in_sync_group,
                          constant_op.constant(0))

  # Saved first so that `__exit__` can restore the previous value.
  self._summary_recording_distribution_strategy = (
      ctx.summary_recording_distribution_strategy)
  ctx.summary_recording_distribution_strategy = replica_id_is_zero
def __exit__(self, exception_type, exception_value, traceback):
  """Restores the saved summary-recording setting and pops the thread mode.

  The exception arguments are ignored and nothing is returned, so an
  exception raised inside the `with` block propagates.
  """
  ctx = eager_context.context()
  # Put back the value that `__enter__` saved.
  ctx.summary_recording_distribution_strategy = (
      self._summary_recording_distribution_strategy)
  _pop_per_thread_mode()
def merge_call(self, merge_fn, args=(), kwargs=None):
  """Merge args across replicas and run `merge_fn` in a cross-replica context.

  This allows communication and coordination when there are multiple calls
  to a model function triggered by a call to
  `strategy.experimental_run_v2(model_fn, ...)`.

  See `tf.distribute.Strategy.experimental_run_v2` for an
  explanation.

  If not inside a distributed scope, this is equivalent to:

  ```
  strategy = tf.distribute.get_strategy()
  with cross-replica-context(strategy):
    return merge_fn(strategy, *args, **kwargs)
  ```

  Args:
    merge_fn: function that joins arguments from threads that are given as
      PerReplica. It accepts `tf.distribute.Strategy` object as
      the first argument.
    args: List or tuple with positional per-thread arguments for `merge_fn`.
    kwargs: Dict with keyword per-thread arguments for `merge_fn`. `None` is
      treated as an empty dict.

  Returns:
    The return value of `merge_fn`, except for `PerReplica` values which are
    unpacked.
  """
  require_replica_context(self)
  merge_kwargs = {} if kwargs is None else kwargs
  return self._merge_call(merge_fn, args, merge_kwargs)
def _merge_call(self, merge_fn, args, kwargs):
  """Default implementation for single replica.

  Runs `merge_fn(strategy, *args, **kwargs)` with a cross-replica thread mode
  pushed for the duration of the call; the mode is popped again even if
  `merge_fn` raises.
  """
  cross_replica_mode = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
      self._strategy)
  # thread-local, so not needed with multiple threads
  _push_per_thread_mode(cross_replica_mode)
  try:
    merged = merge_fn(self._strategy, *args, **kwargs)
  finally:
    _pop_per_thread_mode()
  return merged
@property
def num_replicas_in_sync(self):
  """Returns number of replicas over which gradients are aggregated.

  The value is read from the strategy this context was constructed with.
  """
  owning_strategy = self._strategy
  return owning_strategy.num_replicas_in_sync
@property
def replica_id_in_sync_group(self):
  """Which replica is being defined, from 0 to `num_replicas_in_sync - 1`.

  `require_replica_context(self)` is called before the stored id is returned.
  """
  require_replica_context(self)
  replica_id = self._replica_id_in_sync_group
  return replica_id
@property
def strategy(self):
  """The current `tf.distribute.Strategy` object.

  This is the object that was passed to the constructor.
  """
  owning_strategy = self._strategy
  return owning_strategy
@property
def devices(self):
  """The devices this replica is to be executed on, as a tuple of strings.

  This default returns a one-element tuple holding `device_util.current()`,
  after `require_replica_context(self)` has been called.
  """
  require_replica_context(self)
  current_device = device_util.current()
  return (current_device,)
def all_reduce(self, reduce_op, value):
  """All-reduces the given `Tensor` nest across replicas.

  If `all_reduce` is called in any replica, it must be called in all replicas.
  The nested structure and `Tensor` shapes must be identical in all replicas.

  IMPORTANT: The ordering of communications must be identical in all replicas.

  Example with two replicas:
    Replica 0 `value`: {'a': 1, 'b': [40,  1]}
    Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}

    If `reduce_op` == `SUM`:
      Result (on all replicas): {'a': 4, 'b': [42, 99]}

    If `reduce_op` == `MEAN`:
      Result (on all replicas): {'a': 2, 'b': [21, 49.5]}

  Args:
    reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
    value: The nested structure of `Tensor`s to be all-reduced.
      The structure must be compatible with `tf.nest`.

  Returns:
    A `Tensor` nest with the reduced `value`s from each replica. For
    reductions other than `SUM` and `MEAN` every returned tensor is wrapped
    in `prevent_gradient`.
  """
  differentiable_ops = (reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN)
  flat_value = nest.flatten(value)

  def batch_all_reduce(strategy, *value_flat):
    # Pair every flattened input with the destination it is reduced to.
    value_destination_pairs = [
        (v, _batch_reduce_destination(v)) for v in value_flat]
    return strategy.extended.batch_reduce_to(
        reduce_op, value_destination_pairs)

  if reduce_op not in differentiable_ops:
    # TODO(cjfj): Implement gradients for other reductions.
    reduced = nest.pack_sequence_as(
        value, self.merge_call(batch_all_reduce, args=flat_value))
    return nest.map_structure(array_ops.prevent_gradient, reduced)

  # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
  @custom_gradient.custom_gradient
  def grad_wrapper(*xs):
    ys = self.merge_call(batch_all_reduce, args=xs)
    # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
    return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)

  return nest.pack_sequence_as(value, grad_wrapper(*flat_value))
# TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
# all-reduce. It would return a function returning the result of reducing `t`
# across all replicas. The caller would wait to call this function until they
# needed the reduce result, allowing an efficient implementation:
# * With eager execution, the reduction could be performed asynchronously
# in the background, not blocking until the result was needed.
# * When constructing a graph, it could batch up all reduction requests made
# up to the point at which the first result is needed. Most likely this can
# be implemented in terms of `merge_call()` and `batch_reduce_to()`.
def _batch_reduce_destination(x):
  """Returns the destinations for batch all-reduce.

  Args:
    x: A value about to be reduced.

  Returns:
    `x.device` when `x` is an `ops.Tensor` (one device strategies); otherwise
    `x` itself.
  """
  return x.device if isinstance(x, ops.Tensor) else x
# ------------------------------------------------------------------------------
class _DefaultDistributionStrategy(StrategyV1):
  """Default `tf.distribute.Strategy` if none is explicitly selected."""

  def __init__(self):
    """Builds the strategy around a `_DefaultDistributionExtended`."""
    extended = _DefaultDistributionExtended(self)
    super(_DefaultDistributionStrategy, self).__init__(extended)
class _DefaultDistributionExtended(StrategyExtendedV1):
  """Implementation of _DefaultDistributionStrategy."""

  def _scope(self, strategy):
    """Context manager setting a variable creator and `self` as current.

    Raises:
      RuntimeError: If a strategy is already active.
    """
    if distribution_strategy_context.has_strategy():
      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_strategy_scope_strategy(strategy)
      return next_creator(*args, **kwargs)

    return _CurrentDistributionContext(
        strategy, variable_scope.variable_creator_scope(creator))

  def colocate_vars_with(self, colocate_with_variable):
    """Does not require `self.scope`."""
    _require_strategy_scope_extended(self)
    return ops.colocate_with(colocate_with_variable)

  def variable_created_in_scope(self, v):
    """Returns whether `v` has no distribution strategy attached."""
    return v._distribute_strategy is None  # pylint: disable=protected-access

  def _experimental_distribute_dataset(self, dataset):
    """Returns `dataset` unchanged."""
    return dataset

  def _make_dataset_iterator(self, dataset):
    """Wraps `dataset` in a `DefaultInputIterator`."""
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)

  def _make_input_fn_iterator(self,
                              input_fn,
                              replication_mode=InputReplicationMode.PER_WORKER):
    """Calls `input_fn` once and wraps the resulting dataset in an iterator.

    Args:
      input_fn: Callable taking an `InputContext` and returning a dataset. It
        is invoked with a default-constructed `InputContext`.
      replication_mode: Not used by this implementation.

    Returns:
      A `DefaultInputIterator` over the dataset returned by `input_fn`.
    """
    dataset = input_fn(InputContext())
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    """Builds a dataset of slices from a nest of numpy arrays.

    For every array in the flattened `numpy_input` a non-trainable resource
    variable of zeros with the array's shape and dtype is created and then
    initialized from the array via `numpy_dataset.init_var_from_numpy`. The
    variables are packed back into the structure of `numpy_input` before being
    passed to `Dataset.from_tensor_slices`.
    """
    numpy_flat = nest.flatten(numpy_input)
    vars_flat = tuple(
        variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
                                trainable=False, use_resource=True)
        for i in numpy_flat
    )
    for v, i in zip(vars_flat, numpy_flat):
      numpy_dataset.init_var_from_numpy(v, i, session)
    vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
    return dataset_ops.Dataset.from_tensor_slices(vars_nested)

  def _broadcast_to(self, tensor, destinations):
    """Returns `tensor` when `destinations` is None.

    Raises:
      NotImplementedError: For any other `destinations`.
    """
    if destinations is None:
      return tensor
    else:
      raise NotImplementedError("TODO")

  def _call_for_each_replica(self, fn, args, kwargs):
    """Runs `fn(*args, **kwargs)` once inside a replica context with id 0."""
    with ReplicaContext(
        self._container_strategy(),
        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    """Returns `value` as is; `reduce_op` and `destinations` are ignored."""
    # TODO(josh11b): Use destinations?
    del reduce_op, destinations
    return value

  def _update(self, var, fn, args, kwargs, group):
    """Runs `fn(var, *args, **kwargs)` via `_update_non_slot`."""
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
    """Runs `fn(*args, **kwargs)` inside an `UpdateContext`.

    Returns:
      The result of `fn` when `should_group` is true; otherwise the same
      structure with every leaf passed through `_local_results`.
    """
    # TODO(josh11b): Figure out what we should be passing to UpdateContext()
    # once that value is used for something.
    with UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if should_group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Returns `array_ops.identity` of the given variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, distributed_value):
    """Returns the value wrapped in a one-element tuple."""
    return (distributed_value,)

  def value_container(self, value):
    """Returns `value` itself."""
    return value

  @property
  def _num_replicas_in_sync(self):
    """Always 1 for this implementation."""
    return 1

  @property
  def worker_devices(self):
    """Unsupported; always raises `RuntimeError`."""
    raise RuntimeError("worker_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  @property
  def parameter_devices(self):
    """Unsupported; always raises `RuntimeError`."""
    raise RuntimeError("parameter_devices() method unsupported by default "
                       "tf.distribute.Strategy.")

  def non_slot_devices(self, var_list):
    """Returns the element of `var_list` whose `name` sorts first."""
    return min(var_list, key=lambda x: x.name)

  # TODO(priyag): This should inherit from `InputIterator`, once dependency
  # issues have been resolved.
  class DefaultInputIterator(object):
    """Default implementation of `InputIterator` for default strategy."""

    def __init__(self, dataset):
      """Creates the underlying iterator for `dataset`.

      A one-shot iterator is used when executing eagerly and an initializable
      iterator otherwise. The iterators are built with the module-level
      `dataset_ops` helpers rather than the deprecated `Dataset` instance
      methods, so datasets that lack those methods also work.
      """
      self._dataset = dataset
      if eager_context.executing_eagerly():
        self._iterator = dataset_ops.make_one_shot_iterator(dataset)
      else:
        self._iterator = dataset_ops.make_initializable_iterator(dataset)

    def get_next(self):
      """Returns the next element from the underlying iterator."""
      return self._iterator.get_next()

    def initialize(self):
      """Resets the iterator.

      Returns:
        An empty list when executing eagerly (a fresh one-shot iterator is
        created in place); otherwise a one-element list holding the
        iterator's `initializer`.
      """
      if eager_context.executing_eagerly():
        self._iterator = dataset_ops.make_one_shot_iterator(self._dataset)
        return []
      else:
        return [self._iterator.initializer]

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for this strategy."""
    return True
# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
_original_from_proto = resource_variable_ops._from_proto_fn


def _from_proto_fn(v, import_scope=None):
  """Deserializes a variable unless a distribution strategy is active.

  Args:
    v: Forwarded unchanged to the original `_from_proto_fn`.
    import_scope: Forwarded unchanged to the original `_from_proto_fn`.

  Returns:
    Whatever the original `resource_variable_ops._from_proto_fn` returns.

  Raises:
    NotImplementedError: If `distribution_strategy_context.has_strategy()`
      is true.
  """
  if not distribution_strategy_context.has_strategy():
    return _original_from_proto(v, import_scope=import_scope)
  raise NotImplementedError(
      "Deserialization of variables is not yet supported when using a "
      "tf.distribute.Strategy.")


# Install the guard in place of the original deserializer.
resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
#-------------------------------------------------------------------------------
# Shorthand for some methods from distribution_strategy_context.
# These aliases are bound after the code in this module that uses them (for
# example `_push_per_thread_mode` / `_pop_per_thread_mode` in the context
# manager methods above); that works because the names are looked up when
# those functions are called, not when they are defined.
_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode  # pylint: disable=protected-access
_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode  # pylint: disable=protected-access
_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode  # pylint: disable=protected-access
_get_default_replica_mode = (
    distribution_strategy_context._get_default_replica_mode)  # pylint: disable=protected-access