# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import UserList
import io
import pathlib
import pytest
import socket
import threading
import weakref
import numpy as np
import pyarrow as pa
from pyarrow.tests.util import changed_environ, invoke_script
try:
from pandas.testing import assert_frame_equal
import pandas as pd
except ImportError:
pass
class IpcFixture:
    """Base helper that writes a fixed set of record batches to a sink.

    Subclasses provide ``_get_writer(sink, schema)`` returning an IPC writer;
    this base class does not define it, so ``write_batches`` only works on a
    subclass.
    """

    # Set by write_batches() to the writer's stats, captured just before the
    # writer is closed.
    write_stats = None

    def __init__(self, sink_factory=lambda: io.BytesIO()):
        self._sink_factory = sink_factory
        self.sink = self.get_sink()

    def get_sink(self):
        """Return a new sink produced by the configured factory."""
        return self._sink_factory()

    def get_source(self):
        """Return what has been written to ``self.sink``.

        Requires the sink to expose ``getvalue()`` (as ``io.BytesIO`` does).
        """
        return self.sink.getvalue()

    def write_batches(self, num_batches=5, as_table=False):
        """Write ``num_batches`` record batches to ``self.sink``.

        Each batch has a float64 column ``one`` of random values and a utf8
        column ``two`` with fixed strings (one of them null).

        Parameters
        ----------
        num_batches : int
            Number of batches to generate.
        as_table : bool
            If true, combine the batches into a table and write it with
            ``write_table``; otherwise write each one with ``write_batch``.

        Returns
        -------
        list of pyarrow.RecordBatch
            The batches that were written, in order.
        """
        strings = ['foo', None, 'bar', 'bazbaz', 'qux']
        # Both columns must have the same length, so derive it from the
        # literal string column instead of hard-coding it.
        nrows = len(strings)
        schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])

        writer = self._get_writer(self.sink, schema)
        try:
            batches = [
                pa.record_batch([np.random.randn(nrows), strings],
                                schema=schema)
                for _ in range(num_batches)
            ]

            if as_table:
                writer.write_table(pa.Table.from_batches(batches))
            else:
                for batch in batches:
                    writer.write_batch(batch)

            self.write_stats = writer.stats
        finally:
            # Close the writer even if building or writing a batch failed.
            writer.close()
        return batches
class FileFormatFixture(IpcFixture):
    """IPC fixture that writes batches using the Arrow IPC *file* format."""

    is_file = True
    # Passed through as ``options=`` to ``pa.ipc.new_file``.
    options = None

    def _get_writer(self, sink, schema):
        return pa.ipc.new_file(sink, schema, options=self.options)

    def _check_roundtrip(self, as_table=False):
        """Write batches, read them back with ``open_file`` and compare.

        Checks the batch count, each batch's contents, the schema, and that
        the reader's ReadStats match the writer's WriteStats field by field.
        """
        batches = self.write_batches(as_table=as_table)
        file_contents = pa.BufferReader(self.get_source())
        reader = pa.ipc.open_file(file_contents)

        assert reader.num_record_batches == len(batches)
        for i, expected in enumerate(batches):
            assert expected.equals(reader.get_batch(i))
        assert reader.schema.equals(batches[0].schema)

        assert isinstance(reader.stats, pa.ipc.ReadStats)
        assert isinstance(self.write_stats, pa.ipc.WriteStats)
        assert tuple(reader.stats) == tuple(self.write_stats)
class StreamFormatFixture(IpcFixture):
    """IPC fixture that writes batches using the Arrow IPC *stream* format."""

    is_file = False

    # ARROW-6474: flag for testing writes of the old IPC protocol with the
    # 4-byte prefix.
    use_legacy_ipc_format = False

    # ARROW-9395: hook for testing writes of an old metadata version.
    options = None

    def _get_writer(self, sink, schema):
        legacy = self.use_legacy_ipc_format
        write_options = self.options
        return pa.ipc.new_stream(sink, schema,
                                 use_legacy_format=legacy,
                                 options=write_options)
