# Repository URL to install this package: (not captured in this copy)
# Version: 4.5.4.dev1
from decimal import Decimal
import json
import logging
import typing as t
import warnings
from cachetools import LRUCache
from fsspec.implementations.http import HTTPFileSystem
from fsspec.implementations.local import LocalFileSystem
from sarus_data_spec.manager.cache_utils import lru_caching
from sarus_data_spec.manager.ops.sql_utils.bigdata import (
sqlalchemy_query_to_string,
)
from sarus_data_spec.manager.ops.sql_utils.schema_translations import (
sa_metadata_from_dataset,
)
from sarus_data_spec.manager.ops.sql_utils.type_mapping import (
CursorDescription,
rename_columns_if_ambigous,
sarus_types_from_sql_description,
)
import sarus_data_spec.type as sdt
try:
import pyqrlew as pyqrl
from sarus_data_spec.manager.ops.sql_utils.pyqrlew_utils import (
pyqrlew_dataset,
compose_and_translate_query,
compose_and_translate_queries,
)
from sarus_data_spec.manager.ops.sql_utils.queries import (
flatten_queries_dict,
nested_dict_of_types,
nested_unions_from_nested_dict_of_types,
)
except ModuleNotFoundError:
warnings.warn("Pyqrlew not installed. Can't process SQL queries")
from sarus_data_spec.arrow.type import to_arrow
from sarus_data_spec.constants import (
DATA,
SQL_CACHING_URI,
SQLALCHEMY_DIALECT_MAP,
TABLE_MAPPING,
)
from sarus_data_spec.path import straight_path
import subprocess
try:
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy import text
import sqlalchemy as sa
except ModuleNotFoundError:
warnings.warn("SQLAlchemy not available.")
import fsspec
import numpy as np
import pyarrow as pa
from sarus_data_spec import typing as st
from sarus_data_spec.manager.base import Base
from sarus_data_spec.manager.ops.base import _ensure_batch_correct
from sarus_data_spec.manager.ops.processor.routing import (
TransformedDataset,
TransformedScalar,
)
from sarus_data_spec.manager.ops.source.routing import (
SourceScalar,
source_dataset_to_arrow,
)
from sarus_data_spec.manager.typing import Computation
from sarus_data_spec.path import Path
import sarus_data_spec.protobuf as sp
import sarus_data_spec.storage.typing as storage_typing
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class PrivateManager(Base):
    """Base class for managers that we do not want to expose outside of Sarus.

    It shares common methods between its subclasses and, in particular,
    decides when a dataspec is cached (``is_cached``) or pushed to an SQL
    database (``is_cached_to_sql``).
    """

    def __init__(
        self, storage: storage_typing.Storage, protobuf: sp.Manager
    ) -> None:
        super().__init__(storage, protobuf)
        # Named LRU caches exposed through ``caches()``; presumably looked up
        # by name by ``lru_caching`` (``async_to`` uses ``"values"``).
        self._caches: t.Dict[str, t.Any] = {
            "sql_storage": LRUCache(maxsize=512),
            "computation": LRUCache(maxsize=32),
            "values": LRUCache(maxsize=4),
            "statistics": LRUCache(maxsize=3),
        }
        # to define in subclasses, computations for DP attributes:
        self._sql_computation: Computation[t.AsyncIterator[pa.RecordBatch]]
        self.to_sql_computation: Computation[t.Mapping[str, str]]
        self.push_sql_computation: Computation[t.Mapping[str, str]]
        self.multiplicity_computation: Computation[st.Multiplicity]
        self.size_computation: Computation[st.Size]
        self.bounds_computation: Computation[st.Bounds]
        self.marginals_computation: Computation[st.Marginals]
        self.links_computation: Computation[st.Links]

    def dataspec_computation(
        self,
        dataspec: st.DataSpec,
    ) -> Computation:
        """Return the computation for a DataSpec.

        Datasets use ``to_arrow_computation``; anything else (scalars) uses
        ``value_computation``.
        """
        proto = dataspec.prototype()
        if proto == sp.Dataset:
            return self.to_arrow_computation
        return self.value_computation

    def sql_computation(self) -> Computation:
        """Returns the SQL computation."""
        return self._sql_computation

    def caches(self) -> t.Any:
        """Return the named LRU caches created in ``__init__``."""
        return self._caches

    @lru_caching("values", use_first_arg=False)
    async def async_to(
        self,
        dataset: st.Dataset,
        kind: t.Type,
        drop_admin: bool = True,
        batch_size: t.Optional[int] = None,
    ) -> st.DatasetCastable:
        """Delegate to ``Base.async_to``, wrapped by ``lru_caching("values")``
        (presumably memoizing results in the ``"values"`` cache)."""
        return await super().async_to(dataset, kind, drop_admin, batch_size)

    async def async_to_arrow_op(
        self, dataset: st.Dataset, batch_size: int
    ) -> t.AsyncIterator[pa.RecordBatch]:
        """Op that enables routing to compute the arrow iterator.

        This method is shared because when the data is not
        cached to parquet, even the Api manager should be
        able to stream its content.

        Raises:
            ValueError: if the dataset is neither transformed nor source.
        """
        if dataset.is_transformed():
            iterator = await TransformedDataset(dataset).to_arrow(
                batch_size=batch_size
            )
            return iterator
        elif dataset.is_source():
            iterator = await source_dataset_to_arrow(
                dataset, batch_size=batch_size
            )
            return iterator
        else:
            raise ValueError("Dataset is either transformed or source")

    async def execute_sql_query(
        self,
        dataset: st.Dataset,
        caching_properties: t.Mapping[str, str],
        query: t.Union[str, st.NestedQueryDict],
        dialect: t.Optional[st.SQLDialect] = None,
        batch_size: int = 10000,
        result_type: t.Optional[st.Type] = None,
    ) -> t.AsyncIterator[pa.RecordBatch]:
        """Execute ``query`` against the database where ``dataset`` lives.

        The query (a single string, or a nested dict of queries) is composed
        with table renamings built from ``caching_properties[TABLE_MAPPING]``
        and translated from ``dialect`` (POSTGRES by default) to the dialect
        of the engine built from ``caching_properties[SQL_CACHING_URI]``.

        When ``result_type`` is None, the type is deduced from the cursor
        descriptions of the executed queries.

        Returns:
            An async iterator whose batches are resized to ``batch_size``
            by ``_ensure_batch_correct``.
        """
        # we called it pushed_uri because it can be either the caching_uri
        # if the dataset is not source or the source uri otherwise.
        pushed_uri = caching_properties[SQL_CACHING_URI]
        # Keys are base64-encoded sarus paths; values are the database paths.
        table_mapping = {
            t.cast(st.Path, Path(sp.utilities.from_base64(key))): value
            for key, value in json.loads(
                caching_properties[TABLE_MAPPING]
            ).items()
        }
        if dialect is None:
            dialect = st.SQLDialect.POSTGRES
        engine = self.engine(pushed_uri)
        destination_dialect = SQLALCHEMY_DIALECT_MAP[engine.dialect.name]
        schema = await dataset.manager().async_schema(dataset)
        pyqrl_ds = await pyqrlew_dataset(dataset, str(schema))
        renaming_relations = renaming_relations_from_table_mapping(
            dataset, pyqrl_ds, schema.name(), table_mapping
        )
        if isinstance(query, str):
            new_query = compose_and_translate_query(
                query,
                pyqrl_ds,
                renaming_relations,
                (dialect, destination_dialect),
            )
            if result_type is None:
                iterator = sa_results_to_batcharray(
                    result_type,
                    new_query,
                    batch_size,
                    engine,
                    destination_dialect,
                )
            else:
                query_dict: t.Dict[t.Tuple[str, ...], str] = {(): new_query}
                iterator = iterator_from_queries_type(
                    queries=query_dict,
                    result_type=result_type,
                    engine=engine,
                    batch_size=batch_size,
                )
        else:
            updated_queries = compose_and_translate_queries(
                query,
                pyqrl_ds,
                renaming_relations,
                (dialect, destination_dialect),
            )
            # Rewrite queries dict as : dict(Tuple[str, ...]: query)
            flatten_queries = flatten_queries_dict(updated_queries)
            if result_type is None:
                # Run each query once to read its cursor description and
                # build the union type of all results.
                types: t.Dict[t.Tuple[str, ...], st.Type] = {}
                for path_tuple, flat_query in flatten_queries.items():
                    with engine.begin() as conn:
                        results = conn.execution_options(yield_per=1).execute(
                            text(flat_query)
                        )
                        types[path_tuple] = sarus_types_from_sql_description(
                            t.cast(
                                CursorDescription, results.cursor.description
                            ),
                            destination_dialect,
                        )
                nested_result_types = nested_dict_of_types(types)
                result_type = sdt.Union(
                    nested_unions_from_nested_dict_of_types(
                        nested_result_types
                    )
                )
            iterator = iterator_from_queries_type(
                queries=flatten_queries,
                result_type=result_type,
                engine=engine,
                batch_size=batch_size,
            )

        # now ensure batches have right size
        async def identity(x: pa.Array) -> pa.Array:
            return x

        return _ensure_batch_correct(
            async_iterator=iterator,
            func_to_apply=identity,
            batch_size=batch_size,
        )

    async def async_sql_op(
        self,
        dataset: st.Dataset,
        query: t.Union[str, t.Dict[str, t.Any]],
        dialect: t.Optional[st.SQLDialect] = None,
        batch_size: int = 10000,
        result_type: t.Optional[st.Type] = None,
    ) -> t.AsyncIterator[pa.RecordBatch]:
        """SQL routing. Pass the query to the parents.

        Raises:
            ValueError: if ``dataset`` is a source dataset.
        """
        if dataset.is_source():
            raise ValueError(
                "Source dataset has no parents to pass the query",
            )
        return await TransformedDataset(dataset).sql(
            query, dialect, batch_size, result_type
        )

    async def async_to_sql(self, dataset: st.Dataset) -> None:
        """Complete the ``to_sql_computation`` task for ``dataset``."""
        await self.to_sql_computation.complete_task(dataspec=dataset)

    async def async_push_sql(self, dataset: st.Dataset) -> None:
        """Complete the ``push_sql_computation`` task for ``dataset``."""
        await self.push_sql_computation.complete_task(dataspec=dataset)

    async def async_value_op(self, scalar: st.Scalar) -> t.Any:
        """Route a Scalar to its Op implementation.

        This method is shared between API and Worker because when the data is
        not cached the API manager should also be able to compute the value.
        """
        if scalar.is_transformed():
            return await TransformedScalar(scalar).value()
        else:
            return await SourceScalar(scalar).value()

    def is_cached(self, dataspec: st.DataSpec) -> bool:
        """Return whether a dataspec should be cached.

        Datasets: never for source or big-data datasets; otherwise only when
        produced by one of the transforms listed below.
        Scalars: always, except pretrained models and scalars produced by
        the ``automatic_*`` transforms listed below.
        """
        proto = dataspec.prototype()
        if proto == sp.Dataset:
            dataset = t.cast(st.Dataset, dataspec)
            if dataset.is_source() or self.is_big_data(dataset):
                return False
            if dataset.transform().spec() in [
                "user_settings",
                "privacy_unit_tracking",
                "synthetic",
                "assign_budget",
                "sample",
                "differentiated_sample",
                "external",
                "generate_from_model",
                "to_small_data",
                "group_by_pe",
            ]:
                return True
            return False
        else:  # scalars
            scalar = t.cast(st.Scalar, dataspec)
            if scalar.is_transformed() and scalar.transform().name() in [
                "automatic_user_settings",
                "automatic_privacy_unit_tracking_paths",
                "automatic_public_paths",
            ]:
                return False
            elif scalar.is_pretrained_model():
                return False
            return True

    def is_cached_to_sql(self, dataspec: st.DataSpec) -> bool:
        """Return whether a dataspec should be pushed to an SQL database.

        This mimics ``is_cached``: source datasets are always pushed,
        non-big-data datasets produced by ``assign_budget`` or ``synthetic``
        are pushed, and nothing else (including scalars) is.
        """
        proto = dataspec.prototype()
        if proto == sp.Dataset:
            dataset = t.cast(st.Dataset, dataspec)
            if dataset.is_source():
                return True
            if (
                dataset.transform().spec() in ["assign_budget", "synthetic"]
            ) and not self.is_big_data(dataset):
                return True
            return False
        return False

    def engine(self, uri: str) -> "Engine":
        """Create an SQLAlchemy engine for ``uri``.

        The goal of this method is to abstract the use of dataconnection in
        DOD to access data on the client server. Dataspec manager is not aware
        of dataconnections so it simply uses the uri.

        A new engine is created on every call.
        """
        # Return annotation is a string so that defining this class does not
        # require SQLAlchemy, whose import is optional at module level.
        connect_args: t.Dict[str, t.Any] = {}
        if uri.startswith("sqlite"):
            # Allow the sqlite connection to be used from other threads.
            connect_args["check_same_thread"] = False
        elif uri.startswith("mssql+pyodbc"):
            connect_args["fast_executemany"] = True
            connect_args["autocommit"] = True
        elif uri.startswith("postgresql"):
            # TCP keepalive options passed to the PostgreSQL driver.
            connect_args["keepalives"] = 1
            connect_args["keepalives_idle"] = 30
            connect_args["keepalives_interval"] = 10
            connect_args["keepalives_count"] = 5
        return create_engine(url=uri, connect_args=connect_args, echo=False)

    def filesystem(self, uri: str) -> fsspec.AbstractFileSystem:
        """Return the fsspec filesystem used to read the data at ``uri``.

        To be changed in DoD.

        Raises:
            NotImplementedError: for schemes other than http(s) and file.
        """
        if uri.startswith(("http://", "https://")):
            return HTTPFileSystem()
        if uri.startswith("file://"):
            return LocalFileSystem()
        raise NotImplementedError("Uri not implemented")

    async def async_multiplicity(self, dataset: st.Dataset) -> st.Multiplicity:
        """Return the multiplicity of ``dataset`` (checked transforms only)."""
        check_attribute_can_be_computed(
            dataset=dataset, attribute_name="Multiplicities"
        )
        return await self.multiplicity_computation.task_result(dataset)

    async def async_size(self, dataset: st.Dataset) -> st.Size:
        """Return the size of ``dataset`` (checked transforms only)."""
        check_attribute_can_be_computed(
            dataset=dataset, attribute_name="Sizes"
        )
        return await self.size_computation.task_result(dataset)

    async def async_bounds(self, dataset: st.Dataset) -> st.Bounds:
        """Return the bounds of ``dataset`` (checked transforms only)."""
        check_attribute_can_be_computed(
            dataset=dataset, attribute_name="Bounds"
        )
        return await self.bounds_computation.task_result(dataset)

    async def async_marginals(self, dataset: st.Dataset) -> st.Marginals:
        """Return the marginals of ``dataset`` (checked transforms only)."""
        check_attribute_can_be_computed(
            dataset=dataset, attribute_name="Marginals"
        )
        return await self.marginals_computation.task_result(dataset)

    async def async_links(self, dataset: st.Dataset) -> t.Any:
        """Return the links of ``dataset`` (no transform check)."""
        return await self.links_computation.task_result(dataspec=dataset)

    def launch_job(self, command: t.List[str], env: t.Dict[str, str]) -> None:
        """Run ``command`` in a subprocess with ``env`` and wait for it.

        Raises:
            RuntimeError: if the process exits with a non-zero code; the
                message includes the command and its stderr.
        """
        process = subprocess.Popen(
            command, env=env, stderr=subprocess.PIPE, text=True
        )
        _, errs = process.communicate()
        if process.returncode != 0:
            raise RuntimeError(
                f"Subprocess error occurred with command {command}: {errs}"
            )
