Repository URL to install this package:
|
Version:
4.5.4.dev1 ▾
|
import os
import tempfile
import typing as t
import warnings
import pyarrow as pa
import pyarrow.parquet as pq
try:
import sqlalchemy as sa
except ModuleNotFoundError:
warnings.warn("Sqlalchemy not installed, cannot send sql queries")
from sarus_data_spec import typing as st
from sarus_data_spec.constants import (
DATA,
OPTIONAL_VALUE,
PU_COLUMN,
PUBLIC,
SQL_CACHING_URI,
TO_SQL_CACHING_TASK,
WEIGHTS,
)
from sarus_data_spec.manager.ops.foreign_keys import fk_visitor
from sarus_data_spec.manager.ops.primary_keys import pk_visitor
from sarus_data_spec.manager.ops.sql_utils.schema_translations import (
BQFieldType,
to_bq_schema,
to_sqlalchemy_metadata,
)
from sarus_data_spec.manager.ops.sql_utils.table_mapping import (
base64encode_table_map,
expand_table_map,
explicit_table_map,
name_encoder,
table_mapping,
)
from sarus_data_spec.transform import filter, get_item
import sarus_data_spec.protobuf as sp
try:
from google.cloud import bigquery
from google.oauth2 import service_account
except ImportError:
warnings.warn(
"Google API Python client and related libraries are not installed,"
" cannot push dataset to bigquery"
)
big_query_schema_field = t.Any
else:
big_query_schema_field = bigquery.SchemaField # type:ignore
# Size of the batches (in rows) when pushing to the database. If too high, it
# can cause big memory consumption when the batch is materialised,
# especially if the dataset has a lot of columns.
# Beware when increasing it.
BATCH_SIZE = 5_000
# Length of the encoded SQL table names (``encoded_name_length`` passed to
# ``table_mapping`` in ``async_sql_op``).
TAB_NAME_LENGTH = 8
# Length of the encoded dataset-uuid part of SQL schema names (``length``
# passed to ``name_encoder`` in ``async_to_sql_op``).
SCHEMA_NAME_LENGTH = TAB_NAME_LENGTH
def _collect_primary_keys(data_type: st.Type) -> t.List[t.Tuple[str, ...]]:
    """Return the path (as a tuple of names) of every pk found by pk_visitor."""
    return [tuple(path.to_strings_list()[0]) for path in pk_visitor(data_type)]


def _collect_foreign_keys(
    data_type: st.Type,
) -> t.Dict[t.Tuple[str, ...], t.Tuple[str, ...]]:
    """Map each foreign-key column path to the column path it references.

    When a key path contains ``OPTIONAL_VALUE`` its last component is dropped
    (presumably the optional wrapper level - TODO confirm). The leading
    component of every referenced path is dropped, mirroring how table paths
    are made relative in ``async_sql_op``.
    """
    foreign_keys: t.Dict[t.Tuple[str, ...], t.Tuple[str, ...]] = {}
    for key_path, value_path in fk_visitor(data_type).items():
        key = tuple(key_path.to_strings_list()[0])
        if OPTIONAL_VALUE in key:
            key = key[:-1]
        foreign_keys[key] = tuple(value_path.to_strings_list()[0][1:])
    return foreign_keys


def _flatten_data_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
    """Lift the children of the ``DATA`` struct column to top-level columns.

    The admin columns ``PU_COLUMN``, ``WEIGHTS`` and ``PUBLIC`` are appended
    after the data columns. A batch without a ``DATA`` column is returned
    unchanged.
    """
    if DATA not in batch.schema.names:
        return batch
    prefix = DATA + "."
    arrays = batch.column(DATA).flatten()
    names = []
    for field in batch.field(DATA).flatten():
        # Field.flatten() names children "<parent>.<child>". Strip only that
        # leading prefix: str.replace would also remove it from inside child
        # names (e.g. "metadata.x" -> "metax").
        name = field.name
        names.append(name[len(prefix):] if name.startswith(prefix) else name)
    for col in (PU_COLUMN, WEIGHTS, PUBLIC):
        arrays.append(batch.column(col))
        names.append(batch.field(col).name)
    return pa.record_batch(arrays, names=names)


