# NOTE: the lines below are package-index page residue kept as comments.
# Repository URL to install this package: (see package index)
# Version: 4.5.4.dev1
import typing as t
import warnings
import sarus_data_spec.typing as st
try:
    import pyspark.sql.types as pst

    def spark_type(_type: st.Type) -> t.Tuple[pst.DataType, bool]:
        """Return the Spark data type corresponding to a Sarus type.

        Args:
            _type: the Sarus type to translate.

        Returns:
            A ``(data_type, nullable)`` pair: the Spark ``DataType``
            equivalent to ``_type`` and whether the value may be null
            (``True`` for ``Optional``, ``Null`` and ``Unit`` types).

        Raises:
            NotImplementedError: for Sarus types with no Spark
                counterpart (Date, Duration, Time, Union, Constrained,
                List, Array, Hypothesis) or unrecognized numeric bases.
        """

        class SchemaVisitor(st.TypeVisitor):
            """Visitor that records the Spark type of the visited node."""

            # Defaults used when a visit method does not overwrite them.
            schema: pst.DataType = pst.StructType(fields=None)
            nullable: bool = False

            def Boolean(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                self.schema = pst.BooleanType()

            def Bytes(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                # BUGFIX: Spark's ByteType is an 8-bit signed integer;
                # raw byte strings are represented by BinaryType.
                self.schema = pst.BinaryType()

            def Datetime(
                self,
                format: str,
                min: str,
                max: str,
                base: st.DatetimeBase,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # TODO: decide how to deal with different datetime types
                # Spark timestamps carry microsecond precision, so both
                # integer bases map to the same TimestampType.
                if base in (
                    st.DatetimeBase.INT64_NS,
                    st.DatetimeBase.INT64_MS,
                ):
                    self.schema = pst.TimestampType()
                else:
                    self.schema = pst.StringType()

            def Date(
                self,
                format: str,
                min: str,
                max: str,
                base: st.DateBase,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError

            def Duration(
                self,
                unit: str,
                min: int,
                max: int,
                possible_values: t.Iterable[int],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError

            def Time(
                self,
                format: str,
                min: str,
                max: str,
                base: st.TimeBase,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError

            def Enum(
                self,
                name: str,
                name_values: t.Sequence[t.Tuple[str, int]],
                ordered: bool,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Enums are carried as their string labels.
                self.schema = pst.StringType()

            def Float(
                self,
                min: float,
                max: float,
                base: st.FloatBase,
                possible_values: t.Iterable[float],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                if base == st.FloatBase.FLOAT64:
                    self.schema = pst.DoubleType()
                elif base in (st.FloatBase.FLOAT32, st.FloatBase.FLOAT16):
                    # Spark has no half-precision type: promote to float32.
                    self.schema = pst.FloatType()
                else:
                    raise NotImplementedError("Type not implemented")

            def Id(
                self,
                unique: bool,
                reference: t.Optional[st.Path] = None,
                base: t.Optional[st.IdBase] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                if base == st.IdBase.INT64:
                    self.schema = pst.LongType()
                elif base == st.IdBase.INT32:
                    self.schema = pst.IntegerType()
                elif base == st.IdBase.INT16:
                    self.schema = pst.ShortType()
                elif base == st.IdBase.INT8:
                    self.schema = pst.ByteType()
                elif base == st.IdBase.STRING:
                    self.schema = pst.StringType()
                else:
                    # Unknown/absent base: fall back to opaque binary.
                    self.schema = pst.BinaryType()

            def Integer(
                self,
                min: int,
                max: int,
                base: st.IntegerBase,
                possible_values: t.Iterable[int],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                if base == st.IntegerBase.INT64:
                    self.schema = pst.LongType()
                elif base == st.IntegerBase.INT32:
                    self.schema = pst.IntegerType()
                elif base == st.IntegerBase.INT16:
                    self.schema = pst.ShortType()
                elif base == st.IntegerBase.INT8:
                    self.schema = pst.ByteType()
                else:
                    raise NotImplementedError("Type not implemented")

            def Null(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                self.schema = pst.NullType()
                self.nullable = True

            def Optional(
                self,
                field_type: st.Type,
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Unwrap the inner type and mark the field nullable.
                self.schema = spark_type(field_type)[0]
                self.nullable = True

            def Struct(
                self,
                fields: t.Mapping[str, st.Type],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                # Recursively translate each field, propagating its
                # nullability into the StructField.
                spark_fields = []
                for field_name, field_type in fields.items():
                    spark_field_type, nullable = spark_type(field_type)
                    spark_fields.append(
                        pst.StructField(
                            name=field_name,
                            dataType=spark_field_type,
                            nullable=nullable,
                        )
                    )
                self.schema = pst.StructType(fields=spark_fields)

            def Text(
                self,
                encoding: str,
                possible_values: t.Iterable[str],
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                self.schema = pst.StringType()

            def Union(
                self,
                fields: t.Mapping[str, st.Type],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def Unit(
                self, properties: t.Optional[t.Mapping[str, str]] = None
            ) -> None:
                self.schema = pst.NullType()
                self.nullable = True

            def Constrained(
                self,
                type: st.Type,
                constraint: st.Predicate,
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def List(
                self,
                type: st.Type,
                max_size: int,
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def Array(
                self,
                type: st.Type,
                shape: t.Tuple[int, ...],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

            def Hypothesis(
                self,
                *types: t.Tuple[st.Type, float],
                name: t.Optional[str] = None,
                properties: t.Optional[t.Mapping[str, str]] = None,
            ) -> None:
                raise NotImplementedError("Type not implemented")

        visitor = SchemaVisitor()
        _type.accept(visitor)
        return visitor.schema, visitor.nullable

except ModuleNotFoundError:
    # pyspark is an optional dependency: degrade gracefully when absent.
    warnings.warn("pyspark not found, pyspark functions not available")