# Repository URL to install this package:
# Version: 1.14.0
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six.moves
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
def has_fully_defined_shape(tensor):
  """Reports whether `tensor` has a fully defined shape.

  Args:
    tensor: The tensor to inspect.

  Returns:
    True when `tensor` is an `ops.EagerTensor` (its shape is not inspected);
    otherwise the result of `tensor.shape.is_fully_defined()`.
  """
  if isinstance(tensor, ops.EagerTensor):
    return True
  return tensor.shape.is_fully_defined()
def rank(tensor):
  """Looks up the rank of `tensor`.

  Args:
    tensor: Any object.

  Returns:
    The value of `tensor._rank()` when `tensor` is an `ops.Tensor`, and None
    for every other kind of object.
  """
  if not isinstance(tensor, ops.Tensor):
    return None
  return tensor._rank()  # pylint: disable=protected-access
def scalar_shape(unused_op):
  """Shape function for ops whose single output is a scalar.

  Args:
    unused_op: The operation; it is not inspected.

  Returns:
    A single-element list holding `tensor_shape.scalar()`.
  """
  scalar = tensor_shape.scalar()
  return [scalar]
def unchanged_shape(op):
  """Shape function for ops whose output is shaped like their first input.

  Args:
    op: The operation whose first input provides the shape.

  Returns:
    A single-element list holding the shape of `op.inputs[0]`.
  """
  first_input = op.inputs[0]
  return [first_input.get_shape()]
def unchanged_shape_with_rank(rank):
  """Builds a shape function that pins the input to an exact rank.

  Args:
    rank: The exact rank required of the input (and therefore the output).

  Returns:
    A function taking an op and returning a single-element list with the
    shape of the op's first input after `with_rank(rank)` has been applied.
  """

  def _ShapeFunction(op):
    input_shape = op.inputs[0].get_shape()
    return [input_shape.with_rank(rank)]

  return _ShapeFunction
def unchanged_shape_with_rank_at_least(rank):
  """Builds a shape function that enforces a minimum input rank.

  Args:
    rank: The smallest rank allowed for the input (and therefore the output).

  Returns:
    A function taking an op and returning a single-element list with the
    shape of the op's first input after `with_rank_at_least(rank)` has been
    applied.
  """

  def _ShapeFunction(op):
    input_shape = op.inputs[0].get_shape()
    return [input_shape.with_rank_at_least(rank)]

  return _ShapeFunction
def unchanged_shape_with_rank_at_most(rank):
  """Builds a shape function that enforces a maximum input rank.

  Args:
    rank: The largest rank allowed for the input (and therefore the output).

  Returns:
    A function taking an op and returning a single-element list with the
    shape of the op's first input after `with_rank_at_most(rank)` has been
    applied.
  """

  def _ShapeFunction(op):
    input_shape = op.inputs[0].get_shape()
    return [input_shape.with_rank_at_most(rank)]

  return _ShapeFunction
def matmul_shape(op):
  """Infers the output shape of a MatMul op.

  Both inputs must be rank 2. The "transpose_a" / "transpose_b" attrs select
  which axis of each operand is the contracted (inner) one; the two inner
  dimensions must be compatible.

  Args:
    op: A MatMul operation.

  Returns:
    A single-element list holding the [rows, cols] shape of the product.
  """
  lhs_shape = op.inputs[0].get_shape().with_rank(2)
  lhs_transposed = op.get_attr("transpose_a")
  rhs_shape = op.inputs[1].get_shape().with_rank(2)
  rhs_transposed = op.get_attr("transpose_b")

  if lhs_transposed:
    rows, lhs_inner = lhs_shape[1], lhs_shape[0]
  else:
    rows, lhs_inner = lhs_shape[0], lhs_shape[1]

  if rhs_transposed:
    cols, rhs_inner = rhs_shape[0], rhs_shape[1]
  else:
    cols, rhs_inner = rhs_shape[1], rhs_shape[0]

  # The contracted dimensions of the two operands have to agree.
  lhs_inner.assert_is_compatible_with(rhs_inner)
  return [tensor_shape.TensorShape([rows, cols])]
