import signal
import sys
import threading

import pytest

from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

from dask.dataframe.utils import assert_eq

pytestmark = pytest.mark.skipif(
    sys.platform != "linux",
    reason="Unnecessary, and hard to get spark working on non-linux platforms",
)

# pyspark auto-converts timezones -- round-tripping timestamps is easier if
# we set everything to UTC.
pdf = timeseries(freq="1H").compute()
pdf.index = pdf.index.tz_localize("UTC")
pdf = pdf.reset_index()
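# After reset_index, the "timestamp" index becomes a regular column alongside
# the id/name/x/y columns that dask.datasets.timeseries generates.
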
@pytest.fixture(scope="module")
def spark_session():
    # Spark registers a global signal handler that can cause problems elsewhere
    # in the test suite. In particular, the handler fails if the spark session
    # is stopped (a bug in pyspark).
    prev = signal.getsignal(signal.SIGINT)
    # Create a spark session. Note that we set the timezone to UTC to avoid
    # conversion to local time when reading parquet files.
    spark = (
        pyspark.sql.SparkSession.builder.master("local")
        .appName("Dask Testing")
        .config("spark.sql.session.timeZone", "UTC")
        .getOrCreate()
    )
    yield spark
    spark.stop()
    # Make sure we restore the original signal handler once the session is stopped.
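    # (signal.signal may only be called from the main thread; calling it from
    # any other thread raises ValueError, hence the guard.)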
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, prev)

@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_spark_to_dask(spark_session, npartitions, tmpdir, engine):
    tmpdir = str(tmpdir)
    sdf = spark_session.createDataFrame(pdf)
    # We are not overwriting any data, but spark complains if the directory
    # already exists (as tmpdir does) and we don't set overwrite
    sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")
    ddf = dd.read_parquet(tmpdir, engine=engine)
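    # With the default read_parquet settings each file spark wrote comes back
    # as a single dask partition, which the npartitions assertion below checks.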
    # Papercut: pandas TZ localization doesn't survive roundtrip
    ddf = ddf.assign(timestamp=ddf.timestamp.dt.tz_localize("UTC"))
    assert ddf.npartitions == npartitions
    assert_eq(ddf, pdf, check_index=False)

@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_hive_parquet_spark_to_dask(spark_session, tmpdir, engine):
    tmpdir = str(tmpdir)
    sdf = spark_session.createDataFrame(pdf)
    # not overwriting any data, but spark complains if the directory
    # already exists and we don't set overwrite
    sdf.write.parquet(tmpdir, mode="overwrite", partitionBy="name")
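    # partitionBy writes hive-style key=value subdirectories (one per distinct
    # value of the "name" column) instead of a flat set of parquet files.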
    ddf = dd.read_parquet(tmpdir, engine=engine)
    # Papercut: pandas TZ localization doesn't survive roundtrip
    ddf = ddf.assign(timestamp=ddf.timestamp.dt.tz_localize("UTC"))
    # Partitioning can change the column order. This is mostly okay,
    # but we sort them here to ease comparison
    ddf = ddf.compute().sort_index(axis=1)
    # Dask automatically converts hive-partitioned columns to categories.
    # This is fine, but convert back to strings for comparison.
    ddf = ddf.assign(name=ddf.name.astype("str"))
    assert_eq(ddf, pdf.sort_index(axis=1), check_index=False)

@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_dask_to_spark(spark_session, npartitions, tmpdir, engine):
    tmpdir = str(tmpdir)
    ddf = dd.from_pandas(pdf, npartitions=npartitions)
    # Papercut: https://github.com/dask/fastparquet/issues/646#issuecomment-885614324
    kwargs = {"times": "int96"} if engine == "fastparquet" else {}
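    # (int96 is Spark's legacy parquet timestamp encoding; asking fastparquet
    # to write it keeps the timestamps readable on the spark side.)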
    ddf.to_parquet(tmpdir, engine=engine, write_index=False, **kwargs)
    sdf = spark_session.read.parquet(tmpdir)
    sdf = sdf.toPandas()
    # Papercut: pandas TZ localization doesn't survive roundtrip
    sdf = sdf.assign(timestamp=sdf.timestamp.dt.tz_localize("UTC"))
    assert_eq(sdf, ddf, check_index=False)