Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import os
import tempfile
import typing as t
import warnings

import pyarrow as pa
import pyarrow.parquet as pq

try:
    import sqlalchemy as sa
except ModuleNotFoundError:
    warnings.warn("Sqlalchemy not installed, cannot send sql queries")

from sarus_data_spec import typing as st
from sarus_data_spec.constants import (
    DATA,
    OPTIONAL_VALUE,
    PU_COLUMN,
    PUBLIC,
    SQL_CACHING_URI,
    TO_SQL_CACHING_TASK,
    WEIGHTS,
)
from sarus_data_spec.manager.ops.foreign_keys import fk_visitor
from sarus_data_spec.manager.ops.primary_keys import pk_visitor
from sarus_data_spec.manager.ops.sql_utils.schema_translations import (
    BQFieldType,
    to_bq_schema,
    to_sqlalchemy_metadata,
)
from sarus_data_spec.manager.ops.sql_utils.table_mapping import (
    base64encode_table_map,
    expand_table_map,
    explicit_table_map,
    name_encoder,
    table_mapping,
)
from sarus_data_spec.transform import filter, get_item
import sarus_data_spec.protobuf as sp

try:
    from google.cloud import bigquery
    from google.oauth2 import service_account

except ImportError:
    warnings.warn(
        "Google API Python client and related libraries are not installed,"
        " cannot push dataset to bigquery"
    )
    big_query_schema_field = t.Any
else:
    big_query_schema_field = bigquery.SchemaField  # type:ignore

# Number of rows requested per batch when pushing data to the database.
# Every batch is fully materialised in memory before being pushed, so a large
# value can use a lot of memory -- all the more when the dataset has many
# columns. Be careful before increasing it.
BATCH_SIZE = 5000

# Length of the encoded SQL table names. The SQL schema names derived from
# dataset uuids are encoded with the same length.
TAB_NAME_LENGTH = 8
SCHEMA_NAME_LENGTH = TAB_NAME_LENGTH


def _flatten_data_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
    """Flatten the ``DATA`` struct column of ``batch`` into plain columns.

    The fields of the ``DATA`` struct become top-level columns (without the
    ``DATA.`` prefix), followed by the ``PU_COLUMN``, ``WEIGHTS`` and
    ``PUBLIC`` columns. A batch without a ``DATA`` column is returned as is.
    """
    if DATA not in batch.schema.names:
        return batch
    arrays = batch.column(DATA).flatten()
    names = [
        field.name.replace(DATA + ".", "")
        for field in batch.field(DATA).flatten()
    ]
    for admin_col in (PU_COLUMN, WEIGHTS, PUBLIC):
        arrays.append(batch.column(admin_col))
        names.append(batch.field(admin_col).name)
    return pa.record_batch(arrays, names=names)


