# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing DistributionStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import copy
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import coordinator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# TODO(josh11b): Replace asserts in this file with if ...: raise ...
@contextlib.contextmanager
def _enter_graph(g, eager, creator_stack=None):
"""Context manager for selecting a graph and maybe eager mode."""
if eager:
with g.as_default(), context.eager_mode():
if creator_stack is not None:
g._variable_creator_stack = creator_stack # pylint: disable=protected-access
yield
else:
with g.as_default():
if creator_stack is not None:
g._variable_creator_stack = creator_stack # pylint: disable=protected-access
yield
def _cpu_device(device):
cpu_device = tf_device.DeviceSpec.from_string(device)
cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
return cpu_device.to_string()
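# For example (illustrative): _cpu_device("/job:worker/task:0/device:GPU:1")
# returns "/job:worker/task:0/device:CPU:0".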
class _RequestedStop(Exception): # pylint: disable=g-bad-exception-name
pass
# _call_for_each_replica is not a member of MirroredStrategy so that it is
# not allowed to use anything specific to MirroredStrategy and thus
# can be shared with other distribution strategies.
# TODO(yuefengz): maybe create a common class for those who need to call this
# _call_for_each_replica.
def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
"""Run `fn` in separate threads, once per replica/worker device.
Args:
distribution: the DistributionStrategy object.
device_map: the DeviceMap with the devices to run `fn` on.
fn: function to run (will be run once per replica, each in its own thread).
args: positional arguments for `fn`
kwargs: keyword arguments for `fn`.
Returns:
Merged return value of `fn` across all replicas.
Raises:
RuntimeError: If fn() calls get_replica_context().merge_call() a different
number of times on different replicas.
"""
# TODO(josh11b): Add this option once we add synchronization to variable
# creation. Until then, this is pretty unsafe to use.
run_concurrently = False
if not context.executing_eagerly():
# Needed for per-thread device, etc. contexts in graph mode.
ops.get_default_graph().switch_to_thread_local()
coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
shared_variable_store = {}
# TODO(isaprykin): Create these threads once instead of during every call.
threads = []
for index in range(device_map.num_replicas_in_graph):
variable_creator_fn = shared_variable_creator.make_fn(
shared_variable_store, index)
t = _MirroredReplicaThread(
distribution, coord, index, device_map, variable_creator_fn, fn,
values.select_replica(index, args),
values.select_replica(index, kwargs))
threads.append(t)
for t in threads:
t.start()
# To run `fn`, the `should_run` event is set on each _MirroredReplicaThread
# (`MRT`). The main thread then waits until `MRT.has_paused` is set, which
# indicates that either `fn` is complete or a
# `get_replica_context().merge_call()` has been reached. If `fn` is
# complete, then `MRT.done` is set to True. Otherwise, the arguments
# of `get_replica_context().merge_call` from all paused threads are grouped
# and the `merge_fn` is performed. The results of the
# `get_replica_context().merge_call` are then assigned to `MRT.merge_result`.
# Each such `get_replica_context().merge_call` call returns the
# `MRT.merge_result` for its thread once the `MRT.should_run` event
# is set again, and execution of `fn` resumes.
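# A minimal sketch of this two-Event handshake with a single worker thread
# (illustrative only; the names below are not part of this module):
#
#   should_run, has_paused = threading.Event(), threading.Event()
#
#   def worker():          # body of a _MirroredReplicaThread, simplified
#     should_run.wait()    # block until the main thread signals us
#     should_run.clear()
#     ...                  # run `fn` until it finishes or calls merge_call()
#     has_paused.set()     # hand control back to the main thread
#
#   should_run.set()       # main thread: let the worker run
#   has_paused.wait()      # main thread: wait until the worker pauses
#   has_paused.clear()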
try:
with coord.stop_on_exception():
all_done = False
while not all_done and not coord.should_stop():
done = []
if run_concurrently:
for t in threads:
t.should_run.set()
for t in threads:
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
else:
for t in threads:
t.should_run.set()
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
if coord.should_stop():
return None
all_done = all(done)
if not all_done:
if any(done):
raise RuntimeError("Some replicas made a different number of "
"replica_context().merge_call() calls.")
# get_replica_context().merge_call() case
merge_args = values.regroup(
device_map, tuple(t.merge_args for t in threads))
merge_kwargs = values.regroup(
device_map, tuple(t.merge_kwargs for t in threads))
# We capture the name_scope of the MRT when we call merge_fn
# to ensure that if we have opened a name scope in the MRT,
# it will be respected when executing the merge function. We only
# capture the name_scope from the first MRT and assume it is
# the same for all other MRTs.
mtt_captured_name_scope = threads[0].captured_name_scope
mtt_captured_var_scope = threads[0].captured_var_scope
# Capture and merge the control dependencies from all the threads.
mtt_captured_control_deps = set()
for t in threads:
mtt_captured_control_deps.update(t.captured_control_deps)
with ops.name_scope(mtt_captured_name_scope),\
ops.control_dependencies(mtt_captured_control_deps), \
variable_scope.variable_scope(mtt_captured_var_scope):
merge_result = threads[0].merge_fn(distribution, *merge_args,
**merge_kwargs)
for r, t in enumerate(threads):
t.merge_result = values.select_replica(r, merge_result)
finally:
for t in threads:
t.should_run.set()
coord.join(threads)
return values.regroup(device_map, tuple(t.main_result for t in threads))
def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: disable=missing-docstring
real_mirrored_creator, *args, **kwargs):
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
collections = kwargs.pop("collections", None)
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
# Get synchronization value
synchronization = kwargs.get("synchronization",
variable_scope.VariableSynchronization.ON_WRITE)
if synchronization == variable_scope.VariableSynchronization.NONE:
raise ValueError("`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please"
" change the `synchronization` for variable: " +
kwargs["name"])
elif synchronization == variable_scope.VariableSynchronization.ON_READ:
# Variables that are to be synced on read are replica local.
is_sync_on_read = True
kwargs["trainable"] = False
elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
synchronization == variable_scope.VariableSynchronization.AUTO):
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
is_sync_on_read = False
else:
raise ValueError(
"Invalid variable synchronization mode: %s for variable: %s" %
(synchronization, kwargs["name"]))
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
variable_scope.VariableAggregation.MEAN,
variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
):
raise ValueError(
"Invalid variable aggregation mode: %s for variable: %s" %
(aggregation, kwargs["name"]))
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
devices = device_map.logical_to_actual_devices(logical_device)
value_list = real_mirrored_creator(devices, *args, **kwargs)
if is_sync_on_read:
result = values.SyncOnReadVariable(
strategy, device_map, value_list, aggregation,
logical_device=logical_device)
else:
result = values.MirroredVariable(
strategy, device_map, value_list, aggregation,
logical_device=logical_device)
# Add the wrapped variable to the requested collections.
# The handling of eager mode and the global step matches
# ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in value_list:
if v in l:
l.remove(v)
g.add_to_collections(collections, result)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
return result
def _is_device_list_local(devices):
"""Checks whether the devices list is for local or multi-worker.
Args:
devices: a list of device strings, either local for remote devices.
Returns:
a boolean indicating whether these device strings are for local or for
remote.
Raises:
ValueError: if device strings are not consistent.
"""
all_local = None
for d in devices:
d_spec = tf_device.DeviceSpec.from_string(d)
is_local = d_spec.job in (None, "localhost")
if all_local is None: # Determine all_local from first device.
all_local = is_local
if all_local:
if not is_local:
raise ValueError("Local device string cannot have job specified other "
"than 'localhost'")
else:
if is_local:
raise ValueError("Remote device string must have job specified.")
if d_spec.task is None:
raise ValueError("Remote device string must have task specified.")
return all_local
def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
"""Returns a device list given a cluster spec."""
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
devices = []
for task_type in ("chief", "worker"):
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
if num_gpus_per_worker == 0:
devices.append("/job:%s/task:%d" % (task_type, task_id))
else:
devices.extend([
"/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
for gpu_id in range(num_gpus_per_worker)
])
return devices
def _group_device_list(devices):
"""Groups the devices list by task_type and task_id.
Args:
devices: a list of device strings for remote devices.
Returns:
a dict mapping each task_type to a list of lists of device strings, one
list per task_id in ascending order.
"""
assert not _is_device_list_local(devices)
device_dict = {}
for d in devices:
d_spec = tf_device.DeviceSpec.from_string(d)
# Create an entry for the task_type.
if d_spec.job not in device_dict:
device_dict[d_spec.job] = []
# Fill the device list for task_type until it covers the task_id.
while len(device_dict[d_spec.job]) <= d_spec.task:
device_dict[d_spec.job].append([])
device_dict[d_spec.job][d_spec.task].append(d)
return device_dict
def _is_gpu_device(device):
return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
def _infer_num_gpus_per_worker(devices):
"""Infers the number of GPUs on each worker.
Currently to make multi-worker cross device ops work, we need all workers to
have the same number of GPUs.
Args:
devices: a list of device strings, can be either local devices or remote
devices.
Returns:
number of GPUs per worker.
Raises:
ValueError: if workers have different numbers of GPUs, or if GPU indices
are not consecutive starting from 0.
"""
if _is_device_list_local(devices):
return sum(1 for d in devices if _is_gpu_device(d))
else:
device_dict = _group_device_list(devices)
num_gpus = None
for _, devices_in_task in device_dict.items():
for device_in_task in devices_in_task:
if num_gpus is None:
num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))
# Verify other workers have the same number of GPUs.
elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
raise ValueError("All workers should have the same number of GPUs.")
for d in device_in_task:
d_spec = tf_device.DeviceSpec.from_string(d)
if (d_spec.device_type == "GPU" and
d_spec.device_index >= num_gpus):
raise ValueError("GPU `device_index` on a worker should be "
"consecutive and start from 0.")
return num_gpus
def all_local_devices(num_gpus=None):
if num_gpus is None:
num_gpus = context.num_gpus()
return device_util.local_devices_from_num_gpus(num_gpus)
def _all_devices():
devices = []
tfconfig = TFConfigClusterResolver()
if tfconfig.cluster_spec().as_dict():
devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
context.num_gpus())
return devices if devices else all_local_devices()
@tf_export("distribute.MirroredStrategy", v1=[])
class MirroredStrategy(distribute_lib.Strategy):
"""Mirrors vars to distribute across multiple devices and machines.
This strategy uses one replica per device and sync replication for its
multi-GPU version.
The multi-worker version will be added in the future.
Args:
devices: a list of device strings. If `None`, all available GPUs are used.
If no GPUs are found, CPU is used.
cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
not set, NCCL will be used by default.
"""
def __init__(self, devices=None, cross_device_ops=None):
extended = MirroredExtended(
self, devices=devices, cross_device_ops=cross_device_ops)
super(MirroredStrategy, self).__init__(extended)
@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1):
__doc__ = MirroredStrategy.__doc__
def __init__(self, devices=None, cross_device_ops=None):
extended = MirroredExtended(
self, devices=devices, cross_device_ops=cross_device_ops)
super(MirroredStrategyV1, self).__init__(extended)
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
"""Implementation of MirroredStrategy."""
def __init__(self, container_strategy, devices=None, cross_device_ops=None):
super(MirroredExtended, self).__init__(container_strategy)
if devices is None:
devices = _all_devices()
if not devices:
raise ValueError("Got an empty `devices` list. Please make sure the "
"`devices` you pass in is not empty.")
self._cross_device_ops = cross_device_ops
self._initialize_strategy(devices)
def _initialize_strategy(self, devices):
# The _initialize_strategy method is intended to be used by distribute
# coordinator as well.
if _is_device_list_local(devices):
self._initialize_local(devices)
else:
self._initialize_multi_worker(devices)
def _initialize_local(self, devices):
"""Initializes the object for local training."""
self._local_mode = True
assert devices, "Must specify at least one device."
devices = tuple(device_util.resolve(d) for d in devices)
assert len(set(devices)) == len(devices), (
"No duplicates allowed in `devices` argument: %s" % (devices,))
# TODO(josh11b): Require at least 2 devices?
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = input_lib.InputWorkers(self._device_map)
self._inferred_cross_device_ops = cross_device_ops_lib.choose_the_best(
devices)
self._host_input_device = numpy_dataset.SingleDevice("/cpu:0")
def _initialize_multi_worker(self, devices):
"""Initializes the object for multi-worker training."""
self._local_mode = False
assert devices, "Must specify at least one device."
devices = tuple(device_util.resolve(d) for d in devices)
assert len(set(devices)) == len(devices), (
"No duplicates allowed in `devices` argument: %s" % devices)
# TODO(josh11b): Require at least 2 devices?
device_dict = _group_device_list(devices)
workers = []
worker_devices = []
for job in ("chief", "worker"):
for task in range(len(device_dict.get(job, []))):
worker = "/job:%s/task:%d" % (job, task)
workers.append(worker)
worker_devices.append((worker, device_dict[job][task]))
# Setting `_default_device` will add a device scope in the
# distribution.scope. We set the default device to the first worker. When
# users specify device under distribution.scope by
# with tf.device("/cpu:0"):
# ...
their ops will end up on the CPU device of the first worker, e.g.
"/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
self._default_device = workers[0]
self._host_input_device = numpy_dataset.SingleDevice(workers[0])
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = input_lib.InputWorkers(
self._device_map, worker_devices)
if len(workers) > 1:
if not isinstance(self._cross_device_ops,
cross_device_ops_lib.MultiWorkerAllReduce):
raise ValueError(
"In-graph multi-worker training with `MirroredStrategy` is not "
"supported.")
self._inferred_cross_device_ops = self._cross_device_ops
else:
# TODO(yuefengz): make `choose_the_best` work with device strings
# containing job names.
self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device
def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
value_list = []
for i, d in enumerate(devices):
with ops.device(d):
if i > 0:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
# We append a / to variable names created on replicas with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
def initial_value_fn(device=d):
if context.executing_eagerly() or ops.inside_function():
init_value = value_list[0].value()
return array_ops.identity(init_value)
else:
with ops.device(device):
init_value = value_list[0].initial_value
return array_ops.identity(init_value)
kwargs["initial_value"] = initial_value_fn
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
with tape.stop_recording():
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
value_list.append(v)
return value_list
return _create_mirrored_variable(
self._container_strategy(), device_map, logical_device,
_real_mirrored_creator, *args, **kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate_distributed_variable(colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
return input_lib.DatasetIterator(
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
input_contexts = []
num_workers = self._input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.InputFunctionIterator(input_fn, self._input_workers,
input_contexts,
self._container_strategy())
def _experimental_distribute_dataset(self, dataset):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._host_input_device, session)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):
if initial_loop_values is None:
initial_loop_values = {}
initial_loop_values = nest.flatten(initial_loop_values)
ctx = input_lib.MultiStepContext()
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
fn_result = fn(ctx, iterator.get_next())
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self._local_results(output)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
# We capture the control_flow_context at this point, before we run `fn`
# inside a while_loop. This is useful in cases where we might need to exit
# these contexts and get back to the outer context to do some things, e.g.
# to create an op which should be evaluated only once at the end of the
# loop on the host. One such usage is in creating metrics' value op.
self._outer_control_flow_context = (
ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
# Convert the last_step_outputs from a list to the original dict structure
# of last_step_outputs.
last_step_tensor_outputs = loop_result[1:]
last_step_tensor_outputs_dict = nest.pack_sequence_as(
ctx.last_step_outputs, last_step_tensor_outputs)
for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
output = last_step_tensor_outputs_dict[name]
# For outputs that have already been reduced, wrap them in a Mirrored
# container, else in a PerReplica container.
if reduce_op is None:
last_step_tensor_outputs_dict[name] = values.regroup(self._device_map,
output)
else:
assert len(output) == 1
last_step_tensor_outputs_dict[name] = output[0]
ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access
return ctx
def _broadcast_to(self, tensor, destinations):
# This is both a fast path for Python constants, and a way to delay
# converting Python values to a tensor until we know what type it
# should be converted to. Otherwise we have trouble with:
# global_step.assign_add(1)
# since the `1` gets broadcast as an int32 but global_step is int64.
if isinstance(tensor, (float, int)):
return tensor
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
if not destinations:
# TODO(josh11b): Use current logical device instead of 0 here.
destinations = values.LogicalDeviceSpec(
device_map=self._device_map, logical_device=0)
return self._get_cross_device_ops().broadcast(tensor, destinations)
def _call_for_each_replica(self, fn, args, kwargs):
return _call_for_each_replica(self._container_strategy(), self._device_map,
fn, args, kwargs)
def _configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
del task_type, task_id
if session_config:
session_config.CopyFrom(self._update_config_proto(session_config))
if cluster_spec:
# TODO(yuefengz): remove the following code once cluster_resolver is
# added.
num_gpus_per_worker = _infer_num_gpus_per_worker(
self._device_map.all_devices)
multi_worker_devices = _cluster_spec_to_device_list(
cluster_spec, num_gpus_per_worker)
self._initialize_multi_worker(multi_worker_devices)
def _update_config_proto(self, config_proto):
updated_config = copy.deepcopy(config_proto)
updated_config.isolate_session_state = True
return updated_config
def _get_cross_device_ops(self):
return self._cross_device_ops or self._inferred_cross_device_ops
def _reduce_to(self, reduce_op, value, destinations):
if (isinstance(value, values.Mirrored) and
reduce_op == reduce_util.ReduceOp.MEAN):
return value
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.DistributedValues):
# This function handles reducing values that are not PerReplica or
# Mirrored values. For example, the same value could be present on all
# replicas, in which case `value` would be a single value; or `value`
# could be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, self._device_map, value, destinations)
return self._get_cross_device_ops().reduce(
reduce_op, value, destinations=destinations)
def _batch_reduce_to(self, reduce_op, value_destination_pairs):
return self._get_cross_device_ops().batch_reduce(
reduce_op, value_destination_pairs)
def _update(self, var, fn, args, kwargs, group):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
updates = []
for i, (d, v) in enumerate(zip(var.devices, var.values)):
name = "update_%d" % i
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
# If args and kwargs are not mirrored, the value is returned as is.
updates.append(fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs)))
return values.update_regroup(self, self._device_map, updates, group)
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
assert isinstance(colocate_with, tuple)
# TODO(josh11b): In eager mode, use one thread per device.
updates = []
for i, d in enumerate(colocate_with):
name = "update_%d" % i
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates.append(fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs)))
return values.update_regroup(self, self._device_map, updates, group)
def read_var(self, replica_local_var):
"""Read the aggregate value of a replica-local variable."""
if isinstance(replica_local_var, values.SyncOnReadVariable):
return replica_local_var._get_cross_replica() # pylint: disable=protected-access
assert isinstance(replica_local_var, values.Mirrored)
return array_ops.identity(replica_local_var.get())
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
return val.values
return (val,)
def value_container(self, val):
return values.value_container(val)
@property
def _num_replicas_in_sync(self):
return self._device_map.num_replicas_in_graph
@property
def worker_devices(self):
return self._device_map.all_devices
@property
def worker_devices_by_replica(self):
return self._device_map.devices_by_replica
@property
def parameter_devices(self):
return self._device_map.all_devices
@property
def experimental_between_graph(self):
return False
@property
def experimental_should_init(self):
return True
@property
def should_checkpoint(self):
return True
@property
def should_save_summary(self):
return True
def non_slot_devices(self, var_list):
del var_list
# TODO(josh11b): Should this be the last logical device instead?
return self._device_map.logical_to_actual_devices(0)
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
"""`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
`make_input_fn_iterator` assumes per-replica batching.
Returns:
Boolean.
"""
return True
class _MirroredReplicaThread(threading.Thread):
"""A thread that runs() a function on a device."""
def __init__(self, dist, coord, replica_id, device_map, variable_creator_fn,
fn, args, kwargs):
super(_MirroredReplicaThread, self).__init__()
self.coord = coord
self.distribution = dist
self.device_map = device_map
self.replica_id = replica_id
self.variable_creator_fn = variable_creator_fn
# State needed to run and return the results of `fn`.
self.main_fn = fn
self.main_args = args
self.main_kwargs = kwargs
self.main_result = None
self.done = False
# State needed to run the next merge_call() (if any) requested via
# ReplicaContext.
self.merge_fn = None
self.merge_args = None
self.merge_kwargs = None
self.merge_result = None
self.captured_name_scope = None
self.captured_var_scope = None
# We use a thread.Event for the main thread to signal when this
# thread should start running (`should_run`), and another for
# this thread to transfer control back to the main thread
# (`has_paused`, either when it gets to a
# `get_replica_context().merge_call` or when `fn` returns). In
either case the event starts cleared and is signaled by calling
set(). The receiving thread waits for the signal by calling
wait() and then immediately clears the event using clear().
self.should_run = threading.Event()
self.has_paused = threading.Event()
# These fields have to do with inheriting various contexts from the
# parent thread:
context.ensure_initialized()
ctx = context.context()
self.in_eager = ctx.executing_eagerly()
self.record_thread_local_context_fields()
self.context_device_policy = (
pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
ctx._context_handle))
self.graph = ops.get_default_graph()
with ops.init_scope():
self._init_in_eager = context.executing_eagerly()
self._init_graph = ops.get_default_graph()
self._variable_creator_stack = self.graph._variable_creator_stack[:]
self._var_scope = variable_scope.get_variable_scope()
# Adding a "/" at end lets us re-enter this scope later.
self._name_scope = self.graph.get_name_scope()
if self._name_scope:
self._name_scope += "/"
if self.replica_id > 0:
if not self._name_scope:
self._name_scope = ""
self._name_scope += "replica_%d/" % self.replica_id
def run(self):
self.should_run.wait()
self.should_run.clear()
try:
if self.coord.should_stop():
return
self.restore_thread_local_context_fields()
# TODO(josh11b): Use current logical device instead of 0 here.
with self.coord.stop_on_exception(), \
_enter_graph(self._init_graph, self._init_in_eager), \
_enter_graph(self.graph, self.in_eager,
self._variable_creator_stack), \
context.device_policy(self.context_device_policy), \
MirroredReplicaContext(self.distribution, constant_op.constant(
self.replica_id, dtypes.int32)), \
ops.device(self.device_map.logical_to_actual_devices(0)[
self.replica_id]), \
ops.name_scope(self._name_scope), \
variable_scope.variable_scope(
self._var_scope, reuse=self.replica_id > 0), \
variable_scope.variable_creator_scope(self.variable_creator_fn):
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
self.done = True
finally:
self.has_paused.set()
def record_thread_local_context_fields(self):
"""Record thread local fields of context.context() in self."""
ctx = context.context()
self._summary_step = ctx.summary_step
self._summary_writer = ctx.summary_writer
self._summary_recording = ctx.summary_recording
self._summary_recording_distribution_strategy = (
ctx.summary_recording_distribution_strategy)
# TODO(b/125892694): record other fields in EagerContext.
def restore_thread_local_context_fields(self):
"""Restore thread local fields of context.context() from self."""
ctx = context.context()
ctx.summary_step = self._summary_step
ctx.summary_writer = self._summary_writer
ctx.summary_recording = self._summary_recording
ctx.summary_recording_distribution_strategy = (
self._summary_recording_distribution_strategy)
# TODO(b/125892694): restore other fields in EagerContext.
class MirroredReplicaContext(distribute_lib.ReplicaContext):
"""ReplicaContext used in MirroredStrategy.extended.call_for_each_replica().
Opened in `_MirroredReplicaThread`, to allow the user to invoke
`MirroredStrategy`'s specific implementation of `merge_call()`,
which works by delegating the function and its arguments to
the main thread (the one that invoked
`MirroredStrategy.extended.call_for_each_replica()`).
"""
def _merge_call(self, fn, args, kwargs):
"""Delegate to the main thread to actually perform merge_call()."""
t = threading.current_thread() # a _MirroredReplicaThread
t.merge_fn = fn
t.merge_args = args
t.merge_kwargs = kwargs
t.captured_name_scope = t.graph.get_name_scope()
# Adding a "/" at end lets us re-enter this scope later.
if t.captured_name_scope:
t.captured_name_scope += "/"
t.captured_var_scope = variable_scope.get_variable_scope()
t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access
# NOTE(priyag): Throw an error if there is a merge call in the middle of a
# `fn` passed to call_for_each_replica which changes the graph being used
# while calling `fn`. This can happen when the `fn` is decorated with
# `tf.function` and there is a merge_call in `fn`. This breaks because each
# thread tries to create a distinct tf.function. Each tf.function creation
# takes a lock, and so if there is a merge call in the middle, the lock is
# never released and subsequent replica threads cannot proceed to define
# their own functions. Checking for the graph being the same is one way for
# us to check this didn't happen.
if ops.get_default_graph() != t.graph:
raise RuntimeError(
"`merge_call` called while defining a new graph or a tf.function. "
"This can often happen if the function `fn` passed to "
"`strategy.experimental_run()` is decorated with "
"`@tf.function` (or contains a nested `@tf.function`), and `fn` "
"contains a synchronization point, such as aggregating gradients. "
"This behavior is not yet supported. Instead, please wrap the entire "
"call `strategy.experimental_run(fn)` in a `@tf.function`, and avoid "
"nested `tf.function`s that may potentially cross a synchronization "
"boundary.")
t.has_paused.set()
t.should_run.wait()
t.should_run.clear()
if t.coord.should_stop():
raise _RequestedStop()
return t.merge_result
@property
def devices(self):
distribute_lib.require_replica_context(self)
replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
return [self._strategy.extended.worker_devices_by_replica[replica_id]]