Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

/ tests / test_ipc.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from collections import UserList
import io
import pathlib
import pytest
import socket
import threading
import weakref

import numpy as np

import pyarrow as pa
from pyarrow.tests.util import changed_environ, invoke_script


try:
    from pandas.testing import assert_frame_equal
    import pandas as pd
except ImportError:
    pass


class IpcFixture:
    """Base helper that writes sample record batches into a sink.

    Subclasses choose the IPC writer by overriding ``_get_writer``. The sink
    is produced by ``sink_factory`` and must expose ``getvalue()`` so that
    ``get_source`` can return what was written.
    """

    # Statistics of the most recent ``write_batches`` call (None until then).
    write_stats = None

    def __init__(self, sink_factory=lambda: io.BytesIO()):
        """Store the factory and create the sink used by this fixture.

        Parameters
        ----------
        sink_factory : callable
            Zero-argument callable returning a fresh sink object.
        """
        self._sink_factory = sink_factory
        self.sink = self.get_sink()

    def get_sink(self):
        """Return a new sink created by the configured factory."""
        return self._sink_factory()

    def get_source(self):
        """Return the contents written to the sink so far."""
        return self.sink.getvalue()

    def _get_writer(self, sink, schema):
        """Return a writer for ``schema`` on ``sink``; subclasses override.

        Raises
        ------
        NotImplementedError
            Always, on the base class, so that a misuse of the bare fixture
            fails with an explicit message rather than an AttributeError.
        """
        raise NotImplementedError(
            "{} must implement _get_writer".format(type(self).__name__))

    def write_batches(self, num_batches=5, as_table=False):
        """Write ``num_batches`` sample batches to the sink.

        Parameters
        ----------
        num_batches : int
            Number of five-row batches to generate and write.
        as_table : bool
            When true, the batches are combined into a single table and
            written with ``write_table``; otherwise each batch is written
            with ``write_batch``.

        Returns
        -------
        list
            The record batches that were generated.
        """
        nrows = 5
        schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])

        writer = self._get_writer(self.sink, schema)
        try:
            batches = [
                pa.record_batch(
                    [np.random.randn(nrows),
                     ['foo', None, 'bar', 'bazbaz', 'qux']],
                    schema=schema)
                for _ in range(num_batches)
            ]

            if as_table:
                # Passing the schema keeps num_batches=0 working, since an
                # empty batch list carries no schema of its own.
                table = pa.Table.from_batches(batches, schema=schema)
                writer.write_table(table)
            else:
                for batch in batches:
                    writer.write_batch(batch)

            # Stats are captured before closing, as the original code did.
            self.write_stats = writer.stats
        finally:
            # Always release the writer, even when a write fails.
            writer.close()
        return batches


class FileFormatFixture(IpcFixture):
    """Fixture that writes batches with the IPC file writer."""

    options = None
    is_file = True

    def _get_writer(self, sink, schema):
        """Create an IPC file writer on ``sink`` using ``self.options``."""
        file_writer = pa.ipc.new_file(sink, schema, options=self.options)
        return file_writer

    def _check_roundtrip(self, as_table=False):
        """Write batches, read them back by index and compare them.

        Also checks that the reader and writer statistics agree.
        """
        written = self.write_batches(as_table=as_table)
        reader = pa.ipc.open_file(pa.BufferReader(self.get_source()))

        assert reader.num_record_batches == len(written)

        for index, expected in enumerate(written):
            # Fetch the batch at the same position from the file reader.
            actual = reader.get_batch(index)
            assert expected.equals(actual)
            assert reader.schema.equals(written[0].schema)

        assert isinstance(reader.stats, pa.ipc.ReadStats)
        assert isinstance(self.write_stats, pa.ipc.WriteStats)
        assert tuple(reader.stats) == tuple(self.write_stats)


class StreamFormatFixture(IpcFixture):
    """Fixture that writes batches with the IPC stream writer."""

    is_file = False
    # ARROW-9395, for testing writing old metadata version
    options = None
    # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix
    use_legacy_ipc_format = False

    def _get_writer(self, sink, schema):
        """Create an IPC stream writer configured from the class settings."""
        writer_kwargs = {
            'use_legacy_format': self.use_legacy_ipc_format,
            'options': self.options,
        }
        return pa.ipc.new_stream(sink, schema, **writer_kwargs)


