Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensor-like objects that are composed from tf.Tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util import nest
@six.add_metaclass(abc.ABCMeta)
class CompositeTensor(object):
"""Abstract base class for Tensor-like objects that are composed from Tensors.
Each `CompositeTensor` can be decomposed into a structured collection of
component `tf.Tensor`s, and reconstructed from those components.
The `tensorflow.python.util.nest` module has support for treating composite
tensors as structure, which makes it easy to flatten and reconstruct
composite tensors (or larger structures that contain composite tensors).
E.g.:
```python
ct = ... # Create a composite tensor.
flat_list_of_tensors = nest.flatten(ct, expand_composites=True)
transformed_list_of_tensors = ... # do something with the flat tensors.
result = nest.pack_sequence_as(ct, transformed_list_of_tensors,
expand_composites=True)
```
"""
@abc.abstractmethod
def _to_components(self):
"""Decomposes this composite tensor into its component tensors.
Returns:
A nested structure of `tf.Tensor`s and `CompositeTensor`s that can be
used to reconstruct this composite tensor (along with metadata returned
by `_component_metadata`).
"""
raise NotImplementedError("CompositeTensor._to_components")
def _component_metadata(self):
"""Returns any non-tensor metadata needed to reconstruct a composite tensor.
Returns:
A nested structure of metadata that can be used to reconstruct this
composite tensor (along with the tensors returned by `_to_components`).
"""
return None
@abc.abstractmethod
def _from_components(cls, components, metadata): # pylint: disable=no-self-argument
"""Creates a composite tensor of type `cls` from components.
Args:
components: A nested structure whose values are `tf.Tensor`s or
`tf.CompositeTensor`s (as returned by `_to_components`).
metadata: A nested structure containing any additional metadata needed to
reconstruct the composite tensor (as returned by `_composite_metadata`).
Returns:
A `CompositeTensor` of type `cls`.
"""
raise NotImplementedError("CompositeTensor._from_components")
@abc.abstractmethod
def _shape_invariant_to_components(self, shape=None):
"""Converts a shape invariant into invariants for individual components.
Args:
shape: A `tf.TensorShape` object. The shape invariant for this
`CompositeTensor`, or `None` if a default shape invariant should be
used (based on the value of this `CompositeTensor`).
Returns:
A nested structure whose values are `tf.TensorShape` objects, specifying
the shape invariants for the tensors that comprise this `CompositeTensor`.
"""
raise NotImplementedError("CompositeTensor._shape_invariant_to_components")
@abc.abstractproperty
def _is_graph_tensor(self):
"""Returns True if this tensor's components belong to a TF graph."""
raise NotImplementedError("CompositeTensor._is_graph_tensor")
def _consumers(self):
"""Returns a list of `Operation`s that consume this `CompositeTensor`.
Returns:
A list of `Operation`s.
Raises:
RuntimeError: If this method is called while executing eagerly.
"""
consumers = nest.flatten([
component.consumers()
for component in self._to_components()
if getattr(component, "graph", None) is not None
])
return list(set(consumers))
pywrap_tensorflow.RegisterType("CompositeTensor", CompositeTensor)
def replace_composites_with_components(structure):
"""Recursively replaces CompositeTensors with their components.
Args:
structure: A `nest`-compatible structure, possibly containing composite
tensors.
Returns:
A copy of `structure`, where each composite tensor has been replaced by
its components. The result will contain no composite tensors.
Note that `nest.flatten(replace_composites_with_components(structure))`
returns the same value as `nest.flatten(structure)`.
"""
if isinstance(structure, CompositeTensor):
return replace_composites_with_components(structure._to_components()) # pylint: disable=protected-access
elif not nest.is_sequence(structure):
return structure
else:
return nest.map_structure(replace_composites_with_components, structure,
expand_composites=False)
# @TODO(edloper): Can we replace convert_to_tensor_or_xyz with just
# convert_to_tensor_or_composite? Alternatively, should composite tensors
# register a dispatch override for tf.convert_to_tensor?