async def async_sql_op(
    dataset: st.Dataset,
    pushed_uri: str,
    sql_schema: str,
    table_names: t.Optional[t.List[str]],
    drop_admin: bool,
) -> t.Tuple[str, t.Dict[str, t.Tuple[str, ...]], t.List[str]]:
    """Push every table of ``dataset`` to the SQL DB at ``pushed_uri``.

    - It creates ``sql_schema`` in the DB if it does not exist yet.
    - It iterates over the sarus tables a first time to declare the SQL
      tables with the correct column types and associated constraints
      (pks, fks and nullable columns), then creates them all at once.
    - It iterates over the created tables a second time, in the order given
      by ``SarusMetadata.sorted_tables`` (so that tables referenced by a
      foreign key are pushed first), to push the data in batches of
      ``BATCH_SIZE`` rows.

    Args:
        dataset: the dataset to push.
        pushed_uri: URI of the target DB, used to build the engine.
        sql_schema: name of the SQL schema that holds the pushed tables.
        table_names: optional SQL table names forwarded to
            ``table_mapping``; ``None`` lets it choose encoded names.
        drop_admin: if True, the tables are declared from their
            ``data_type()`` and the data is read with ``drop_admin=True``.
            Otherwise the full ``type()`` is used and each batch's ``DATA``
            struct is flattened next to the ``PU_COLUMN``, ``WEIGHTS`` and
            ``PUBLIC`` columns.

    Returns:
        (pushed_uri, encoded table mapping, expanded table mapping)

    Raises:
        ValueError: if ``pushed_uri`` is empty.
    """
    # Explicit check rather than ``assert``, which is stripped under -O.
    if not pushed_uri:
        raise ValueError("pushed_uri must be a non-empty SQL DB URI.")

    manager = dataset.manager()
    ds_schema = await manager.async_schema(dataset)
    tables = ds_schema.tables()
    engine = manager.engine(uri=pushed_uri)
    # Creates ``sql_schema`` in the DB when missing.
    metadata = SarusMetadata(engine, sql_schema)

    primary_keys = [
        tuple(path.to_strings_list()[0])
        for path in pk_visitor(ds_schema.data_type())
    ]
    foreign_keys: t.Dict[t.Tuple[str, ...], t.Tuple[str, ...]] = {}
    for key, value in fk_visitor(ds_schema.data_type()).items():
        key_path = tuple(key.to_strings_list()[0])
        if OPTIONAL_VALUE in key_path:
            # Presumably the optional wrapper appends a trailing
            # OPTIONAL_VALUE element to the path: drop the last element so
            # the key points at the column itself.
            key_path = key_path[:-1]
        foreign_keys[key_path] = tuple(value.to_strings_list()[0][1:])

    # table mapping: Path (without DATA) -> (sql_schema, sql_tables)
    table_map = table_mapping(
        tables=tables,
        sarus_schema_name=ds_schema.name(),
        encoded_name_length=TAB_NAME_LENGTH,
        sql_schema=sql_schema,
        table_names=table_names,
    )

    # Tuple[str] representation of table mapping -> (sql_schema,
    # sql_tables). Used to correctly map pks and fks.
    explicit_map = explicit_table_map(table_map)

    # Tables are iterated twice: once to fill the metadata and create all
    # the tables at once in the DB, a second time to push the data.

    # sql table name -> sarus table dataset, used to push the data.
    sqltable_to_table_ds = {}

    # First pass: declare the sql tables from each sarus table type.
    for table_path in tables:
        # full path as Tuple[str]
        explicit_table_path = (
            ds_schema.name(),
            *table_path.to_strings_list()[0][1:],
        )
        # encoded sql table name
        table_name = explicit_map[explicit_table_path][-1]

        # get table dataset and schema
        ds_filter = ds_schema.data_type().get(table_path)
        filtered_ds = filter(filter=ds_filter)(dataset)
        table_ds = get_item(table_path)(filtered_ds)
        schema_table = await manager.async_schema(t.cast(st.Dataset, table_ds))

        # pks and fks of this table, expressed with the encoded table names
        table_pks = [
            pk[-1] for pk in primary_keys if pk[:-1] == explicit_table_path[1:]
        ]
        table_fks = {
            keys[-1]: ".".join(
                [
                    explicit_map[(ds_schema.name(), *values[:-1])][-1],
                    values[-1],
                ]
            )
            for keys, values in foreign_keys.items()
            if keys[:-1] == explicit_table_path[1:]
        }

        schema_type = (
            schema_table.data_type() if drop_admin else schema_table.type()
        )
        metadata.fill(
            schema_type,
            table_name,
            primary_keys=table_pks,
            foreign_keys=table_fks,
        )
        sqltable_to_table_ds[table_name] = table_ds

    metadata.create_all_tables()

    # Second pass: push the data of each sarus table to its sql table.
    for table_name in metadata.sorted_tables():
        table_ds = sqltable_to_table_ds[table_name]
        iterator = await manager.async_to(
            t.cast(st.Dataset, table_ds),
            kind=t.AsyncIterator[pa.RecordBatch],
            drop_admin=drop_admin,
            batch_size=BATCH_SIZE,
        )
        iterator = t.cast(t.AsyncIterator[pa.RecordBatch], iterator)
        async for batch in iterator:
            batch = _flatten_data_batch(batch)
            # Nothing to insert for empty batches.
            if batch.num_rows > 0:
                metadata.push_tosql(table_name, batch)

    # Path (without DATA) -> (sql_schema, sql_tables)
    encoded_map = base64encode_table_map(table_map)
    expanded_map = expand_table_map(list(explicit_map.keys()))
    return pushed_uri, encoded_map, expanded_map


