Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_dataset / sarus_dataset / dataset.py
Size: Mime:
from abc import ABC, abstractmethod
from sarus_dataset import schema_pb2
from sarus_dataset.converters import multistage_converters
from sarus_dataset.column_name_encoding import encode as encode_name

from typing import List, Dict, Union, Tuple
from enum import Enum

import numpy as np
import tensorflow as tf


class BaseTypeName(Enum):
    """
    Physical base types a logical type can be stored as.

    Each member's value equals its name. These names are matched against the
    ``Base`` enum of the protobuf messages when decoding/validating types.
    """

    # signed integers
    INT8 = "INT8"
    INT16 = "INT16"
    INT32 = "INT32"
    INT64 = "INT64"

    # unsigned integers
    UINT8 = "UINT8"
    UINT16 = "UINT16"

    # floating point
    FLOAT32 = "FLOAT32"
    FLOAT64 = "FLOAT64"

    # text
    STRING = "STRING"


# numpy dtype used at the raw numpy stage for each base type
# (insertion order follows the declaration order of BaseTypeName).
BASE_TYPE_NAME_TO_NP_DTYPE = {
    base_name: np.dtype(dtype_name)
    for base_name, dtype_name in (
        (BaseTypeName.INT8, "int8"),
        (BaseTypeName.INT16, "int16"),
        (BaseTypeName.INT32, "int32"),
        (BaseTypeName.INT64, "int64"),
        (BaseTypeName.UINT8, "uint8"),
        (BaseTypeName.UINT16, "uint16"),
        (BaseTypeName.FLOAT32, "float32"),
        (BaseTypeName.FLOAT64, "float64"),
        (BaseTypeName.STRING, "str"),
    )
}


