Repository URL to install this package:
|
Version:
2022.10.0 ▾
|
import io
import sys
from contextlib import contextmanager
import pytest
# import dask
from dask.dataframe.io.sql import read_sql, read_sql_query, read_sql_table
from dask.dataframe.utils import PANDAS_GT_120, assert_eq
from dask.utils import tmpfile
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
pytest.importorskip("sqlalchemy")
pytest.importorskip("sqlite3")
np = pytest.importorskip("numpy")
# When ``PANDAS_GT_120`` is false, apply a module-wide pytest mark that
# ignores every warning raised while the tests below run.
if not PANDAS_GT_120:
    pytestmark = pytest.mark.filterwarnings("ignore")

# CSV text backing the reference frame shared by the tests in this module.
data = """
name,number,age,negish
Alice,0,33,-5
Bob,1,40,-3
Chris,2,22,3
Dora,3,16,5
Edith,4,53,0
Francis,5,30,0
Garreth,6,20,0
"""

# Reference frame parsed from the CSV text, indexed by its "number" column.
with io.StringIO(data) as _csv_buffer:
    df = pd.read_csv(_csv_buffer, index_col="number")
@pytest.fixture
def db():
    """Yield the URI of a temporary SQLite database holding ``df`` as table "test".

    The frame's index is written as a column and any existing "test" table
    is replaced. The backing file comes from ``tmpfile`` and stays open for
    as long as the test using the fixture runs.
    """
    with tmpfile() as path:
        sqlite_uri = "sqlite:///%s" % path
        df.to_sql("test", sqlite_uri, index=True, if_exists="replace")
        yield sqlite_uri
def test_empty(db):
    """Reading an empty table gives an empty frame carrying the requested index name."""
    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    with tmpfile() as path:
        uri = "sqlite:///%s" % path
        engine = create_engine(uri)
        schema = MetaData()
        empty_table = Table(
            "empty_table",
            schema,
            Column("id", Integer, primary_key=True),
            Column("col2", Integer),
        )
        schema.create_all(engine)

        ddf = read_sql_table(empty_table.name, uri, index_col="id", npartitions=1)
        assert ddf.index.name == "id"
        # The dtype of the empty result might no longer be as expected, so the
        # dtype of ``col2`` is deliberately not compared against int64 here.
        computed = ddf.compute()
        assert computed.empty is True
@pytest.mark.filterwarnings(
    "ignore:The default dtype for empty Series " "will be 'object' instead of 'float64'"
)
@pytest.mark.parametrize("use_head", [True, False])
def test_single_column(db, use_head):
    """Round-trip a table whose only column is the index.

    With ``use_head`` false, head sampling is disabled (``head_rows=0``) and
    an empty slice of the expected frame is passed as ``meta`` instead.
    """
    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    with tmpfile() as path:
        uri = "sqlite:///%s" % path
        schema = MetaData()
        engine = create_engine(uri)
        table = Table(
            "single_column",
            schema,
            Column("id", Integer, primary_key=True),
        )
        schema.create_all(engine)

        expected = pd.DataFrame({"id": list(range(50))}).set_index("id")
        expected.to_sql(table.name, uri, index=True, if_exists="replace")

        read_kwargs = {"index_col": "id", "npartitions": 2}
        if not use_head:
            read_kwargs.update(head_rows=0, meta=expected.iloc[:0])
        ddf = read_sql_table(table.name, uri, **read_kwargs)

        assert ddf.index.name == "id"
        assert ddf.npartitions == 2
        computed = ddf.compute()
        assert_eq(expected, computed)
def test_passing_engine_as_uri_raises_helpful_error(db):
    """``to_sql`` must reject an Engine object passed where a URI string is expected.

    Regression test for https://github.com/dask/dask/issues/6473
    """
    from sqlalchemy import create_engine

    records = [{"i": i, "s": str(i) * 2} for i in range(4)]
    ddf = dd.from_pandas(pd.DataFrame(records), npartitions=2)

    with tmpfile() as path:
        engine = create_engine("sqlite:///%s" % path)
        with pytest.raises(ValueError, match="Expected URI to be a string"):
            ddf.to_sql("test", engine, if_exists="replace")
