# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import os
import shutil
import subprocess
import weakref
from uuid import uuid4, UUID
import sys
import pytest
try:
import numpy as np
except ImportError:
np = None
import pyarrow as pa
from pyarrow.vendored.version import Version
@contextlib.contextmanager
def registered_extension_type(ext_type):
    """
    Keep *ext_type* registered with pyarrow while the ``with`` body runs.

    The type is registered on entry and unregistered by its
    ``extension_name`` on exit, including when the body raises.
    """
    pa.register_extension_type(ext_type)
    with contextlib.ExitStack() as cleanup:
        # Wrapped in a lambda so the extension name is looked up at exit
        # time, not when the callback is registered.
        cleanup.callback(
            lambda: pa.unregister_extension_type(ext_type.extension_name))
        yield
@contextlib.contextmanager
def enabled_auto_load():
    """
    Turn on ``pa.PyExtensionType`` auto-loading for the ``with`` body.

    Auto-loading is set to ``True`` on entry and always set back to
    ``False`` on exit, including when the body raises.
    """
    pa.PyExtensionType.set_auto_load(True)
    with contextlib.ExitStack() as cleanup:
        # Wrapped in a lambda so the attribute lookup happens at exit time.
        cleanup.callback(lambda: pa.PyExtensionType.set_auto_load(False))
        yield
