Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for automatic batching and unbatching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops
# pylint: disable=unused-import
from tensorflow.python.ops.batch_ops import batch
from tensorflow.python.ops.batch_ops import batch_function
from tensorflow.python.ops.batch_ops import unbatch
# pylint: enable=unused-import
@ops.RegisterGradient("Batch")
def _BatchGrad(op, *out_grads): # pylint: disable=invalid-name
"""Gradient for batch op."""
gradients = []
for i in range(len(op.inputs)):
gradients.append(
gen_batch_ops.unbatch(
out_grads[i],
op.outputs[-2],
op.outputs[-1],
timeout_micros=op.get_attr("grad_timeout_micros"),
shared_name="batch_gradient_{}_{}".format(op.name, i)))
return gradients
@ops.RegisterGradient("Unbatch")
def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
return [
gen_batch_ops.unbatch_grad(
op.inputs[0],
op.inputs[1],
grad,
op.inputs[2],
shared_name="unbatch_gradient_{}".format(op.name)), None, None
]
def batch_function_v1(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
grad_timeout_micros=60 * 1000 * 1000,
unbatch_timeout_micros=60 * 1000 * 1000,
max_enqueued_batches=10):
"""Batches the computation done by the decorated function.
This is the older version of batch_function(). Please use the former instead
of this.
Args:
num_batch_threads: Number of scheduling threads for processing batches
of work. Determines the number of batches processed in parallel.
max_batch_size: Batch sizes will never be bigger than this.
batch_timeout_micros: Maximum number of microseconds to wait before
outputting an incomplete batch.
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
grad_timeout_micros: The timeout to use for the gradient. See the
documentation of the unbatch op for more details. Defaults to 60s.
unbatch_timeout_micros: The timeout to use for unbatching. See the
documentation of the unbatch op for more details. Defaults to 60s.
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
def decorator(f): # pylint: disable=missing-docstring
def decorated(*args):
with ops.name_scope("batch") as name:
for a in args:
if not isinstance(a, ops.Tensor):
raise ValueError("All arguments to functions decorated with "
"`batch_function` are supposed to be Tensors; "
"found %s" % repr(a))
batched_tensors, batch_index, id_t = gen_batch_ops.batch(
args,
num_batch_threads=num_batch_threads,
max_batch_size=max_batch_size,
batch_timeout_micros=batch_timeout_micros,
max_enqueued_batches=max_enqueued_batches,
allowed_batch_sizes=allowed_batch_sizes,
grad_timeout_micros=grad_timeout_micros,
shared_name=name)
outputs = f(*batched_tensors)
if isinstance(outputs, ops.Tensor):
outputs_list = [outputs]
else:
outputs_list = outputs
with ops.name_scope("unbatch") as unbatch_name:
unbatched = [
gen_batch_ops.unbatch(t, batch_index, id_t,
timeout_micros=unbatch_timeout_micros,
shared_name=unbatch_name + "/" + t.name)
for t in outputs_list]
if isinstance(outputs, ops.Tensor):
return unbatched[0]
return unbatched
return decorated
return decorator