@pytest.mark.skip(
    reason="Requires a postgres server. Sqlite does not support multiple schemas."
)
def test_empty_other_schema():
    """Read an empty table that lives in a non-default schema.

    Skipped by default: it needs a PostgreSQL server reachable with the
    connection settings hard-coded below. The schema created for the test is
    dropped again even if one of the assertions fails.
    """
    from sqlalchemy import DDL, Column, Integer, MetaData, Table, create_engine, event

    # Database configurations.
    pg_host = "localhost"
    pg_port = "5432"
    pg_user = "user"
    pg_pass = "pass"
    pg_db = "db"
    db_url = f"postgresql://{pg_user}:{pg_pass}@{pg_host}:{pg_port}/{pg_db}"

    # Describe an empty table in a different schema.
    table_name = "empty_table"
    schema_name = "other_schema"
    engine = create_engine(db_url)
    metadata = MetaData()
    table = Table(
        table_name,
        metadata,
        Column("id", Integer, primary_key=True),
        Column("col2", Integer),
        schema=schema_name,
    )

    # Make sure the schema exists right before the table is created.
    event.listen(
        metadata, "before_create", DDL("CREATE SCHEMA IF NOT EXISTS %s" % schema_name)
    )

    try:
        # Create the schema and the table.
        metadata.create_all(engine)

        # Read the empty table from the other schema.
        dask_df = read_sql_table(
            table.name, db_url, index_col="id", schema=table.schema, npartitions=1
        )

        # Validate that the retrieved table is empty.
        assert dask_df.index.name == "id"
        assert dask_df.col2.dtype == np.dtype("int64")
        pd_dataframe = dask_df.compute()
        assert pd_dataframe.empty is True
    finally:
        # Drop the schema and the table even when a step above fails, so a
        # failed run does not leave state behind on the server.
        engine.execute("DROP SCHEMA IF EXISTS %s CASCADE" % schema_name)
def test_needs_rational(db):
    """Check dtypes when a boolean column contains NULLs in only some partitions.

    Covers the default head sample, a head large enough to see the NULLs,
    empty partitions, and an explicitly supplied ``meta``.
    """
    import datetime

    start = datetime.datetime.now()
    second = datetime.timedelta(seconds=1)
    complete_rows = pd.DataFrame(
        {
            "a": list("ghjkl"),
            "b": [start + i * second for i in range(5)],
            "c": [True, True, False, True, True],
        }
    )
    rows_with_nulls = pd.DataFrame(
        [
            {"a": "x", "b": start + second * 1000, "c": None},
            {"a": None, "b": start + second * 1001, "c": None},
        ]
    )
    frame = pd.concat([complete_rows, rows_with_nulls])

    with tmpfile() as path:
        uri = "sqlite:///%s" % path
        frame.to_sql("test", uri, index=False, if_exists="replace")
        expected = frame.set_index("b")

        # one partition contains NULL
        ddf = read_sql_table("test", uri, npartitions=2, index_col="b")
        assert_eq(ddf, expected.astype({"c": bool}))  # bools are coerced

        # one partition contains NULL, but big enough head
        ddf = read_sql_table("test", uri, npartitions=2, index_col="b", head_rows=12)
        assert_eq(ddf, expected)

        # empty partitions
        ddf = read_sql_table("test", uri, npartitions=20, index_col="b")
        part = ddf.get_partition(12).compute()
        assert part.dtypes.tolist() == ["O", bool]
        assert part.empty
        assert_eq(ddf, expected.astype({"c": bool}))

        # explicit meta
        ddf = read_sql_table(
            "test", uri, npartitions=2, index_col="b", meta=expected[:0]
        )
        part = ddf.get_partition(1).compute()
        assert part.dtypes.tolist() == ["O", "O"]
        assert_eq(ddf, expected)
def test_simple(db):
    """Read the whole "test" table in two partitions and compare it with ``df``."""
    ddf = read_sql_table("test", db, npartitions=2, index_col="number")
    result = ddf.compute()

    assert (result.name == df.name).all()
    assert result.index.name == "number"
    assert_eq(result, df)