async def async_to_sql_op(
    dataset: st.Dataset,
) -> t.Tuple[str, t.Dict[str, t.Tuple[str, ...]], t.List[str]]:
    """Cache ``dataset`` in a SQL DB and return its table mappings.

    - If ``dataset`` is a SQL source, nothing is pushed. The table mapping
      is computed from its schema (with no SQL schema name), and the
      source's own ``spec.sql.uri`` is returned. A ready status still needs
      such a mapping.
    - Otherwise, the ``SQL_CACHING_URI`` property is read from the
      ``TO_SQL_CACHING_TASK`` status of the dataset's source. A SQL schema
      name is built from the manager's pushing prefix and the encoded dataset
      uuid, and the data is pushed with ``async_sql_op`` (admin columns kept,
      encoded table names).

    Returns:
        (pushed_uri, encoded table mapping, expanded table mapping)

    Raises:
        ValueError: if the dataset has no source dataset, or if the source
            has no ``TO_SQL_CACHING_TASK`` status, task or
            ``SQL_CACHING_URI`` value.
    """
    manager = dataset.manager()
    if dataset.is_source() and dataset.protobuf().spec.HasField("sql"):
        ds_schema = await manager.async_schema(dataset)
        table_map = table_mapping(
            tables=ds_schema.tables(),
            sarus_schema_name=ds_schema.name(),
            sql_schema=None,
        )
        explicit_map = explicit_table_map(table_map)
        pushed_uri = dataset.protobuf().spec.sql.uri
        # Path (without DATA) -> (sql_schema, sql_tables)
        encoded_map = base64encode_table_map(table_map)
        expanded_map = expand_table_map(list(explicit_map.keys()))
        return pushed_uri, encoded_map, expanded_map

    sources = dataset.sources(sp.type_name(sp.Dataset))
    if not sources:
        # Without this check ``.pop()`` would fail with an opaque error.
        raise ValueError(
            f"Dataset {dataset.uuid()} has no source dataset, "
            "can't use to sql."
        )
    source_ds = sources.pop()
    source_st = manager.status(
        dataspec=t.cast(st.DataSpec, source_ds),
        task_name=TO_SQL_CACHING_TASK,
    )
    if source_st is None:
        raise ValueError(
            f"Missing {TO_SQL_CACHING_TASK} for dataset "
            f"{source_ds.uuid()}"
        )
    source_st_task = source_st.task(task=TO_SQL_CACHING_TASK)
    if source_st_task is None:
        raise ValueError(
            f"Missing {TO_SQL_CACHING_TASK} task for dataset "
            f"{source_ds.uuid()}"
        )

    caching_uri = source_st_task.properties().get(SQL_CACHING_URI)
    if caching_uri is None:
        raise ValueError(
            f"Source does not have a {SQL_CACHING_URI} value, can't use to sql."
        )
    sql_schema = manager.sql_pushing_schema_prefix(dataset) + name_encoder(
        names=(dataset.uuid(),),
        length=SCHEMA_NAME_LENGTH,
    )
    # table_names=None: let the table mapping choose encoded table names.
    return await async_sql_op(
        dataset, caching_uri, sql_schema, None, drop_admin=False
    )


async def async_push_sql_op(
    dataset: st.Dataset,
) -> t.Tuple[str, t.Dict[str, t.Tuple[str, ...]], t.List[str]]:
    """Push the output of a ``push_sql`` transform to an external SQL DB.

    The target URI, SQL schema name and table name are read from the
    ``push_sql`` spec of the dataset's transform. The data is then pushed
    with ``async_sql_op`` with the admin columns dropped.

    Returns:
        (pushed_uri, encoded table mapping, expanded table mapping)

    Raises:
        ValueError: if ``dataset`` is not the output of a ``push_sql``
            transform.
    """
    # Checked separately: the error message below needs
    # ``dataset.transform()``, which should not be called on a
    # non-transformed dataset.
    if not dataset.is_transformed():
        raise ValueError(
            "Can't push a non transformed dataset to an external database "
            "with push sql, please try with the transform push sql."
        )

    transform_spec = dataset.transform().protobuf().spec
    if not transform_spec.HasField("push_sql"):
        raise ValueError(
            f"Can't push dataset with transform {dataset.transform().name()} "
            "to an external database with push sql, please try with the "
            "transform push sql."
        )

    push_sql = transform_spec.push_sql
    return await async_sql_op(
        dataset,
        push_sql.uri,
        push_sql.schema_name,
        [push_sql.table_name],
        drop_admin=True,
    )


