Repository URL to install this package:
|
Version:
4.5.4.dev1 ▾
|
import typing as t
from pyarrow import csv
import pyarrow as pa
import pyarrow.compute as pc
from sarus_data_spec.arrow.type import to_arrow
from sarus_data_spec.manager.ops.source.csv.schema import (
convert_names,
csv_params_from_clevercsv,
)
import sarus_data_spec.typing as st
async def csv_to_arrow(
    dataset: st.Dataset, batch_size: int
) -> t.AsyncIterator[pa.RecordBatch]:
    """Read the CSV file behind `dataset` and yield it as Arrow batches.

    The file located at ``dataset.protobuf().spec.file.uri`` is opened
    through the filesystem provided by the dataset manager. Its first
    lines are passed to `csv_params_from_clevercsv` to obtain the read and
    parse options, then the whole file is loaded with `pyarrow.csv`,
    cast to the dataset schema and re-emitted batch by batch.

    Args:
        dataset: Dataset whose spec points to the CSV file.
        batch_size: Maximum number of rows per yielded record batch.

    Yields:
        Record batches whose columns follow the children of the dataset
        schema data type.
    """
    uri = dataset.protobuf().spec.file.uri
    source_fs = dataset.manager().filesystem(uri)  # type:ignore

    # Only a sample (at most 100 lines) is used to derive the CSV options.
    with source_fs.open(uri) as f:
        sample = "".join([f.readline().decode("utf-8") for _ in range(100)])
    read_options, parse_options = csv_params_from_clevercsv(sample=sample)
    convert_options = csv.ConvertOptions(
        strings_can_be_null=True,
        timestamp_parsers=[csv.ISO8601, "%Y-%m-%d"],
        true_values=["True"],
        false_values=["False"],
    )

    # TODO: use csv.open_csv to stream batch by batch
    # directly, only issue is that batch_size seems
    # to be set as bytes and not lines
    # The handle is opened in a `with` block so that it is always closed,
    # even if reading fails.
    with source_fs.open(uri) as f:
        raw_content = f.read()
    table = csv.read_csv(
        pa.py_buffer(raw_content),
        read_options=read_options,
        parse_options=parse_options,
        convert_options=convert_options,
    )

    # Promote every `string` column to `large_string`; other column types
    # are kept as inferred by the CSV reader.
    new_fields = [
        pa.field(name, pa.large_string())
        if orig_type == pa.string()
        else pa.field(name, orig_type)
        for name, orig_type in zip(table.schema.names, table.schema.types)
    ]
    table = table.cast(pa.schema(new_fields))

    column_names = convert_names(table.schema.names)
    table = table.rename_columns(column_names)

    # Cast the table following the dataset schema, then flatten the
    # resulting array back into one column per schema child.
    datatype = (await dataset.manager().async_schema(dataset)).data_type()
    array = cast_array(datatype, table)
    table = pa.Table.from_arrays(
        arrays=array.flatten(), names=list(datatype.children().keys())
    )

    for batch in table.to_batches(max_chunksize=batch_size):
        yield batch