def test_npartitions(db):
    """Partition counts follow ``npartitions`` or the ``bytes_per_chunk`` budget."""

    def read(**options):
        # Every read in this test targets the same table and index column.
        return read_sql_table("test", db, index_col="number", **options)

    two_parts = read(columns=list(df.columns), npartitions=2)
    assert len(two_parts.divisions) == 3
    assert (two_parts.name.compute() == df.name).all()

    names_only = read(columns=["name"], npartitions=6)
    assert_eq(names_only, df[["name"]])

    # Byte budget given as a string.
    string_budget = read(columns=list(df.columns), bytes_per_chunk="2 GiB")
    assert string_budget.npartitions == 1
    assert (string_budget.name.compute() == df.name).all()

    # Byte budget given as an integer, sized from a one-row head sample.
    int_budget = read(columns=list(df.columns), bytes_per_chunk=2**30, head_rows=1)
    assert int_budget.npartitions == 1
    assert (int_budget.name.compute() == df.name).all()

    # A small budget splits the table into two partitions.
    small_budget = read(columns=list(df.columns), bytes_per_chunk=250, head_rows=1)
    assert small_budget.npartitions == 2
def test_divisions(db):
    """Explicit ``divisions`` are kept as given and bound the rows that are read."""
    ddf = read_sql_table(
        "test", db, columns=["name"], divisions=[0, 2, 4], index_col="number"
    )

    assert ddf.divisions == (0, 2, 4)
    assert ddf.index.max().compute() == 4

    within_divisions = df.index <= 4
    assert_eq(ddf, df[["name"]][within_divisions])
def test_division_or_partition(db):
    """Passing both ``divisions`` and ``npartitions`` raises ``TypeError``.

    Also checks that ``bytes_per_chunk=100`` produces partitions whose deep
    memory usage lies strictly between 50 and 200 bytes.
    """
    conflicting = {"divisions": [0, 2, 4], "npartitions": 3}
    with pytest.raises(TypeError):
        read_sql_table("test", db, columns=["name"], index_col="number", **conflicting)

    ddf = read_sql_table("test", db, index_col="number", bytes_per_chunk=100)
    partition_sizes = ddf.map_partitions(
        lambda d: d.memory_usage(deep=True, index=True).sum()
    ).compute()

    assert (50 < partition_sizes).all() and (partition_sizes < 200).all()
    assert_eq(ddf, df)
def test_meta(db):
    """A ``meta`` built from ``df`` is accepted and the table still round-trips."""
    meta = dd.from_pandas(df, npartitions=1)
    result = read_sql_table("test", db, index_col="number", meta=meta).compute()

    assert (result.name == df.name).all()
    assert result.index.name == "number"
    assert_eq(result, df)
def test_meta_no_head_rows(db):
    """With ``meta`` supplied, ``head_rows=0`` works for both partitioning styles."""
    # First by partition count, then by explicit divisions.
    for partitioning in ({"npartitions": 2}, {"divisions": [0, 3, 6]}):
        ddf = read_sql_table(
            "test",
            db,
            index_col="number",
            meta=dd.from_pandas(df, npartitions=1),
            head_rows=0,
            **partitioning,
        )
        assert len(ddf.divisions) == 3

        result = ddf.compute()
        assert (result.name == df.name).all()
        assert result.index.name == "number"
        assert_eq(result, df)
def test_no_meta_no_head_rows(db):
    """``head_rows=0`` without a ``meta`` argument must raise ``ValueError``."""
    no_sampling = {"head_rows": 0, "npartitions": 1}
    with pytest.raises(ValueError):
        read_sql_table("test", db, index_col="number", **no_sampling)
def test_limits(db):
    """``limits=[1, 4]`` restricts the index values read to the range 1..4."""
    ddf = read_sql_table("test", db, npartitions=2, index_col="number", limits=[1, 4])
    index = ddf.index

    assert index.min().compute() == 1
    assert index.max().compute() == 4
def test_datetimes():
    """A datetime column can serve as the index; the first division is its minimum."""
    import datetime

    now = datetime.datetime.now()
    step = datetime.timedelta(seconds=1)
    # Timestamps are written in descending order.
    stamps = [now + k * step for k in range(2, -3, -1)]
    frame = pd.DataFrame({"a": list("ghjkl"), "b": stamps})

    with tmpfile() as path:
        uri = "sqlite:///%s" % path
        frame.to_sql("test", uri, index=False, if_exists="replace")
        ddf = read_sql_table("test", uri, npartitions=2, index_col="b")

        assert ddf.index.dtype.kind == "M"
        assert ddf.divisions[0] == frame.b.min()

        # Sort both sides by index before comparing.
        expected = frame.set_index("b").sort_index()
        assert_eq(ddf.map_partitions(lambda x: x.sort_index()), expected)