class Dataset:
    """
    A named collection of typed features (columns).

    Provides the schema of the dataset at the different pipeline stages:
    Spark, raw numpy, processed numpy and TensorFlow.
    """

    features: Dict[str, "Type"]
    name: str

    def __init__(self, name: str, features: Dict[str, "Type"]):
        self.name = name
        self.features = features

    @classmethod
    def from_proto(cls, dataset_pb: schema_pb2.DataSet):
        """
        Build a Dataset from a ``schema_pb2.DataSet`` message.

        Each feature's ``type`` (a union ``schema_pb2.Type`` message) is
        decoded with ``Type.from_union_type_proto``.
        """
        features = {
            feature_pb.name: Type.from_union_type_proto(feature_pb.type)
            for feature_pb in dataset_pb.features
        }

        return cls(
            name=dataset_pb.name,
            features=features,
        )

    def transform(self, add_id=False, delete_images=False, encode_names=False):
        """
        Return a new Dataset derived from this one.

        Parameters
        ----------
        add_id: add an ``Integer(INT64)`` feature named "id" (a Spark
            LongType column).
        delete_images: drop every feature whose type is ``Image``.
        encode_names: encode feature names with ``encode_name``; the "id"
            feature added by ``add_id`` is not encoded.

        Raises
        ------
        ValueError: if ``add_id`` is set and this dataset already has a
            feature named "id".
        """
        features = {
            feature_name: feature_type
            for feature_name, feature_type in self.features.items()
            if not (delete_images and isinstance(feature_type, Image))
        }

        if encode_names:
            features = {
                encode_name(feature_name): feature_type
                for feature_name, feature_type in features.items()
            }

        if add_id:
            # NOTE(review): checked against the ORIGINAL names, so an existing
            # "id" column is rejected even when delete_images or encode_names
            # would have removed/renamed it. Confirm this is intended.
            if "id" in self.features:
                raise ValueError(
                    "add_id=True is invalid where is already an existing column named id."
                )

            features["id"] = Integer(BaseTypeName.INT64)

        return self.__class__(self.name, features)

    def get_spark_schema(
        self,
        for_parsing=False,
        for_synthetic=False,
    ):
        """
        Return the ``pyspark.sql.types.StructType`` describing this dataset.

        Depending on the flags provided, the returned schema is:

         - no flags: schema tied to the stored Parquet files
         - for_parsing: schema needed at the CSV parsing step; types listed
           in ``TO_PARSE_TYPES`` (DateTime, Categorical) become StringType
         - for_synthetic: types listed in ``SYNTHETIC_NULLABLE_TYPES``
           (DateTime, Real, Integer) are forced nullable, except the "id"
           feature

        The flags never add or remove fields; they only change the type or
        nullability of the existing fields. ``Optional`` features are always
        nullable and are described by the type they wrap.

        Raises
        ------
        ValueError: if both flags are set.
        """
        from pyspark.sql import types as spark_types

        if for_parsing and for_synthetic:
            raise ValueError(
                "Only one of [for_parsing, for_synthetic] should be set to True"
            )

        struct_fields = []

        for feature_name, feature_type in self.features.items():
            # Nullability first: Optional is unwrapped here, so the type
            # checks below operate on the wrapped type.
            if isinstance(feature_type, Optional):
                feature_type = feature_type.type
                nullable = True
            elif (
                for_synthetic
                and feature_name != "id"
                and isinstance(feature_type, SYNTHETIC_NULLABLE_TYPES)
            ):
                nullable = True
            else:
                nullable = False

            if for_parsing and isinstance(feature_type, TO_PARSE_TYPES):
                spark_type = spark_types.StringType()
            else:
                spark_type = feature_type.get_spark_type()

            struct_fields.append(
                spark_types.StructField(feature_name, spark_type, nullable)
            )

        return spark_types.StructType(struct_fields)

    def get_raw_np_schema(self):
        """Map each feature name to its numpy dtype at the raw numpy stage."""
        return {
            feature_name: feature_type.get_raw_np_type()
            for feature_name, feature_type in self.features.items()
        }

    def get_processed_np_schema(self):
        """Map each feature name to its ``(shape, dtype)`` at the processed numpy stage."""
        return {
            feature_name: feature_type.get_processed_np_type()
            for feature_name, feature_type in self.features.items()
        }

    def get_tf_schema(self):
        """Map each feature name to a ``tf.TensorSpec`` built from its TF ``(shape, dtype)``."""
        tf_schema = {}

        for feature_name, feature_type in self.features.items():
            shape, dtype = feature_type.get_tf_type()

            tf_schema[feature_name] = tf.TensorSpec(shape=shape, dtype=dtype)

        return tf_schema


# Union of the nested message classes of ``schema_pb2.Type``; used to
# annotate the ``type_pb`` argument of ``Type.from_proto``/``_from_proto``.
TypePbUnion = Union[
    schema_pb2.Type.Null,
    schema_pb2.Type.Boolean,
    schema_pb2.Type.Integer,
    schema_pb2.Type.Categorical,
    schema_pb2.Type.Real,
    schema_pb2.Type.Text,
    schema_pb2.Type.Optional,
    schema_pb2.Type.Struct,
    schema_pb2.Type.Union,
    schema_pb2.Type.DateTime,
    schema_pb2.Type.Latitude,
    schema_pb2.Type.Longitude,
    schema_pb2.Type.Image,
]