class SarusMetadata:
    """Helper abstracting the APIs used to create tables and push data to a
    SQL DB.

    SQLAlchemy is used most of the time, but it is not always the most
    efficient way to push data (for instance with bigquery). For bigquery,
    the column types are kept in a dict and the data is loaded through the
    bigquery client from parquet files instead.
    """

    def __init__(
        self, engine: "sa.engine.Engine", sql_schema_name: str
    ) -> None:
        # The engine annotation is a string because sqlalchemy is an optional
        # import of this module. Evaluating ``sa`` when the class is defined
        # would raise a NameError if sqlalchemy is not installed.
        self.engine = engine
        self.sql_schema_name = sql_schema_name

        self.sa_metadata: sa.MetaData
        # Only set by ``_setup`` when the engine dialect is bigquery.
        self.bq_client: t.Any
        # SQL table name -> sqlalchemy ``Table`` (non-bigquery dialects), or
        # -> dict filled by ``to_bq_schema`` whose values are used as the
        # load job schema (bigquery).
        self.tabname_to_tabmetadata: t.Dict[str, t.Any] = {}
        self._setup()

    def _setup(self) -> None:
        """Create the SQL schema when it does not exist yet.

        For bigquery, also build a bigquery client from the service account
        credentials file configured on the engine dialect. Finally, create
        the sqlalchemy ``MetaData`` bound to the SQL schema.
        """
        with self.engine.begin() as conn:
            existing_schemas = conn.dialect.get_schema_names(conn)
            if self.sql_schema_name not in existing_schemas:
                operation = sa.schema.CreateSchema(self.sql_schema_name)  # type: ignore
                conn.execute(operation)

        if self.engine.dialect.name == "bigquery":
            credentials_path = (
                self.engine.dialect.credentials_path  # type: ignore
            )
            credentials = (
                service_account.Credentials.from_service_account_file(
                    credentials_path
                )
            )
            self.bq_client = bigquery.Client(
                credentials=credentials,
                project=credentials.project_id,
            )

        self.sa_metadata = sa.MetaData(schema=self.sql_schema_name)

    def fill(
        self,
        schema_type: st.Type,
        table_name: str,
        primary_keys: t.List[str],
        foreign_keys: t.Dict[str, str],
    ) -> None:
        """Declare the SQL table ``table_name`` from the sarus type of one
        table.

        For bigquery, the column types computed by ``to_bq_schema`` are
        stored in a dict and used later when pushing; the primary and foreign
        keys are not used. For the other dialects, a sqlalchemy ``Table``
        is added to ``self.sa_metadata``. Its columns are typed by
        ``to_sqlalchemy_metadata`` and carry the given keys.

        Args:
            schema_type: sarus type of a single table.
            table_name: SQL name of the table.
            primary_keys: names of the primary key columns.
            foreign_keys: column name -> "referenced_table.column".
        """
        dialect = self.engine.dialect.name
        if dialect == "bigquery":
            bq_tabmetadata: t.Dict[str, t.Any] = {}
            to_bq_schema(
                schema_type,
                bq_tabmetadata,
                typemode=BQFieldType.REQUIRED.value,
            )
            self.tabname_to_tabmetadata[table_name] = bq_tabmetadata
            return

        sa_tabmetadata = sa.Table(table_name, self.sa_metadata)
        to_sqlalchemy_metadata(
            schema_type,
            sa_tabmetadata,
            col_name=None,
            nullable=False,
            primary_keys=primary_keys,
            foreign_keys=foreign_keys,
        )
        self.tabname_to_tabmetadata[table_name] = sa_tabmetadata

    def create_all_tables(self) -> None:
        """Create in the DB every table declared in ``self.sa_metadata``.

        Bigquery tables are never added to the sqlalchemy metadata, so this
        does not create them.
        """
        self.sa_metadata.create_all(self.engine)

    def sorted_tables(self) -> t.List[str]:
        """Return the table names in the order they should be pushed.

        For sqlalchemy dialects, this is the metadata's ``sorted_tables``
        order, so a table is pushed before the tables whose foreign keys
        reference it. Order does not matter for bigquery, so the declaration
        order is returned.
        """
        if self.engine.dialect.name == "bigquery":
            return list(self.tabname_to_tabmetadata.keys())
        return [tab.name for tab in self.sa_metadata.sorted_tables]

    def push_tosql(self, table_name: str, batch: pa.RecordBatch) -> None:
        """Push ``batch`` into ``table_name`` with the method best suited to
        the DB dialect.

        Bigquery uses a parquet load job, mssql uses ``mssql_insert``, and
        any other dialect uses ``default_insert``.
        """
        dialect = self.engine.dialect.name
        if dialect == "bigquery":
            bq_tabmetadata = self.tabname_to_tabmetadata[table_name]
            columns = list(bq_tabmetadata.keys())
            col_types = [bq_tabmetadata[colname] for colname in columns]
            job_config = bigquery.LoadJobConfig(
                source_format=bigquery.SourceFormat.PARQUET, schema=col_types
            )
            with tempfile.TemporaryDirectory() as temp_dir:
                parquet_file = os.path.join(temp_dir, "temp.parquet")
                # We write to parquet instead of passing a pandas dataframe
                # since python-bigquery would write to parquet anyways
                # and there are no methods to load directly a pyarrow object.
                # Link to loading methods:
                # https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.client.Client#google_cloud_bigquery_client_Client_load_table_from_dataframe
                pq.write_table(
                    table=pa.Table.from_batches([batch]),
                    where=parquet_file,
                    version="2.6",
                    coerce_timestamps="us",
                    allow_truncated_timestamps=True,
                )

                with open(parquet_file, "rb") as f:
                    job = self.bq_client.load_table_from_file(
                        file_obj=f,
                        destination=f"{self.sql_schema_name}.{table_name}",
                        job_config=job_config,
                    )
                    # Block until the load job completes (errors surface here).
                    job.result()
            return

        sa_tabmetadata = self.tabname_to_tabmetadata[table_name]
        if dialect == "mssql":
            mssql_insert(self.engine, batch, sa_tabmetadata)
        else:
            default_insert(self.engine, batch, sa_tabmetadata)