def test_extra_connection_engine_keywords(caplog, db):
    """``engine_kwargs`` are honoured: ``echo`` controls whether SQL shows up in the logs."""

    def read_with_echo(echo):
        return read_sql_table(
            "test", db, npartitions=2, index_col="number", engine_kwargs={"echo": echo}
        ).compute()

    def captured_log():
        return "\n".join(r.message for r in caplog.records)

    # With echo=False no log messages are captured.
    result = read_with_echo(False)
    assert captured_log() == ""
    assert_eq(result, df)

    # With echo=True the captured log records contain the SQL queries issued.
    result = read_with_echo(True)
    logged = captured_log()
    for fragment in ("WHERE", "FROM", "SELECT", "AND", ">= ?", "< ?", "<= ?"):
        assert fragment in logged
    assert_eq(result, df)
def test_query(db):
    """``read_sql_query`` accepts SQLAlchemy selects, including casts and filters."""
    import sqlalchemy as sa
    from sqlalchemy import sql

    # Plain two-column select.
    plain_columns = [sql.column("number"), sql.column("name")]
    plain = sql.select(plain_columns).select_from(sql.table("test"))
    out = read_sql_query(plain, db, npartitions=2, index_col="number")
    assert_eq(out, df[["name"]])

    # Select with the index cast to BigInteger and a filter on its value.
    number_as_bigint = sa.cast(sql.column("number"), sa.types.BigInteger).label(
        "number"
    )
    filtered = sql.select([number_as_bigint, sql.column("name")])
    filtered = filtered.where(sql.column("number") >= 5)
    filtered = filtered.select_from(sql.table("test"))
    out = read_sql_query(filtered, db, npartitions=2, index_col="number")
    assert_eq(out, df.loc[5:, ["name"]])
def test_query_index_from_query(db):
    """The index column may be an expression computed and labelled inside the query."""
    from sqlalchemy import sql

    number = sql.column("number")
    name = sql.column("name")
    name_length = sql.func.length(name).label("lenname")
    query = sql.select([number, name, name_length]).select_from(sql.table("test"))

    out = read_sql_query(query, db, npartitions=2, index_col="lenname")

    # Build the same result in pandas: name length as the index.
    expected = df.copy()
    expected["lenname"] = expected["name"].str.len()
    expected = expected.reset_index().set_index("lenname")
    assert_eq(out, expected.loc[:, ["number", "name"]])
def test_query_with_meta(db):
    """``read_sql_query`` accepts an explicit empty pandas frame as ``meta``."""
    from sqlalchemy import sql

    meta = pd.DataFrame(
        {
            "name": pd.Series([], name="name", dtype="str"),
            "age": pd.Series([], name="age", dtype="int"),
        },
        index=pd.Index([], name="number", dtype="int"),
    )

    selected = [sql.column("number"), sql.column("name"), sql.column("age")]
    query = sql.select(selected).select_from(sql.table("test"))
    out = read_sql_query(query, db, npartitions=2, index_col="number", meta=meta)

    # Don't check dtype for windows https://github.com/dask/dask/issues/8620
    check_dtype = sys.platform != "win32"
    assert_eq(out, df[["name", "age"]], check_dtype=check_dtype)
def test_no_character_index_without_divisions(db):
    """Using the string column "name" as index without divisions raises ``TypeError``."""
    character_index = {"index_col": "name", "divisions": None}
    with pytest.raises(TypeError):
        read_sql_table("test", db, npartitions=2, **character_index)
def test_read_sql(db):
    """Read via ``read_sql`` with a select, then via ``read_sql_table`` by name."""
    from sqlalchemy import sql

    columns = [sql.column("number"), sql.column("name")]
    query = sql.select(columns).select_from(sql.table("test"))
    from_query = read_sql(query, db, npartitions=2, index_col="number")
    assert_eq(from_query, df[["name"]])

    from_table = read_sql_table("test", db, npartitions=2, index_col="number")
    result = from_table.compute()
    assert (result.name == df.name).all()
    assert result.index.name == "number"
    assert_eq(result, df)
