# NOTE(review): removed a package-hosting page header (Gemfury scrape:
# "Why Gemfury? ... Repository URL ... Size: Mime:") that was accidentally
# prepended to this module and broke Python parsing.
import typing as t
import warnings

import sarus_data_spec.typing as st

try:
    import pyspark.sql.types as pst

    def spark_type(_type: st.Type) -> t.Tuple[pst.DataType, bool]:
        """Convert a Sarus type into the corresponding Spark SQL type.

        Args:
            _type: the Sarus type to convert.

        Returns:
            A tuple ``(data_type, nullable)`` where ``data_type`` is the
            Spark ``pst.DataType`` and ``nullable`` is True when the Sarus
            type admits missing values (``Optional``, ``Null``, ``Unit``).

        Raises:
            NotImplementedError: for Sarus types with no Spark mapping yet
                (Date, Duration, Time, Union, Constrained, List, Array,
                Hypothesis, and unknown Float/Integer bases).
        """

        class SchemaVisitor(st.TypeVisitor):
            # Defaults: an empty struct, assumed non-nullable until a
            # nullable-producing visit method says otherwise.
            schema: pst.DataType = pst.StructType(fields=None)
            nullable: bool = False

            def Boolean(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                self.schema = pst.BooleanType()

            def Bytes(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                # BUGFIX: use BinaryType (arbitrary byte sequence).
                # The previous mapping, pst.ByteType(), is Spark's signed
                # 8-bit integer (tinyint) — a numeric type, not raw bytes.
                self.schema = pst.BinaryType()

            def Datetime(
                self,
                format: str,
                min: str,
                max: str,
                base: st.DatetimeBase,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # TODO : decide how to deal with different datetime types.
                # Both ns and ms integer bases map to Spark's timestamp
                # (microsecond precision); anything else falls back to a
                # plain string representation.
                if base in (st.DatetimeBase.INT64_NS, st.DatetimeBase.INT64_MS):
                    self.schema = pst.TimestampType()
                else:
                    self.schema = pst.StringType()

            def Date(
                self,
                format: str,
                min: str,
                max: str,
                base: st.DateBase,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError

            def Duration(
                self,
                unit: str,
                min: int,
                max: int,
                possible_values: t.Iterable[int],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError

            def Time(
                self,
                format: str,
                min: str,
                max: str,
                base: st.TimeBase,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError

            def Enum(
                self,
                name: str,
                name_values: t.Sequence[t.Tuple[str, int]],
                ordered: bool,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Enums are carried as their string labels.
                self.schema = pst.StringType()

            def Float(
                self,
                min: float,
                max: float,
                base: st.FloatBase,
                possible_values: t.Iterable[float],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                if base == st.FloatBase.FLOAT64:
                    self.schema = pst.DoubleType()
                elif base in (st.FloatBase.FLOAT32, st.FloatBase.FLOAT16):
                    # Spark has no half-precision type; widen FLOAT16 to
                    # single precision.
                    self.schema = pst.FloatType()
                else:
                    raise NotImplementedError("Type not implemented")

            def Id(
                self,
                unique: bool,
                reference: t.Optional[st.Path] = None,
                base: t.Optional[st.IdBase] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Integer-based ids keep their width; string ids stay
                # strings; anything else is treated as opaque binary.
                if base == st.IdBase.INT64:
                    self.schema = pst.LongType()
                elif base == st.IdBase.INT32:
                    self.schema = pst.IntegerType()
                elif base == st.IdBase.INT16:
                    self.schema = pst.ShortType()
                elif base == st.IdBase.INT8:
                    self.schema = pst.ByteType()
                elif base == st.IdBase.STRING:
                    self.schema = pst.StringType()
                else:
                    self.schema = pst.BinaryType()

            def Integer(
                self,
                min: int,
                max: int,
                base: st.IntegerBase,
                possible_values: t.Iterable[int],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                if base == st.IntegerBase.INT64:
                    self.schema = pst.LongType()
                elif base == st.IntegerBase.INT32:
                    self.schema = pst.IntegerType()
                elif base == st.IntegerBase.INT16:
                    self.schema = pst.ShortType()
                elif base == st.IntegerBase.INT8:
                    self.schema = pst.ByteType()
                else:
                    raise NotImplementedError("Type not implemented")

            def Null(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                self.schema = pst.NullType()
                self.nullable = True

            def Optional(
                self,
                field_type: st.Type,
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Recurse into the wrapped type and mark it nullable.
                self.schema = spark_type(field_type)[0]
                self.nullable = True

            def Struct(
                self,
                fields: t.Mapping[str, st.Type],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Recursively convert each field, propagating per-field
                # nullability into the StructField.
                spark_fields = []
                for field_name, field_type in fields.items():
                    spark_field_type, nullable = spark_type(field_type)
                    spark_fields.append(
                        pst.StructField(
                            name=field_name,
                            dataType=spark_field_type,
                            nullable=nullable,
                        )
                    )
                self.schema = pst.StructType(fields=spark_fields)

            def Text(
                self,
                encoding: str,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                self.schema = pst.StringType()

            def Union(
                self,
                fields: t.Mapping[str, st.Type],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def Unit(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                self.schema = pst.NullType()
                self.nullable = True

            def Constrained(
                self,
                type: st.Type,
                constraint: st.Predicate,
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def List(
                self,
                type: st.Type,
                max_size: int,
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def Array(
                self,
                type: st.Type,
                shape: t.Tuple[int, ...],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def Hypothesis(
                self,
                *types: t.Tuple[st.Type, float],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

        visitor = SchemaVisitor()
        _type.accept(visitor)
        return visitor.schema, visitor.nullable

except ModuleNotFoundError:
    # pyspark is an optional dependency: keep the module importable and
    # simply skip defining the Spark helpers when it is absent.
    warnings.warn("pyspark not found, pyspark functions not available")