async def async_sql_op(
    dataset: st.Dataset,
    pushed_uri: str,
    sql_schema: str,
    table_names: t.Optional[t.List[str]],
    drop_admin: bool,
) -> t.Tuple[str, t.Dict[str, t.Tuple[str, ...]], t.List[str]]:
    """Push every table of ``dataset`` to the SQL database at ``pushed_uri``.

    Steps:
    - open an engine on ``pushed_uri`` and create ``sql_schema`` if missing
      (done by ``SarusMetadata``);
    - map each sarus table to a SQL table name with ``table_mapping``
      (using ``table_names`` when given);
    - first pass over the tables: declare each SQL table with its column
      types, primary keys, foreign keys and nullability, then create them all;
    - second pass: stream each table in batches of up to ``BATCH_SIZE`` rows
      and insert them, referenced tables first.

    Args:
        dataset: dataset whose tables are pushed.
        pushed_uri: URI of the target database; must be non-empty.
        sql_schema: SQL schema in which the tables are created.
        table_names: optional explicit SQL table names, forwarded to
            ``table_mapping``.
        drop_admin: when True, tables are declared from
            ``schema_table.data_type()`` instead of ``schema_table.type()``;
            also forwarded as ``drop_admin`` to ``manager.async_to``.

    Returns:
        (pushed_uri, base64-encoded table mapping, expanded table mapping)

    Raises:
        ValueError: if ``pushed_uri`` is empty.
    """
    if not pushed_uri:
        raise ValueError("A non-empty URI is required to push a dataset to SQL.")
    manager = dataset.manager()
    ds_schema = await manager.async_schema(dataset)
    schema_name = ds_schema.name()
    data_type = ds_schema.data_type()
    tables = ds_schema.tables()
    engine = manager.engine(uri=pushed_uri)
    # Creates ``sql_schema`` in the DB if it does not exist yet.
    metadata = SarusMetadata(engine, sql_schema)
    primary_keys = _collect_primary_keys(data_type)
    foreign_keys = _collect_foreign_keys(data_type)
    # Table mapping: Path (without DATA) -> (sql_schema, sql_tables)
    table_map = table_mapping(
        tables=tables,
        sarus_schema_name=schema_name,
        encoded_name_length=TAB_NAME_LENGTH,
        sql_schema=sql_schema,
        table_names=table_names,
    )
    # Tuple[str] representation of the table mapping -> (sql_schema,
    # sql_tables). Used to correctly map pks and fks.
    explicit_map = explicit_table_map(table_map)

    # First pass: fill the metadata so every table is created at once.
    # Remember which sarus table dataset feeds each SQL table for the push.
    sqltable_to_table_ds: t.Dict[str, t.Any] = {}
    for table_path in tables:
        table_subpath = tuple(table_path.to_strings_list()[0][1:])
        explicit_table_path = (schema_name, *table_subpath)
        # encoded sql table name
        table_name = explicit_map[explicit_table_path][-1]
        filtered_ds = filter(filter=data_type.get(table_path))(dataset)
        table_ds = get_item(table_path)(filtered_ds)
        schema_table = await manager.async_schema(t.cast(st.Dataset, table_ds))
        # pks and fks of this table, referring to tables by encoded names
        table_pks = [pk[-1] for pk in primary_keys if pk[:-1] == table_subpath]
        table_fks = {
            fk_path[-1]: ".".join(
                (explicit_map[(schema_name, *ref_path[:-1])][-1], ref_path[-1])
            )
            for fk_path, ref_path in foreign_keys.items()
            if fk_path[:-1] == table_subpath
        }
        schema_type = (
            schema_table.data_type() if drop_admin else schema_table.type()
        )
        metadata.fill(
            schema_type,
            table_name,
            primary_keys=table_pks,
            foreign_keys=table_fks,
        )
        sqltable_to_table_ds[table_name] = table_ds
    metadata.create_all_tables()

    # Second pass: push the data, in an order where referenced tables
    # are pushed before the tables pointing to them.
    for table_name in metadata.sorted_tables():
        iterator = await manager.async_to(
            t.cast(st.Dataset, sqltable_to_table_ds[table_name]),
            kind=t.AsyncIterator[pa.RecordBatch],
            drop_admin=drop_admin,
            batch_size=BATCH_SIZE,
        )
        async for batch in t.cast(t.AsyncIterator[pa.RecordBatch], iterator):
            batch = _flatten_data_batch(batch)
            if batch.num_rows > 0:
                metadata.push_tosql(table_name, batch)

    encoded_map = base64encode_table_map(table_map)
    expanded_map = expand_table_map(list(explicit_map))
    return pushed_uri, encoded_map, expanded_map
