Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import functools
import typing as t

try:
    import tensorflow as tf
except ModuleNotFoundError:
    pass  # Warning is displayed by t.py

import sarus_data_spec.typing as st

# Inspired by:
# https://github.com/tensorflow/datasets/blob/v4.4.0/tensorflow_datasets/core/utils/type_utils.py#L43
# NOTE(review): if the guarded `import tensorflow` above failed, `tf` is
# undefined and the aliases below raise NameError at import time.

# Path from the root of a nested structure down to one leaf, e.g.
# ("some_field", "values").
PathType = t.Tuple[str, ...]
# Leaf path -> TensorSpec used to store that leaf; the spec is named after
# the leaf's flattened position ("_0", "_1", ...).
InternalTFSignature = t.Dict[PathType, tf.TensorSpec]

# Leaf tensors indexed by their flattened position.
FlattenedTensors = t.List[tf.Tensor]
# Recursive definition below is ignored; this is a known issue with mypy:
# https://github.com/python/mypy/issues/731
NestedTensors = t.Union[tf.Tensor, t.Dict[str, "NestedTensors"]]
# Flat column name -> tensor mapping, as produced by parsing TFRecords.
TensorDict = t.Dict[str, tf.Tensor]

# Converts one leaf tensor into a tf.train.Feature.
InputFeatureFunction = t.Callable[[tf.Tensor], tf.train.Feature]
# Precomputed (column name, flattened position, feature converter) triples.
PreFeature = t.List[t.Tuple[str, int, InputFeatureFunction]]

# Column name -> parsing spec handed to tf.io.parse_example.
FeatureDescription = t.Dict[str, tf.io.FixedLenFeature]


def _position_to_column(i: int) -> str:
    return f"_{i}"


def _column_to_position(column: str) -> int:
    return int(column.replace("_", ""))


