Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
from __future__ import annotations

from collections import defaultdict
import typing as t

from sqlalchemy import MetaData, Numeric, event, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.selectable import Select
import pyarrow as pa
import sqlalchemy

from sarus_data_spec.arrow.type import to_arrow
import sarus_data_spec.typing as st


async def iter_arrow(
    engine: Engine,
    queries: t.Dict[str, Select],
    sarus_schema: st.Schema,
    batch_size: int,
    multi_schemas: bool,
) -> t.AsyncIterator[pa.RecordBatch]:
    """Execute the SQL queries and yield their rows as Arrow record batches.

    Tables are visited in the order of ``sarus_schema.type().children()``.
    Each query result is fetched in partitions of ``batch_size`` rows; every
    partition is converted to an Arrow array typed after the matching table
    struct and wrapped by ``arrow_recordbatch`` into a ``pa.RecordBatch``
    following ``sarus_schema.to_arrow()``. Nothing is written to disk here.

    Args:
        engine: SQLAlchemy engine; one ``engine.begin()`` transaction is
            opened per table.
        queries: ``SELECT`` statements keyed by ``"<schema>.<table>"`` when
            ``multi_schemas`` is true, by table name otherwise.
        sarus_schema: Schema whose type lists the tables to read and whose
            Arrow conversion is the schema of every yielded batch.
        batch_size: Number of rows per partition (and per yielded batch).
        multi_schemas: True when the schema type has two levels (schemas,
            then tables), False when it directly lists the tables.

    Yields:
        One ``pa.RecordBatch`` per result partition.

    Raises:
        KeyError: If a table listed in ``sarus_schema`` has no query.
    """
    arrow_schema = sarus_schema.to_arrow()
    schema_type = sarus_schema.type()

    if multi_schemas:
        for schema_name, union_type in schema_type.children().items():
            for table_name, table_struct in union_type.children().items():
                query = queries[f"{schema_name}.{table_name}"]
                pa_type = to_arrow(table_struct)
                with engine.begin() as conn:
                    result = conn.execute(query)
                    for res in result.partitions(size=batch_size):
                        # Values are matched to struct fields by position.
                        # NOTE(review): assumes the SELECT column order equals
                        # the field order of ``pa_type`` -- confirm.
                        records = pa.StructArray.from_arrays(
                            [
                                pa.array(
                                    (row[i] for row in res),
                                    size=len(res),
                                    type=field.type,
                                )
                                for i, field in enumerate(pa_type)
                            ],
                            fields=pa_type,
                        )
                        yield arrow_recordbatch(
                            records=records,
                            table_name=table_name,
                            schema_name=schema_name,
                            tables_union=union_type,
                            schemas_union=schema_type,
                            arrow_schema=arrow_schema,
                        )

    else:
        for table_name, table_struct in schema_type.children().items():
            query = queries[table_name]
            # Column names and target Arrow type are the same for every
            # partition of this table, so compute them once.
            columns = [col.name for col in query.selected_columns]
            pa_type = to_arrow(table_struct)
            with engine.begin() as conn:
                result = conn.execute(query)
                for res in result.partitions(size=batch_size):
                    # Dict rows let pyarrow match values to struct fields by
                    # name rather than by position.
                    records = pa.array(
                        [dict(zip(columns, row)) for row in res],
                        type=pa_type,
                    )
                    yield arrow_recordbatch(
                        records=records,
                        table_name=table_name,
                        tables_union=schema_type,
                        arrow_schema=arrow_schema,
                    )