def get_conv_output_size(input_size, filter_size, strides, padding_type):
  """Computes the spatial output size of an n-d convolution or pooling.

  Args:
    input_size: Sequence of input spatial sizes; each entry is converted with
      `tensor_shape.as_dimension`, so unknown sizes become None.
    filter_size: Sequence of filter/window spatial sizes, converted the same
      way as `input_size`.
    strides: Sequence of strides, each converted with `int`.
    padding_type: Either b"VALID" or b"SAME".

  Returns:
    A tuple with one output size per spatial dimension; an entry is None when
    it cannot be computed from the known sizes. When every input and filter
    size equals 1, the converted input sizes are returned unchanged.

  Raises:
    ValueError: If a known filter size exceeds the matching known input size,
      or if `padding_type` is neither b"VALID" nor b"SAME".
  """
  in_dims = tuple(tensor_shape.as_dimension(d).value for d in input_size)
  k_dims = tuple(tensor_shape.as_dimension(d).value for d in filter_size)
  steps = [int(s) for s in strides]

  # A 1x...x1 window over a 1x...x1 input is the identity.
  if all(d == 1 for d in in_dims + k_dims):
    return in_dims

  # Only dimensions known on both sides can be compared.
  for k_dim, in_dim in zip(k_dims, in_dims):
    if k_dim is None or in_dim is None:
      continue
    if k_dim > in_dim:
      raise ValueError("Filter must not be larger than the input: "
                       "Filter: %r Input: %r" % (k_dims, in_dims))

  if padding_type == b"VALID":
    out_dims = [
        None if in_dim is None or k_dim is None
        else (in_dim - k_dim + step) // step
        for in_dim, k_dim, step in zip(in_dims, k_dims, steps)
    ]
  elif padding_type == b"SAME":
    out_dims = [
        None if in_dim is None else (in_dim + step - 1) // step
        for in_dim, step in zip(in_dims, steps)
    ]
  else:
    raise ValueError("Invalid padding: %r" % padding_type)
  return tuple(out_dims)
def get2d_conv_output_size(input_height, input_width, filter_height,
                           filter_width, row_stride, col_stride, padding_type):
  """Computes the (rows, cols) output size of a 2-D convolution or pooling.

  This is the two-dimensional form of `get_conv_output_size`; the arguments
  are grouped into (height, width) pairs and forwarded to it.

  Args:
    input_height: Input height.
    input_width: Input width.
    filter_height: Filter/window height.
    filter_width: Filter/window width.
    row_stride: Stride along rows.
    col_stride: Stride along columns.
    padding_type: Padding type, forwarded unchanged.

  Returns:
    The 2-tuple produced by `get_conv_output_size`.
  """
  spatial_in = (input_height, input_width)
  window = (filter_height, filter_width)
  steps = (row_stride, col_stride)
  return get_conv_output_size(spatial_in, window, steps, padding_type)