async def async_to_sql_op(
    dataset: st.Dataset,
) -> t.Tuple[str, t.Dict[str, t.Tuple[str, ...]], t.List[str]]:
    """Make ``dataset`` available in the SQL caching database.

    - SQL source dataset: it already lives in a database, so nothing is
      pushed. Its table mapping is computed (with no SQL schema) and returned
      together with the source's own URI.
    - Any other dataset: it is pushed with ``async_sql_op`` to the URI stored
      under ``SQL_CACHING_URI`` in the ``TO_SQL_CACHING_TASK`` status of its
      source. The target schema name is the manager's pushing prefix followed
      by the encoded dataset uuid.

    Returns:
        (pushed_uri, encoded table mapping, expanded table mapping)

    Raises:
        ValueError: when the source has no ``TO_SQL_CACHING_TASK`` status or
            task, or when that task carries no ``SQL_CACHING_URI``.
    """
    manager = dataset.manager()

    if dataset.is_source() and dataset.protobuf().spec.HasField("sql"):
        # Already in a SQL DB: only attach a ready status with a table mapping.
        sql_ds_schema = await manager.async_schema(dataset)
        sql_table_map = table_mapping(
            tables=sql_ds_schema.tables(),
            sarus_schema_name=sql_ds_schema.name(),
            sql_schema=None,
        )
        sql_explicit_map = explicit_table_map(sql_table_map)
        source_uri = dataset.protobuf().spec.sql.uri
        # Path (without DATA) -> (sql_schema, sql_tables)
        return (
            source_uri,
            base64encode_table_map(sql_table_map),
            expand_table_map(list(sql_explicit_map)),
        )

    parent_ds = dataset.sources(sp.type_name(sp.Dataset)).pop()
    parent_status = manager.status(
        dataspec=t.cast(st.DataSpec, parent_ds),
        task_name=TO_SQL_CACHING_TASK,
    )
    if parent_status is None:
        raise ValueError(
            f"Missing {TO_SQL_CACHING_TASK} for dataset "
            f"{parent_ds.uuid()}"
        )
    caching_task = parent_status.task(task=TO_SQL_CACHING_TASK)
    if caching_task is None:
        raise ValueError(
            f"Missing {TO_SQL_CACHING_TASK} task for dataset "
            f"{parent_ds.uuid()}"
        )
    caching_uri = caching_task.properties().get(SQL_CACHING_URI)
    if caching_uri is None:
        raise ValueError(
            f"Source does not have a {SQL_CACHING_URI} value, can't use to sql."
        )

    schema_prefix = manager.sql_pushing_schema_prefix(dataset)
    target_schema = schema_prefix + name_encoder(
        names=(dataset.uuid(),),
        length=SCHEMA_NAME_LENGTH,
    )
    # No explicit table names: async_sql_op picks encoded ones.
    return await async_sql_op(
        dataset, caching_uri, target_schema, None, drop_admin=False
    )
async def async_push_sql_op(
    dataset: st.Dataset,
) -> t.Tuple[str, t.Dict[str, t.Tuple[str, ...]], t.List[str]]:
    """Push ``dataset`` to the external database named by its push_sql transform.

    The URI, schema name and table name all come from the ``push_sql`` spec
    of the transform that produced ``dataset``. The push itself is delegated
    to ``async_sql_op`` with ``drop_admin=True``.

    Returns:
        (pushed_uri, encoded table mapping, expanded table mapping)

    Raises:
        ValueError: if ``dataset`` is not the output of a ``push_sql``
            transform.
    """
    if not dataset.is_transformed():
        # Checked separately: the error message below needs
        # ``dataset.transform()``, which a non-transformed dataset may not
        # be able to provide.
        raise ValueError(
            "Can't push a dataset that is not transformed to an external "
            "database with push sql, please try with the transform push sql."
        )
    transform_spec = dataset.transform().protobuf().spec
    if not transform_spec.HasField("push_sql"):
        raise ValueError(
            f"Can't push dataset with transform {dataset.transform().name()} "
            "to an external database with push sql, please try with the "
            "transform push sql."
        )
    push_sql = transform_spec.push_sql
    return await async_sql_op(
        dataset,
        push_sql.uri,
        push_sql.schema_name,
        [push_sql.table_name],
        drop_admin=True,
    )
