Repository URL to install this package:
|
Version:
4.5.4.dev1 ▾
|
from __future__ import annotations
from collections import defaultdict
import json
import typing as t
from sqlalchemy import MetaData, Table, event
import sqlalchemy
from sarus_data_spec.constants import (
DATASET_SLUGNAME,
FOREIGN_KEYS,
PRIMARY_KEYS,
)
from sarus_data_spec.manager.ops.source.sql.type import sql_to_sarus
from sarus_data_spec.path import Path, straight_path
from sarus_data_spec.protobuf.utilities import to_base64
from sarus_data_spec.schema import Schema
from sarus_data_spec.schema import schema as schema_builder
from sarus_data_spec.type import Struct, Type, Union
import sarus_data_spec.typing as st
def type_table(
    table: Table,
    multiple_schemas: bool,
    current_fks: t.Dict[Path, Path],
    accepted_tables: t.List[Path],
) -> t.Tuple[Type, t.List[Path], t.Dict[Path, Path]]:
    """Convert the SQL schema of ``table`` into a Sarus ``Struct`` type.

    Args:
        table: reflected SQLAlchemy table.
        multiple_schemas: when True, paths are schema-qualified
            (``[schema, table, column]``); otherwise they are
            ``[table, column]``.
        current_fks: foreign keys collected so far, mapping the path of a
            pointing column to the path of the column it references. It is
            updated in place with the foreign keys of ``table`` and returned.
        accepted_tables: paths of the tables selected for the dataset;
            foreign keys referencing any other table are ignored.

    Returns:
        ``(struct_type, primary_keys, current_fks)``: a ``Struct`` with one
        field per column (typed by ``sql_to_sarus``), the list of
        primary-key column paths, and the updated foreign-key mapping.
    """
    # Column name -> path of the referenced column. Computed once and used
    # both to type the columns and to extend ``current_fks``, so the
    # filtering rules for foreign keys live in a single place.
    fk_paths: t.Dict[str, Path] = {}
    for key in table.foreign_keys:
        referenced_table = key.column.table
        if multiple_schemas:
            # A referenced table reflected without a schema is taken to be
            # in the same schema as ``table``.
            target_schema = referenced_table.schema or table.schema
            if not target_schema:
                continue
            target_table_nodes = [target_schema, referenced_table.name]
        else:
            target_table_nodes = [referenced_table.name]
        if straight_path(nodes=target_table_nodes) not in accepted_tables:
            continue
        fk_paths[key.parent.name] = straight_path(
            nodes=[*target_table_nodes, key.column.name]
        )

    fields = {}
    primary_keys: t.List[Path] = []
    for column in table.c:
        fields[column.name] = sql_to_sarus(
            sql_column=column, foreign_keys=fk_paths
        )
        if column.primary_key:
            if multiple_schemas:
                primary_keys.append(
                    straight_path(
                        [
                            table.schema,  # type:ignore
                            table.name,
                            column.name,
                        ]
                    )
                )
            else:
                primary_keys.append(straight_path([table.name, column.name]))

    # Register this table's accepted foreign keys as
    # pointing column path -> pointed column path.
    for column_name, pointed_path in fk_paths.items():
        if multiple_schemas:
            assert table.schema is not None
            pointing_nodes = [table.schema, table.name, column_name]
        else:
            pointing_nodes = [table.name, column_name]
        current_fks[straight_path(nodes=pointing_nodes)] = pointed_path

    return (
        Struct(fields=fields),
        primary_keys,
        current_fks,
    )
def info_from_single_schema(
    metadata: MetaData,
    accepted_table_paths: t.List[Path],
    tables_list: t.List[str],
) -> t.Tuple[st.Type, t.List[Path], t.Dict[Path, Path]]:
    """Type the listed tables of a database reflected without SQL schemas.

    Args:
        metadata: SQL metadata into which the tables were reflected.
        accepted_table_paths: ``[table]`` paths of the selected tables.
        tables_list: names of the tables to type (keys of
            ``metadata.tables``).

    Returns:
        ``(union_type, primary_keys, foreign_keys)``: a ``Union`` with one
        field per table name, the list of primary-key paths of all tables,
        and the dict mapping pointing column paths to pointed column paths.
    """
    table_types = {}
    pk_paths: t.List[Path] = []
    fk_map: t.Dict[Path, Path] = {}
    for table_name in tables_list:
        # The foreign-key dict is threaded through every call.
        table_type, table_pks, fk_map = type_table(
            table=metadata.tables[table_name],
            multiple_schemas=False,
            current_fks=fk_map,
            accepted_tables=accepted_table_paths,
        )
        pk_paths += table_pks
        table_types[table_name] = table_type
    return Union(table_types), pk_paths, fk_map
def info_from_multiples_schemas(
    metadata: MetaData, accepted_tables: t.List[Path], tables_list: t.List[str]
) -> t.Tuple[st.Type, t.List[Path], t.Dict[Path, Path]]:
    """Type the listed tables and group them by SQL schema.

    Args:
        metadata: SQL metadata into which the tables were reflected.
        accepted_tables: ``[schema, table]`` paths of the selected tables.
        tables_list: schema-qualified keys of ``metadata.tables``.

    Returns:
        ``(union_type, primary_keys, foreign_keys)``: a ``Union`` with one
        field per schema, each a ``Union`` (named after the schema) of that
        schema's tables; the list of primary-key paths; and the dict mapping
        pointing column paths to pointed column paths.
    """
    tables_by_schema: t.Dict[str, t.Dict] = {}
    all_primary_keys: t.List[Path] = []
    fks: t.Dict[Path, Path] = {}
    for table_key in tables_list:
        sql_table = metadata.tables[table_key]
        # The foreign-key dict is threaded through every call.
        table_type, table_pks, fks = type_table(
            table=sql_table,
            multiple_schemas=True,
            current_fks=fks,
            accepted_tables=accepted_tables,
        )
        all_primary_keys.extend(table_pks)
        schema_tables = tables_by_schema.setdefault(
            t.cast(str, sql_table.schema), {}
        )
        schema_tables[sql_table.name] = table_type

    per_schema_unions = {
        schema_name: Union(members, name=schema_name)
        for schema_name, members in tables_by_schema.items()
    }
    return Union(per_schema_unions), all_primary_keys, fks