def conv2d_shape(op):
  """Infers the output shape of a Conv2D op.

  The op has two inputs:
    * input: rank 4, laid out as [batch_size, rows, cols, depth_in], or as
      [batch_size, depth_in, rows, cols] when "data_format" is b"NCHW".
    * filter: rank 4, [filter_rows, filter_cols, depth_in, depth_out].

  The output is rank 4, in the same layout as the input, with the spatial
  sizes derived from the op's "padding" and "strides" attrs and the channel
  size equal to depth_out.

  Args:
    op: A Conv2D Operation.

  Returns:
    A list containing the Shape of the Conv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible, or if
      the batch or depth stride is not 1.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  # An op without a "data_format" attr is handled like NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  is_nchw = data_format == b"NCHW"

  if is_nchw:
    batch_size, depth_in, in_rows, in_cols = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
  else:
    batch_size, in_rows, in_cols, depth_in = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])

  filter_rows, filter_cols = filter_shape[0], filter_shape[1]
  depth_out = filter_shape[3]

  # The input channels must match the filter's input-channel dimension.
  depth_in.assert_is_compatible_with(filter_shape[2])

  strides = op.get_attr("strides")
  if is_nchw:
    stride_b, stride_d, stride_r, stride_c = strides
  else:
    stride_b, stride_r, stride_c, stride_d = strides

  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(
      in_rows, in_cols, filter_rows, filter_cols, stride_r, stride_c, padding)

  # Emit the result in the same layout the input arrived in.
  if is_nchw:
    output_shape = [batch_size, depth_out, out_rows, out_cols]
  else:
    output_shape = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(output_shape)]
def depthwise_conv2d_native_shape(op):
  """Infers the output shape of a DepthwiseConv2dNative op.

  The op has two inputs:
    * input: rank 4, [batch_size, rows, cols, depth_in].
    * filter: rank 4, [filter_rows, filter_cols, depth_in,
      depthwise_multiplier].

  The output is rank 4, [batch_size, out_rows, out_cols,
  depth_in * depthwise_multiplier], with out_rows and out_cols derived from
  the op's "padding" and "strides" attrs.

  Args:
    op: A DepthwiseConv2dNative Operation.

  Returns:
    A list containing the Shape of the DepthwiseConv2DNative output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible, if the
      batch or depth stride is not 1, or if the row and column strides differ.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  batch_size, in_rows, in_cols, depth_in = (
      input_shape[0], input_shape[1], input_shape[2], input_shape[3])
  filter_rows, filter_cols = filter_shape[0], filter_shape[1]
  # Each input channel produces `depthwise_multiplier` output channels.
  depth_out = filter_shape[3] * filter_shape[2]

  # The input channels must match the filter's input-channel dimension.
  depth_in.assert_is_compatible_with(filter_shape[2])

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  spatial_stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(
      in_rows, in_cols, filter_rows, filter_cols, spatial_stride,
      spatial_stride, padding)
  output_dims = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(output_dims)]
def separable_conv2d_shape(op):
  """Infers the output shape of a SeparableConv2D op.

  The op has three inputs:
    * input: rank 4, [batch_size, rows, cols, depth_in].
    * depthwise_filter: [filter_rows, filter_cols, depth_in,
      depth_multiplier].
    * pointwise_filter: [1, 1, depth_in * depth_multiplier, depth_out].

  The output is rank 4, [batch_size, out_rows, out_cols, depth_out], with
  out_rows and out_cols derived from the op's "padding" and "strides" attrs.

  Args:
    op: A SeparableConv2D Operation.

  Returns:
    A list containing the Shape of the SeparableConv2D output.

  Raises:
    ValueError: If the shapes of the input or filters are incompatible, if
      the batch or depth stride is not 1, or if the row and column strides
      differ.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)

  # The depthwise filter's third dimension must agree with the input depth.
  depthwise_actual = op.inputs[1].get_shape()
  depthwise_filter_shape = depthwise_actual.merge_with(
      tensor_shape.TensorShape([None, None, input_shape[3], None]))

  # The pointwise filter is 1x1 and consumes depth_in * depth_multiplier
  # channels.
  pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3]
  pointwise_actual = op.inputs[2].get_shape()
  pointwise_filter_shape = pointwise_actual.merge_with(
      tensor_shape.TensorShape([1, 1, pointwise_depth_in, None]))

  batch_size, in_rows, in_cols = (
      input_shape[0], input_shape[1], input_shape[2])
  filter_rows, filter_cols = (
      depthwise_filter_shape[0], depthwise_filter_shape[1])
  depth_out = pointwise_filter_shape[3]

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  spatial_stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(
      in_rows, in_cols, filter_rows, filter_cols, spatial_stride,
      spatial_stride, padding)
  output_dims = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(output_dims)]
def avg_pool_shape(op):
  """Infers the output shape of an AvgPool op.

  The op has one input: a rank-4 tensor laid out as
  [batch_size, rows, cols, depth], or as [batch_size, depth, rows, cols] when
  the "data_format" attr is b"NCHW". The "ksize" and "strides" attrs follow
  the same layout as the input.

  The output is rank 4, in the same layout as the input, with the batch and
  depth sizes unchanged and the spatial sizes derived from the op's "ksize",
  "strides", and "padding" attrs.

  Args:
    op: An AvgPool Operation.

  Returns:
    A single-element list containing the Shape of the AvgPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs, or if the window size or stride in the batch
      or depth dimension is not 1.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)

  # An op without a "data_format" attr is handled like NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  is_nchw = data_format == b"NCHW"

  if is_nchw:
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
    batch_size, depth, in_rows, in_cols = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
  else:
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
    batch_size, in_rows, in_cols, depth = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])

  if ksize_b != 1 or ksize_d != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch and depth dimensions.")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch and depth dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(
      in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)

  # Emit the result in the same layout the input arrived in.
  if is_nchw:
    output_shape = [batch_size, depth, out_rows, out_cols]
  else:
    output_shape = [batch_size, out_rows, out_cols, depth]
  return [tensor_shape.TensorShape(output_shape)]