class SarusMetadata:
    """Helper abstracting the APIs used to push data to a SQL database.

    SQLAlchemy is used most of the time, but it is not always the most
    efficient way to push data (e.g. for BigQuery):

    - BigQuery: the column types of each table are kept in a dict, and rows
      are uploaded with BigQuery load jobs from temporary parquet files;
    - SQL Server (``mssql``): rows are inserted with ``mssql_insert``;
    - other dialects: rows are inserted with ``default_insert``.

    Annotations on third-party and project types are quoted so that defining
    this class does not evaluate them. The module tolerates a missing
    SQLAlchemy at import time, and an unquoted ``sa.engine.Engine`` would
    raise NameError in that case.
    """

    def __init__(self, engine: "sa.engine.Engine", sql_schema_name: str) -> None:
        """Bind to ``engine`` and create ``sql_schema_name`` if it is missing.

        Args:
            engine: SQLAlchemy engine connected to the target database.
            sql_schema_name: SQL schema in which every table is created.
        """
        self.engine = engine
        self.sql_schema_name = sql_schema_name
        # Set by ``_setup``; ``bq_client`` only for BigQuery.
        self.sa_metadata: "sa.MetaData"
        self.bq_client: t.Any
        # sql table name -> description recorded by ``fill``: an ``sa.Table``
        # for SQLAlchemy dialects, a dict column name -> BigQuery field
        # (presumably ``bigquery.SchemaField``) for BigQuery.
        self.tabname_to_tabmetadata: t.Dict[str, t.Any] = {}
        self._setup()

    @property
    def _dialect(self) -> str:
        """Name of the SQLAlchemy dialect of ``self.engine``."""
        return self.engine.dialect.name

    def _setup(self) -> None:
        """Create the SQL schema if it does not exist and init the clients.

        For BigQuery, also load the service-account credentials found at the
        dialect's ``credentials_path`` and build the ``bigquery.Client`` used
        by ``push_tosql``.
        """
        with self.engine.begin() as conn:
            existing_schemas = conn.dialect.get_schema_names(conn)
            if self.sql_schema_name not in existing_schemas:
                conn.execute(
                    sa.schema.CreateSchema(self.sql_schema_name)  # type: ignore
                )
        if self._dialect == "bigquery":
            credentials = service_account.Credentials.from_service_account_file(
                self.engine.dialect.credentials_path  # type: ignore
            )
            self.bq_client = bigquery.Client(
                credentials=credentials,
                project=credentials.project_id,
            )
        self.sa_metadata = sa.MetaData(schema=self.sql_schema_name)

    def fill(
        self,
        schema_type: "st.Type",
        table_name: str,
        primary_keys: t.List[str],
        foreign_keys: t.Dict[str, str],
    ) -> None:
        """Record the column types of one sarus table under ``table_name``.

        For BigQuery, the column types are collected into a dict by
        ``to_bq_schema`` (called with ``typemode=REQUIRED``), and pks/fks are
        not used. For other dialects, an ``sa.Table`` is added to
        ``self.sa_metadata`` and filled by ``to_sqlalchemy_metadata`` (called
        with ``nullable=False``) with the given pks and fks.

        Args:
            schema_type: sarus type of a single table.
            table_name: SQL table name.
            primary_keys: names of the primary-key columns.
            foreign_keys: column name -> "referenced_table.column".
        """
        if self._dialect == "bigquery":
            bq_tabmetadata: t.Dict[str, t.Any] = {}
            to_bq_schema(
                schema_type,
                bq_tabmetadata,
                typemode=BQFieldType.REQUIRED.value,
            )
            self.tabname_to_tabmetadata[table_name] = bq_tabmetadata
            return
        sa_tabmetadata = sa.Table(table_name, self.sa_metadata)
        to_sqlalchemy_metadata(
            schema_type,
            sa_tabmetadata,
            col_name=None,
            nullable=False,
            primary_keys=primary_keys,
            foreign_keys=foreign_keys,
        )
        self.tabname_to_tabmetadata[table_name] = sa_tabmetadata

    def create_all_tables(self) -> None:
        """Create, via SQLAlchemy, every table declared in ``sa_metadata``.

        BigQuery tables are never added to ``sa_metadata``; they are created
        by the load jobs in ``push_tosql``.
        """
        self.sa_metadata.create_all(self.engine)

    def sorted_tables(self) -> t.List[str]:
        """Return table names in an order safe for pushing data.

        For SQLAlchemy dialects, ``MetaData.sorted_tables`` puts tables
        referenced by a foreign key before the tables pointing to them. For
        BigQuery the order does not matter, so insertion order is kept.
        """
        if self._dialect == "bigquery":
            return list(self.tabname_to_tabmetadata)
        return [table.name for table in self.sa_metadata.sorted_tables]

    def push_tosql(self, table_name: str, batch: "pa.RecordBatch") -> None:
        """Insert ``batch`` into ``table_name`` with the dialect's best method.

        - BigQuery: parquet load job (see ``_push_to_bigquery``).
        - SQL Server: ``mssql_insert``.
        - Other dialects: ``default_insert``.
        """
        dialect = self._dialect
        if dialect == "bigquery":
            self._push_to_bigquery(table_name, batch)
        elif dialect == "mssql":
            mssql_insert(
                self.engine, batch, self.tabname_to_tabmetadata[table_name]
            )
        else:
            default_insert(
                self.engine, batch, self.tabname_to_tabmetadata[table_name]
            )

    def _push_to_bigquery(self, table_name: str, batch: "pa.RecordBatch") -> None:
        """Upload ``batch`` to ``<schema>.<table_name>`` and wait for the job."""
        bq_fields = list(self.tabname_to_tabmetadata[table_name].values())
        job_config = bigquery.LoadJobConfig(
            source_format=bigquery.SourceFormat.PARQUET, schema=bq_fields
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            parquet_file = os.path.join(temp_dir, "temp.parquet")
            # We write to parquet instead of passing a pandas dataframe
            # since python-bigquery would write to parquet anyways
            # and there are no methods to load directly a pyarrow object.
            # Link to loading methods:
            # https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.client.Client#google_cloud_bigquery_client_Client_load_table_from_dataframe
            pq.write_table(
                table=pa.Table.from_batches([batch]),
                where=parquet_file,
                version="2.6",
                coerce_timestamps="us",
                allow_truncated_timestamps=True,
            )
            with open(parquet_file, "rb") as parquet_fh:
                job = self.bq_client.load_table_from_file(
                    file_obj=parquet_fh,
                    destination=f"{self.sql_schema_name}.{table_name}",
                    job_config=job_config,
                )
                job.result()
def default_insert(
    engine: "sa.engine.Engine", batch: "pa.RecordBatch", sa_table: "sa.Table"
) -> None:
    """Insert every row of ``batch`` into ``sa_table`` in one transaction.

    Uses the SQLAlchemy Core API: a single ``INSERT`` executed with the list
    of row dicts (an executemany), which is efficient and general:
    https://towardsdatascience.com/how-to-perform-bulk-inserts-with-sqlalchemy-efficiently-in-python-23044656b97d

    An empty batch is a no-op. SQLAlchemy treats an empty parameter list as
    a single execution without parameters, which would try to insert a row
    of default values.

    Annotations are quoted so the function can be defined even when the
    optional SQLAlchemy import at the top of the module failed.
    """
    if batch.num_rows == 0:
        return
    with engine.begin() as conn:
        conn.execute(sa.insert(sa_table), batch.to_pylist())
def mssql_insert(
engine: sa.engine.Engine, batch: pa.RecordBatch, sa_table: sa.Table
) -> None:
"""Method to push data on an SQL server. This is actually though when
the server is Azure. This is equivalent to a parametrized INSERT query
since pyodb seems to not support fast_executemany for AZURE.
TODO: we should be able to apply this only for azure because for
synapse which still uses mssql-pyodb the most efficient method is different
"""
azure_max_parameters = 2100
chunksize = azure_max_parameters // batch.num_columns - 1
batch_pylist = batch.to_pylist()
with engine.begin() as conn:
for i in range(0, len(batch_pylist), chunksize):
conn.execute(
sa.insert(sa_table).values(batch_pylist[i : i + chunksize])
)