class Type(ABC):
    """
    Base class of the logical feature types.

    Concrete subclasses set ``type_pb_cls`` to the nested message of
    ``schema_pb2.Type`` they are decoded from, implement ``_from_proto``,
    and describe the type at each pipeline stage (raw numpy, processed
    numpy, TF and Spark).
    """

    @classmethod
    def from_proto(cls, type_pb: TypePbUnion) -> "Type":
        """
        Instantiate ``cls`` from ``type_pb``, which must be an instance of
        ``cls.type_pb_cls``.
        """
        assert isinstance(type_pb, cls.type_pb_cls)

        return cls._from_proto(type_pb)

    @classmethod
    @abstractmethod
    def _from_proto(cls, type_pb: TypePbUnion) -> "Type":
        """
        From a protobuffer type, instantiate a Python Type object

        Parameters
        ----------
        type_pb: instance of one of the nested messages of schema_pb2.Type
        """
        raise NotImplementedError()

    @staticmethod
    def from_union_type_proto(union_type_pb: schema_pb2.Type) -> "Type":
        """
        Decode a ``schema_pb2.Type`` message into the matching Type subclass,
        dispatching on the field set in its ``type`` oneof.

        Raises
        ------
        ValueError: if no field of the ``type`` oneof is set.
        KeyError: if the field that is set has no entry in the table below.
        """
        assert isinstance(union_type_pb, schema_pb2.Type)

        type_name = union_type_pb.WhichOneof("type")
        if type_name is None:
            # Without this check getattr(..., None) fails with an opaque TypeError.
            raise ValueError("No field of the 'type' oneof is set")

        type_pb = getattr(union_type_pb, type_name)

        # Built at call time because the subclasses are defined after Type.
        name_to_cls = {
            "boolean": Boolean,
            "integer": Integer,
            "categorical": Categorical,
            "real": Real,
            "text": Text,
            "optional": Optional,
            "datetime": DateTime,
            "latitude": Latittude,
            "longitude": Longitude,
            "image": Image,
        }

        type_cls = name_to_cls[type_name]
        return type_cls.from_proto(type_pb)

    @abstractmethod
    def get_raw_np_type(self) -> np.dtype:
        """
        Returns expected type at the Raw Numpy Stage
        """
        raise NotImplementedError()

    @abstractmethod
    def get_processed_np_type(self) -> Tuple[Tuple[int], np.dtype]:
        """
        Returns expected type at the Processed Numpy Stage.
        The return value is a tuple of (shape, dtype)
        """
        raise NotImplementedError()

    @abstractmethod
    def get_tf_type(self) -> Tuple[Tuple[int], np.dtype]:
        """
        Returns expected tensor spec at the TF Stage, as a (shape, dtype) tuple.
        """
        raise NotImplementedError()

    @staticmethod
    def get_spark_type_from_np_dtype(
        np_dtype: np.dtype,
    ) -> "spark_types.DataType":
        """
        Return an instance of the Spark SQL type matching ``np_dtype``.

        Raises
        ------
        KeyError: if ``np_dtype`` has no entry in the table below.
        """
        # Imported lazily (presumably to keep pyspark optional for callers
        # that never build Spark schemas).
        from pyspark.sql import types as spark_types

        NP_DTYPE_TO_SPARK_TYPE_CLS = {
            np.dtype("int64"): spark_types.LongType,
            np.dtype("int32"): spark_types.IntegerType,
            np.dtype("int16"): spark_types.ShortType,
            np.dtype("int8"): spark_types.ByteType,
            # Spark has no unsigned integer types: widen to the smallest
            # signed type holding the full range (same mapping Spark uses
            # when reading Parquet UINT_8 / UINT_16 columns).
            np.dtype("uint16"): spark_types.IntegerType,
            np.dtype("uint8"): spark_types.ShortType,
            np.dtype("float32"): spark_types.FloatType,
            np.dtype("float64"): spark_types.DoubleType,
            np.dtype("datetime64[us]"): spark_types.TimestampType,
            np.dtype("bool"): spark_types.BooleanType,
            np.dtype("str"): spark_types.StringType,
            np.dtype("bytes"): spark_types.BinaryType,
        }

        spark_type_cls = NP_DTYPE_TO_SPARK_TYPE_CLS[np_dtype]

        return spark_type_cls()

    def get_spark_type(self) -> "spark_types.DataType":
        """Return the Spark SQL type matching this type's raw numpy dtype."""
        return self.get_spark_type_from_np_dtype(self.get_raw_np_type())

    def __repr__(self):
        cls_name = self.__class__.__name__

        return f"{cls_name}()"


class WithRawNpType:
    """
    Mixin for types whose raw numpy dtype is fixed.

    Inheriting classes provide that dtype through the ``raw_np_type``
    attribute.
    """

    raw_np_type: np.dtype

    def get_raw_np_type(self):
        """Return the ``raw_np_type`` attribute as is."""
        dtype = self.raw_np_type
        return dtype


