Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
from __future__ import annotations

from collections import defaultdict
import json
import typing as t

from sqlalchemy import MetaData, Table, event
import sqlalchemy

from sarus_data_spec.constants import (
    DATASET_SLUGNAME,
    FOREIGN_KEYS,
    PRIMARY_KEYS,
)
from sarus_data_spec.manager.ops.source.sql.type import sql_to_sarus
from sarus_data_spec.path import Path, straight_path
from sarus_data_spec.protobuf.utilities import to_base64
from sarus_data_spec.schema import Schema
from sarus_data_spec.schema import schema as schema_builder
from sarus_data_spec.type import Struct, Type, Union
import sarus_data_spec.typing as st


def _accepted_foreign_keys(
    table: Table,
    multiple_schemas: bool,
    accepted_tables: t.List[Path],
) -> t.List[t.Tuple[str, Path]]:
    """List the foreign keys of ``table`` that point to an accepted table.

    Returns one ``(local column name, pointed column path)`` pair per
    foreign key, in ``table.foreign_keys`` iteration order. The pointed path
    is ``schema/table/column`` when ``multiple_schemas`` is True and
    ``table/column`` otherwise. Foreign keys whose referenced table is not
    in ``accepted_tables`` are skipped. In multiple-schema mode, keys whose
    resolved schema is empty are skipped as well.
    """
    accepted = []
    for key in table.foreign_keys:
        pointed_table = key.column.table
        if multiple_schemas:
            # A referenced table without an explicit schema is taken to live
            # in the referencing table's schema.
            pointed_schema = pointed_table.schema or table.schema
            if not pointed_schema:
                continue
            table_nodes = [pointed_schema, pointed_table.name]
        else:
            table_nodes = [pointed_table.name]
        if straight_path(nodes=table_nodes) not in accepted_tables:
            continue
        accepted.append(
            (
                key.parent.name,
                straight_path(nodes=[*table_nodes, key.column.name]),
            )
        )
    return accepted


def type_table(
    table: Table,
    multiple_schemas: bool,
    current_fks: t.Dict[Path, Path],
    accepted_tables: t.List[Path],
) -> t.Tuple[Type, t.List[Path], t.Dict[Path, Path]]:
    """Convert the SQL schema of a table to a Sarus Struct type.

    Args:
        table: the reflected SQLAlchemy table.
        multiple_schemas: when True, paths are ``schema/table/column``;
            otherwise they are ``table/column``.
        current_fks: accumulator of the foreign keys found so far. It is
            updated in place with this table's foreign keys and returned.
        accepted_tables: paths of the tables being converted. Foreign keys
            pointing to any other table are ignored.

    Returns:
        A tuple ``(struct_type, primary_keys, current_fks)``:
        - ``struct_type`` has one field per column.
        - ``primary_keys`` lists the paths of this table's primary key
          columns.
        - ``current_fks`` maps each foreign key column path to the path of
          the column it points to.
    """
    # Resolve the accepted foreign keys once. Both the column typing and the
    # global FK mapping below use this single list, so their filtering
    # cannot drift apart.
    accepted_fks = _accepted_foreign_keys(
        table, multiple_schemas, accepted_tables
    )
    # Local column name -> pointed column path, as passed to sql_to_sarus.
    fk_paths = dict(accepted_fks)

    fields = {}
    primary_keys = []
    for column in table.c:
        fields[column.name] = sql_to_sarus(
            sql_column=column, foreign_keys=fk_paths
        )
        if column.primary_key:
            if multiple_schemas:
                primary_keys.append(
                    straight_path(
                        [
                            table.schema,  # type:ignore
                            table.name,
                            column.name,
                        ]
                    )
                )
            else:
                primary_keys.append(straight_path([table.name, column.name]))

    # Record this table's foreign keys in the shared accumulator, keyed by
    # the full path of the referencing column.
    for column_name, pointed_path in accepted_fks:
        if multiple_schemas:
            assert table.schema is not None
            fk_path = straight_path(
                nodes=[table.schema, table.name, column_name]
            )
        else:
            fk_path = straight_path(nodes=[table.name, column_name])
        current_fks[fk_path] = pointed_path

    return (
        Struct(fields=fields),
        primary_keys,
        current_fks,
    )


