# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import UserList
import datetime
import io
import pathlib
import pytest
import random
import socket
import threading
import weakref
try:
import numpy as np
except ImportError:
np = None
import pyarrow as pa
from pyarrow.tests.util import changed_environ, invoke_script
try:
from pandas.testing import assert_frame_equal
import pandas as pd
except ImportError:
pass
class IpcFixture:
    """Base helper that writes a few random record batches to an IPC sink.

    Subclasses must provide ``_get_writer(sink, schema)`` returning a pyarrow
    IPC writer; this base class does not define it, so ``write_batches`` only
    works on a subclass.
    """

    # Writer statistics snapshotted by the most recent ``write_batches`` call
    # on an instance; ``None`` until then.
    write_stats = None

    def __init__(self, sink_factory=lambda: io.BytesIO()):
        # The factory is kept so ``get_sink`` can create further sinks later.
        self._sink_factory = sink_factory
        self.sink = self.get_sink()

    def get_sink(self):
        """Return a new sink produced by the configured ``sink_factory``."""
        return self._sink_factory()

    def get_source(self):
        """Return whatever ``self.sink.getvalue()`` yields (data written so far)."""
        return self.sink.getvalue()

    def write_batches(self, num_batches=5, as_table=False):
        """Write ``num_batches`` random 5-row batches to ``self.sink``.

        Each batch has a float64 column ``one`` filled with ``random.random()``
        values and a utf8 column ``two`` with fixed values (one of them null).

        Parameters
        ----------
        num_batches : int
            Number of record batches to generate and write.
        as_table : bool
            If true, write all batches at once via ``write_table``; otherwise
            write them one by one via ``write_batch``.

        Returns
        -------
        list of pyarrow.RecordBatch
            The batches that were written, in order.
        """
        # Must match the length of the literal string column below.
        nrows = 5
        schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])

        writer = self._get_writer(self.sink, schema)
        try:
            batches = [
                pa.record_batch(
                    [[random.random() for _ in range(nrows)],
                     ['foo', None, 'bar', 'bazbaz', 'qux']],
                    schema=schema)
                for _ in range(num_batches)
            ]

            if as_table:
                # An explicit schema keeps num_batches=0 working; a schema
                # cannot be inferred from an empty batch list.
                table = pa.Table.from_batches(batches, schema=schema)
                writer.write_table(table)
            else:
                for batch in batches:
                    writer.write_batch(batch)

            # Stats are snapshotted before close(), as readers compare
            # against this value.
            self.write_stats = writer.stats
        finally:
            # Always release the writer, even if a write above raised.
            writer.close()
        return batches
class FileFormatFixture(IpcFixture):
    """IPC fixture that writes the random-access *file* format."""

    is_file = True
    # Optional write options forwarded to ``pa.ipc.new_file``.
    options = None

    def _get_writer(self, sink, schema):
        """Return a file-format writer on ``sink`` using ``self.options``."""
        return pa.ipc.new_file(sink, schema, options=self.options)

    def _check_roundtrip(self, as_table=False):
        """Write batches, reopen the bytes as an IPC file and compare.

        Checks that the reader sees the same number of batches, each batch
        and the schema compare equal, and the reader's stats equal the
        writer's stats.

        Parameters
        ----------
        as_table : bool
            Forwarded to ``write_batches``; selects table vs. batch writes.
        """
        batches = self.write_batches(as_table=as_table)
        file_contents = pa.BufferReader(self.get_source())
        reader = pa.ipc.open_file(file_contents)

        assert reader.num_record_batches == len(batches)
        # Random access: fetch each batch by index and compare to what was
        # written.
        for i, expected in enumerate(batches):
            assert expected.equals(reader.get_batch(i))
        assert reader.schema.equals(batches[0].schema)

        assert isinstance(reader.stats, pa.ipc.ReadStats)
        assert isinstance(self.write_stats, pa.ipc.WriteStats)
        assert tuple(reader.stats) == tuple(self.write_stats)