class ProcessedNpAsRawNp:
    """Mixin: the processed numpy type is a scalar of the raw numpy dtype."""

    def get_processed_np_type(self):
        """Return ``((), raw_dtype)``: an empty (scalar) shape and the raw dtype."""
        return (), self.get_raw_np_type()


class TfAsProccessedNp:
    """Mixin: the TF type is the processed numpy ``(shape, dtype)`` unchanged."""

    def get_tf_type(self):
        """Delegate to ``get_processed_np_type``."""
        processed = self.get_processed_np_type()
        return processed


class DirectType(TfAsProccessedNp, ProcessedNpAsRawNp):
    """
    Mixin for scalar types: the processed numpy type is a scalar of the raw
    numpy dtype, and the TF type equals the processed numpy type.
    """


class RawNpFromBaseType:
    """
    Mixin for logical types whose physical type is chosen per instance
    through a ``BaseTypeName`` (used by Integer, Categorical, Real,
    Latittude and Longitude).

    Inheriting classes define ``type_pb_cls``; the names of its nested
    ``Base`` protobuf enum are the base names that are accepted.
    """

    base_name: BaseTypeName

    def __init__(self, base_name: BaseTypeName, *args, **kwargs):
        super().__init__(*args, **kwargs)

        allowed_bases = self.type_pb_cls.Base.keys()
        if base_name.value in allowed_bases:
            self.base_name = base_name
            return

        raise ValueError(
            f"No base {base_name.value} within protobuffer definition"
        )

    def __repr__(self):
        return f"{type(self).__name__}({self.base_name.value})"

    def get_raw_np_type(self):
        """Return the numpy dtype associated with ``self.base_name``."""
        return BASE_TYPE_NAME_TO_NP_DTYPE[self.base_name]

    @classmethod
    def _from_proto(cls, type_pb):
        """
        Build an instance from ``type_pb``, whose ``base`` field is a value
        of the ``cls.type_pb_cls.Base`` enum.
        """
        proto_base_enum = cls.type_pb_cls.Base
        return cls(base_name=BaseTypeName[proto_base_enum.Name(type_pb.base)])