def cast_array(_type: st.Type, array: pa.Table) -> pa.Array:
    """Convert `array` according to the Sarus type `_type`.

    A visitor is applied to `_type`. Its result starts as `array` itself
    and is replaced only by the handlers of the Datetime, Date, Time,
    Struct and Optional types; every other type leaves the input
    untouched.

    Args:
        _type: Sarus type describing the expected data.
        array: Arrow data to convert. The Struct handler reads its
            ``columns`` attribute, the temporal handlers read its ``type``.

    Returns:
        The converted Arrow data, or `array` unchanged when the type has
        no dedicated conversion.

    Raises:
        NotImplementedError: For a Datetime whose base (string input) or
            unit (timestamp input) is not handled.
        ValueError: For a Date whose input is neither ``date32`` nor
            ``large_string``.
    """

    class ArrayConvertor(st.TypeVisitor):
        # Result of the visit; defaults to the unmodified input.
        batch_array: pa.Array = array

        def Datetime(
            self,
            format: str,
            min: str,
            max: str,
            base: st.DatetimeBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            if array.type == pa.large_string():
                # Strings are parsed with the format given by the type; the
                # resulting timestamp unit follows the requested base.
                if base == st.DatetimeBase.INT64_NS:
                    unit = "ns"
                elif base == st.DatetimeBase.INT64_MS:
                    unit = "ms"
                else:
                    raise NotImplementedError(
                        "Got string to convert to a datetime "
                        "with unsupported base"
                    )
                self.batch_array = pc.strptime(
                    self.batch_array, format=format, unit=unit
                )
            elif array.type.unit == "ns":
                self.batch_array = pc.round_temporal(array, unit="second")
            elif array.type.unit in ("ms", "s"):
                self.batch_array = self.batch_array.cast(pa.timestamp("ns"))
            else:
                raise NotImplementedError(
                    "Got a datetime with an unsupported unit"
                )

        def Date(
            self,
            format: str,
            min: str,
            max: str,
            base: st.DateBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            if array.type == pa.date32():
                # Already in the target representation.
                pass
            elif array.type == pa.large_string():
                # in this case it is a string, because the type
                # has been inferred (for csv source)
                self.batch_array = pc.strptime(
                    self.batch_array, format=format, unit="ns"
                ).cast(pa.date32())
            else:
                raise ValueError(
                    f"Unexpected type {array.type} for array "
                    f"in Sarus Date type"
                )

        def Time(
            self,
            format: str,
            min: str,
            max: str,
            base: st.TimeBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            # Any base other than the two handled ones falls back to a
            # string representation.
            if base == st.TimeBase.INT64_US:
                target_type = pa.time64("us")
            elif base == st.TimeBase.INT32_MS:
                target_type = pa.time32("ms")
            else:
                target_type = pa.large_string()
            self.batch_array = pc.cast(array, target_type=target_type)

        def Struct(
            self,
            fields: t.Mapping[str, st.Type],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            # Columns are paired with the struct fields by position, each
            # one being converted recursively with its own field type.
            self.batch_array = pa.StructArray.from_arrays(
                arrays=[
                    cast_array(fields[field_name], column.combine_chunks())
                    for field_name, column in zip(fields.keys(), array.columns)
                ],
                fields=[
                    pa.field(
                        name=field_name,
                        type=to_arrow(fields[field_name]),
                        # A field is nullable when its type is either
                        # optional or unit.
                        nullable=fields[field_name]
                        .protobuf()
                        .HasField("optional")
                        or fields[field_name].protobuf().HasField("unit"),
                    )
                    for field_name in fields.keys()
                ],
            )

        def Optional(
            self,
            type: st.Type,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            # Convert according to the wrapped type.
            self.batch_array = cast_array(type, array)

        # The remaining types need no conversion: the input is returned
        # as it is.

        def Array(
            self,
            type: st.Type,
            shape: t.Tuple[int, ...],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def List(
            self,
            type: st.Type,
            max_size: int,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Boolean(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Bytes(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Unit(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Constrained(
            self,
            type: st.Type,
            constraint: st.Predicate,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Duration(
            self,
            unit: str,
            min: int,
            max: int,
            possible_values: t.Iterable[int],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Enum(
            self,
            name: str,
            name_values: t.Sequence[t.Tuple[str, int]],
            ordered: bool,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Text(
            self,
            encoding: str,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Hypothesis(
            self,
            *types: t.Tuple[st.Type, float],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Id(
            self,
            unique: bool,
            reference: t.Optional[st.Path] = None,
            base: t.Optional[st.IdBase] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Integer(
            self,
            min: int,
            max: int,
            base: st.IntegerBase,
            possible_values: t.Iterable[int],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Null(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Float(
            self,
            min: float,
            max: float,
            base: st.FloatBase,
            possible_values: t.Iterable[float],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Union(
            self,
            fields: t.Mapping[str, st.Type],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

    visitor = ArrayConvertor()
    _type.accept(visitor)
    return visitor.batch_array