class StreamFormatFixture(IpcFixture):
    """IPC fixture that writes the sequential *stream* format."""

    is_file = False
    # ARROW-9395: optional write options, used to test writing an old
    # metadata version.
    options = None
    # ARROW-6474: when true, write the old IPC protocol with a 4-byte prefix.
    use_legacy_ipc_format = False

    def _get_writer(self, sink, schema):
        """Create a stream writer honoring this fixture's format settings."""
        writer_kwargs = dict(
            use_legacy_format=self.use_legacy_ipc_format,
            options=self.options,
        )
        return pa.ipc.new_stream(sink, schema, **writer_kwargs)
class MessageFixture(IpcFixture):
    """IPC fixture that writes via ``pa.RecordBatchStreamWriter`` directly."""

    def _get_writer(self, sink, schema):
        writer_type = pa.RecordBatchStreamWriter
        return writer_type(sink, schema)
@pytest.fixture
def ipc_fixture():
    """Provide a bare ``IpcFixture`` backed by the default BytesIO sink."""
    fixture = IpcFixture()
    return fixture
@pytest.fixture
def file_fixture():
    """Provide a fresh ``FileFormatFixture`` for each test."""
    fixture = FileFormatFixture()
    return fixture
@pytest.fixture
def stream_fixture():
    """Provide a fresh ``StreamFormatFixture`` for each test."""
    fixture = StreamFormatFixture()
    return fixture
@pytest.fixture(params=[
    pytest.param('file_fixture', id='File Format'),
    pytest.param('stream_fixture', id='Stream Format'),
])
def format_fixture(request):
    """Resolve, in turn, to the file-format and stream-format fixtures."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
def test_empty_file():
    """Opening zero bytes as an IPC file must raise ArrowInvalid."""
    with pytest.raises(pa.ArrowInvalid):
        empty_source = pa.BufferReader(b'')
        pa.ipc.open_file(empty_source)
def test_file_simple_roundtrip(file_fixture):
    """Round-trip batches written one at a time through the file format."""
    check = file_fixture._check_roundtrip
    check(as_table=False)
def test_file_write_table(file_fixture):
    """Round-trip batches written as a single table through the file format."""
    check = file_fixture._check_roundtrip
    check(as_table=True)
@pytest.mark.parametrize("sink_factory", [
    lambda: io.BytesIO(),
    lambda: pa.BufferOutputStream()
])
def test_file_read_all(sink_factory):
    """read_all() on an IPC file returns all written batches as one table."""
    fixture = FileFormatFixture(sink_factory)
    written = fixture.write_batches()

    reader = pa.ipc.open_file(pa.BufferReader(fixture.get_source()))
    assert reader.read_all().equals(pa.Table.from_batches(written))
def test_open_file_from_buffer(file_fixture):
    """ARROW-2859: file-reading APIs accept objects with the buffer protocol."""
    file_fixture.write_batches()
    source = file_fixture.get_source()

    readers = [
        pa.ipc.open_file(source),
        pa.ipc.open_file(pa.BufferReader(source)),
        pa.RecordBatchFileReader(source),
    ]
    first, *others = [reader.read_all() for reader in readers]
    for other in others:
        assert first.equals(other)

    # 6 messages = 1 schema message + 5 record batches.
    stats = readers[0].stats
    assert stats.num_messages == 6
    assert stats.num_record_batches == 5
    for reader in readers[1:]:
        assert reader.stats == stats
@pytest.mark.pandas
def test_file_read_pandas(file_fixture):
    """read_pandas() equals the concatenation of each batch's DataFrame."""
    written = file_fixture.write_batches()
    frames = [b.to_pandas() for b in written]

    reader = pa.ipc.open_file(pa.BufferReader(file_fixture.get_source()))
    result = reader.read_pandas()

    expected = pd.concat(frames).reset_index(drop=True)
    assert_frame_equal(result, expected)