def max_pool_shape(op):
  """Infers the output shape of a MaxPool op.

  The op has one input: a rank-4 tensor laid out as
  [batch_size, rows, cols, depth_in], or as [batch_size, depth_in, rows, cols]
  when the "data_format" attr is b"NCHW". The "ksize" and "strides" attrs
  follow the same layout as the input.

  Pooling happens either across the spatial dimensions (depth window of 1) or
  across depth (row and column windows of 1). In the spatial case the output
  rows/cols come from "ksize", "strides", and "padding" and the depth is
  unchanged; in the depth case rows/cols are unchanged and the depth is
  divided by the depth window.

  Args:
    op: A MaxPool Operation.

  Returns:
    A single-element list containing the Shape of the MaxPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)

  # An op without a "data_format" attr is handled like NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  is_nchw = data_format == b"NCHW"

  if is_nchw:
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
    batch_size, depth, in_rows, in_cols = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
  else:
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
    batch_size, in_rows, in_cols, depth = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])

  if ksize_b != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch dimension.")
  if stride_b != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch dimension.")
  # A window may span the spatial dimensions or the depth, but not both.
  if (ksize_r != 1 or ksize_c != 1) and ksize_d != 1:
    raise ValueError("MaxPooling supports exactly one of pooling across depth "
                     "or pooling across width/height.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  if ksize_d == 1:
    # Spatial pooling: depth is preserved.
    padding = op.get_attr("padding")
    out_rows, out_cols = get2d_conv_output_size(
        in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
    out_depth = depth
  else:
    # Depthwise pooling: rows and columns are preserved.
    if depth % ksize_d > 0:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to evenly divide the input depth.")
    if stride_d != ksize_d:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to equal the depth stride.")
    out_rows, out_cols = in_rows, in_cols
    out_depth = depth // ksize_d

  # Emit the result in the same layout the input arrived in.
  if is_nchw:
    output_shape = [batch_size, out_depth, out_rows, out_cols]
  else:
    output_shape = [batch_size, out_rows, out_cols, out_depth]
  return [tensor_shape.TensorShape(output_shape)]
def no_outputs(unused_op):
  """Shape function for ops that produce no outputs.

  Args:
    unused_op: The operation; it is not inspected.

  Returns:
    A new empty list.
  """
  return list()
def unknown_shape(op):
  """Shape function for ops whose output shapes cannot be determined.

  Args:
    op: The operation whose outputs are being described.

  Returns:
    A list with one `tensor_shape.unknown_shape()` per entry of `op.outputs`.
  """
  shapes = []
  for _ in op.outputs:
    shapes.append(tensor_shape.unknown_shape())
  return shapes
def _broadcast_shape_helper(shape_x, shape_y):
  """Shared implementation for is_broadcast_compatible and broadcast_shape.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None when the shapes are not broadcast compatible; otherwise a list of
    the broadcast dimensions, where an entry is None if that dimension cannot
    be determined.
  """
  # Align the two shapes on their trailing dimensions, padding the shorter
  # one on the left with size-1 dimensions, then restore leading-first order.
  unit_dim = tensor_shape.Dimension(1)
  aligned = list(six.moves.zip_longest(
      reversed(shape_x.dims), reversed(shape_y.dims), fillvalue=unit_dim))
  aligned.reverse()

  # Combine each pair according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  result = []
  for dim_x, dim_y in aligned:
    x_val = dim_x.value
    y_val = dim_y.value

    if x_val is None or y_val is None:
      # One or both dimensions is unknown. If either dimension is greater
      # than 1, we assume that the program is correct, and the other
      # dimension will be broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if x_val is not None and x_val > 1:
        result.append(dim_x)
      elif y_val is not None and y_val > 1:
        result.append(dim_y)
      else:
        result.append(None)
      continue

    if x_val == 1:
      # dim_x is stretched to dim_y.
      result.append(dim_y)
    elif y_val == 1:
      # dim_y is stretched to dim_x.
      result.append(dim_x)
    elif x_val == y_val:
      # Equal sizes: the output keeps that size.
      result.append(dim_x.merge_with(dim_y))
    else:
      # Two different sizes, neither of them 1: not broadcastable.
      return None
  return result
def is_broadcast_compatible(shape_x, shape_y):
  """Checks whether `shape_x` and `shape_y` can be broadcast together.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be
    broadcasted to. False otherwise, including when the rank of either shape
    is unknown.
  """
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not ranks_known:
    return False
  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  return broadcast_dims is not None
def broadcast_shape(shape_x, shape_y):
  """Computes the shape that results from broadcasting two shapes together.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape; an unknown shape when
    the rank of either argument is unknown.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not ranks_known:
    return tensor_shape.unknown_shape()

  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  if broadcast_dims is not None:
    return tensor_shape.TensorShape(broadcast_dims)
  raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                   % (shape_x, shape_y))