class MessageFixture(IpcFixture):
    """IPC fixture whose writer is a ``pa.RecordBatchStreamWriter``."""

    def _get_writer(self, sink, schema):
        writer_cls = pa.RecordBatchStreamWriter
        return writer_cls(sink, schema)
@pytest.fixture
def ipc_fixture():
    """Provide a bare IpcFixture using its default sink factory."""
    fixture = IpcFixture()
    return fixture
@pytest.fixture
def file_fixture():
    """Provide a fresh FileFormatFixture (IPC file format writer)."""
    fixture = FileFormatFixture()
    return fixture
@pytest.fixture
def stream_fixture():
    """Provide a fresh StreamFormatFixture (IPC stream format writer)."""
    fixture = StreamFormatFixture()
    return fixture
@pytest.fixture(params=[
    pytest.param('file_fixture', id='File Format'),
    pytest.param('stream_fixture', id='Stream Format'),
])
def format_fixture(request):
    """Parametrized fixture resolving to ``file_fixture`` or ``stream_fixture``.

    Each parameter is the name of another fixture, looked up at runtime.
    """
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
def test_empty_file():
    """Opening an empty buffer as an IPC file raises ArrowInvalid."""
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_file(pa.BufferReader(b''))
def test_file_simple_roundtrip(file_fixture):
    """Round-trip batches written one at a time through the file format."""
    file_fixture._check_roundtrip()
def test_file_write_table(file_fixture):
    """Round-trip batches written as a single table through the file format."""
    write_as_table = True
    file_fixture._check_roundtrip(as_table=write_as_table)
@pytest.mark.parametrize("sink_factory", [
    lambda: io.BytesIO(),
    lambda: pa.BufferOutputStream()
])
def test_file_read_all(sink_factory):
    """``read_all`` returns every written batch as one table, for each sink."""
    fixture = FileFormatFixture(sink_factory)
    written = fixture.write_batches()

    reader = pa.ipc.open_file(pa.BufferReader(fixture.get_source()))
    assert reader.read_all().equals(pa.Table.from_batches(written))
def test_open_file_from_buffer(file_fixture):
    """ARROW-2859: file-reader APIs accept objects with the buffer protocol.

    Opening the raw bytes, a BufferReader over them, or building a
    RecordBatchFileReader directly must all give the same table and stats.
    The reader stats must also match the stats recorded by the writer, as
    ``test_open_stream_from_buffer`` checks for the stream format.
    """
    file_fixture.write_batches()
    source = file_fixture.get_source()

    reader1 = pa.ipc.open_file(source)
    reader2 = pa.ipc.open_file(pa.BufferReader(source))
    reader3 = pa.RecordBatchFileReader(source)

    result1 = reader1.read_all()
    result2 = reader2.read_all()
    result3 = reader3.read_all()

    assert result1.equals(result2)
    assert result1.equals(result3)

    st1 = reader1.stats
    assert st1.num_messages == 6
    assert st1.num_record_batches == 5
    assert reader2.stats == st1
    assert reader3.stats == st1
    assert tuple(st1) == tuple(file_fixture.write_stats)
@pytest.mark.pandas
def test_file_read_pandas(file_fixture):
    """``read_pandas`` equals the concatenation of each batch's DataFrame."""
    written = file_fixture.write_batches()
    expected = pd.concat(
        [batch.to_pandas() for batch in written]
    ).reset_index(drop=True)

    reader = pa.ipc.open_file(pa.BufferReader(file_fixture.get_source()))
    assert_frame_equal(reader.read_pandas(), expected)
def test_file_pathlib(file_fixture, tmpdir):
    """``open_file`` accepts a ``pathlib.Path`` and an explicit ``OSFile``.

    Both ways of opening the same on-disk file must yield equal tables.
    """
    file_fixture.write_batches()
    source = file_fixture.get_source()

    path = tmpdir.join('file.arrow').strpath
    with open(path, 'wb') as f:
        f.write(source)

    t1 = pa.ipc.open_file(pathlib.Path(path)).read_all()
    # Close the OSFile deterministically instead of leaking the handle until
    # garbage collection (an open handle can block tmpdir cleanup).
    with pa.OSFile(path) as osfile:
        t2 = pa.ipc.open_file(osfile).read_all()
        assert t1.equals(t2)
