# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from collections import UserList
import datetime
import io
import pathlib
import random
import socket
import threading
import weakref

import pytest

try:
    import numpy as np
except ImportError:
    np = None

import pyarrow as pa
from pyarrow.tests.util import changed_environ, invoke_script


try:
    from pandas.testing import assert_frame_equal
    import pandas as pd
except ImportError:
    pass


class IpcFixture:
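    """Shared helper: writes record batches through a format-specific writer
    into a sink and exposes the serialized bytes and writer stats."""
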
    write_stats = None

    def __init__(self, sink_factory=lambda: io.BytesIO()):
        self._sink_factory = sink_factory
        self.sink = self.get_sink()

    def get_sink(self):
        return self._sink_factory()

    def get_source(self):
        return self.sink.getvalue()

    def write_batches(self, num_batches=5, as_table=False):
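        # Write `num_batches` five-row batches, either individually or as a
        # single Table, recording the writer's stats for later comparison.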
        nrows = 5
        schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])

        writer = self._get_writer(self.sink, schema)

        batches = []
        for _ in range(num_batches):
            batch = pa.record_batch(
                [[random.random() for _ in range(nrows)],
                 ['foo', None, 'bar', 'bazbaz', 'qux']],
                schema=schema)
            batches.append(batch)

        if as_table:
            table = pa.Table.from_batches(batches)
            writer.write_table(table)
        else:
            for batch in batches:
                writer.write_batch(batch)

        self.write_stats = writer.stats
        writer.close()
        return batches


class FileFormatFixture(IpcFixture):
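    """IpcFixture that writes the random-access Arrow IPC file format."""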

    is_file = True
    options = None

    def _get_writer(self, sink, schema):
        return pa.ipc.new_file(sink, schema, options=self.options)

    def _check_roundtrip(self, as_table=False):
        batches = self.write_batches(as_table=as_table)
        file_contents = pa.BufferReader(self.get_source())

        reader = pa.ipc.open_file(file_contents)

        assert reader.num_record_batches == len(batches)

        assert reader.schema.equals(batches[0].schema)

        for i, written in enumerate(batches):
            # compare each batch read back against the one written
            read = reader.get_batch(i)
            assert written.equals(read)

        assert isinstance(reader.stats, pa.ipc.ReadStats)
        assert isinstance(self.write_stats, pa.ipc.WriteStats)
        assert tuple(reader.stats) == tuple(self.write_stats)


class StreamFormatFixture(IpcFixture):
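    """IpcFixture that writes the Arrow IPC stream format."""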

    # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix
    use_legacy_ipc_format = False
    # ARROW-9395, for testing writing old metadata version
    options = None
    is_file = False

    def _get_writer(self, sink, schema):
        return pa.ipc.new_stream(
            sink,
            schema,
            use_legacy_format=self.use_legacy_ipc_format,
            options=self.options,
        )


class MessageFixture(IpcFixture):
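    """IpcFixture that writes through RecordBatchStreamWriter directly."""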

    def _get_writer(self, sink, schema):
        return pa.RecordBatchStreamWriter(sink, schema)


@pytest.fixture
def ipc_fixture():
    return IpcFixture()


@pytest.fixture
def file_fixture():
    return FileFormatFixture()


@pytest.fixture
def stream_fixture():
    return StreamFormatFixture()


@pytest.fixture(params=[
    pytest.param(
        'file_fixture',
        id='File Format'
    ),
    pytest.param(
        'stream_fixture',
        id='Stream Format'
    )
])
def format_fixture(request):
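    # Indirect parametrization: look up the named fixture so each test that
    # takes format_fixture runs once per IPC format (file and stream).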
    return request.getfixturevalue(request.param)


def test_empty_file():
    buf = b''
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_file(pa.BufferReader(buf))


def test_file_simple_roundtrip(file_fixture):
    file_fixture._check_roundtrip(as_table=False)


def test_file_write_table(file_fixture):
    file_fixture._check_roundtrip(as_table=True)


@pytest.mark.parametrize("sink_factory", [
    lambda: io.BytesIO(),
    lambda: pa.BufferOutputStream()
])
def test_file_read_all(sink_factory):
    fixture = FileFormatFixture(sink_factory)

    batches = fixture.write_batches()
    file_contents = pa.BufferReader(fixture.get_source())

    reader = pa.ipc.open_file(file_contents)

    result = reader.read_all()
    expected = pa.Table.from_batches(batches)
    assert result.equals(expected)


def test_open_file_from_buffer(file_fixture):
    # ARROW-2859; APIs accept the buffer protocol
    file_fixture.write_batches()
    source = file_fixture.get_source()

    reader1 = pa.ipc.open_file(source)
    reader2 = pa.ipc.open_file(pa.BufferReader(source))
    reader3 = pa.RecordBatchFileReader(source)

    result1 = reader1.read_all()
    result2 = reader2.read_all()
    result3 = reader3.read_all()

    assert result1.equals(result2)
    assert result1.equals(result3)

    st1 = reader1.stats
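    # 6 messages = 1 schema message + 5 record batch messages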
    assert st1.num_messages == 6
    assert st1.num_record_batches == 5
    assert reader2.stats == st1
    assert reader3.stats == st1


@pytest.mark.pandas
def test_file_read_pandas(file_fixture):
    frames = [batch.to_pandas() for batch in file_fixture.write_batches()]

    file_contents = pa.BufferReader(file_fixture.get_source())
    reader = pa.ipc.open_file(file_contents)
    result = reader.read_pandas()

    expected = pd.concat(frames).reset_index(drop=True)
    assert_frame_equal(result, expected)


def test_file_pathlib(file_fixture, tmpdir):
    file_fixture.write_batches()
    source = file_fixture.get_source()

    path = tmpdir.join('file.arrow').strpath
    with open(path, 'wb') as f:
        f.write(source)

    t1 = pa.ipc.open_file(pathlib.Path(path)).read_all()
    t2 = pa.ipc.open_file(pa.OSFile(path)).read_all()

    assert t1.equals(t2)


def test_empty_stream():
    buf = io.BytesIO(b'')
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_stream(buf)


@pytest.mark.pandas
@pytest.mark.processes
def test_read_year_month_nano_interval(tmpdir):
    """ARROW-15783: Verify to_pandas works for interval types.

    Interval types require static structures to be enabled. This test verifies
    that they are when no other library functions are invoked.
    """
    mdn_interval_type = pa.month_day_nano_interval()
    schema = pa.schema([pa.field('nums', mdn_interval_type)])

    path = tmpdir.join('file.arrow').strpath
    with pa.OSFile(path, 'wb') as sink:
        with pa.ipc.new_file(sink, schema) as writer:
            interval_array = pa.array([(1, 2, 3)], type=mdn_interval_type)
            batch = pa.record_batch([interval_array], schema)
            writer.write(batch)
    invoke_script('read_record_batch.py', path)


@pytest.mark.pandas
def test_stream_categorical_roundtrip(stream_fixture):
    df = pd.DataFrame({
        'one': np.random.randn(5),
        'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
                              categories=['foo', 'bar'],
                              ordered=True)
    })
    batch = pa.RecordBatch.from_pandas(df)
    with stream_fixture._get_writer(stream_fixture.sink, batch.schema) as wr:
        wr.write_batch(batch)

    table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
             .read_all())
    assert_frame_equal(table.to_pandas(), df)


def test_open_stream_from_buffer(stream_fixture):
    # ARROW-2859
    stream_fixture.write_batches()
    source = stream_fixture.get_source()

    reader1 = pa.ipc.open_stream(source)
    reader2 = pa.ipc.open_stream(pa.BufferReader(source))
    reader3 = pa.RecordBatchStreamReader(source)

    result1 = reader1.read_all()
    result2 = reader2.read_all()
    result3 = reader3.read_all()

    assert result1.equals(result2)
    assert result1.equals(result3)

    st1 = reader1.stats
    assert st1.num_messages == 6
    assert st1.num_record_batches == 5
    assert reader2.stats == st1
    assert reader3.stats == st1

    assert tuple(st1) == tuple(stream_fixture.write_stats)


@pytest.mark.parametrize('options', [
    pa.ipc.IpcReadOptions(),
    pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_stream_options(stream_fixture, options):
    stream_fixture.write_batches()
    source = stream_fixture.get_source()

    reader = pa.ipc.open_stream(source, options=options)

    reader.read_all()
    st = reader.stats
    assert st.num_messages == 6
    assert st.num_record_batches == 5

    assert tuple(st) == tuple(stream_fixture.write_stats)


def test_open_stream_with_wrong_options(stream_fixture):
    stream_fixture.write_batches()
    source = stream_fixture.get_source()

    with pytest.raises(TypeError):
        pa.ipc.open_stream(source, options=True)


@pytest.mark.parametrize('options', [
    pa.ipc.IpcReadOptions(),
    pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_file_options(file_fixture, options):
    file_fixture.write_batches()
    source = file_fixture.get_source()

    reader = pa.ipc.open_file(source, options=options)
    reader.read_all()
    st = reader.stats
    assert st.num_messages == 6
    assert st.num_record_batches == 5

    assert tuple(st) == tuple(file_fixture.write_stats)