class TensorflowSignatureConverter(st.TypeVisitor):
    """This visitor builds an internal structure which is the cornerstone to
    solve the following problems:
      - flattening the nested structure to a structure suitable for TFRecords
      - nesting a flattened structure to reshape the data from TFRecords
      - describing the Feature to serialize data as protobuf messages

    It does so by scanning in a deterministic way the Sarus Type of the data,
    and building along the exploration of the structure a dictionary mapping
    the path to a leaf in the structure with a position in the protobuf
    message and the TensorSpec to use to serialize the data of the leaf.

    Each supported scalar type registers two leaves under the current path:
    ``path + ("input_mask",)`` holding a boolean mask and
    ``path + ("values",)`` holding the data, both with the same shape.
    Null and Bytes raise ``ValueError``; Constrained, List, Array and
    Hypothesis raise ``NotImplementedError``.

    The signature abides by the standard defined in `tensorflow_visitor.py`.

    More about the Features:
      - https://www.tensorflow.org/api_docs/python/tf/train/Feature
      - https://www.tensorflow.org/tutorials/load_data/tfrecord#data_types_for_tftrainexample  # noqa: E501
    """

    def __init__(self) -> None:
        # Leaf path -> TensorSpec renamed after the leaf's flattened position.
        self.flattened_features: InternalTFSignature = {}
        # Leaf path -> flattened position, allocated in visiting order.
        self.path_index: t.Dict[PathType, int] = {}
        # Path of the node currently being visited.
        self.path: PathType = t.cast(PathType, ())
        # Shape of every value spec (scalar; not modified by this class).
        self.shape: tf.TensorShape = tf.TensorShape([])

    def _add_feature(self, path: PathType, tf_spec: tf.TensorSpec) -> None:
        """Registers a feature with the given path and spec.

        The flattened position of ``path`` is allocated the first time the
        path is seen and reused afterwards, so registering the same path again
        (e.g. an Integer or Float overridden by Int32/64 or Float32/64) only
        replaces its spec. The stored spec is named ``_<position>``.
        """
        if path in self.path_index:
            position = self.path_index[path]
        else:
            position = len(self.path_index)
            self.path_index[path] = position
        name = _position_to_column(position)
        self.flattened_features[path] = tf.TensorSpec.from_spec(
            tf_spec, name=name
        )

    def _add_feature_with_mask(
        self, path: PathType, tf_spec: tf.TensorSpec
    ) -> None:
        """Registers the value at ``path`` together with its boolean mask.

        The mask lives next to the value: its path replaces the last element
        of ``path`` with ``"input_mask"`` and it has the value's shape. The
        mask is registered first, so it gets the lower position.
        """
        mask_path = path[:-1] + ("input_mask",)
        mask_spec = tf.TensorSpec(tf_spec.shape, dtype=tf.dtypes.bool)

        self._add_feature(mask_path, mask_spec)
        self._add_feature(path, tf_spec)

    def _add_value_spec_with_mask(self, tf_type: tf.dtypes.DType) -> None:
        """Registers a ``values`` leaf of dtype ``tf_type`` (and its mask)
        under the current path."""
        tf_spec = tf.TensorSpec(self.shape, dtype=tf_type)
        self._add_feature_with_mask(self.path + ("values",), tf_spec)

    def Null(self, properties: t.Optional[t.Mapping[str, str]] = None) -> None:
        raise ValueError(
            "Null type is not supported by the TensorFlow signature converter"
        )

    def Unit(self, properties: t.Optional[t.Mapping[str, str]] = None) -> None:
        # Unit is registered as a float64 column.
        self._add_value_spec_with_mask(tf.dtypes.float64)

    def Text(
        self,
        encoding: str,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Bytes(
        self, properties: t.Optional[t.Mapping[str, str]] = None
    ) -> None:
        raise ValueError(
            "Bytes type is not supported by the TensorFlow signature converter"
        )

    def Boolean(
        self, properties: t.Optional[t.Mapping[str, str]] = None
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.bool)

    def Id(
        self,
        unique: bool,
        reference: t.Optional[st.Path] = None,
        base: t.Optional[st.IdBase] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Enum(
        self,
        name: str,
        name_values: t.Sequence[t.Tuple[str, int]],
        ordered: bool,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Integer(
        self,
        min: int,
        max: int,
        base: st.IntegerBase,
        possible_values: t.Iterable[int],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.int64)

    def Float(
        self,
        min: float,
        max: float,
        base: st.FloatBase,
        possible_values: t.Iterable[float],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float64)

    def Float32(self, min: float, max: float) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float32)

    def Float64(self, min: float, max: float) -> None:
        self._add_value_spec_with_mask(tf.dtypes.float64)

    def Datetime(
        self,
        format: str,
        min: str,
        max: str,
        base: st.DatetimeBase,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Date(
        self,
        format: str,
        min: str,
        max: str,
        base: st.DateBase,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Time(
        self,
        format: str,
        min: str,
        max: str,
        base: st.TimeBase,
        possible_values: t.Iterable[str],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.string)

    def Duration(
        self,
        unit: str,
        min: int,
        max: int,
        possible_values: t.Iterable[int],
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        self._add_value_spec_with_mask(tf.dtypes.int64)

    def Struct(
        self,
        fields: t.Mapping[str, st.Type],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        """Visits each field under ``path + (field_name,)``, in mapping
        order."""
        parent_path = self.path
        try:
            for field_name, field_type in fields.items():
                self.path = parent_path + (field_name,)
                field_type.accept(self)
        finally:
            # Restore the path even if a child visit raises.
            self.path = parent_path

    def Constrained(
        self,
        type: st.Type,
        constraint: st.Predicate,
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "Constrained type is not supported by the TensorFlow signature "
            "converter"
        )

    def List(
        self,
        type: st.Type,
        max_size: int,
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "List type is not supported by the TensorFlow signature converter"
        )

    def Array(
        self,
        type: st.Type,
        shape: t.Tuple[int, ...],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "Array type is not supported by the TensorFlow signature converter"
        )

    def Optional(
        self,
        type: st.Type,
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        # No extra leaf is added: the inner type is visited at the same path
        # (each scalar already registers its own ``input_mask`` leaf).
        type.accept(self)

    def Union(
        self,
        fields: t.Mapping[str, st.Type],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        """Visits each variant under ``path + (variant_name,)``, in mapping
        order."""
        parent_path = self.path
        try:
            for field_name, field_type in fields.items():
                self.path = parent_path + (field_name,)
                field_type.accept(self)
        finally:
            # Restore the path even if a child visit raises.
            self.path = parent_path

    def Hypothesis(
        self,
        *types: t.Tuple[st.Type, float],
        name: t.Optional[str] = None,
        properties: t.Optional[t.Mapping[str, str]] = None,
    ) -> None:
        raise NotImplementedError(
            "Hypothesis type is not supported by the TensorFlow signature "
            "converter"
        )


def to_internal_signature(sarus_type: st.Type) -> InternalTFSignature:
    """Build the internal TF signature of ``sarus_type``.

    A fresh ``TensorflowSignatureConverter`` walks the type. The resulting
    mapping from leaf paths to position-named tensor specs is returned.
    """
    converter = TensorflowSignatureConverter()
    sarus_type.accept(converter)
    signature = converter.flattened_features
    return signature


# From: https://www.tensorflow.org/tutorials/load_data/tfrecord#tftrainexample


# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.
def _bytes_feature(value: tf.Tensor) -> tf.train.Feature:
    """Wrap a string / bytes scalar into a ``bytes_list`` Feature."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList cannot unpack a string held in an EagerTensor, so pull the
    # raw value out first.
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value: tf.Tensor) -> tf.train.Feature:
    """Wrap a float / double scalar into a ``float_list`` Feature."""
    raw = value
    if isinstance(raw, type(tf.constant(0))):
        # FloatList cannot unpack a float held in an EagerTensor.
        raw = raw.numpy()
    float_list = tf.train.FloatList(value=[raw])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value: tf.Tensor) -> tf.train.Feature:
    """Wrap a bool / enum / int / uint scalar into an ``int64_list`` Feature.

    The value is first cast to ``tf.int64``.
    """
    as_int64 = tf.cast(value, tf.int64)
    eager_tensor_type = type(tf.constant(0))
    if isinstance(as_int64, eager_tensor_type):
        # Int64List cannot unpack an int held in an EagerTensor.
        as_int64 = as_int64.numpy()
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=[as_int64])
    )


def _input_feature_factory(tf_spec: tf.TensorSpec) -> InputFeatureFunction:
    """Pick the function that converts a leaf of ``tf_spec`` into a Feature.

    Strings map to ``bytes_list``, float32/float64 to ``float_list``, and
    bool/int32/int64 to ``int64_list``. Any other dtype raises
    ``NotImplementedError``.

    For more details:
      - types: https://www.tensorflow.org/api_docs/python/tf/dtypes
      - features: https://www.tensorflow.org/tutorials/load_data/tfrecord#data_types_for_tftrainexample  # noqa: E501
    """
    dtype = tf_spec.dtype
    dispatch = (
        ((tf.dtypes.string,), _bytes_feature),
        ((tf.dtypes.float32, tf.dtypes.float64), _float_feature),
        ((tf.dtypes.bool, tf.dtypes.int32, tf.dtypes.int64), _int64_feature),
    )
    for accepted_dtypes, converter in dispatch:
        if dtype in accepted_dtypes:
            return converter
    raise NotImplementedError(f"{dtype} is not supported")


def serialize_example(
    features: FlattenedTensors, pre_feature: PreFeature
) -> bytes:
    """Serialize one flattened example as a binary ``tf.train.Example``.

    Args:
        features: Leaf tensors indexed by their flattened position.
        pre_feature: ``(column name, position, converter)`` triples, as
            precomputed by ``serialize``.

    Returns:
        The serialized ``tf.train.Example`` proto (``bytes``), ready to be
        written to a TFRecord file.
    """
    # Map each column name to the tf.train.Feature built from the leaf that
    # sits at that column's flattened position.
    feature = {
        name: to_feature(features[position])
        for name, position, to_feature in pre_feature
    }
    example_proto = tf.train.Example(
        features=tf.train.Features(feature=feature)
    )
    # SerializeToString returns the wire-format bytes of the proto.
    return example_proto.SerializeToString()


def serialize(signature: InternalTFSignature) -> t.Callable:
    """Return ``serialize_example`` with its per-column data precomputed.

    For every spec in ``signature``, the column name, flattened position and
    Feature converter are computed once. They are then bound as
    ``pre_feature``, so the returned callable only needs the flattened
    features.
    """
    pre_feature = []
    for tf_spec in signature.values():
        column = tf_spec.name
        position = _column_to_position(column)
        converter = _input_feature_factory(tf_spec)
        pre_feature.append((column, position, converter))
    return functools.partial(serialize_example, pre_feature=pre_feature)


def _output_feature_factory(tf_spec: tf.TensorSpec) -> tf.io.FixedLenFeature:
    """Build the scalar parsing spec for a column described by ``tf_spec``.

    Strings parse as string, float32/float64 as float32, and
    bool/int32/int64 as int64. Any other dtype raises
    ``NotImplementedError``.

    For more details:
      - types: https://www.tensorflow.org/api_docs/python/tf/dtypes
      - features: https://www.tensorflow.org/api_docs/python/tf/io/FixedLenFeature  # noqa: E501
    """
    source_dtype = tf_spec.dtype
    if source_dtype == tf.dtypes.string:
        parsed_dtype = tf.dtypes.string
    elif source_dtype in (tf.dtypes.float32, tf.dtypes.float64):
        parsed_dtype = tf.dtypes.float32
    elif source_dtype in (tf.dtypes.bool, tf.dtypes.int32, tf.dtypes.int64):
        parsed_dtype = tf.dtypes.int64
    else:
        raise NotImplementedError(f"{source_dtype} is not supported")
    # No default_value is given. Per the tf.io.FixedLenFeature docs, parsing
    # a record that lacks the column then fails.
    return tf.io.FixedLenFeature(shape=(), dtype=parsed_dtype)


def deserialize_example(
    raw_bytes: tf.Tensor, feature_description: FeatureDescription
) -> TensorDict:
    """Parse serialized ``tf.train.Example`` protos into a dict of tensors.

    ``raw_bytes`` is passed to ``tf.io.parse_example``, which expects a batch
    of serialized protos. The result maps each column name of
    ``feature_description`` to its parsed tensor.

    Docs:
      https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file
    """
    parsed = tf.io.parse_example(raw_bytes, feature_description)
    return t.cast(TensorDict, parsed)


def deserialize(signature: InternalTFSignature) -> t.Callable:
    """Return ``deserialize_example`` bound to the parsing specs of
    ``signature``.

    One scalar ``FixedLenFeature`` is built per column, keyed by the spec
    name, and bound as ``feature_description``.
    """
    feature_description: FeatureDescription = {}
    for tf_spec in signature.values():
        column = tf_spec.name
        feature_description[column] = _output_feature_factory(tf_spec)
    return functools.partial(
        deserialize_example, feature_description=feature_description
    )


def _flatten(
    batch: NestedTensors, signature: InternalTFSignature
) -> FlattenedTensors:
    """Recursively flattens the batch structure, setting each leaf value at
    the position stored in the internal TF signature.

    Raises:
        KeyError: if a leaf of ``batch`` has a path that is not in
            ``signature``.
        ValueError: if a node is neither a ``dict`` nor a ``tf.Tensor``, or
            if some path of ``signature`` has no matching leaf in ``batch``.

    Motivation:
      From:https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict#example_4  # noqa: E501

    TF Datasets flattens the structure of a datum in order to serialize it as
    a protobuf message. From the docs linked above:

    ```
    tfds.features.FeaturesDict({
        'input': tf.int32,
        'target': {
            'height': tf.int32,
            'width': tf.int32,
        },
    })
    ```
    Will internally store the data as:
    ```
    {
        'input': tf.io.FixedLenFeature(shape=(), dtype=tf.int32),
        'target/height': tf.io.FixedLenFeature(shape=(), dtype=tf.int32),
        'target/width': tf.io.FixedLenFeature(shape=(), dtype=tf.int32),
    }
    ```

    We copy this behavior but with the additional information from the Sarus
    schema, we do not need the complex machinery of TF Datasets.
    """
    flattened_batch: t.List[t.Optional[tf.Tensor]] = [None] * len(signature)
    # Depth-first traversal with an explicit stack. The visiting order does
    # not matter because each leaf is written at its precomputed position.
    stack = [(t.cast(PathType, ()), batch)]

    while stack:
        path, node = stack.pop()

        if isinstance(node, dict):
            # Internal node: push every child with its extended path.
            for field in sorted(node.keys()):
                stack.append((path + (field,), node[field]))
        elif isinstance(node, tf.Tensor):
            # Leaf: an unknown path raises KeyError here.
            tf_spec = signature[path]
            flattened_batch[_column_to_position(tf_spec.name)] = node
        else:
            raise ValueError(
                f"Unsupported node of type {type(node).__name__} at path "
                f"{path}; expected a dict or a tf.Tensor"
            )

    # Fail fast on missing leaves instead of returning a list holding None.
    missing_paths = [
        path
        for path, tf_spec in signature.items()
        if flattened_batch[_column_to_position(tf_spec.name)] is None
    ]
    if missing_paths:
        raise ValueError(f"Batch has no leaf for paths: {missing_paths}")

    return t.cast(FlattenedTensors, flattened_batch)


def flatten(signature: InternalTFSignature) -> t.Callable:
    """Return ``_flatten`` with ``signature`` bound as a keyword argument."""
    flattener = functools.partial(_flatten, signature=signature)
    return flattener


def _nest(batch: TensorDict, signature: InternalTFSignature) -> NestedTensors:
    """Rebuild the nested structure of data read back from TFRecords.

    For every leaf path in ``signature``, intermediate dicts are created as
    needed. The column named by the leaf's spec is then cast back to the
    spec's dtype and stored at the leaf. Simplicity is favoured over
    performance.
    """
    output = t.cast(NestedTensors, {})
    for path, tf_spec in signature.items():
        if not path:
            # An empty path designates no leaf; nothing to store.
            continue
        *parents, leaf = path
        node = output
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = tf.cast(batch[tf_spec.name], tf_spec.dtype)
    return output


def nest(signature: InternalTFSignature) -> t.Callable:
    """Return ``_nest`` with ``signature`` bound as a keyword argument."""
    nester = functools.partial(_nest, signature=signature)
    return nester