def check_attribute_can_be_computed(
    dataset: st.Dataset, attribute_name: str
) -> None:
    """Raise unless the DP attribute ``attribute_name`` is computable.

    The same rule is shared by multiplicities, sizes, bounds and marginals:
    the dataset must be transformed, and its transform spec must be one of
    the supported specs listed below.

    Raises:
        NotImplementedError: for a source dataset, or for a dataset produced
            by an unsupported transform.
    """
    if not dataset.is_transformed():
        raise NotImplementedError(
            f"{attribute_name} cannot be computed on source dataset"
        )

    supported_specs = (
        "assign_budget",
        "filter",
        "get_item",
        "project",
        "shuffle",
        "differentiated_sample",
        "sample",
        "synthetic",
        "external",
    )
    transform_spec = dataset.transform().spec()
    if transform_spec in supported_specs:
        return

    raise NotImplementedError(
        f"{attribute_name} cannot be computed for dataset"
        f" transformed by {transform_spec}"
    )
async def iterator_from_queries_type(
    result_type: st.Type,
    queries: t.Dict[t.Tuple[str, ...], str],
    batch_size: int,
    engine: "Engine",
) -> t.AsyncIterator[pa.Array]:
    """Execute each query and yield arrow arrays typed as ``result_type``.

    Each query is run on ``engine`` and its rows are streamed in partitions
    of ``batch_size``. Each partition is converted to a struct array for the
    table at ``[DATA, *path]`` of the data type. ``create_array_from_type``
    then embeds it in the full data type, filling the other union branches
    with nulls. When ``result_type`` has administrative columns, they are
    read from the same rows and placed next to the ``DATA`` field.

    Args:
        result_type: expected type of the results, possibly wrapping the
            ``DATA`` field with administrative columns.
        queries: mapping from a table path to the SQL query producing that
            table. A path starting with ``""`` selects no table (as for CSV).
        batch_size: number of rows fetched per partition.
        engine: SQLAlchemy engine used to run the queries. The annotation is
            a string because the SQLAlchemy import is optional.

    Yields:
        One arrow array per partition of each query result.
    """
    has_admin_cols = result_type.has_admin_columns()
    admincol_data_type: t.Dict[str, pa.DataType] = {}
    admin_fields: t.List[pa.Field] = []
    if has_admin_cols:
        children = result_type.children()
        data_type = children[DATA]
        admin_cols = [col for col in children.keys() if col != DATA]
        admincol_data_type = {
            col: to_arrow(children[col], False) for col in admin_cols
        }
        # The admin fields only depend on the schema, so they are built once
        # here rather than once per partition. Their nullability follows the
        # schema so that the output struct matches the expected type.
        admin_fields = [
            pa.field(
                col,
                admincol_data_type[col],
                nullable=children[col].protobuf().HasField("optional"),
            )
            for col in admin_cols
        ]
    else:
        data_type = result_type
        admin_cols = []

    for path, query in queries.items():
        # we suppose here that: the path: ("",) means that there
        # is no table to select in data_type as for CSV.
        if len(path) > 0 and path[0] == "":
            path = path[1:]
        # Type of the table produced by this query, computed once per query.
        table_type = data_type.sub_types(straight_path([DATA, *path]))[0]
        column_names = list(table_type.children().keys())
        arrow_data_type = to_arrow(table_type, nullable=False)
        with engine.begin() as conn:
            result = conn.execution_options(yield_per=batch_size).execute(
                text(query)
            )
            # NOTE(review): relies on a private SQLAlchemy attribute to map
            # column names to their position in each row.
            col_index_dict = result._metadata._key_to_index
            for partition in result.partitions(batch_size):
                arrow_arrays = []
                for col_name in column_names:
                    col_idx = col_index_dict[col_name]
                    col_values = [row[col_idx] for row in partition]
                    try:
                        as_pyarr = pa.array(col_values)
                    except pa.ArrowInvalid as err:
                        warnings.warn(
                            "Encountered presumably a Decimal "
                            f"precision out of range error: {err} "
                            "Trying to convert in float via numpy"
                        )
                        as_pyarr = high_precision_decimals_to_float(
                            col_values
                        )
                    arrow_arrays.append(
                        pa.compute.cast(
                            as_pyarr,
                            arrow_data_type.field(col_name).type,
                            safe=False,  # to allow for float to int conversion
                        )
                    )
                arrow_array = pa.StructArray.from_arrays(
                    arrow_arrays, fields=list(arrow_data_type)
                )
                final_array = create_array_from_type(
                    _type=data_type, data=arrow_array, path=list(path)
                )
                if has_admin_cols:
                    # administrative columns must be added to the query in
                    # order to have them in the results. The order with which
                    # they are added matters.
                    admin_output = [
                        {
                            col_name: row[col_index_dict[col_name]]
                            for col_name in admin_cols
                        }
                        for row in partition
                    ]
                    admin_array = pa.array(admin_output)
                    # Build the struct with explicit fields so that the admin
                    # fields are nullable exactly as in the schema.
                    final_array = pa.StructArray.from_arrays(
                        [
                            final_array,
                            *[
                                admin_array.field(col).cast(
                                    admincol_data_type[col]
                                )
                                for col in admin_cols
                            ],
                        ],
                        fields=[
                            pa.field(DATA, final_array.type, nullable=False),
                            *admin_fields,
                        ],
                    )
                yield final_array