@contextmanager
def tmp_db_uri():
    """Yield a SQLite URI pointing at a temporary file obtained from ``tmpfile``."""
    with tmpfile() as path:
        uri = "sqlite:///%s" % path
        yield uri
@pytest.mark.parametrize("npartitions", (1, 2))
@pytest.mark.parametrize("parallel", (False, True))
def test_to_sql(npartitions, parallel):
    """Round-trip ``df`` through ``to_sql`` and ``read_sql_table`` in several setups.

    Covers writing with and without the index, indexing by another column,
    the ``if_exists`` modes, and the result of a deferred (``compute=False``)
    write.
    """
    expected_by_age = df.set_index("age")
    expected_doubled = pd.concat([df, df])

    ddf = dd.from_pandas(df, npartitions)
    ddf_by_age = ddf.set_index("age")

    # Simple round trip test: use existing "number" index_col
    with tmp_db_uri() as uri:
        ddf.to_sql("test", uri, parallel=parallel)
        assert_eq(df, read_sql_table("test", uri, "number"))

    # Test writing no index, and reading back in with one of the other columns
    # as index (`read_sql_table` requires an index_col)
    with tmp_db_uri() as uri:
        ddf.to_sql("test", uri, parallel=parallel, index=False)

        by_negish = read_sql_table("test", uri, "negish")
        assert_eq(df.set_index("negish"), by_negish)

        by_age = read_sql_table("test", uri, "age")
        assert_eq(expected_by_age, by_age)

    # Index by "age" instead
    with tmp_db_uri() as uri:
        ddf_by_age.to_sql("test", uri, parallel=parallel)
        assert_eq(expected_by_age, read_sql_table("test", uri, "age"))

    # Index column can't have "object" dtype if no partitions are provided
    with tmp_db_uri() as uri:
        ddf.set_index("name").to_sql("test", uri)
        with pytest.raises(
            TypeError,
            match='Provided index column is of type "object".  If divisions is not provided the index column type must be numeric or datetime.',  # noqa: E501
        ):
            read_sql_table("test", uri, "name")

    # Test various "if_exists" values
    with tmp_db_uri() as uri:
        ddf.to_sql("test", uri)

        # Writing a table that already exists fails
        with pytest.raises(ValueError, match="Table 'test' already exists"):
            ddf.to_sql("test", uri)

        # Appending writes the rows a second time.
        ddf.to_sql("test", uri, parallel=parallel, if_exists="append")
        assert_eq(expected_doubled, read_sql_table("test", uri, "number"))

        # Replacing swaps in the age-indexed frame.
        ddf_by_age.to_sql("test", uri, parallel=parallel, if_exists="replace")
        assert_eq(expected_by_age, read_sql_table("test", uri, "age"))

    # Verify number of partitions returned, when compute=False
    with tmp_db_uri() as uri:
        deferred = ddf.to_sql("test", uri, parallel=parallel, compute=False)

        # the first result is from the "meta" insert
        written = deferred.compute()
        assert len(written) == npartitions
def test_to_sql_kwargs():
    """A known keyword (``method``) is accepted by ``to_sql``; an unknown one raises."""
    ddf = dd.from_pandas(df, 2)
    with tmp_db_uri() as uri:
        # A supported keyword goes through without error.
        ddf.to_sql("test", uri, method="multi")

        # An unsupported keyword is rejected with a TypeError.
        with pytest.raises(
            TypeError, match="to_sql\\(\\) got an unexpected keyword argument 'unknown'"
        ):
            ddf.to_sql("test", uri, unknown=None)
def test_to_sql_engine_kwargs(caplog):
    """``engine_kwargs`` passed to ``to_sql`` control SQLAlchemy's statement echo."""

    def captured_log():
        return "\n".join(r.message for r in caplog.records)

    ddf = dd.from_pandas(df, 2)

    # echo=False: nothing is logged while writing.
    with tmp_db_uri() as uri:
        ddf.to_sql("test", uri, engine_kwargs={"echo": False})
        assert captured_log() == ""
        assert_eq(df, read_sql_table("test", uri, "number"))

    # echo=True: the CREATE and INSERT statements appear in the log.
    with tmp_db_uri() as uri:
        ddf.to_sql("test", uri, engine_kwargs={"echo": True})
        logged = captured_log()
        assert "CREATE" in logged
        assert "INSERT" in logged
        assert_eq(df, read_sql_table("test", uri, "number"))