# Repository URL to install this package:
# Version: 4.5.4.dev1
from __future__ import annotations
from collections import defaultdict
import typing as t
from sqlalchemy import MetaData, Numeric, event, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.selectable import Select
import pyarrow as pa
import sqlalchemy
from sarus_data_spec.arrow.type import to_arrow
import sarus_data_spec.typing as st
async def iter_arrow(
    engine: Engine,
    queries: t.Dict[str, Select],
    sarus_schema: st.Schema,
    batch_size: int,
    multi_schemas: bool,
) -> t.AsyncIterator[pa.RecordBatch]:
    """Execute each table's query and yield its rows as Arrow record batches.

    Args:
        engine: SQLAlchemy engine the queries are executed on.
        queries: Select statements keyed by table name, or by
            ``"<schema>.<table>"`` when ``multi_schemas`` is true.
        sarus_schema: Schema whose type enumerates the tables (or the SQL
            schemas and, below them, the tables) to read. It also provides
            the Arrow schema of the yielded batches.
        batch_size: Partition size passed to ``Result.partitions``; each
            fetched partition becomes one record batch.
        multi_schemas: Whether the first level of the schema type lists SQL
            schemas (true) or directly the tables (false).

    Yields:
        One record batch per fetched partition, built by
        ``arrow_recordbatch``.

    Raises:
        KeyError: If ``queries`` has no entry for a table of the schema.
    """
    # NOTE(review): queries run through the synchronous engine API, and the
    # connection of a table stays open while the consumer handles each batch.
    arrow_schema = sarus_schema.to_arrow()
    schema_type = sarus_schema.type()

    if multi_schemas:
        for schema_name, tables_union in schema_type.children().items():
            for table_name, table_struct in tables_union.children().items():
                query = queries[f"{schema_name}.{table_name}"]
                pa_type = to_arrow(table_struct)
                with engine.begin() as conn:
                    result = conn.execute(query)
                    for rows in result.partitions(size=batch_size):
                        # Values are matched to struct fields by position:
                        # assumes the fields of `table_struct` follow the
                        # order of the selected columns -- TODO confirm.
                        records = pa.StructArray.from_arrays(
                            [
                                pa.array(
                                    (row[i] for row in rows),
                                    size=len(rows),
                                    type=field.type,
                                )
                                for i, field in enumerate(pa_type)
                            ],
                            fields=pa_type,
                        )
                        yield arrow_recordbatch(
                            records=records,
                            table_name=table_name,
                            schema_name=schema_name,
                            tables_union=tables_union,
                            schemas_union=schema_type,
                            arrow_schema=arrow_schema,
                        )
    else:
        for table_name, table_struct in schema_type.children().items():
            query = queries[table_name]
            # `selected_columns` replaces the deprecated `Select.c`, which
            # implicitly wraps the statement in a subquery.
            columns = [col.name for col in query.selected_columns]
            # Invariant for the whole table, so computed once rather than
            # once per partition.
            pa_type = to_arrow(table_struct)
            with engine.begin() as conn:
                result = conn.execute(query)
                for rows in result.partitions(size=batch_size):
                    # Here values are matched to struct fields by name.
                    records = pa.array(
                        [dict(zip(columns, row)) for row in rows],
                        type=pa_type,
                    )
                    yield arrow_recordbatch(
                        records=records,
                        table_name=table_name,
                        tables_union=schema_type,
                        arrow_schema=arrow_schema,
                    )
def _union_arrays(
    union_type: st.Type,
    selected_name: str,
    selected: pa.Array,
    length: int,
) -> t.List[pa.Array]:
    """Build the child arrays of a union plus its selector column.

    Every child of ``union_type`` other than ``selected_name`` becomes an
    all-null array of the child's Arrow type; the selected child is
    ``selected`` itself. A ``large_string`` array repeating
    ``selected_name`` is appended last.
    """
    nulls = [None] * length
    arrays = [
        selected
        if child_name == selected_name
        else pa.array(nulls, type=to_arrow(child_type))
        for child_name, child_type in union_type.children().items()
    ]
    arrays.append(pa.array([selected_name] * length, type=pa.large_string()))
    return arrays