def create_array_from_type(
    _type: st.Type, data: pa.Array, path: t.List[str]
) -> pa.Array:
    """Embed ``data`` into an array of type ``_type``.

    ``path`` lists the union field names to follow from ``_type`` down to
    the struct type of ``data``. At each union level:

    - the selected field holds the (recursively embedded) data;
    - the other fields are filled with nulls of their arrow type;
    - a non-nullable ``field_selected`` column records the selected name.

    A struct type returns ``data`` unchanged. Other types are not supported.
    ``path`` itself is never modified.

    Raises:
        ValueError: if a union is reached with an empty ``path``, or with a
            field name that the union does not contain.
        NotImplementedError: if a type other than union or struct is reached.
    """

    class ArrayCreator(st.TypeVisitor):
        batch_array: pa.Array

        def Union(
            self,
            fields: t.Mapping[str, st.Type],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            if not path:
                raise ValueError(
                    "Cannot build a union array: no field left in the path"
                )
            curr_field, remaining_path = path[0], path[1:]
            if curr_field not in fields:
                raise ValueError(
                    f"Field {curr_field!r} is not in union fields "
                    f"{list(fields.keys())}"
                )
            arrays = []
            for item_name, item_type in fields.items():
                if item_name == curr_field:
                    arrays.append(
                        create_array_from_type(
                            item_type, data, list(remaining_path)
                        )
                    )
                else:
                    # Unselected branches are all-null columns.
                    array_type = to_arrow(_type=item_type)
                    arrays.append(pa.array([None] * len(data), array_type))
            arrays.append(
                pa.array([curr_field] * len(data), pa.large_string())
            )
            names = list(fields.keys()) + ["field_selected"]
            is_nullable = [True] * len(fields) + [False]
            pa_fields = [
                pa.field(name, arrays[idx].type, nullable=is_nullable[idx])
                for idx, name in enumerate(names)
            ]
            self.batch_array = pa.StructArray.from_arrays(
                arrays, fields=pa_fields
            )

        def Struct(
            self,
            fields: t.Mapping[str, st.Type],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            self.batch_array = data

        def Text(
            self,
            encoding: str,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Integer(
            self,
            min: int,
            max: int,
            base: st.IntegerBase,
            possible_values: t.Iterable[int],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Float(
            self,
            min: float,
            max: float,
            base: st.FloatBase,
            possible_values: t.Iterable[float],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Datetime(
            self,
            format: str,
            min: str,
            max: str,
            base: st.DatetimeBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Enum(
            self,
            name: str,
            name_values: t.Sequence[t.Tuple[str, int]],
            ordered: bool,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Optional(
            self,
            type: st.Type,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Null(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            raise NotImplementedError

        def Unit(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            raise NotImplementedError

        def Boolean(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            raise NotImplementedError

        def Id(
            self,
            unique: bool,
            reference: t.Optional[st.Path] = None,
            base: t.Optional[st.IdBase] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Bytes(
            self, properties: t.Optional[t.Mapping[str, str]] = None
        ) -> None:
            raise NotImplementedError

        def List(
            self,
            type: st.Type,
            max_size: int,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Array(
            self,
            type: st.Type,
            shape: t.Tuple[int, ...],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Time(
            self,
            format: str,
            min: str,
            max: str,
            base: st.TimeBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Date(
            self,
            format: str,
            min: str,
            max: str,
            base: st.DateBase,
            possible_values: t.Iterable[str],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Duration(
            self,
            unit: str,
            min: int,
            max: int,
            possible_values: t.Iterable[int],
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Constrained(
            self,
            type: st.Type,
            constraint: st.Predicate,
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

        def Hypothesis(
            self,
            *types: t.Tuple[st.Type, float],
            name: t.Optional[str] = None,
            properties: t.Optional[t.Mapping[str, str]] = None,
        ) -> None:
            raise NotImplementedError

    visitor = ArrayCreator()
    _type.accept(visitor)
    return visitor.batch_array
async def sa_results_to_batcharray(
    result_type: t.Optional[st.Type],
    query: str,
    batch_size: int,
    engine: "Engine",
    destination_dialect: st.SQLDialect,
) -> t.AsyncIterator[pa.Array]:
    """Execute a single query and yield one struct array per result partition.

    Ambiguous result column names are renamed with
    ``rename_columns_if_ambigous``. When ``result_type`` is None, the type
    is deduced from the cursor description. If that is not implemented for
    a column type, a warning is emitted and arrow infers the types from the
    values. When a type is known, each column is cast to it.

    Args:
        result_type: expected type of the results, or None to deduce it.
        query: SQL query to execute.
        batch_size: number of rows per partition and per yielded array.
        engine: SQLAlchemy engine used to run the query. The annotation is a
            string because the SQLAlchemy import is optional.
        destination_dialect: dialect of ``engine``, used to map the cursor
            description to sarus types.

    Yields:
        ``pa.StructArray`` objects whose fields are the (renamed) columns.
    """
    with engine.begin() as conn:
        results = conn.execution_options(yield_per=batch_size).execute(
            text(query)
        )
        if result_type is None:
            try:
                result_type = sarus_types_from_sql_description(
                    t.cast(CursorDescription, results.cursor.description),
                    destination_dialect,
                )
            except NotImplementedError as e:
                # if types from sql_descriptions are not implemented
                # we try to infer types from the results
                warnings.warn(
                    f"Couldn't compute query results type due to {e}"
                )
        # Column names are the same for every partition: rename them once.
        renamed_cols = rename_columns_if_ambigous(list(results.keys()))
        arrow_type = None
        if result_type is not None:
            arrow_type = to_arrow(result_type)
        for res_rows in results.partitions(size=batch_size):
            # Transpose the rows of the partition into one list per column.
            res_as_pydict: t.Dict[str, t.List[t.Any]] = {
                name: [row[idx] for row in res_rows]
                for idx, name in enumerate(renamed_cols)
            }
            arr_output = []
            for name, col in res_as_pydict.items():
                if arrow_type is None:
                    arr_output.append(pa.array(col))
                    continue
                try:
                    as_pyarr = pa.array(col)
                except pa.ArrowInvalid as err:
                    warnings.warn(
                        "Encountered presumably a Decimal "
                        f"precision out of range error: {err} "
                        "Trying to convert in float via numpy"
                    )
                    as_pyarr = high_precision_decimals_to_float(col)
                arr_output.append(
                    pa.compute.cast(as_pyarr, arrow_type[name].type)
                )
            final_array = pa.StructArray.from_arrays(
                arr_output, names=renamed_cols
            )
            yield final_array
def high_precision_decimals_to_float(col_data: t.List[t.Any]) -> pa.Array:
    """Convert a column of Decimals to a float64 arrow array via numpy.

    It is used when the result of an SQL query contains Decimals with a
    precision higher than 76, in which case pyarrow fails to create an
    arrow array directly.

    - Values outside the float64 range (converted to +/-inf) are clipped to
      the largest/smallest finite float64.
    - ``None`` entries become arrow nulls (not NaN).

    Args:
        col_data: non-empty sequence whose non-null values are Decimals.

    Raises:
        ValueError: if ``col_data`` is empty or contains a non-null value
            that is not a ``Decimal``.
    """
    if len(col_data) == 0:
        raise ValueError("Expected a non-empty column of Decimals")
    for value in col_data:
        if value is not None and not isinstance(value, Decimal):
            raise ValueError(
                f"Expected Decimal values, got {type(value).__name__}"
            )
    null_mask = np.array([value is None for value in col_data], dtype=bool)
    # Placeholder 0.0 for nulls; they are masked out when building the array.
    np_array = np.array(
        [0.0 if value is None else value for value in col_data],
        dtype=np.float64,
    )
    float64_info = np.finfo(np.float64)
    np_array[np_array == np.inf] = float64_info.max
    np_array[np_array == -np.inf] = float64_info.min
    return pa.array(np_array, mask=null_mask)
def renaming_relations_from_table_mapping(
    dataset: st.Dataset,
    pyqrl_ds: "pyqrl.Dataset",
    schema_name: str,
    table_mapping: t.Dict[st.Path, t.Sequence[str]],
) -> t.Sequence[t.Tuple[t.Tuple[str, ...], "pyqrl.Relation"]]:
    """Return (sarus table path, relation) pairs used for table renaming.

    Given that ``table_mapping`` maps a sarus table path to a db table path,
    each element of the returned sequence is::

        (path, to, sarus, table), Relation

    where the relation is a Table with the db table path.

    The pyqrlew annotations are strings because the pyqrlew import is
    optional; evaluating them at definition time would fail with a NameError
    when pyqrlew is not installed.
    """
    metadata = sa_metadata_from_dataset(dataset)
    queries = []
    real_to_sarus: t.Dict[t.Tuple[str, ...], t.Tuple[str, ...]] = {}
    for sarus_path, db_path in table_mapping.items():
        # Sarus table path without its leading (schema) component.
        table_path = sarus_path.to_strings_list()[0][1:]
        # An empty table path means the schema itself is the table.
        table_name = ".".join(f'"{name}"' for name in table_path)
        if not table_name:
            table_name = f'"{schema_name}"'
        query = sa.select(metadata.tables[table_name]).alias(db_path[-1])
        real_path = (schema_name, *db_path)
        queries.append((real_path, sqlalchemy_query_to_string(query)))
        real_to_sarus[real_path] = (schema_name, *table_path)
    renaming_ds = pyqrl_ds.from_queries(queries)
    return [
        (real_to_sarus[tuple(path)], rel)
        for path, rel in renaming_ds.relations()
    ]