def info_from_single_schema(
    metadata: MetaData,
    accepted_table_paths: t.List[Path],
    tables_list: t.List[str],
) -> t.Tuple[st.Type, t.List[Path], t.Dict[Path, Path]]:
    """Build the Sarus type and keys of tables that share no schema prefix.

    Each name in ``tables_list`` is looked up in ``metadata.tables`` and
    converted by ``type_table`` with schema-less paths.

    Returns:
        A tuple ``(union_type, primary_keys, foreign_keys)``:
        - ``union_type`` is a Union with one entry per table, keyed by
          table name.
        - ``primary_keys`` is the list of the primary key paths of all
          tables.
        - ``foreign_keys`` maps foreign key column paths to the column
          paths they point to.
    """
    struct_by_table = {}
    all_primary_keys: t.List[Path] = []
    fk_mapping: t.Dict[Path, Path] = {}
    for table_name in tables_list:
        struct_type, table_pks, fk_mapping = type_table(
            metadata.tables[table_name],
            multiple_schemas=False,
            current_fks=fk_mapping,
            accepted_tables=accepted_table_paths,
        )
        all_primary_keys += table_pks
        struct_by_table[table_name] = struct_type
    return Union(struct_by_table), all_primary_keys, fk_mapping


def info_from_multiples_schemas(
    metadata: MetaData, accepted_tables: t.List[Path], tables_list: t.List[str]
) -> t.Tuple[st.Type, t.List[Path], t.Dict[Path, Path]]:
    """Build the Sarus type and keys of tables spread over SQL schemas.

    Each qualified name in ``tables_list`` is looked up in
    ``metadata.tables`` and converted by ``type_table`` with
    ``schema/table/column`` paths.

    Returns:
        A tuple ``(union_type, primary_keys, foreign_keys)``:
        - ``union_type`` is a Union keyed by schema name. Each value is a
          Union (named after the schema) keyed by table name.
        - ``primary_keys`` is the list of the primary key paths of all
          tables.
        - ``foreign_keys`` maps foreign key column paths to the column
          paths they point to.
    """
    structs_by_schema: t.DefaultDict[str, t.Dict] = defaultdict(dict)
    all_primary_keys = []
    fk_mapping: t.Dict[Path, Path] = {}
    for qualified_name in tables_list:
        sql_table = metadata.tables[qualified_name]
        struct_type, table_pks, fk_mapping = type_table(
            table=sql_table,
            multiple_schemas=True,
            current_fks=fk_mapping,
            accepted_tables=accepted_tables,
        )
        all_primary_keys += table_pks
        schema_name = t.cast(str, sql_table.schema)
        structs_by_schema[schema_name][sql_table.name] = struct_type

    # One nested Union per SQL schema, named after that schema.
    schema_unions = {
        schema_name: Union(schema_structs, name=t.cast(str, schema_name))
        for schema_name, schema_structs in structs_by_schema.items()
    }
    return Union(schema_unions), all_primary_keys, fk_mapping


