Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

Version: 19.0.0.dev70 

/ tests / interchange / test_conversion.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime as dt
import pyarrow as pa
from pyarrow.vendored.version import Version
import pytest

try:
    import numpy as np
except ImportError:
    np = None

import pyarrow.interchange as pi
from pyarrow.interchange.column import (
    _PyArrowColumn,
    ColumnNullType,
    DtypeKind,
)
from pyarrow.interchange.from_dataframe import _from_dataframe

try:
    import pandas as pd
    # import pandas.testing as tm
except ImportError:
    pass


@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
def test_datetime(unit, tz):
    """Check interchange column metadata of a timestamp column with a null.

    Runs for every combination of timestamp ``unit`` and ``tz`` string.
    """
    values = [dt(2007, 7, 13), dt(2007, 7, 14), None]
    timestamp_type = pa.timestamp(unit, tz=tz)
    table = pa.table({"A": pa.array(values, type=timestamp_type)})
    column = table.__dataframe__().get_column_by_name("A")

    # Three entries in total, exactly one of them missing.
    assert column.size() == 3
    assert column.offset == 0
    assert column.null_count == 1
    assert column.dtype[0] == DtypeKind.DATETIME
    assert column.describe_null == (ColumnNullType.USE_BITMASK, 0)


@pytest.mark.parametrize(
    ["test_data", "kind"],
    [
        (["foo", "bar"], 21),
        ([1.5, 2.5, 3.5], 2),
        ([1, 2, 3, 4], 0),
    ],
)
def test_array_to_pyarrowcolumn(test_data, kind):
    """Wrap a plain Arrow array in ``_PyArrowColumn`` and check its metadata.

    ``kind`` is the value expected as the first element of the column's
    ``dtype`` tuple for the given ``test_data``.
    """
    source = pa.array(test_data)
    column = _PyArrowColumn(source)

    assert column._col == source
    assert column.size() == len(test_data)
    assert column.dtype[0] == kind
    assert column.num_chunks() == 1
    assert column.null_count == 0
    # No nulls in the input, so no validity buffer is expected.
    assert column.get_buffers()["validity"] is None
    assert len([*column.get_chunks()]) == 1

    # The single chunk must compare equal to the column itself.
    for piece in column.get_chunks():
        assert piece == column


def test_offset_of_sliced_array():
    """Check that a sliced array reports its offset and converts correctly."""
    full = pa.array([1, 2, 3, 4])
    tail = full.slice(2, 2)

    table = pa.table([full], names=["arr"])
    table_sliced = pa.table([tail], names=["arr_sliced"])

    # The slice starts two elements into the parent array.
    sliced_col = table_sliced.__dataframe__().get_column(0)
    assert sliced_col.offset == 2

    # Converting back must give the sliced table, not the full one.
    converted = _from_dataframe(table_sliced.__dataframe__())
    assert table_sliced.equals(converted)
    assert not table.equals(converted)

    # Conversion to pandas can't be tested currently because pandas
    # hardcodes the offset to 0:
    # https://github.com/pandas-dev/pandas/blob/5c66e65d7b9fef47ccb585ce2fd0b3ea18dc82ea/pandas/core/interchange/from_dataframe.py#L247
    #
    # Sketch of the check that is left disabled for that reason:
    #
    # df = pandas_from_dataframe(table)
    # df_sliced = pandas_from_dataframe(table_sliced)
    #
    # tm.assert_series_equal(df["arr"][2:4], df_sliced["arr_sliced"],
    #                        check_index=False, check_names=False)