class WithMultiStageConverter:
    """
    Mixin that instantiates ``self.multistage_converter`` from the class
    attribute ``multistage_converter_cls`` (passing ``self``) and delegates
    every stage conversion to it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        converter_cls = self.multistage_converter_cls
        self.multistage_converter = converter_cls(self)

    def raw_np_to_processed_np(self, raw_np_value):
        """Convert a raw numpy value into its processed numpy form."""
        converter = self.multistage_converter
        return converter.raw_np_to_processed_np(raw_np_value)

    def processed_np_to_raw_np(self, processed_np_value):
        """Convert a processed numpy value back into its raw numpy form."""
        converter = self.multistage_converter
        return converter.processed_np_to_raw_np(processed_np_value)

    def raw_np_to_tf(self, raw_np_value):
        """Convert a raw numpy value into its TF form."""
        converter = self.multistage_converter
        return converter.raw_np_to_tf(raw_np_value)

    def tf_to_raw_np(self, tf_value):
        """Convert a TF value back into its raw numpy form."""
        converter = self.multistage_converter
        return converter.tf_to_raw_np(tf_value)


class Boolean(DirectType, WithRawNpType, WithMultiStageConverter, Type):
    """
    Boolean feature: numpy ``bool`` at the raw stage, converted with the
    ``Identity`` multistage converter.
    """

    multistage_converter_cls = multistage_converters.Identity
    raw_np_type = np.dtype("bool")
    type_pb_cls = schema_pb2.Type.Boolean

    @classmethod
    def _from_proto(cls, type_pb):
        # Nothing is read from ``type_pb``.
        boolean_type = cls()
        return boolean_type


class Integer(
    DirectType,
    RawNpFromBaseType,
    WithMultiStageConverter,
    Type,
):
    """
    Integer feature whose numpy dtype is given by the ``base_name`` passed
    at construction; converted with the ``Identity`` multistage converter.
    """

    multistage_converter_cls = multistage_converters.Identity
    type_pb_cls = schema_pb2.Type.Integer


class Categorical(
    DirectType,
    RawNpFromBaseType,
    WithMultiStageConverter,
    Type,
):
    """
    Categorical feature whose numpy dtype is given by the ``base_name``
    passed at construction; converted with the ``Identity`` multistage
    converter.
    """

    multistage_converter_cls = multistage_converters.Identity
    type_pb_cls = schema_pb2.Type.Categorical


class Real(
    DirectType,
    RawNpFromBaseType,
    WithMultiStageConverter,
    Type,
):
    """
    Real-valued feature whose numpy dtype is given by the ``base_name``
    passed at construction; converted with the ``Identity`` multistage
    converter.
    """

    multistage_converter_cls = multistage_converters.Identity
    type_pb_cls = schema_pb2.Type.Real


class DateTime(
    ProcessedNpAsRawNp, WithRawNpType, WithMultiStageConverter, Type
):
    """
    Date-time feature.

    WARNING

    The ``base`` of the protobuf message is ignored (``_from_proto`` does not
    read ``type_pb``); historically our pipeline never used it either.
    A DateTime is converted into:

     - raw_np, processed_np: datetime64 with microsecond unit
       (``datetime64[us]``, see the comment on ``raw_np_type``)
     - tf: int64

    It is handled as a string only when parsing the CSV file.

    We stick to this old behavior.
    """

    type_pb_cls = schema_pb2.Type.DateTime
    # We use datetime64[us] because
    #  1. Spark internally uses a Long (int64) with microsecond as time unit
    #     starting from version 1.5.0.
    #     We also configured Spark to output Parquet with MICROS Timestamp:
    #         "spark.sql.parquet.outputTimestampType": "TIMESTAMP_MICROS"
    #  2. tolist does not return datetime.datetime if timeunit=ns because
    #     datetime only has a microsecond resolution.
    #     Even with timeunit=us, tolist returns datetime.datetime only if the
    #     year is within [1, 9999] (plus other constraints), and plain ints
    #     otherwise (see numpy/core/src/multiarray/datetime.c)
    raw_np_type = np.dtype("datetime64[us]")
    # TF stage dtype; a numpy scalar type (not an np.dtype instance).
    tf_dtype = np.int64
    multistage_converter_cls = multistage_converters.DateTime

    @classmethod
    def _from_proto(cls, type_pb):
        return cls()

    def get_tf_type(self):
        """Return ``((), tf_dtype)``: a scalar int64 at the TF stage."""
        shape = ()
        dtype = self.tf_dtype

        return shape, dtype


class Latittude(DirectType, RawNpFromBaseType, WithMultiStageConverter, Type):
    """
    Latitude feature whose numpy dtype is given by the ``base_name`` passed
    at construction; converted with the ``Identity`` multistage converter.

    The class name is misspelled and kept for backward compatibility;
    ``Latitude`` is a correctly spelled alias of the same class.
    """

    type_pb_cls = schema_pb2.Type.Latitude
    multistage_converter_cls = multistage_converters.Identity


# Correctly spelled alias; both names refer to the same class object.
Latitude = Latittude


class Longitude(
    DirectType,
    RawNpFromBaseType,
    WithMultiStageConverter,
    Type,
):
    """
    Longitude feature whose numpy dtype is given by the ``base_name`` passed
    at construction; converted with the ``Identity`` multistage converter.
    """

    multistage_converter_cls = multistage_converters.Identity
    type_pb_cls = schema_pb2.Type.Longitude


class Optional(Type):
    """
    Nullable wrapper around another, non-Optional, Type (``self.type``).

    Every stage type and every conversion is delegated to the wrapped type.

    Optional needs a quite good clarification because Parquet and
    Arrow handle missing values with masking while Numpy uses magic
    values. Currently, if an int64 column has missing values, we will
    get a float point column which is a different type with completely
    different characteristics (especially range).  Let's also point
    out casting a NaN to an int is undefined behavior in the C spec and
    is architecture dependent.
    """

    type_pb_cls = schema_pb2.Type.Optional

    def __init__(self, type: Type, *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert not isinstance(
            type, Optional
        ), "You could not nest Optional types"

        self.type = type

    def __repr__(self):
        cls_name = self.__class__.__name__

        return f"{cls_name}({self.type!r})"

    @classmethod
    def _from_proto(cls, type_pb: schema_pb2.Type.Optional):
        # The wrapped type is itself a union ``schema_pb2.Type`` message.
        inner_type = Type.from_union_type_proto(type_pb.type)
        return cls(type=inner_type)

    def get_spark_type(self):
        # Dataset.get_spark_schema unwraps Optional and marks the field
        # nullable instead of calling this.
        raise NotImplementedError(
            "Optional is not existentas a parametric Type in the Spark Type System."
        )

    def get_raw_np_type(self):
        return self.type.get_raw_np_type()

    def get_processed_np_type(self):
        return self.type.get_processed_np_type()

    def get_tf_type(self):
        return self.type.get_tf_type()

    def raw_np_to_processed_np(self, raw_np_value):
        return self.type.raw_np_to_processed_np(raw_np_value)

    def processed_np_to_raw_np(self, processed_np_value):
        return self.type.processed_np_to_raw_np(processed_np_value)

    def raw_np_to_tf(self, raw_np_value):
        return self.type.raw_np_to_tf(raw_np_value)

    def tf_to_raw_np(self, tf_value):
        return self.type.tf_to_raw_np(tf_value)


class Image(
    WithRawNpType,
    TfAsProccessedNp,
    WithMultiStageConverter,
    Type,
):
    """
    Image feature with a fixed ``(height, width, channel)`` shape.

     - raw numpy stage: ``bytes`` dtype
     - processed numpy and TF stages: int32 array of shape ``shape``
    """

    type_pb_cls = schema_pb2.Type.Image
    raw_np_type = np.dtype("bytes")
    multistage_converter_cls = multistage_converters.Image

    shape: Tuple[int, int, int]

    def __init__(self, shape: Tuple[int, int, int], *args, **kwargs):
        # Set before super().__init__(): WithMultiStageConverter builds the
        # converter from ``self`` there, so the shape must already exist.
        # Normalized to a tuple so __repr__ renders as ``Image(h, w, c)``.
        self.shape = tuple(shape)

        super().__init__(*args, **kwargs)

    def __repr__(self):
        cls_name = self.__class__.__name__

        return f"{cls_name}{self.shape}"

    @classmethod
    def _from_proto(cls, type_pb):
        return cls(
            shape=(
                type_pb.shape.height,
                type_pb.shape.width,
                type_pb.shape.channel,
            ),
        )

    def get_processed_np_type(self):
        """Return ``(shape, int32)`` for the processed numpy stage."""
        return (self.shape, np.dtype("int32"))


class Text(DirectType, WithRawNpType, WithMultiStageConverter, Type):
    """
    Text feature: numpy ``str`` dtype at the raw stage, converted with the
    ``IdentityString`` multistage converter.
    """

    multistage_converter_cls = multistage_converters.IdentityString
    raw_np_type = np.dtype("str")
    type_pb_cls = schema_pb2.Type.Text

    @classmethod
    def _from_proto(cls, type_pb):
        # Nothing is read from ``type_pb``.
        text_type = cls()
        return text_type


# Types forced to nullable by Dataset.get_spark_schema(for_synthetic=True),
# except for the "id" feature.
SYNTHETIC_NULLABLE_TYPES = (DateTime, Real, Integer)

# Types read as Spark strings by Dataset.get_spark_schema(for_parsing=True).
TO_PARSE_TYPES = (DateTime, Categorical)