async def sql_schema(
    dataset: st.Dataset, engine: sqlalchemy.engine.Engine
) -> Schema:
    """Compute the Sarus schema of a SQL dataset.

    The tables listed in ``dataset.protobuf().spec.sql.tables`` are
    reflected from ``engine`` and converted to a Sarus Union type:
    - If every table has an empty schema name, paths are ``table/column``.
    - Otherwise paths are ``schema/table/column`` and the type is a Union
      of per-schema Unions.

    Primary and foreign keys are stored in the schema properties as JSON
    built from base64-encoded path protobufs.

    Raises:
        EmptyTableError: if one of the selected tables has no rows.
    """
    metadata = MetaData()
    # https://docs.sqlalchemy.org/en/14/core/reflection.html

    @event.listens_for(metadata, "column_reflect")
    def genericize_datatypes(
        inspector: sqlalchemy.engine.reflection.Inspector,
        tablename: str,
        column_dict: t.Dict,
    ) -> None:
        # Replace each reflected column type by its generic SQLAlchemy type.
        column_dict["type"] = column_dict["type"].as_generic()

    # Group the requested table names by SQL schema ("" means no schema).
    tables_dict: t.Dict[str, t.List[str]] = defaultdict(list)
    for table in dataset.protobuf().spec.sql.tables:
        tables_dict[table.schema].append(table.table)
    tables_dict = dict(tables_dict)

    # Decide once whether paths carry a schema prefix; the table list used
    # for typing and for the emptiness check follows from this choice.
    single_schema = len(tables_dict) == 1 and "" in tables_dict

    # we need to reflect all tables linked to FKs, otherwise FK detection fails
    if single_schema:
        tables_list = tables_dict[""]
        metadata.reflect(
            bind=engine, only=tables_list, resolve_fks=True, views=True
        )
        accepted_table_paths = [
            straight_path(nodes=[table]) for table in tables_list
        ]
        (
            union_type,
            primary_keys,
            foreign_keys,
        ) = info_from_single_schema(
            metadata=metadata,
            accepted_table_paths=accepted_table_paths,
            tables_list=tables_list,
        )
    else:
        schema_table_pairs = [
            (schema, table)
            for schema, table_list in tables_dict.items()
            for table in table_list
        ]
        accepted_table_paths = [
            straight_path(nodes=[schema, table])
            for schema, table in schema_table_pairs
        ]
        for schema, schema_tables in tables_dict.items():
            metadata.reflect(
                bind=engine,
                schema=schema,
                only=schema_tables,
                resolve_fks=True,
                views=True,
            )
        # Keys of metadata.tables for schema-qualified tables.
        tables_list = [
            f"{schema}.{table}" for schema, table in schema_table_pairs
        ]
        (
            union_type,
            primary_keys,
            foreign_keys,
        ) = info_from_multiples_schemas(
            metadata=metadata,
            accepted_tables=accepted_table_paths,
            tables_list=tables_list,
        )

    enforce_no_empty_tables(
        metadata=metadata, engine=engine, tables_list=tables_list
    )
    properties = {}
    properties[PRIMARY_KEYS] = json.dumps(
        [to_base64(key.protobuf()) for key in primary_keys]
    )
    properties[FOREIGN_KEYS] = json.dumps(
        {
            to_base64(pointing.protobuf()): to_base64(pointed.protobuf())
            for pointing, pointed in foreign_keys.items()
        }
    )
    return schema_builder(
        dataset=dataset,
        schema_type=union_type,
        properties=properties,
        name=dataset.properties().get(DATASET_SLUGNAME, None),
    )


async def _sql_schema(dataset: st.Dataset) -> Schema:
    """Compute the Sarus schema of ``dataset`` with the engine its manager
    returns for the dataset's SQL URI."""
    sql_uri = dataset.protobuf().spec.sql.uri
    return await sql_schema(dataset, dataset.manager().engine(uri=sql_uri))


def enforce_no_empty_tables(
    metadata: MetaData,
    engine: sqlalchemy.engine.Engine,
    tables_list: t.List[str],
) -> None:
    """Check that every listed table has at least one row.

    Args:
        metadata: reflected metadata. Every entry of ``tables_list`` must be
            a key of ``metadata.tables``.
        engine: engine used to query the tables.
        tables_list: keys of the tables to check, probed in order.

    Raises:
        EmptyTableError: for the first table found without rows.
    """
    if not tables_list:
        # Nothing to probe: do not open a connection at all.
        return
    # One connection serves all the probes; opening a new connection and
    # transaction per table only adds round trips.
    with engine.begin() as conn:
        for table in tables_list:
            # Fetching at most one row is enough to know if the table is empty.
            query = sqlalchemy.select(*metadata.tables[table].c).limit(1)
            if conn.execute(query).first() is None:
                raise EmptyTableError(f"Found table {table} without rows")


class EmptyTableError(ValueError):
    """Raised when a SQL table selected for a dataset has no rows."""