@pytest.mark.pandas
@pytest.mark.parametrize(
    "uint", [pa.uint8(), pa.uint16(), pa.uint32()]
)
@pytest.mark.parametrize(
    "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
@pytest.mark.parametrize(
    "float, np_float_str", [
        # (pa.float16(), np.float16) is left out: not supported by pandas
        (pa.float32(), "float32"),
        (pa.float64(), "float64")
    ]
)
def test_pandas_roundtrip(uint, int, float, np_float_str):
    """Round-trip numeric and boolean columns pyarrow -> pandas -> pyarrow.

    ``uint``, ``int`` and ``float`` are Arrow data types (the names shadow
    the builtins inside this function); ``np_float_str`` is the NumPy
    dtype name matching ``float``.
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    values = [1, 2, 3]
    columns = {
        "a": pa.array(values, type=uint),
        "b": pa.array(values, type=int),
        "c": pa.array(
            np.array(values, dtype=np.dtype(np_float_str)), type=float
        ),
        "d": [True, False, True],
    }
    table = pa.table(columns)

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    roundtripped = pi.from_dataframe(pandas_from_dataframe(table))
    assert table.equals(roundtripped)

    # The interchange objects of both tables must describe the same layout.
    source_protocol = table.__dataframe__()
    roundtrip_protocol = roundtripped.__dataframe__()

    assert source_protocol.num_columns() == roundtrip_protocol.num_columns()
    assert source_protocol.num_rows() == roundtrip_protocol.num_rows()
    assert source_protocol.num_chunks() == roundtrip_protocol.num_chunks()
    assert source_protocol.column_names() == roundtrip_protocol.column_names()


@pytest.mark.pandas
def test_pandas_roundtrip_string():
    """Round-trip a string column pyarrow -> pandas -> pyarrow.

    The values must be unchanged; the result column is asserted to be
    large_string while the input column is string.
    """
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    table = pa.table({"a": pa.array(["a", "", "c"])})

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    roundtripped = pi.from_dataframe(pandas_from_dataframe(table))

    source_col = table["a"]
    roundtrip_col = roundtripped["a"]
    assert roundtrip_col.to_pylist() == source_col.to_pylist()
    assert pa.types.is_string(source_col.type)
    assert pa.types.is_large_string(roundtrip_col.type)

    # The interchange objects of both tables must describe the same layout.
    source_protocol = table.__dataframe__()
    roundtrip_protocol = roundtripped.__dataframe__()

    assert source_protocol.num_columns() == roundtrip_protocol.num_columns()
    assert source_protocol.num_rows() == roundtrip_protocol.num_rows()
    assert source_protocol.num_chunks() == roundtrip_protocol.num_chunks()
    assert source_protocol.column_names() == roundtrip_protocol.column_names()


@pytest.mark.pandas
def test_pandas_roundtrip_large_string():
    """Round-trip a large_string column pyarrow -> pandas -> pyarrow.

    For pandas older than 2.0.1 the pandas-side conversion is expected to
    raise ``AssertionError`` instead.
    """
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    values = ["a", "", "c"]
    table = pa.table({"a_large": pa.array(values, type=pa.large_string())})

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    if Version(pd.__version__) < Version("2.0.1"):
        # large string not supported by pandas implementation for
        # older versions of pandas
        # https://github.com/pandas-dev/pandas/issues/52795
        with pytest.raises(AssertionError):
            pandas_from_dataframe(table)
        return

    roundtripped = pi.from_dataframe(pandas_from_dataframe(table))

    source_col = table["a_large"]
    roundtrip_col = roundtripped["a_large"]
    assert roundtrip_col.to_pylist() == source_col.to_pylist()
    assert pa.types.is_large_string(source_col.type)
    assert pa.types.is_large_string(roundtrip_col.type)

    # The interchange objects of both tables must describe the same layout.
    source_protocol = table.__dataframe__()
    roundtrip_protocol = roundtripped.__dataframe__()

    assert source_protocol.num_columns() == roundtrip_protocol.num_columns()
    assert source_protocol.num_rows() == roundtrip_protocol.num_rows()
    assert source_protocol.num_chunks() == roundtrip_protocol.num_chunks()
    assert source_protocol.column_names() == roundtrip_protocol.column_names()


@pytest.mark.pandas
def test_pandas_roundtrip_string_with_missing():
    """Round-trip string and large_string columns that contain a null.

    For pandas older than 2.0.2 the pandas-side conversion is expected to
    raise ``NotImplementedError`` instead.
    """
    # See https://github.com/pandas-dev/pandas/issues/50554
    if Version(pd.__version__) < Version("1.6"):
        pytest.skip("Column.size() bug in pandas")

    values = ["a", "", "c", None]
    table = pa.table({
        "a": pa.array(values),
        "a_large": pa.array(values, type=pa.large_string()),
    })

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )

    if Version(pd.__version__) < Version("2.0.2"):
        # older versions of pandas do not have bitmask support
        # https://github.com/pandas-dev/pandas/issues/49888
        with pytest.raises(NotImplementedError):
            pandas_from_dataframe(table)
        return

    roundtripped = pi.from_dataframe(pandas_from_dataframe(table))

    # The string column is asserted to come back as large_string.
    assert roundtripped["a"].to_pylist() == table["a"].to_pylist()
    assert pa.types.is_string(table["a"].type)
    assert pa.types.is_large_string(roundtripped["a"].type)

    # The large_string column keeps its type.
    assert roundtripped["a_large"].to_pylist() == table["a_large"].to_pylist()
    assert pa.types.is_large_string(table["a_large"].type)
    assert pa.types.is_large_string(roundtripped["a_large"].type)


@pytest.mark.pandas
def test_pandas_roundtrip_categorical():
    """Round-trip a dictionary-encoded column pyarrow -> pandas -> pyarrow.

    Checks that the values survive, that both columns are dictionary
    typed (with the index and dictionary types asserted below), and that
    the interchange metadata of the source and round-tripped columns
    agree, including their categorical descriptions.
    """
    if Version(pd.__version__) < Version("2.0.2"):
        pytest.skip("Bitmasks not supported in pandas interchange implementation")

    arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
    table = pa.table(
        {"weekday": pa.array(arr).dictionary_encode()}
    )

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert result["weekday"].to_pylist() == table["weekday"].to_pylist()
    assert pa.types.is_dictionary(table["weekday"].type)
    assert pa.types.is_dictionary(result["weekday"].type)
    # The dictionary and index types asserted for the result differ from
    # those of the source table.
    assert pa.types.is_string(table["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_large_string(result["weekday"].chunk(0).dictionary.type)
    assert pa.types.is_int32(table["weekday"].chunk(0).indices.type)
    assert pa.types.is_int8(result["weekday"].chunk(0).indices.type)

    table_protocol = table.__dataframe__()
    result_protocol = result.__dataframe__()

    assert table_protocol.num_columns() == result_protocol.num_columns()
    assert table_protocol.num_rows() == result_protocol.num_rows()
    assert table_protocol.num_chunks() == result_protocol.num_chunks()
    assert table_protocol.column_names() == result_protocol.column_names()

    col_table = table_protocol.get_column(0)
    col_result = result_protocol.get_column(0)

    assert col_result.dtype[0] == DtypeKind.CATEGORICAL
    assert col_result.dtype[0] == col_table.dtype[0]
    assert col_result.size() == col_table.size()
    assert col_result.offset == col_table.offset

    # Take each description from its own column so that the comparisons
    # below actually compare the source with the round-tripped result.
    desc_cat_table = col_table.describe_categorical
    desc_cat_result = col_result.describe_categorical

    assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
    assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"]
    assert isinstance(desc_cat_result["categories"]._col, pa.Array)


@pytest.mark.pandas
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
def test_pandas_roundtrip_datetime(unit):
    """Round-trip a timestamp column pyarrow -> pandas -> pyarrow.

    ``unit`` is the timestamp resolution of the source column. On older
    pandas the expected result is always in ``'ns'`` resolution; otherwise
    the result must equal the source table.
    """
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    # ``dt`` is the module-level ``datetime.datetime`` alias.
    # timezones not included as they are not yet supported in
    # the pandas implementation
    dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
    table = pa.table({"a": pa.array(dt_arr, type=pa.timestamp(unit))})

    if Version(pd.__version__) < Version("1.6"):
        # pandas < 2.0 always creates datetime64 in "ns"
        # resolution
        expected = pa.table({"a": pa.array(dt_arr, type=pa.timestamp('ns'))})
    else:
        expected = table

    from pandas.api.interchange import (
        from_dataframe as pandas_from_dataframe
    )
    pandas_df = pandas_from_dataframe(table)
    result = pi.from_dataframe(pandas_df)

    assert expected.equals(result)

    # The interchange objects of both tables must describe the same layout.
    expected_protocol = expected.__dataframe__()
    result_protocol = result.__dataframe__()

    assert expected_protocol.num_columns() == result_protocol.num_columns()
    assert expected_protocol.num_rows() == result_protocol.num_rows()
    assert expected_protocol.num_chunks() == result_protocol.num_chunks()
    assert expected_protocol.column_names() == result_protocol.column_names()


@pytest.mark.pandas
@pytest.mark.parametrize(
    "np_float_str", ["float32", "float64"]
)
def test_pandas_to_pyarrow_with_missing(np_float_str):
    if Version(pd.__version__) < Version("1.5.0"):
        pytest.skip("__dataframe__ added to pandas in 1.5.0")

    np_array = np.array([0, np.nan, 2], dtype=np.dtype(np_float_str))
    datetime_array = [None, dt(2007, 7, 14), dt(2007, 7, 15)]
    df = pd.DataFrame({
        # float, ColumnNullType.USE_NAN
        "a": np_array,
        # ColumnNullType.USE_SENTINEL
        "dt": np.array(datetime_array, dtype="datetime64[ns]")
Loading ...