Repository URL to install this package:
|
Version:
4.5.4.dev1 ▾
|
import functools
import typing as t
try:
import tensorflow as tf
except ModuleNotFoundError:
pass # Warning is displayed by t.py
import sarus_data_spec.typing as st
# Inspired by:
# https://github.com/tensorflow/datasets/blob/v4.4.0/tensorflow_datasets/core/utils/type_utils.py#L43
# Path from the root of a nested structure to one of its nodes, expressed as
# the sequence of keys to follow.
PathType = t.Tuple[str, ...]
# Maps each leaf path to the (named) TensorSpec describing that leaf.
InternalTFSignature = t.Dict[PathType, tf.TensorSpec]
# Leaf tensors ordered by their flattened position.
FlattenedTensors = t.List[tf.Tensor]
# Recursive definition below is ignored; this is a known issue with mypy:
# https://github.com/python/mypy/issues/731
# Either a tensor or a dict of (recursively) nested tensors keyed by name.
NestedTensors = t.Union[tf.Tensor, t.Dict[str, "NestedTensors"]]
# Flat mapping from a feature (column) name to its tensor.
TensorDict = t.Dict[str, tf.Tensor]
# Converts one leaf tensor into a `tf.train.Feature` for serialization.
InputFeatureFunction = t.Callable[[tf.Tensor], tf.train.Feature]
# One (feature name, flattened position, conversion function) entry per leaf.
PreFeature = t.List[t.Tuple[str, int, InputFeatureFunction]]
# Maps a feature name to the spec used to parse it back from an Example.
FeatureDescription = t.Dict[str, tf.io.FixedLenFeature]
def _position_to_column(i: int) -> str:
return f"_{i}"
def _column_to_position(column: str) -> int:
return int(column.replace("_", ""))
class TensorflowSignatureConverter(st.TypeVisitor):
    """This visitor builds an internal structure which is the cornerstone to
    solve the following problems:

    - flattening the nested structure to a structure suitable for TFRecords
    - nesting a flattened structure to reshape the data from TFRecords
    - describing the Feature to serialize data as protobuf messages

    It does so by scanning in a deterministic way the Sarus Type of the data,
    and building along the exploration of the structure a dictionary mapping
    the path to a leaf in the structure with a position in the protobuf
    message and the TensorSpec to use to serialize the data of the leaf.

    Every scalar leaf is registered twice: a ``values`` entry holding the data
    and a boolean ``input_mask`` entry sitting next to it.

    The signature abides by the standard defined in `tensorflow_visitor.py`.

    More about the Features:
    - https://www.tensorflow.org/api_docs/python/tf/train/Feature
    - https://www.tensorflow.org/tutorials/load_data/tfrecord#data_types_for_tftrainexample  # noqa: E501
    """

    def __init__(self) -> None:
        # Leaf path -> TensorSpec named after its flattened position ("_<i>").
        self.flattened_features: InternalTFSignature = {}
        # Leaf path -> flattened position, assigned in visiting order.
        self.path_index: t.Dict[PathType, int] = {}
        # Path of the node currently being visited.
        self.path: PathType = t.cast(PathType, ())
        # Shape given to every leaf spec (scalar).
        self.shape: tf.TensorShape = tf.TensorShape([])

    def _add_feature(self, path: PathType, tf_spec: tf.TensorSpec) -> None:
        """Registers a new feature with the given path and spec.

        It sets or reuses the flattened index: registering an already known
        path keeps its position and only replaces its spec (Integer or Float
        can be overridden by Int32/64 or Float32/64).
        """
        # setdefault evaluates the next free index before inserting, so a new
        # path gets position len(path_index) and a known path keeps its own.
        i = self.path_index.setdefault(path, len(self.path_index))
        name = _position_to_column(i)
        named_tf_spec = tf.TensorSpec.from_spec(tf_spec, name=name)
        self.flattened_features[path] = named_tf_spec

    def _add_feature_with_mask(
        self, path: PathType, tf_spec: tf.TensorSpec
    ) -> None:
        """Registers a value feature together with its boolean mask.

        The mask lives at the same level as the value: the last component of
        ``path`` is replaced by ``"input_mask"``. The mask is registered first
        so it takes the lower flattened position.
        """
        mask_path = path[:-1] + ("input_mask",)
        mask_spec = tf.TensorSpec(tf_spec.shape, dtype=tf.dtypes.bool)
        self._add_feature(mask_path, mask_spec)
        self._add_feature(path, tf_spec)

    def _add_value_spec_with_mask(self, tf_type: tf.dtypes.DType) -> None:
        """Registers the current node as a masked leaf of dtype ``tf_type``."""
        tf_spec = tf.TensorSpec(self.shape, dtype=tf_type)
        self._add_feature_with_mask(self.path + ("values",), tf_spec)

    def Null(self, properties: t.Optional[t.Mapping[str, str]] = None) -> None:
        raise ValueError(
            "The Null type cannot be converted to a TensorFlow signature."
        )

    def Unit(self, properties: t.Optional[t.Mapping[str, str]] = None) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float64)

    def Text(
        self,
        encoding: str,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Bytes(
        self, properties: t.Optional[t.Mapping[str, str]] = None
    ) -> None:
        raise ValueError(
            "The Bytes type cannot be converted to a TensorFlow signature."
        )

    def Boolean(
        self, properties: t.Optional[t.Mapping[str, str]] = None
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.bool)

    def Id(
        self,
        unique: bool,
        reference: t.Optional[st.Path] = None,
        base: t.Optional[st.IdBase] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Enum(
        self,
        name: str,
        name_values: t.Sequence[t.Tuple[str, int]],
        ordered: bool,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Integer(
        self,
        min: int,
        max: int,
        base: st.IntegerBase,
        possible_values: t.Iterable[int],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.int64)

    def Float(
        self,
        min: float,
        max: float,
        base: st.FloatBase,
        possible_values: t.Iterable[float],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float64)

    def Float32(self, min: float, max: float) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float32)

    def Float64(self, min: float, max: float) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float64)

    def Datetime(
        self,
        format: str,
        min: str,
        max: str,
        base: st.DatetimeBase,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Date(
        self,
        format: str,
        min: str,
        max: str,
        base: st.DateBase,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Time(
        self,
        format: str,
        min: str,
        max: str,
        base: st.TimeBase,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Duration(
        self,
        unit: str,
        min: int,
        max: int,
        possible_values: t.Iterable[int],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.int64)

    def Struct(
        self,
        fields: t.Mapping[str, st.Type],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        """Visits each field under ``<current path>/<field name>``."""
        parent_path = self.path
        try:
            for field_name, field_type in fields.items():
                self.path = parent_path + (field_name,)
                field_type.accept(self)
        finally:
            # Restore the path even if a child visit raises, so the visitor
            # is not left pointing inside this struct.
            self.path = parent_path

    def Constrained(
        self,
        type: st.Type,
        constraint: st.Predicate,
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "Constrained types are not supported in TensorFlow signatures."
        )

    def List(
        self,
        type: st.Type,
        max_size: int,
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "List types are not supported in TensorFlow signatures."
        )

    def Array(
        self,
        type: st.Type,
        shape: t.Tuple[int, ...],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "Array types are not supported in TensorFlow signatures."
        )

    def Optional(
        self,
        type: st.Type,
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        """Optional adds no level: the inner type is visited at this path."""
        type.accept(self)

    def Union(
        self,
        fields: t.Mapping[str, st.Type],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        """Visits each alternative under ``<current path>/<field name>``."""
        parent_path = self.path
        try:
            for field_name, field_type in fields.items():
                self.path = parent_path + (field_name,)
                field_type.accept(self)
        finally:
            # Same path restoration guarantee as in Struct.
            self.path = parent_path

    def Hypothesis(
        self,
        *types: t.Tuple[st.Type, float],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "Hypothesis types are not supported in TensorFlow signatures."
        )
def to_internal_signature(sarus_type: st.Type) -> InternalTFSignature:
    """Build the internal TF signature of a Sarus type.

    A fresh ``TensorflowSignatureConverter`` walks ``sarus_type``. The mapping
    it accumulates, from leaf paths to named ``tf.TensorSpec``, is returned.
    """
    converter = TensorflowSignatureConverter()
    sarus_type.accept(converter)
    signature = converter.flattened_features
    return signature
# From: https://www.tensorflow.org/tutorials/load_data/tfrecord#tftrainexample
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.
def _bytes_feature(value: tf.Tensor) -> tf.train.Feature:
    """Returns a bytes_list Feature from a string / bytes value.

    Accepted inputs:

    - an eager tensor, unwrapped through ``.numpy()``;
    - ``bytes``;
    - ``str``, encoded as UTF-8 because ``tf.train.BytesList`` holds bytes.
    """
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    if isinstance(value, str):
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value: tf.Tensor) -> tf.train.Feature:
    """Wrap a float / double value into a ``float_list`` Feature."""
    eager_tensor_type = type(tf.constant(0))
    # FloatList cannot unpack a float from an EagerTensor: use its numpy value.
    scalar = value.numpy() if isinstance(value, eager_tensor_type) else value
    float_list = tf.train.FloatList(value=[scalar])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value: tf.Tensor) -> tf.train.Feature:
    """Wrap a bool / enum / int / uint value into an ``int64_list`` Feature.

    The value is first cast to ``tf.int64``.
    """
    as_int64 = tf.cast(value, tf.int64)
    if isinstance(as_int64, type(tf.constant(0))):
        # Int64List cannot unpack an int from an EagerTensor.
        as_int64 = as_int64.numpy()
    int64_list = tf.train.Int64List(value=[as_int64])
    return tf.train.Feature(int64_list=int64_list)
def _input_feature_factory(tf_spec: tf.TensorSpec) -> InputFeatureFunction:
    """Pick the function converting values of ``tf_spec`` into a Feature.

    Mapping:

    - strings become a bytes_list;
    - float32/float64 become a float_list;
    - bool/int32/int64 become an int64_list.

    Any other dtype raises ``NotImplementedError``.

    For more details:
    - types: https://www.tensorflow.org/api_docs/python/tf/dtypes
    - features: https://www.tensorflow.org/tutorials/load_data/tfrecord#data_types_for_tftrainexample  # noqa: E501
    """
    dtype = tf_spec.dtype
    dispatch = (
        ((tf.dtypes.string,), _bytes_feature),
        ((tf.dtypes.float32, tf.dtypes.float64), _float_feature),
        ((tf.dtypes.bool, tf.dtypes.int32, tf.dtypes.int64), _int64_feature),
    )
    for supported_dtypes, to_feature in dispatch:
        if dtype in supported_dtypes:
            return to_feature
    raise NotImplementedError(f"{dtype} is not supported")
def serialize_example(
    features: FlattenedTensors, pre_feature: PreFeature
) -> bytes:
    """Serialize one flattened datum as a ``tf.train.Example`` message.

    Args:
        features: leaf values indexed by their flattened position.
        pre_feature: ``(name, position, to_feature)`` triples, as built by
            ``serialize``. ``to_feature`` converts ``features[position]``
            into the ``tf.train.Feature`` stored under ``name``.

    Returns:
        The serialized ``tf.train.Example`` (``bytes``), ready to be written
        to a file.
    """
    # Create a dictionary mapping the feature name to the
    # tf.train.Example-compatible data type.
    feature = {
        name: to_feature(features[position])
        for name, position, to_feature in pre_feature
    }
    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(
        features=tf.train.Features(feature=feature)
    )
    return example_proto.SerializeToString()
def serialize(signature: InternalTFSignature) -> t.Callable:
    """Return a serializing function with per-feature work precomputed.

    For every spec of ``signature`` this records three things:

    - its feature name;
    - the flattened position decoded from that name;
    - the matching Feature converter.

    The returned callable takes the flattened tensors and yields serialized
    ``tf.train.Example`` bytes.
    """
    pre_feature = []
    for spec in signature.values():
        column = spec.name
        position = _column_to_position(column)
        to_feature = _input_feature_factory(spec)
        pre_feature.append((column, position, to_feature))
    return functools.partial(serialize_example, pre_feature=pre_feature)
def _output_feature_factory(tf_spec: tf.TensorSpec) -> tf.io.FixedLenFeature:
    """Build the scalar ``FixedLenFeature`` used to parse ``tf_spec`` back.

    The stored dtype follows the TFRecord Feature kinds:

    - strings are stored as ``string``;
    - floats as ``float32``;
    - bool/int as ``int64``.

    Any other dtype raises ``NotImplementedError``. No ``default_value`` is
    given.

    For more details:
    - types: https://www.tensorflow.org/api_docs/python/tf/dtypes
    - features: https://www.tensorflow.org/api_docs/python/tf/io/FixedLenFeature  # noqa: E501
    """
    dtype = tf_spec.dtype
    if dtype in (tf.dtypes.string,):
        stored_dtype = tf.dtypes.string
    elif dtype in (tf.dtypes.float32, tf.dtypes.float64):
        # A FloatList stores 32-bit floats, so float64 is read as float32.
        stored_dtype = tf.dtypes.float32
    elif dtype in (tf.dtypes.bool, tf.dtypes.int32, tf.dtypes.int64):
        stored_dtype = tf.dtypes.int64
    else:
        raise NotImplementedError(f"{dtype} is not supported")
    return tf.io.FixedLenFeature(shape=(), dtype=stored_dtype)
def deserialize_example(
    raw_bytes: tf.Tensor, feature_description: FeatureDescription
) -> TensorDict:
    """Parse serialized ``tf.train.Example`` messages into tensors.

    ``feature_description`` tells ``tf.io.parse_example`` how to decode each
    feature. The result is a dict from feature name to tensor.

    Docs:
    https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file
    """
    parsed = tf.io.parse_example(raw_bytes, feature_description)
    return t.cast(TensorDict, parsed)
def deserialize(signature: InternalTFSignature) -> t.Callable:
    """Return a deserializing function with the parsing spec precomputed.

    Each spec of ``signature`` is mapped, under its name, to the
    ``FixedLenFeature`` used to read it back. The returned callable parses
    serialized ``tf.train.Example`` bytes with that description.
    """
    feature_description: FeatureDescription = {}
    for spec in signature.values():
        feature_description[spec.name] = _output_feature_factory(spec)
    return functools.partial(
        deserialize_example, feature_description=feature_description
    )
def _flatten(
    batch: NestedTensors, signature: InternalTFSignature
) -> FlattenedTensors:
    """Flatten a nested batch into a list indexed by signature position.

    Every leaf tensor of ``batch`` is stored at the position encoded in the
    name of its TensorSpec in ``signature``.

    A missing value in the structure leaves its position at None. The
    returned list then contains None, which should fail downstream.

    Raises:
        KeyError: a leaf path is not present in ``signature``.
        ValueError: a node is neither a dict nor a ``tf.Tensor``.

    Motivation:
    From: https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict#example_4  # noqa: E501
    TF Datasets flattens the structure of a datum in order to serialize it as
    a protobuf message. From the docs linked above:
    ```
    tfds.features.FeaturesDict({
        'input': tf.int32,
        'target': {
            'height': tf.int32,
            'width': tf.int32,
        },
    })
    ```
    Will internally store the data as:
    ```
    {
        'input': tf.io.FixedLenFeature(shape=(), dtype=tf.int32),
        'target/height': tf.io.FixedLenFeature(shape=(), dtype=tf.int32),
        'target/width': tf.io.FixedLenFeature(shape=(), dtype=tf.int32),
    }
    ```
    We copy this behavior but with the additional information from the Sarus
    schema, we do not need the complex machinery of TF Datasets.
    """
    flattened_batch: t.List[t.Optional[tf.Tensor]] = [None] * len(signature)
    # Depth-first traversal with an explicit LIFO stack. Visit order does not
    # matter: positions come from the signature.
    stack = [(t.cast(PathType, ()), batch)]
    while stack:
        path, node = stack.pop()
        if isinstance(node, dict):
            # internal node
            for field in sorted(node):
                stack.append((path + (field,), node[field]))
        elif isinstance(node, tf.Tensor):
            # leaf
            try:
                tf_spec = signature[path]
            except KeyError:
                raise KeyError(
                    f"Leaf at path {path} is not part of the signature."
                ) from None
            flattened_batch[_column_to_position(tf_spec.name)] = node
        else:
            raise ValueError(
                f"Unexpected node of type {type(node).__name__} at path "
                f"{path}: expected a dict or a tf.Tensor."
            )
    return t.cast(FlattenedTensors, flattened_batch)
def flatten(signature: InternalTFSignature) -> t.Callable:
    """Return ``_flatten`` with ``signature`` bound as its keyword argument."""
    flattening_fn = functools.partial(_flatten, signature=signature)
    return flattening_fn
def _nest(batch: TensorDict, signature: InternalTFSignature) -> NestedTensors:
    """Rebuild the nested structure of the data read from TFRecords.

    For every ``path -> spec`` entry of ``signature``:

    - ``batch[spec.name]`` is cast back to ``spec.dtype``;
    - it is stored under ``path``, creating intermediate dicts as needed.

    Entries with an empty path are ignored. Readability is favored over
    performance.
    """
    nested: t.Dict[str, t.Any] = {}
    for path, spec in signature.items():
        if not path:
            # Nothing to store an empty path under.
            continue
        *parents, leaf = path
        node = nested
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = tf.cast(batch[spec.name], spec.dtype)
    return t.cast(NestedTensors, nested)
def nest(signature: InternalTFSignature) -> t.Callable:
    """Return ``_nest`` with ``signature`` bound as its keyword argument."""
    nesting_fn = functools.partial(_nest, signature=signature)
    return nesting_fn