Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
dask / dask / tests / test_spark_compat.py
Size: Mime:
import signal
import sys
import threading

import pytest

from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

from dask.dataframe.utils import assert_eq

# Spark is only exercised on Linux; elsewhere it adds little coverage and is
# hard to get running, so every test in this module is skipped there.
pytestmark = pytest.mark.skipif(
    sys.platform != "linux",
    reason="Unnecessary, and hard to get spark working on non-linux platforms",
)

# Reference pandas frame shared by every test below (sampled with freq="1H").
# pyspark auto-converts timezones, so round-tripping timestamps is easiest when
# everything is in UTC: localize the index to UTC, then move it into the
# ``timestamp`` column that the tests compare against.
pdf = timeseries(freq="1H").compute().tz_localize("UTC").reset_index()


@pytest.fixture(scope="module")
def spark_session():
    """Yield a local Spark session, configured for UTC, shared by this module.

    Spark registers a global signal handler that can cause problems elsewhere
    in the test suite. In particular, the handler fails if the spark session
    is stopped (a bug in pyspark). The SIGINT handler that was active before
    the session was created is therefore restored on teardown, even if
    stopping the session itself raises.
    """
    # Remember the current SIGINT handler before Spark replaces it.
    prev = signal.getsignal(signal.SIGINT)
    # Create a spark session. Note that we set the timezone to UTC to avoid
    # conversion to local time when reading parquet files.
    spark = (
        pyspark.sql.SparkSession.builder.master("local")
        .appName("Dask Testing")
        .config("spark.sql.session.timeZone", "UTC")
        .getOrCreate()
    )
    try:
        yield spark
    finally:
        try:
            spark.stop()
        finally:
            # Get rid of Spark's handler once we stop the session. This must
            # happen even if stop() raised. Two cases are skipped:
            # - signal.signal() may only be called from the main thread;
            # - getsignal() returns None when the previous handler was not
            #   installed from Python, and signal.signal() rejects None.
            in_main_thread = threading.current_thread() is threading.main_thread()
            if prev is not None and in_main_thread:
                signal.signal(signal.SIGINT, prev)


@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_spark_to_dask(spark_session, npartitions, tmpdir, engine):
    """Parquet written by Spark in ``npartitions`` pieces is readable by Dask.

    Dask is expected to report exactly ``npartitions`` partitions and to
    return the same data as ``pdf`` (ignoring the index).
    """
    out_dir = str(tmpdir)

    # Nothing is actually being overwritten, but Spark refuses to write into
    # an existing directory (``tmpdir`` already exists) without
    # ``mode="overwrite"``.
    writer = spark_session.createDataFrame(pdf).repartition(npartitions).write
    writer.parquet(out_dir, mode="overwrite")

    result = dd.read_parquet(out_dir, engine=engine)
    # Papercut: the pandas time zone does not survive the roundtrip, so the
    # ``timestamp`` column is re-localized to UTC before comparing.
    result = result.assign(timestamp=result.timestamp.dt.tz_localize("UTC"))
    assert result.npartitions == npartitions

    assert_eq(result, pdf, check_index=False)


@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_hive_parquet_spark_to_dask(spark_session, tmpdir, engine):
    """Hive-partitioned parquet written by Spark (by ``name``) reads in Dask."""
    out_dir = str(tmpdir)

    # Nothing is actually being overwritten, but Spark rejects an existing
    # target directory unless ``mode="overwrite"`` is passed.
    spark_df = spark_session.createDataFrame(pdf)
    spark_df.write.parquet(out_dir, mode="overwrite", partitionBy="name")

    lazy = dd.read_parquet(out_dir, engine=engine)
    # Papercut: the pandas time zone does not survive the roundtrip, so the
    # ``timestamp`` column is re-localized to UTC.
    lazy = lazy.assign(timestamp=lazy.timestamp.dt.tz_localize("UTC"))

    # Partitioning can change the column order. That is acceptable, so sort
    # the columns on both sides to make the comparison order-insensitive.
    result = lazy.compute().sort_index(axis=1)
    # Dask converts hive-partitioned columns to categories; cast ``name``
    # back to strings so it compares equal to the original column.
    result = result.assign(name=result.name.astype("str"))

    expected = pdf.sort_index(axis=1)
    assert_eq(result, expected, check_index=False)


@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_dask_to_spark(spark_session, npartitions, tmpdir, engine):
    """Parquet written by Dask with ``engine`` is readable by Spark."""
    out_dir = str(tmpdir)
    source = dd.from_pandas(pdf, npartitions=npartitions)

    # Papercut: https://github.com/dask/fastparquet/issues/646#issuecomment-885614324
    # fastparquet is asked to write int96 timestamps; presumably this is what
    # lets Spark read them back (see the linked issue).
    extra = {}
    if engine == "fastparquet":
        extra["times"] = "int96"

    source.to_parquet(out_dir, engine=engine, write_index=False, **extra)

    result = spark_session.read.parquet(out_dir).toPandas()
    # Papercut: the pandas time zone does not survive the roundtrip, so the
    # ``timestamp`` column is re-localized to UTC before comparing.
    result = result.assign(timestamp=result.timestamp.dt.tz_localize("UTC"))

    assert_eq(result, source, check_index=False)