def call_cpp_shape_fn(op, require_shape_fn=True):
  """A shape function that delegates to the registered C++ shape function.

  The C++ function may report that it needs the values of some inputs (as
  tensors or as shapes). Inference is therefore repeated, each time supplying
  every input index requested so far, until nothing is requested or the set
  of requested indices stops growing.

  Args:
    op: the node in the graph for which to compute output shapes.
    require_shape_fn: If true, and the C++ shape function is not registered
      in the current binary then an exception is raised; otherwise, if the
      C++ shape function is not registered then unknown_shape is used.

  Returns:
    Normally a dictionary with the following keys:
      shapes: The output shapes of the op, as computed using the C++ shape
        inference function registered for the op.
      handle_data: One entry per output; the handle data for that output, or
        None if it has none.
      inputs_needed: The serialized inputs-needed message from the last
        inference call (absent for "Const" ops, which are special-cased).
    When the C++ shape function is not registered and `require_shape_fn` is
    false, the list returned by `unknown_shape(op)` is returned instead.

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function).
    RuntimeError: If the C++ shape function is not registered and
      <require_shape_fn> is True.
  """
  if op.type == "Const":
    # To avoid serializing large constants, we special-case constant
    # here, even though it has a C++ shape function. When Python
    # calls the C / C-API directly, we should be able to remove this.
    const_shape = tensor_shape.TensorShape(op.get_attr("value").tensor_shape)
    return {"shapes": [const_shape], "handle_data": [None]}

  tensor_indices = []
  shape_indices = []
  while True:
    res = _call_cpp_shape_fn_impl(op, tensor_indices, shape_indices,
                                  require_shape_fn)
    # A non-dict result is what _call_cpp_shape_fn_impl returns when it falls
    # back to unknown_shape(op).
    if not isinstance(res, dict):
      return res

    # Finished when the shape function requests no input values.
    if not res["inputs_needed"]:
      return res

    needed = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded().FromString(
        res["inputs_needed"])

    # Record any newly requested input index; retry only if something is new.
    grew = False
    requests = (
        (needed.input_tensors_needed, tensor_indices),
        (needed.input_tensors_as_shapes_needed, shape_indices),
    )
    for requested, collected in requests:
      for idx in requested:
        if idx not in collected:
          collected.append(idx)
          grew = True
    if not grew:
      return res
