Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import typing as t

from pyarrow import csv
import pyarrow as pa
import pyarrow.compute as pc

from sarus_data_spec.arrow.type import to_arrow
from sarus_data_spec.manager.ops.source.csv.schema import (
    convert_names,
    csv_params_from_clevercsv,
)
import sarus_data_spec.typing as st


async def csv_to_arrow(
    dataset: st.Dataset, batch_size: int
) -> t.AsyncIterator[pa.RecordBatch]:
    """Read the CSV file backing ``dataset`` and yield arrow record batches.

    The CSV dialect is sniffed with ``csv_params_from_clevercsv`` on the
    first 100 lines of the file. The whole file is then parsed with pyarrow,
    ``string`` columns are widened to ``large_string``, column names are
    normalized with ``convert_names``, and the data is converted to match
    the dataset's Sarus schema with ``cast_array``.

    Args:
        dataset: dataset whose protobuf spec holds the CSV file URI
            (``spec.file.uri``).
        batch_size: maximum number of rows per yielded batch.

    Yields:
        ``pa.RecordBatch`` objects of at most ``batch_size`` rows each.
    """
    uri = dataset.protobuf().spec.file.uri
    source_fs = dataset.manager().filesystem(uri)  # type:ignore

    # Only a small sample is needed to sniff the CSV dialect.
    with source_fs.open(uri) as f:
        sample = "".join(f.readline().decode("utf-8") for _ in range(100))

    read_options, parse_options = csv_params_from_clevercsv(sample=sample)
    convert_options = csv.ConvertOptions(
        strings_can_be_null=True,
        timestamp_parsers=[csv.ISO8601, "%Y-%m-%d"],
        true_values=["True"],
        false_values=["False"],
    )
    # TODO: use csv.open_csv to stream batch by batch
    # directly, only issue is that batch_size seems
    # to be set as bytes and not lines
    # Read the full content inside a context manager so the file handle is
    # closed deterministically once its bytes are in the arrow buffer.
    with source_fs.open(uri) as f:
        buffer = pa.py_buffer(f.read())
    table = csv.read_csv(
        buffer,
        read_options=read_options,
        parse_options=parse_options,
        convert_options=convert_options,
    )

    # Widen every ``string`` column to ``large_string``; the downstream
    # conversions in cast_array test for ``pa.large_string()``.
    new_fields = [
        pa.field(name, orig_type)
        if orig_type != pa.string()
        else pa.field(name, pa.large_string())
        for name, orig_type in zip(table.schema.names, table.schema.types)
    ]
    table = table.cast(pa.schema(new_fields))
    table = table.rename_columns(convert_names(table.schema.names))

    datatype = (await dataset.manager().async_schema(dataset)).data_type()
    # Presumably ``datatype`` is a Struct, so cast_array returns a
    # StructArray; flatten it back into one column per field.
    array = cast_array(datatype, table)
    table = pa.Table.from_arrays(
        arrays=array.flatten(), names=list(datatype.children().keys())
    )
    for batch in table.to_batches(max_chunksize=batch_size):
        yield batch


