Repository URL to install this package:
|
Version:
0.3.0.dev0 ▾
|
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Callable
from urllib.parse import urlparse
from daft.dependencies import pa
from daft.logical.schema import Schema
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
# Module-level logger; the query-execution methods below log which backend runs each statement.
logger = logging.getLogger(__name__)
class SQLConnection:
    """Describe how to reach a SQL database and run queries against it.

    A connection is backed either by a database URL string or by a
    zero-argument factory returning a SQLAlchemy ``Connection``. Queries run
    through connectorx when the URL qualifies (see ``_should_use_connectorx``)
    and through SQLAlchemy otherwise; results are returned as pyarrow tables.
    """

    # Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources
    _CONNECTORX_SUPPORTED_DBS = frozenset(
        {
            "postgres",
            "postgresql",
            "mysql",
            "mssql",
            "oracle",
            "bigquery",
            "sqlite",
            "clickhouse",
            "redshift",
        }
    )

    def __init__(self, conn: str | Callable[[], Connection], driver: str, dialect: str, url: str) -> None:
        """Store the connection source together with its parsed metadata.

        Args:
            conn: A database URL, or a callable returning a SQLAlchemy connection.
            driver: Driver name; an empty string when the URL names no driver.
            dialect: Dialect name (for example ``"postgresql"``).
            url: The database URL associated with ``conn``.
        """
        self.conn = conn
        self.dialect = dialect
        self.driver = driver
        self.url = url

    def __repr__(self) -> str:
        return f"SQLConnection(conn={self.conn})"

    @classmethod
    def from_url(cls, url: str) -> SQLConnection:
        """Build a connection from a database URL.

        The URL scheme is lower-cased and read as ``dialect`` or
        ``dialect+driver``; without a ``+`` the driver is the empty string.
        """
        scheme = urlparse(url).scheme.strip().lower()
        # Split on the first "+" only so a scheme containing several "+" signs
        # cannot break the two-way unpacking.
        dialect, _, driver = scheme.partition("+")
        return cls(url, driver, dialect, url)

    @classmethod
    def from_connection_factory(cls, conn_factory: Callable[[], Connection]) -> SQLConnection:
        """Build a connection from a factory returning a SQLAlchemy connection.

        The factory is called once to read the dialect name, driver and URL
        from the connection's engine.

        Raises:
            ValueError: If the factory fails, or if it returns something that
                is not a SQLAlchemy ``Connection``. Every failure, including
                the type check, is reported through the same wrapping message.
        """
        from sqlalchemy.engine import Connection

        try:
            with conn_factory() as connection:
                if not isinstance(connection, Connection):
                    raise ValueError(
                        f"Connection factory must return a SQLAlchemy connection object, got: {type(connection)}"
                    )
                dialect = connection.engine.dialect.name
                driver = connection.engine.driver
                url = connection.engine.url.render_as_string()
            return cls(conn_factory, driver, dialect, url)
        except Exception as e:
            raise ValueError(f"Unexpected error while calling the connection factory: {e}") from e

    def read_schema(self, sql: str, infer_schema_length: int) -> Schema:
        """Run ``sql`` with a row limit and return the schema of the result.

        On the connectorx path the query is limited to zero rows; on the
        SQLAlchemy path it is limited to ``infer_schema_length`` rows, because
        that path builds the table (and therefore its types) from the fetched
        Python values.
        """
        if self._should_use_connectorx():
            sql = self.construct_sql_query(sql, limit=0)
        else:
            sql = self.construct_sql_query(sql, limit=infer_schema_length)
        table = self.execute_sql_query(sql)
        return Schema.from_pyarrow_schema(table.schema)

    def construct_sql_query(
        self,
        sql: str,
        projection: list[str] | None = None,
        predicate: str | None = None,
        limit: int | None = None,
        partition_bounds: tuple[str, str] | None = None,
    ) -> str:
        """Wrap ``sql`` in a subquery and render it for this connection's dialect.

        Args:
            sql: The base query; it becomes a subquery aliased ``subquery``.
            projection: Columns to select; ``None`` selects ``*``.
            predicate: Optional WHERE condition.
            limit: Optional row limit.
            partition_bounds: Optional pair of conditions, each added as a
                further WHERE condition.

        Returns:
            The generated SQL text.

        Raises:
            ValueError: If sqlglot has no dialect matching this connection.
        """
        import sqlglot

        # sqlglot does not support "postgresql" dialect, it only supports "postgres".
        # sqlglot does not recognize "mssql" as a dialect, it instead recognizes "tsql",
        # which is the SQL dialect for Microsoft SQL Server.
        sqlglot_aliases = {"postgresql": "postgres", "mssql": "tsql"}
        target_dialect = sqlglot_aliases.get(self.dialect, self.dialect)

        if not any(target_dialect == supported_dialect.value for supported_dialect in sqlglot.Dialects):
            raise ValueError(
                f"Unsupported dialect: {target_dialect}, please refer to the documentation for supported dialects."
            )

        query = sqlglot.subquery(sql, "subquery")

        if projection is not None:
            query = query.select(*projection)
        else:
            query = query.select("*")

        if predicate is not None:
            query = query.where(predicate)

        if partition_bounds is not None:
            query = query.where(partition_bounds[0]).where(partition_bounds[1])

        if limit is not None:
            query = query.limit(limit)

        return query.sql(dialect=target_dialect)

    def _should_use_connectorx(self) -> bool:
        """Return True when the connection is a URL connectorx can handle.

        That requires a URL string (not a factory), no explicit driver in the
        scheme, and a dialect listed in ``_CONNECTORX_SUPPORTED_DBS``.
        """
        return (
            isinstance(self.conn, str)
            and self.driver == ""
            and self.dialect in self._CONNECTORX_SUPPORTED_DBS
        )

    def execute_sql_query(self, sql: str) -> pa.Table:
        """Execute ``sql`` with connectorx when possible, otherwise with SQLAlchemy."""
        if self._should_use_connectorx():
            return self._execute_sql_query_with_connectorx(sql)
        return self._execute_sql_query_with_sqlalchemy(sql)

    def _execute_sql_query_with_connectorx(self, sql: str) -> pa.Table:
        """Execute ``sql`` through connectorx and return the arrow result.

        Raises:
            RuntimeError: If connectorx fails to run the query.
        """
        import connectorx as cx

        assert isinstance(self.conn, str)
        logger.info("Using connectorx to execute sql: %s", sql)
        try:
            table = cx.read_sql(conn=self.conn, query=sql, return_type="arrow")
            return table
        except Exception as e:
            raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.conn}, error: {e}") from e

    def _connection_label(self) -> str:
        """Return a printable name for the connection, for use in error messages."""
        if isinstance(self.conn, str):
            return self.conn
        # Not every callable has ``__name__`` (functools.partial objects and
        # callable instances do not); fall back to repr so that building the
        # error message can never hide the original failure.
        return getattr(self.conn, "__name__", repr(self.conn))

    def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table:
        """Execute ``sql`` through SQLAlchemy and return the rows as a pyarrow table.

        For a URL connection a temporary engine is created and disposed once
        the query finishes; for a factory connection the factory is called to
        obtain the connection.

        Raises:
            RuntimeError: If executing the query or building the table fails.
        """
        from sqlalchemy import create_engine, text

        logger.info("Using sqlalchemy to execute sql: %s", sql)
        try:
            if isinstance(self.conn, str):
                engine = create_engine(self.conn)
                try:
                    with engine.connect() as connection:
                        result = connection.execute(text(sql))
                        rows = result.fetchall()
                        column_names = list(result.keys())
                finally:
                    # The engine only lives for this call; release its pool
                    # instead of waiting for garbage collection.
                    engine.dispose()
            else:
                with self.conn() as connection:
                    result = connection.execute(text(sql))
                    rows = result.fetchall()
                    column_names = list(result.keys())

            pydict = {column_name: [row[i] for row in rows] for i, column_name in enumerate(column_names)}
            # TODO: Use type codes from cursor description to create pyarrow schema
            return pa.Table.from_pydict(pydict)
        except Exception as e:
            connection_str = self._connection_label()
            raise RuntimeError(f"Failed to execute sql: {sql} from connection: {connection_str}, error: {e}") from e