def arrow_recordbatch(
    records: pa.Array,
    table_name: str,
    tables_union: st.Type,
    arrow_schema: pa.Schema,
    schema_name: t.Optional[str] = None,
    schemas_union: t.Optional[st.Type] = None,
) -> pa.RecordBatch:
    """Wrap the records of one SQL table into a record batch of the union.

    The batch holds one column per child of the union: ``records`` for the
    table they come from, nulls for every other table, followed by a
    ``large_string`` column holding the name of the selected child.

    When both ``schema_name`` and ``schemas_union`` are given, the table
    columns are first packed into a struct array (its last field is named
    ``"field_selected"`` and is non-nullable), and the same layout is then
    repeated one level up across the schemas. If either is missing, the
    single-level layout is used.

    Args:
        records: Rows of the table as an Arrow array.
        table_name: Name of the table ``records`` belongs to.
        tables_union: Type whose children are the tables at this level.
        arrow_schema: Schema applied to the resulting record batch.
        schema_name: Name of the SQL schema holding the table, if any.
        schemas_union: Type whose children are the SQL schemas, if any.

    Returns:
        A record batch with as many rows as ``records``.
    """
    n_rows = len(records)
    arrays = _union_arrays(tables_union, table_name, records, n_rows)

    if schema_name is not None and schemas_union is not None:
        # zip stops after the last table name, leaving out the selector
        # array, whose field is declared explicitly below.
        table_fields = [
            pa.field(name, array.type)
            for name, array in zip(tables_union.children().keys(), arrays)
        ]
        table_fields.append(
            pa.field("field_selected", pa.large_string(), nullable=False)
        )
        tables_struct = pa.StructArray.from_arrays(
            arrays,
            fields=table_fields,
        )
        arrays = _union_arrays(
            schemas_union, schema_name, tables_struct, n_rows
        )

    return pa.RecordBatch.from_arrays(arrays, schema=arrow_schema)
async def sql_to_arrow(
    dataset: st.Dataset,
    sarus_schema: st.Schema,
    engine: Engine,
    batch_size: int,
) -> t.AsyncIterator[pa.RecordBatch]:
    """Reflect the dataset's SQL tables and stream them as record batches.

    The tables declared in the dataset spec are reflected through
    ``engine``, one ``SELECT`` of all columns is built per table, and the
    async iterator produced by ``iter_arrow`` over those queries is
    returned.

    Args:
        dataset: Dataset whose protobuf spec lists the SQL tables to read.
        sarus_schema: Schema the produced batches follow.
        engine: SQLAlchemy engine used for reflection and for the queries.
        batch_size: Number of rows fetched per partition.

    Returns:
        The async iterator of record batches created by ``iter_arrow``.
    """
    metadata = MetaData()

    @event.listens_for(metadata, "column_reflect")
    def genericize_datatypes(
        inspector: sqlalchemy.engine.reflection.Inspector,
        tablename: str,
        column_dict: t.Dict,
    ) -> None:
        # Replace every reflected column type by its generic counterpart.
        generic_type = column_dict["type"].as_generic()
        column_dict["type"] = generic_type

    # Table names grouped by the SQL schema declared in the dataset spec.
    tables_dict: t.Dict[str, t.List[str]] = {}
    for table in dataset.protobuf().spec.sql.tables:
        tables_dict.setdefault(table.schema, []).append(table.table)

    # Single-schema mode only when every table sits under the "" schema.
    multischemas = list(tables_dict) != [""]

    if not multischemas:
        metadata.reflect(
            bind=engine,
            only=tables_dict[""],
            resolve_fks=True,
            views=True,
        )
    else:
        for schema, schema_tables in tables_dict.items():
            metadata.reflect(
                bind=engine,
                schema=schema,
                only=schema_tables,
                resolve_fks=True,
                views=True,
            )

    # This needs to happen before any queries: switch `asdecimal` off on
    # every Numeric column type.
    for reflected_table in metadata.tables.values():
        for column in reflected_table.columns.values():
            column_type = column.type
            if not isinstance(column_type, Numeric):
                continue
            column_type.asdecimal = False

    # Keys of `metadata.tables` are "<schema>.<table>" in multi-schema mode
    # and the bare table name otherwise; the queries use the same keys.
    if multischemas:
        table_keys = [
            f"{schema}.{table_name}"
            for schema, schema_tables in tables_dict.items()
            for table_name in schema_tables
        ]
    else:
        table_keys = list(tables_dict[""])

    queries = {key: select(*metadata.tables[key].c) for key in table_keys}

    return iter_arrow(
        engine=engine,
        queries=queries,
        batch_size=batch_size,
        sarus_schema=sarus_schema,
        multi_schemas=multischemas,
    )
async def _sql_to_arrow(
    dataset: st.Dataset,
    sarus_schema: st.Schema,
    batch_size: int,
) -> t.AsyncIterator[pa.RecordBatch]:
    """Stream a SQL dataset as record batches using its manager's engine.

    The engine is obtained from the dataset's manager for the URI stored in
    the dataset's SQL spec, then everything is delegated to
    ``sql_to_arrow``.

    Args:
        dataset: Dataset whose SQL spec gives the URI and the tables.
        sarus_schema: Schema the produced batches follow.
        batch_size: Number of rows fetched per partition.

    Returns:
        The async iterator of record batches returned by ``sql_to_arrow``.
    """
    manager = dataset.manager()
    uri = dataset.protobuf().spec.sql.uri
    engine = manager.engine(uri=uri)
    batches = await sql_to_arrow(
        dataset=dataset,
        sarus_schema=sarus_schema,
        engine=engine,
        batch_size=batch_size,
    )
    return batches