def test_file_pathlib(file_fixture, tmpdir):
    """open_file reads the same table from a pathlib.Path and a pa.OSFile."""
    file_fixture.write_batches()
    source = file_fixture.get_source()
    path = tmpdir.join('file.arrow').strpath

    with open(path, 'wb') as out:
        out.write(source)

    t1 = pa.ipc.open_file(pathlib.Path(path)).read_all()
    # Close the OSFile deterministically instead of leaking the handle until
    # garbage collection (an open handle can also hinder tmpdir cleanup).
    with pa.OSFile(path) as handle:
        t2 = pa.ipc.open_file(handle).read_all()
    assert t1.equals(t2)
def test_empty_stream():
    """Opening an empty byte stream as an IPC stream raises ArrowInvalid."""
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_stream(io.BytesIO(b''))
@pytest.mark.pandas
@pytest.mark.processes
def test_read_year_month_nano_interval(tmpdir):
    """ARROW-15783: check that to_pandas works for interval types.

    Interval types require static structures to be enabled. The file written
    here is read back by the ``read_record_batch.py`` helper script, so that
    no other library functions are invoked before the read.
    """
    interval_type = pa.month_day_nano_interval()
    schema = pa.schema([pa.field('nums', interval_type)])
    path = tmpdir.join('file.arrow').strpath

    with pa.OSFile(path, 'wb') as sink:
        with pa.ipc.new_file(sink, schema) as writer:
            values = pa.array([(1, 2, 3)], type=interval_type)
            writer.write(pa.record_batch([values], schema))

    invoke_script('read_record_batch.py', path)
@pytest.mark.pandas
def test_stream_categorical_roundtrip(stream_fixture):
    """An ordered pandas Categorical column survives a stream round trip."""
    numbers = np.random.randn(5)
    labels = pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
                            categories=['foo', 'bar'],
                            ordered=True)
    df = pd.DataFrame({'one': numbers, 'two': labels})

    batch = pa.RecordBatch.from_pandas(df)
    sink = stream_fixture.sink
    with stream_fixture._get_writer(sink, batch.schema) as writer:
        writer.write_batch(batch)

    reader = pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
    assert_frame_equal(reader.read_all().to_pandas(), df)
def test_open_stream_from_buffer(stream_fixture):
    """ARROW-2859: stream-reading APIs accept buffer-protocol objects."""
    stream_fixture.write_batches()
    source = stream_fixture.get_source()

    readers = [
        pa.ipc.open_stream(source),
        pa.ipc.open_stream(pa.BufferReader(source)),
        pa.RecordBatchStreamReader(source),
    ]
    first, *others = [reader.read_all() for reader in readers]
    for other in others:
        assert first.equals(other)

    # 6 messages = 1 schema message + 5 record batches.
    stats = readers[0].stats
    assert stats.num_messages == 6
    assert stats.num_record_batches == 5
    for reader in readers[1:]:
        assert reader.stats == stats
    assert tuple(stats) == tuple(stream_fixture.write_stats)
@pytest.mark.parametrize('options', [
    pa.ipc.IpcReadOptions(),
    pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_stream_options(stream_fixture, options):
    """Reading a stream with explicit read options yields the expected stats."""
    stream_fixture.write_batches()
    reader = pa.ipc.open_stream(stream_fixture.get_source(), options=options)
    reader.read_all()

    stats = reader.stats
    assert (stats.num_messages, stats.num_record_batches) == (6, 5)
    assert tuple(stats) == tuple(stream_fixture.write_stats)
def test_open_stream_with_wrong_options(stream_fixture):
    """Passing ``options=True`` instead of read options raises TypeError."""
    stream_fixture.write_batches()
    source = stream_fixture.get_source()
    bogus_options = True
    with pytest.raises(TypeError):
        pa.ipc.open_stream(source, options=bogus_options)
@pytest.mark.parametrize('options', [
pa.ipc.IpcReadOptions(),
pa.ipc.IpcReadOptions(use_threads=False),
])
def test_open_file_options(file_fixture, options):
file_fixture.write_batches()
source = file_fixture.get_source()
reader = pa.ipc.open_file(source, options=options)
Loading ...