class MessageFixture(IpcFixture):
    """Fixture whose writer is a ``pa.RecordBatchStreamWriter``."""

    def _get_writer(self, sink, schema):
        """Construct the stream writer directly from its class."""
        writer_cls = pa.RecordBatchStreamWriter
        return writer_cls(sink, schema)


@pytest.fixture
def ipc_fixture():
    """Provide a fresh base ``IpcFixture`` instance."""
    fixture = IpcFixture()
    return fixture


@pytest.fixture
def file_fixture():
    """Provide a fresh ``FileFormatFixture`` instance."""
    fixture = FileFormatFixture()
    return fixture


@pytest.fixture
def stream_fixture():
    """Provide a fresh ``StreamFormatFixture`` instance."""
    fixture = StreamFormatFixture()
    return fixture


@pytest.fixture(params=[
    pytest.param('file_fixture', id='File Format'),
    pytest.param('stream_fixture', id='Stream Format'),
])
def format_fixture(request):
    """Resolve the fixture named by the current parameter and return it."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)


def test_empty_file():
    """Opening an empty buffer as an IPC file raises ``ArrowInvalid``."""
    with pytest.raises(pa.ArrowInvalid):
        empty_source = pa.BufferReader(b'')
        pa.ipc.open_file(empty_source)


def test_file_simple_roundtrip(file_fixture):
    """Round-trip through the file format, writing batch by batch."""
    write_as_table = False
    file_fixture._check_roundtrip(as_table=write_as_table)


def test_file_write_table(file_fixture):
    """Round-trip through the file format, writing a single table."""
    write_as_table = True
    file_fixture._check_roundtrip(as_table=write_as_table)


@pytest.mark.parametrize("sink_factory", [
    lambda: io.BytesIO(),
    lambda: pa.BufferOutputStream()
])
def test_file_read_all(sink_factory):
    """``read_all`` returns a table equal to the batches that were written."""
    fixture = FileFormatFixture(sink_factory)
    written = fixture.write_batches()

    source = pa.BufferReader(fixture.get_source())
    table = pa.ipc.open_file(source).read_all()

    assert table.equals(pa.Table.from_batches(written))


def test_open_file_from_buffer(file_fixture):
    """File readers built from a raw buffer agree with a BufferReader one."""
    # ARROW-2859; APIs accept the buffer protocol
    file_fixture.write_batches()
    source = file_fixture.get_source()

    # Build every reader first, then read them, as three equivalent entry
    # points over the same bytes.
    readers = [
        pa.ipc.open_file(source),
        pa.ipc.open_file(pa.BufferReader(source)),
        pa.RecordBatchFileReader(source),
    ]
    tables = [reader.read_all() for reader in readers]

    reference_table = tables[0]
    for other_table in tables[1:]:
        assert reference_table.equals(other_table)

    reference_stats = readers[0].stats
    # Five record batches are written by default; the sixth message is
    # presumably the schema -- confirm against the IPC format docs.
    assert reference_stats.num_messages == 6
    assert reference_stats.num_record_batches == 5
    for reader in readers[1:]:
        assert reader.stats == reference_stats


@pytest.mark.pandas
def test_file_read_pandas(file_fixture):
    """``read_pandas`` matches the concatenation of the per-batch frames."""
    written = file_fixture.write_batches()
    frames = [batch.to_pandas() for batch in written]

    source = pa.BufferReader(file_fixture.get_source())
    result = pa.ipc.open_file(source).read_pandas()

    # Concatenated frames keep per-batch indices, so renumber them.
    expected = pd.concat(frames).reset_index(drop=True)
    assert_frame_equal(result, expected)


def test_file_pathlib(file_fixture, tmpdir):
    """Opening an IPC file via ``pathlib.Path`` matches opening via OSFile.

    The written bytes are dumped to a file under ``tmpdir`` and read back
    through both entry points; the resulting tables must be equal.
    """
    file_fixture.write_batches()
    source = file_fixture.get_source()

    path = tmpdir.join('file.arrow').strpath
    with open(path, 'wb') as f:
        f.write(source)

    t1 = pa.ipc.open_file(pathlib.Path(path)).read_all()
    # Use OSFile as a context manager so the handle is closed
    # deterministically instead of lingering until garbage collection.
    with pa.OSFile(path) as os_file:
        t2 = pa.ipc.open_file(os_file).read_all()

    assert t1.equals(t2)


def test_empty_stream():
    """Opening an empty byte stream as an IPC stream raises ``ArrowInvalid``."""
    empty_stream = io.BytesIO(b'')
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_stream(empty_stream)


@pytest.mark.pandas
def test_read_year_month_nano_interval(tmpdir):
    """ARROW-15783: Verify to_pandas works for interval types.

    Interval types require static structures to be enabled. This test verifies
    that they are when no other library functions are invoked.
    """
    interval_type = pa.month_day_nano_interval()
    schema = pa.schema([pa.field('nums', interval_type)])
    path = tmpdir.join('file.arrow').strpath

    # Write a one-row file, then hand its path to the helper script.
    with pa.OSFile(path, 'wb') as sink, \
            pa.ipc.new_file(sink, schema) as writer:
        intervals = pa.array([(1, 2, 3)], type=interval_type)
        writer.write(pa.record_batch([intervals], schema))

    invoke_script('read_record_batch.py', path)


@pytest.mark.pandas
def test_stream_categorical_roundtrip(stream_fixture):
    """An ordered categorical column survives a stream round-trip."""
    df = pd.DataFrame({
        'one': np.random.randn(5),
        'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
                              categories=['foo', 'bar'],
                              ordered=True)
    })
    batch = pa.RecordBatch.from_pandas(df)

    writer = stream_fixture._get_writer(stream_fixture.sink, batch.schema)
    with writer:
        writer.write_batch(batch)

    source = pa.BufferReader(stream_fixture.get_source())
    table = pa.ipc.open_stream(source).read_all()
    assert_frame_equal(table.to_pandas(), df)


def test_open_stream_from_buffer(stream_fixture):
    """Stream readers built from a raw buffer agree with a BufferReader one."""
    # ARROW-2859
    stream_fixture.write_batches()
    source = stream_fixture.get_source()

    # Build every reader first, then read them all.
    readers = [
        pa.ipc.open_stream(source),
        pa.ipc.open_stream(pa.BufferReader(source)),
        pa.RecordBatchStreamReader(source),
    ]
    tables = [reader.read_all() for reader in readers]

    reference_table = tables[0]
    for other_table in tables[1:]:
        assert reference_table.equals(other_table)

    reference_stats = readers[0].stats
    assert reference_stats.num_messages == 6
    assert reference_stats.num_record_batches == 5
    for reader in readers[1:]:
        assert reader.stats == reference_stats

    # Reader-side statistics must mirror what the writer recorded.
    assert tuple(reference_stats) == tuple(stream_fixture.write_stats)


@pytest.mark.parametrize('options', [
    pa.ipc.IpcReadOptions(),
    pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_stream_options(stream_fixture, options):
    """``open_stream`` accepts ``IpcReadOptions`` and reads every batch."""
    stream_fixture.write_batches()

    reader = pa.ipc.open_stream(stream_fixture.get_source(), options=options)
    reader.read_all()

    read_stats = reader.stats
    assert read_stats.num_messages == 6
    assert read_stats.num_record_batches == 5

    # Reader-side statistics must mirror what the writer recorded.
    assert tuple(read_stats) == tuple(stream_fixture.write_stats)


def test_open_stream_with_wrong_options(stream_fixture):
    """Passing a non-options object as ``options`` raises ``TypeError``."""
    stream_fixture.write_batches()
    written_bytes = stream_fixture.get_source()

    invalid_options = True
    with pytest.raises(TypeError):
        pa.ipc.open_stream(written_bytes, options=invalid_options)


@pytest.mark.parametrize('options', [
    pa.ipc.IpcReadOptions(),
    pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_file_options(file_fixture, options):
    """``open_file`` accepts ``IpcReadOptions`` and reads every batch.

    Runs once with default read options and once with ``use_threads=False``.
    """
    file_fixture.write_batches()
    source = file_fixture.get_source()

    reader = pa.ipc.open_file(source, options=options)

    # Read everything before inspecting the reader statistics.
    reader.read_all()

    st = reader.stats
    # write_batches() writes five batches by default; the sixth message is
    # presumably the schema -- TODO confirm.
    assert st.num_messages == 6
    assert st.num_record_batches == 5
Loading ...