class TinyIntType(pa.ExtensionType):
    """
    Parameter-free extension type with ``int8`` storage.
    """

    def __init__(self):
        storage = pa.int8()
        super().__init__(storage, 'pyarrow.tests.TinyIntType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Only the exact (empty payload, int8 storage) pair is accepted.
        assert serialized == b''
        assert storage_type == pa.int8()
        return cls()
class IntegerType(pa.ExtensionType):
    """
    Parameter-free extension type with ``int64`` storage.
    """

    def __init__(self):
        storage = pa.int64()
        super().__init__(storage, 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Only the exact (empty payload, int64 storage) pair is accepted.
        assert serialized == b''
        assert storage_type == pa.int64()
        return cls()
class IntegerEmbeddedType(pa.ExtensionType):
    """
    Extension type whose storage is itself an extension type (IntegerType).
    """

    def __init__(self):
        # NOTE(review): this reuses IntegerType's extension name
        # ('pyarrow.tests.IntegerType') - confirm that is intentional.
        inner = IntegerType()
        super().__init__(inner, 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # XXX pa.BaseExtensionType should expose C++ serialization method
        inner = self.storage_type
        return inner.__arrow_ext_serialize__()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # NOTE(review): IntegerType.__arrow_ext_deserialize__ takes
        # (storage_type, serialized) but only `serialized` is passed here -
        # confirm whether this path is ever exercised.
        roundtripped = storage_type.__arrow_ext_deserialize__(serialized)
        assert roundtripped == storage_type
        return cls()
class ExampleUuidScalarType(pa.ExtensionScalar):
    """
    Extension scalar whose Python value is a ``uuid.UUID``.
    """

    def as_py(self):
        """Return the storage bytes as a ``UUID``, or None if there is no value."""
        if self.value is None:
            return None
        return UUID(bytes=self.value.as_py())
class ExampleUuidType(pa.ExtensionType):
    """
    Parameter-free extension type with ``binary(16)`` storage.

    Values of this type are exposed through ExampleUuidScalarType.
    """

    def __init__(self):
        storage = pa.binary(16)
        super().__init__(storage, 'pyarrow.tests.ExampleUuidType')

    def __reduce__(self):
        # Pickle support: rebuild through the no-argument constructor.
        return ExampleUuidType, ()

    def __arrow_ext_scalar_class__(self):
        # Scalar class to use for values of this type.
        return ExampleUuidScalarType

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Both arguments are ignored; a fresh instance is always returned.
        return cls()
class ExampleUuidType2(pa.ExtensionType):
    """
    Second parameter-free extension type with ``binary(16)`` storage.

    It has the same storage as ExampleUuidType but a distinct extension
    name.
    """

    def __init__(self):
        storage = pa.binary(16)
        super().__init__(storage, 'pyarrow.tests.ExampleUuidType2')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Both arguments are ignored; a fresh instance is always returned.
        return cls()
class LabelType(pa.ExtensionType):
    """
    Parameter-free extension type with ``string`` storage.
    """

    def __init__(self):
        storage = pa.string()
        super().__init__(storage, 'pyarrow.tests.LabelType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Both arguments are ignored; a fresh instance is always returned.
        return cls()
class ParamExtType(pa.ExtensionType):
    """
    Parametrized extension type with ``binary(width)`` storage.

    The width is the only parameter; it is serialized as its decimal
    string representation.
    """

    def __init__(self, width):
        # Set before the base initializer runs, as in the original ordering.
        self._width = width
        storage = pa.binary(width)
        super().__init__(storage, 'pyarrow.tests.ParamExtType')

    @property
    def width(self):
        """Width passed to the constructor."""
        return self._width

    def __arrow_ext_serialize__(self):
        text = str(self._width)
        return text.encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        text = serialized.decode()
        width = int(text)
        # The storage type must agree with the serialized width.
        assert storage_type == pa.binary(width)
        return cls(width)
class MyStructType(pa.ExtensionType):
    """
    Parameter-free extension type stored as a struct of two ``int64``
    fields named 'left' and 'right'.
    """

    # Class-level storage type, shared by __init__ and deserialization.
    storage_type = pa.struct(
        [(field_name, pa.int64()) for field_name in ('left', 'right')])

    def __init__(self):
        super().__init__(self.storage_type, 'pyarrow.tests.MyStructType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Only the empty payload with the class-level storage is accepted.
        assert serialized == b''
        assert storage_type == cls.storage_type
        return cls()
class MyListType(pa.ExtensionType):
    """
    Extension type wrapping a caller-supplied ``pa.ListType`` storage.
    """

    def __init__(self, storage_type):
        # Only list storage is accepted.
        assert isinstance(storage_type, pa.ListType)
        super().__init__(storage_type, 'pyarrow.tests.MyListType')

    def __arrow_ext_serialize__(self):
        # Nothing beyond the storage type needs to be persisted.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # Rebuild around whatever storage type was handed in.
        return cls(storage_type)
class MyFixedListType(pa.ExtensionType):
    """
    Extension type wrapping a caller-supplied ``pa.FixedSizeListType``
    storage.
    """

    def __init__(self, storage_type):
        # Only fixed-size list storage is accepted.
        assert isinstance(storage_type, pa.FixedSizeListType)
        super().__init__(storage_type, 'pyarrow.tests.MyFixedListType')

    def __arrow_ext_serialize__(self):
        # Nothing beyond the storage type needs to be persisted.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # Rebuild around whatever storage type was handed in.
        return cls(storage_type)
class AnnotatedType(pa.ExtensionType):
    """
    Generic extension type that can store any storage type.

    The ``annotation`` is kept only as a Python attribute. It is not part
    of the serialized form, so a deserialized instance has
    ``annotation`` set to None.
    """

    def __init__(self, storage_type, annotation):
        self.annotation = annotation
        super().__init__(storage_type, 'pyarrow.tests.AnnotatedType')

    def __arrow_ext_serialize__(self):
        # The annotation is deliberately not persisted.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # __init__ requires an annotation but none is serialized; calling
        # cls(storage_type) alone would raise TypeError, so pass None.
        return cls(storage_type, None)
class LegacyIntType(pa.PyExtensionType):
    """
    Extension type built on ``pa.PyExtensionType`` with ``int8`` storage.
    """

    def __init__(self):
        super().__init__(pa.int8())

    def __reduce__(self):
        # Pickle support: rebuild through the no-argument constructor.
        return LegacyIntType, ()
def ipc_write_batch(batch):
    """
    Write a single record batch to an in-memory IPC stream.

    Parameters
    ----------
    batch
        The record batch to write; its ``schema`` is used for the stream.

    Returns
    -------
    The contents of the output stream, as returned by its ``getvalue()``.
    """
    stream = pa.BufferOutputStream()
    # The context manager closes the writer even if write_batch() raises.
    with pa.RecordBatchStreamWriter(stream, batch.schema) as writer:
        writer.write_batch(batch)
    return stream.getvalue()
def ipc_read_batch(buf):
    """
    Read the first record batch from the IPC stream in *buf*.
    """
    return pa.RecordBatchStreamReader(buf).read_next_batch()
def test_ext_type_basics():
    """The extension name passed to the base initializer is exposed."""
    uuid_type = ExampleUuidType()
    assert uuid_type.extension_name == "pyarrow.tests.ExampleUuidType"
def test_ext_type_str():
    """``str()`` and ``DataType.__str__`` give the same rendering."""
    ty = IntegerType()
    expected = "extension<pyarrow.tests.IntegerType<IntegerType>>"
    for render in (str, pa.DataType.__str__):
        assert render(ty) == expected
def test_ext_type_repr():
    """``repr()`` shows the class name wrapping the storage type."""
    rendered = repr(IntegerType())
    assert rendered == "IntegerType(DataType(int64))"
def test_ext_type_lifetime():
    """An extension type instance is freed once its last reference goes."""
    ty = ExampleUuidType()
    probe = weakref.ref(ty)
    # Drop the only strong reference; the weak reference must then be dead.
    del ty
    assert probe() is None
def test_ext_type_storage_type():
    """``storage_type`` and the Python class are preserved on instances."""
    cases = [
        (ExampleUuidType, (), pa.binary(16)),
        (ParamExtType, (5,), pa.binary(5)),
    ]
    for ext_class, ctor_args, expected_storage in cases:
        ty = ext_class(*ctor_args)
        assert ty.storage_type == expected_storage
        assert ty.__class__ is ext_class
def test_ext_type_byte_width():
    """``byte_width`` works for fixed-size storage and raises otherwise."""
    # Test for fixed-size binary types
    fixed_cases = ((pa.uuid(), 16), (ParamExtType(5), 5))
    for ty, expected_bytes in fixed_cases:
        assert ty.byte_width == expected_bytes

    # Test for non fixed-size binary types
    label_type = LabelType()
    with pytest.raises(ValueError, match="Non-fixed width type"):
        _ = label_type.byte_width
def test_ext_type_bit_width():
    """``bit_width`` works for fixed-size storage and raises otherwise."""
    # Test for fixed-size binary types
    fixed_cases = ((pa.uuid(), 128), (ParamExtType(5), 40))
    for ty, expected_bits in fixed_cases:
        assert ty.bit_width == expected_bits

    # Test for non fixed-size binary types
    label_type = LabelType()
    with pytest.raises(ValueError, match="Non-fixed width type"):
        _ = label_type.bit_width
def test_ext_type_as_py():
ty = ExampleUuidType()
expected = uuid4()
scalar = pa.ExtensionScalar.from_storage(ty, expected.bytes)
assert scalar.as_py() == expected
# test array
uuids = [uuid4() for _ in range(3)]
storage = pa.array([uuid.bytes for uuid in uuids], type=pa.binary(16))
arr = pa.ExtensionArray.from_storage(ty, storage)
# Works for __get_item__
for i, expected in enumerate(uuids):
assert arr[i].as_py() == expected
# Works for __iter__
for result, expected in zip(arr, uuids):
assert result.as_py() == expected
# test chunked array
data = [
pa.ExtensionArray.from_storage(ty, storage),
pa.ExtensionArray.from_storage(ty, storage)
]
carr = pa.chunked_array(data)
for i, expected in enumerate(uuids + uuids):
assert carr[i].as_py() == expected
for result, expected in zip(carr, uuids + uuids):
Loading ...