def arrow_recordbatch(
    records: pa.Array,
    table_name: str,
    tables_union: st.Type,
    arrow_schema: pa.Schema,
    schema_name: t.Optional[str] = None,
    schemas_union: t.Optional[st.Type] = None,
) -> pa.RecordBatch:
    """Embed one table's records into a record batch of the full schema.

    A column is produced for every table in ``tables_union`` (in its
    children order). The column named ``table_name`` holds ``records``;
    every other one is an all-null array of ``to_arrow(<table type>)``. A
    ``large_string`` column repeating ``table_name`` on every row is
    appended last.

    When both ``schema_name`` and ``schemas_union`` are given, those
    table-level columns are first packed into one struct array, whose last
    field ``field_selected`` is declared non-nullable. The same layout is
    then repeated one level up: the struct sits under ``schema_name``,
    every other schema of ``schemas_union`` is all-null, and a
    ``large_string`` column repeating ``schema_name`` is appended.

    Args:
        records: Rows of table ``table_name`` for this batch.
        table_name: Name of the table ``records`` belong to.
        tables_union: Type whose children are the sibling tables.
        arrow_schema: Schema of the returned batch.
        schema_name: Name of the schema holding the table, if any.
        schemas_union: Type whose children are the sibling schemas.

    Returns:
        ``pa.RecordBatch.from_arrays`` of the columns described above, with
        ``schema=arrow_schema``.
    """
    n_rows = len(records)
    # All-null placeholders only need the row count; there is no need to
    # iterate (and box) every element of ``records``.
    zero = [None] * n_rows

    table_items = list(tables_union.children().items())
    arrays_list = [
        records
        if el_name == table_name
        else pa.array(zero, type=to_arrow(el_type))
        for el_name, el_type in table_items
    ]
    arrays_list.append(
        pa.array([table_name] * n_rows, type=pa.large_string())
    )

    if schema_name is not None and schemas_union is not None:
        # Struct fields are only needed for the nested (multi-schema) case.
        # zip stops before the trailing tag column, which gets its own
        # non-nullable field below.
        pa_fields = [
            pa.field(el_name, array.type)
            for (el_name, _), array in zip(table_items, arrays_list)
        ]
        pa_fields.append(
            pa.field("field_selected", pa.large_string(), nullable=False)
        )
        struct_array = pa.StructArray.from_arrays(
            arrays_list,
            fields=pa_fields,
        )
        # Same layout one level up: one column per schema.
        arrays_list = [
            struct_array
            if el_name == schema_name
            else pa.array(zero, type=to_arrow(el_type))
            for el_name, el_type in schemas_union.children().items()
        ]
        arrays_list.append(
            pa.array([schema_name] * n_rows, type=pa.large_string())
        )

    return pa.RecordBatch.from_arrays(arrays_list, schema=arrow_schema)


async def sql_to_arrow(
    dataset: st.Dataset,
    sarus_schema: st.Schema,
    engine: Engine,
    batch_size: int,
) -> t.AsyncIterator[pa.RecordBatch]:
    """Reflect the dataset's SQL tables and return an Arrow batch iterator.

    The tables to read come from ``dataset.protobuf().spec.sql.tables`` and
    are grouped by schema name. They are reflected through ``engine`` with
    generic column types, ``Numeric`` columns get ``asdecimal = False``, and
    one ``SELECT`` of all columns is built per table. The value returned is
    the ``iter_arrow`` iterator over those queries; rows are only fetched
    when it is iterated.

    Multi-schema mode is used unless every listed table has the empty
    schema name ``""``.
    """
    metadata = MetaData()

    # Registered before any reflection so every reflected column is given
    # its generic SQLAlchemy type instead of a dialect-specific one.
    @event.listens_for(metadata, "column_reflect")
    def genericize_datatypes(
        inspector: sqlalchemy.engine.reflection.Inspector,
        tablename: sqlalchemy.Table,  # SQLAlchemy passes the Table object
        column_dict: t.Dict,
    ) -> None:
        column_dict["type"] = column_dict["type"].as_generic()

    grouped: t.DefaultDict[str, t.List[str]] = defaultdict(list)
    for table_spec in dataset.protobuf().spec.sql.tables:
        grouped[table_spec.schema].append(table_spec.table)
    # Plain dict from here on, so a lookup such as ``tables_dict[""]``
    # raises KeyError instead of inserting an empty list.
    tables_dict: t.Dict[str, t.List[str]] = dict(grouped)

    multischemas = list(tables_dict) != [""]

    if not multischemas:
        metadata.reflect(
            bind=engine, only=tables_dict[""], resolve_fks=True, views=True
        )
    else:
        for schema_name, names in tables_dict.items():
            metadata.reflect(
                bind=engine,
                schema=schema_name,
                only=names,
                resolve_fks=True,
                views=True,
            )

    # Must happen before any query is built or executed.
    for reflected in metadata.tables.values():
        for column in reflected.columns:
            if isinstance(column.type, Numeric):
                column.type.asdecimal = False

    if multischemas:
        queries = {
            f"{schema_name}.{name}": select(
                *metadata.tables[f"{schema_name}.{name}"].c
            )
            for schema_name, names in tables_dict.items()
            for name in names
        }
    else:
        queries = {
            name: select(*metadata.tables[name].c) for name in tables_dict[""]
        }

    return iter_arrow(
        engine=engine,
        queries=queries,
        batch_size=batch_size,
        sarus_schema=sarus_schema,
        multi_schemas=multischemas,
    )


async def _sql_to_arrow(
    dataset: st.Dataset,
    sarus_schema: st.Schema,
    batch_size: int,
) -> t.AsyncIterator[pa.RecordBatch]:
    """Build an engine for the dataset's SQL URI and delegate to
    ``sql_to_arrow``.

    The engine is obtained from ``dataset.manager().engine`` using the URI
    stored in ``dataset.protobuf().spec.sql.uri``.
    """
    manager = dataset.manager()
    sql_uri = dataset.protobuf().spec.sql.uri
    engine = manager.engine(uri=sql_uri)
    batches = await sql_to_arrow(
        dataset=dataset,
        engine=engine,
        batch_size=batch_size,
        sarus_schema=sarus_schema,
    )
    return batches