def cast_array(_type: st.Type, array: pa.Table) -> pa.Array:
    """Convert ``array`` so that it matches the Sarus type ``_type``.

    A ``st.TypeVisitor`` is dispatched on ``_type``:

    * ``Struct`` converts each column of ``array`` (paired with the struct
      fields in declaration order) recursively and packs the results into
      a ``pa.StructArray``;
    * ``Optional`` recurses on the wrapped type with the same data;
    * ``Datetime``, ``Date`` and ``Time`` convert the arrow temporal or
      string representation of the data;
    * every other type returns ``array`` unchanged.

    Args:
        _type: Sarus type describing the expected data.
        array: data to convert. At the top level this is the table read from
            the CSV; recursive calls pass single columns.

    Returns:
        The converted arrow data.

    Raises:
        NotImplementedError: for datetime data whose base, arrow type or
            unit is not supported.
        ValueError: for date data stored with an unexpected arrow type.
    """

    class ArrayConvertor(st.TypeVisitor):
        # Result of the conversion; defaults to the input untouched.
        batch_array: pa.Array = array

        def Datetime(
            self,
            format: str,
            min: str,
            max: str,
            base: st.DatetimeBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            if array.type == pa.large_string():
                # Strings are parsed with the declared format into the
                # timestamp unit matching the declared base.
                if base == st.DatetimeBase.INT64_NS:
                    self.batch_array = pc.strptime(
                        self.batch_array, format=format, unit="ns"
                    )
                elif base == st.DatetimeBase.INT64_MS:
                    self.batch_array = pc.strptime(
                        self.batch_array, format=format, unit="ms"
                    )
                else:
                    raise NotImplementedError(
                        "Got string to convert to a datetime "
                        "with unsupported base"
                    )
            elif not pa.types.is_timestamp(array.type):
                # Only timestamp types expose a ``unit`` attribute.
                raise NotImplementedError(
                    f"Got a datetime stored with unsupported arrow type "
                    f"{array.type}"
                )
            elif array.type.unit == "ns":
                # NOTE(review): ns values are rounded to whole seconds while
                # coarser units below keep their precision — confirm intended.
                self.batch_array = pc.round_temporal(array, unit="second")
            elif array.type.unit in ("us", "ms", "s"):
                # Normalize coarser timestamps to nanoseconds.
                self.batch_array = self.batch_array.cast(pa.timestamp("ns"))
            else:
                raise NotImplementedError(
                    "Got a datetime with an unsupported unit"
                )

        def Date(
            self,
            format: str,
            min: str,
            max: str,
            base: st.DateBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            if array.type == pa.date32():
                pass
            elif array.type == pa.large_string():
                # in this case it is a string, because the type
                # has been inferred (for csv source)
                self.batch_array = pc.strptime(
                    self.batch_array, format=format, unit="ns"
                ).cast(pa.date32())
            else:
                raise ValueError(
                    f"Unexpected type {array.type} for array "
                    f"in Sarus Date type"
                )

        def Time(
            self,
            format: str,
            min: str,
            max: str,
            base: st.TimeBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            # Cast to the arrow time type matching the declared base;
            # any other base is kept as text.
            if base == st.TimeBase.INT64_US:
                self.batch_array = pc.cast(
                    array,
                    target_type=pa.time64("us"),
                )
            elif base == st.TimeBase.INT32_MS:
                self.batch_array = pc.cast(
                    array,
                    target_type=pa.time32("ms"),
                )
            else:
                self.batch_array = pc.cast(
                    array, target_type=pa.large_string()
                )

        def Struct(
            self,
            fields: t.Mapping[str, st.Type],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            # Columns are paired with fields positionally: assumes the
            # column order of ``array`` matches the field declaration order.
            self.batch_array = pa.StructArray.from_arrays(
                arrays=[
                    cast_array(fields[field_name], column.combine_chunks())
                    for field_name, column in zip(fields.keys(), array.columns)
                ],
                fields=[
                    pa.field(
                        name=field_name,
                        type=to_arrow(fields[field_name]),
                        # Optional and Unit fields may hold nulls.
                        nullable=fields[field_name]
                        .protobuf()
                        .HasField("optional")
                        or fields[field_name].protobuf().HasField("unit"),
                    )
                    for field_name in fields.keys()
                ],
            )

        def Optional(
            self,
            type: st.Type,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            self.batch_array = cast_array(type, array)

        def Array(
            self,
            type: st.Type,
            shape: t.Tuple[int, ...],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def List(
            self,
            type: st.Type,
            max_size: int,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Boolean(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Bytes(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Unit(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Constrained(
            self,
            type: st.Type,
            constraint: st.Predicate,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Duration(
            self,
            unit: str,
            min: int,
            max: int,
            possible_values: t.Iterable[int],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Enum(
            self,
            name: str,
            name_values: t.Sequence[t.Tuple[str, int]],
            ordered: bool,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Text(
            self,
            encoding: str,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Hypothesis(
            self,
            *types: t.Tuple[st.Type, float],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Id(
            self,
            unique: bool,
            reference: t.Optional[st.Path] = None,
            base: t.Optional[st.IdBase] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Integer(
            self,
            min: int,
            max: int,
            base: st.IntegerBase,
            possible_values: t.Iterable[int],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Null(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            pass

        def Float(
            self,
            min: float,
            max: float,
            base: st.FloatBase,
            possible_values: t.Iterable[float],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

        def Union(
            self,
            fields: t.Mapping[str, st.Type],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            pass

    visitor = ArrayConvertor()
    _type.accept(visitor)
    return visitor.batch_array