async def sql_schema(
    dataset: st.Dataset, engine: sqlalchemy.engine.Engine
) -> Schema:
    """Compute the Sarus schema of a SQL dataset.

    The tables listed in ``dataset.protobuf().spec.sql.tables`` are reflected
    through ``engine`` and typed. Tables without a schema give a ``Union`` of
    tables; otherwise the result is a ``Union`` of per-schema ``Union``s.
    Primary and foreign keys are stored base64-encoded in the schema
    properties under ``PRIMARY_KEYS`` and ``FOREIGN_KEYS``.

    Raises:
        EmptyTableError: if one of the selected tables has no rows.
    """
    metadata = MetaData()

    # Replace reflected dialect-specific column types by their generic
    # SQLAlchemy equivalents.
    # https://docs.sqlalchemy.org/en/14/core/reflection.html
    @event.listens_for(metadata, "column_reflect")
    def genericize_datatypes(
        inspector: sqlalchemy.engine.reflection.Inspector,
        tablename: str,
        column_dict: t.Dict,
    ) -> None:
        column_dict["type"] = column_dict["type"].as_generic()

    # Group the requested table names by SQL schema ("" means no schema).
    grouped_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)
    for table in dataset.protobuf().spec.sql.tables:
        grouped_tables[table.schema].append(table.table)
    tables_dict: t.Dict[str, t.List[str]] = dict(grouped_tables)

    # Evaluated once: selects bare table names vs schema-qualified keys/paths.
    single_schema = len(tables_dict) == 1 and "" in tables_dict

    if single_schema:
        tables_list = tables_dict[""]
        # we need to reflect all tables linked to FKs, otherwise FK
        # detection fails
        metadata.reflect(
            bind=engine, only=tables_list, resolve_fks=True, views=True
        )
        accepted_table_paths = [
            straight_path(nodes=[table]) for table in tables_list
        ]
        union_type, primary_keys, foreign_keys = info_from_single_schema(
            metadata=metadata,
            accepted_table_paths=accepted_table_paths,
            tables_list=tables_list,
        )
    else:
        accepted_table_paths = [
            straight_path(nodes=[schema, table])
            for schema, table_list in tables_dict.items()
            for table in table_list
        ]
        for schema, schema_tables in tables_dict.items():
            # we need to reflect all tables linked to FKs, otherwise FK
            # detection fails
            metadata.reflect(
                bind=engine,
                schema=schema,
                only=schema_tables,
                resolve_fks=True,
                views=True,
            )
        # Schema-qualified keys, as used by ``metadata.tables``.
        tables_list = [
            f"{schema}.{table}"
            for schema, table_list in tables_dict.items()
            for table in table_list
        ]
        union_type, primary_keys, foreign_keys = info_from_multiples_schemas(
            metadata=metadata,
            accepted_tables=accepted_table_paths,
            tables_list=tables_list,
        )

    enforce_no_empty_tables(
        metadata=metadata, engine=engine, tables_list=tables_list
    )

    properties = {}
    properties[PRIMARY_KEYS] = json.dumps(
        [to_base64(key.protobuf()) for key in primary_keys]
    )
    properties[FOREIGN_KEYS] = json.dumps(
        {
            to_base64(pointing.protobuf()): to_base64(pointed.protobuf())
            for pointing, pointed in foreign_keys.items()
        }
    )
    return schema_builder(
        dataset=dataset,
        schema_type=union_type,
        properties=properties,
        name=dataset.properties().get(DATASET_SLUGNAME, None),
    )
async def _sql_schema(dataset: st.Dataset) -> Schema:
    """Compute the Sarus schema of ``dataset``.

    The engine comes from the dataset's manager, built from the URI stored
    in the dataset's SQL spec.
    """
    sql_spec = dataset.protobuf().spec.sql
    sql_engine = dataset.manager().engine(uri=sql_spec.uri)
    return await sql_schema(dataset, sql_engine)
def enforce_no_empty_tables(
    metadata: MetaData,
    engine: sqlalchemy.engine.Engine,
    tables_list: t.List[str],
) -> None:
    """Check that every table in ``tables_list`` has at least one row.

    Args:
        metadata: metadata the tables were reflected into; each entry of
            ``tables_list`` must be a key of ``metadata.tables``.
        engine: engine used to query the tables.
        tables_list: keys of the tables to check.

    Raises:
        EmptyTableError: for the first table, in ``tables_list`` order, whose
            query returns no row.
    """
    # A single transaction serves every probe instead of one per table.
    with engine.begin() as conn:
        for table in tables_list:
            # Fetching one row is enough to know the table is non-empty.
            query = sqlalchemy.select(*metadata.tables[table].c).limit(1)
            if conn.execute(query).first() is None:
                raise EmptyTableError(f"Found table {table} without rows")
class EmptyTableError(ValueError):
    """Raised by ``enforce_no_empty_tables`` when a table has no rows."""