def _call_cpp_shape_fn_impl(
    op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
  """Runs C++ shape inference once for `op`.

  Args:
    op: The operation whose output shapes are inferred.
    input_tensors_needed: Indices of inputs whose constant values should be
      passed to the shape function when they can be determined.
    input_tensors_as_shapes_needed: Indices of inputs whose constant values,
      interpreted as shapes, should be passed when they can be determined.
    require_shape_fn: Whether a missing C++ shape function is an error.

  Returns:
    A dict with keys "shapes", "handle_data" and "inputs_needed", or the
    result of `unknown_shape(op)` when no C++ shape function exists and
    `require_shape_fn` is false.

  Raises:
    ValueError: If shape inference fails with an InvalidArgumentError other
      than a missing shape function.
    RuntimeError: If no C++ shape function exists and `require_shape_fn` is
      true.
  """
  graph_def_version = op.graph.graph_def_versions.producer
  node_def_str = op.node_def.SerializeToString()

  def tensor_to_inference_result(t):
    """Serializes the shape (and handle data, if any) of tensor `t`."""
    inference_result = cpp_shape_inference_pb2.CppShapeInferenceResult()
    inference_result.shape.CopyFrom(t.get_shape().as_proto())
    # pylint: disable=protected-access
    if t._handle_data is not None:
      inference_result.handle_data.CopyFrom(t._handle_data)
    # pylint: enable=protected-access
    return inference_result.SerializeToString()

  input_shapes = [tensor_to_inference_result(inp) for inp in op.inputs]
  num_inputs = len(input_shapes)

  # Requested input values; entries stay None when the value is not constant.
  input_tensors = [None] * num_inputs
  for idx in input_tensors_needed:
    const_value = tensor_util.constant_value(op.inputs[idx])
    if const_value is not None:
      input_tensors[idx] = np.asarray(const_value)

  # Requested inputs-as-shapes; entries default to a serialized unknown shape.
  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [serialized_unknown_shape] * num_inputs
  for idx in input_tensors_as_shapes_needed:
    const_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
    if const_shape is not None:
      input_tensors_as_shapes[idx] = const_shape.as_proto().SerializeToString()

  no_shape_fn = False
  try:
    output = pywrap_tensorflow.RunCppShapeInference(
        graph_def_version, node_def_str, input_shapes, input_tensors,
        input_tensors_as_shapes)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith("No shape inference function exists for op"):
      raise ValueError(err.message)
    no_shape_fn = True

  if no_shape_fn:
    if require_shape_fn:
      raise RuntimeError(
          "No C++ shape function registered for standard op: %s" % op.type)
    return unknown_shape(op)

  # The last element is the inputs-needed message; all earlier elements are
  # serialized per-output inference results.
  shapes = []
  handle_data = []
  for serialized in output[:-1]:
    proto = cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(
        serialized)
    shapes.append(proto.shape)
    handle_data.append(proto.handle_data if proto.handle_data.is_set else None)

  return {
      "shapes": shapes,
      "handle_data": handle_data,
      "inputs_needed": output[-1]
  }
# Hand `call_cpp_shape_fn` to the `ops` module at import time.
# NOTE(review): presumably `ops` cannot import this module directly (circular
# import) and uses this hook to reach the C++ shape function - confirm.
# pylint: disable=protected-access
ops._set_call_cpp_shape_fn(call_cpp_shape_fn)
# pylint: enable=protected-access