def default_insert(
    engine: "sa.engine.Engine", batch: "pa.RecordBatch", sa_table: "sa.Table"
) -> None:
    """Insert the rows of ``batch`` into ``sa_table`` in one transaction.

    Uses the SQLAlchemy Core API: a single ``INSERT`` statement executed with
    the list of row dicts (executemany). This should be efficient and quite
    general:
    https://towardsdatascience.com/how-to-perform-bulk-inserts-with-sqlalchemy-efficiently-in-python-23044656b97d

    Empty batches are skipped without opening a connection, because an
    empty parameter list is not guaranteed to be a no-op for SQLAlchemy.

    The annotations are strings because sqlalchemy is an optional import of
    this module. Evaluating ``sa`` at definition time would raise a NameError
    when it is not installed.
    """
    if batch.num_rows == 0:
        return
    with engine.begin() as conn:
        conn.execute(sa.insert(sa_table), batch.to_pylist())


def mssql_rows_per_insert(
    num_columns: int,
    max_parameters: int = 2100,
    max_rows: int = 1000,
) -> int:
    """Return how many rows fit in one multi-row ``INSERT ... VALUES``.

    SQL Server rejects requests with ``max_parameters`` (2100) or more bound
    parameters, and more than ``max_rows`` (1000) row value expressions in
    one VALUES list. The result keeps ``rows * num_columns`` strictly below
    ``max_parameters`` and ``rows`` at most ``max_rows``.

    At least 1 is always returned so that callers iterating with it as a
    step make progress. A single row wider than the parameter limit is left
    to the server to reject, rather than being silently skipped.
    """
    per_row = max(num_columns, 1)
    return max(1, min(max_rows, (max_parameters - 1) // per_row))


def mssql_insert(
    engine: "sa.engine.Engine", batch: "pa.RecordBatch", sa_table: "sa.Table"
) -> None:
    """Push ``batch`` to an SQL Server table with chunked multi-row INSERTs.

    This targets Azure in particular. It is equivalent to parametrized
    ``INSERT ... VALUES`` queries, since pyodbc does not seem to support
    ``fast_executemany`` on Azure. Rows are sent in chunks sized by
    ``mssql_rows_per_insert`` to stay within SQL Server's parameter and
    row-count limits. All chunks run in a single transaction.

    The annotations are strings because sqlalchemy is an optional import of
    this module. Evaluating ``sa`` at definition time would raise a NameError
    when it is not installed.

    TODO: this should only apply to Azure. Synapse, which also uses
    mssql-pyodbc, has a different most-efficient method.
    """
    if batch.num_rows == 0:
        return
    chunksize = mssql_rows_per_insert(batch.num_columns)
    rows = batch.to_pylist()
    with engine.begin() as conn:
        for start in range(0, len(rows), chunksize):
            conn.execute(
                sa.insert(sa_table).values(rows[start : start + chunksize])
            )