def test_empty_stream():
    """Opening an empty byte stream as an IPC stream raises ArrowInvalid."""
    empty_source = io.BytesIO()
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_stream(empty_source)
@pytest.mark.pandas
def test_read_year_month_nano_interval(tmpdir):
    """ARROW-15783: check that to_pandas works for interval types.

    Interval types need static structures to be enabled. Reading is delegated
    to the ``read_record_batch.py`` helper script via ``invoke_script``, to
    verify this holds when no other library functions have been invoked.
    """
    interval_type = pa.month_day_nano_interval()
    schema = pa.schema([pa.field('nums', interval_type)])
    path = tmpdir.join('file.arrow').strpath

    with pa.OSFile(path, 'wb') as sink:
        with pa.ipc.new_file(sink, schema) as writer:
            values = pa.array([(1, 2, 3)], type=interval_type)
            writer.write(pa.record_batch([values], schema))

    invoke_script('read_record_batch.py', path)
@pytest.mark.pandas
def test_stream_categorical_roundtrip(stream_fixture):
    """An ordered categorical column survives a stream write/read round trip."""
    two = pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
                         categories=['foo', 'bar'],
                         ordered=True)
    df = pd.DataFrame({'one': np.random.randn(5), 'two': two})
    batch = pa.RecordBatch.from_pandas(df)

    sink, schema = stream_fixture.sink, batch.schema
    with stream_fixture._get_writer(sink, schema) as writer:
        writer.write_batch(batch)

    reader = pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
    assert_frame_equal(reader.read_all().to_pandas(), df)
def test_open_stream_from_buffer(stream_fixture):
    """ARROW-2859: stream-reader APIs accept objects with the buffer protocol.

    Raw bytes, a BufferReader, and a direct RecordBatchStreamReader must all
    produce the same table and stats, matching the writer's stats.
    """
    stream_fixture.write_batches()
    source = stream_fixture.get_source()

    readers = [
        pa.ipc.open_stream(source),
        pa.ipc.open_stream(pa.BufferReader(source)),
        pa.RecordBatchStreamReader(source),
    ]
    first, *others = [reader.read_all() for reader in readers]
    for other in others:
        assert first.equals(other)

    stats = readers[0].stats
    assert stats.num_messages == 6
    assert stats.num_record_batches == 5
    for reader in readers[1:]:
        assert reader.stats == stats
    assert tuple(stats) == tuple(stream_fixture.write_stats)
@pytest.mark.parametrize('options', [
    pa.ipc.IpcReadOptions(),
    pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_stream_options(stream_fixture, options):
    """``open_stream`` accepts IpcReadOptions; stats after read_all match."""
    stream_fixture.write_batches()
    reader = pa.ipc.open_stream(stream_fixture.get_source(), options=options)
    reader.read_all()

    stats = reader.stats
    assert (stats.num_messages, stats.num_record_batches) == (6, 5)
    assert tuple(stats) == tuple(stream_fixture.write_stats)
def test_open_stream_with_wrong_options(stream_fixture):
    """Passing ``options=True`` (not an IpcReadOptions) raises TypeError."""
    stream_fixture.write_batches()
    data = stream_fixture.get_source()
    bogus_options = True
    with pytest.raises(TypeError):
        pa.ipc.open_stream(data, options=bogus_options)
@pytest.mark.parametrize('options', [
pa.ipc.IpcReadOptions(),
pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_file_options(file_fixture, options):
file_fixture.write_batches()
source = file_fixture.get_source()
reader = pa.ipc.open_file(source, options=options)
reader.read_all()
st = reader.stats
assert st.num_messages == 6
assert st.num_record_batches == 5
Loading ...