# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import os
import shutil
import subprocess
import weakref
from uuid import uuid4, UUID
import sys
import numpy as np
import pyarrow as pa
from pyarrow.vendored.version import Version
import pytest
@contextlib.contextmanager
def registered_extension_type(ext_type):
    """Register *ext_type* for the duration of a ``with`` block.

    The type is unregistered on exit, including when the body raises.
    """
    pa.register_extension_type(ext_type)
    try:
        yield
    finally:
        # unregister_extension_type takes the extension *name*, not the type.
        pa.unregister_extension_type(ext_type.extension_name)
@contextlib.contextmanager
def enabled_auto_load():
    """Turn on ``PyExtensionType`` auto-loading inside a ``with`` block.

    On exit, auto-loading is switched off unconditionally (it is not restored
    to whatever value it had before), even if the body raises.
    """
    pa.PyExtensionType.set_auto_load(True)
    try:
        yield
    finally:
        pa.PyExtensionType.set_auto_load(False)
class TinyIntType(pa.ExtensionType):
    """Parameterless extension type backed by ``int8`` storage."""

    def __init__(self):
        super().__init__(pa.int8(), 'pyarrow.tests.TinyIntType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the metadata payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check that both the payload and the storage came back unchanged.
        assert serialized == b''
        assert storage_type == pa.int8()
        return cls()
class IntegerType(pa.ExtensionType):
    """Parameterless extension type backed by ``int64`` storage."""

    def __init__(self):
        super().__init__(pa.int64(), 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # No parameters: nothing to encode.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # The payload must be empty and the storage must still be int64.
        assert serialized == b''
        assert storage_type == pa.int64()
        return cls()
class IntegerEmbeddedType(pa.ExtensionType):
    """Extension type whose storage type is itself an extension type
    (:class:`IntegerType`), used to exercise nested extension types.
    """

    def __init__(self):
        # NOTE(review): this reuses IntegerType's extension name
        # ('pyarrow.tests.IntegerType'); presumably registering both types at
        # the same time would clash — confirm this is intended.
        super().__init__(IntegerType(), 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # XXX pa.BaseExtensionType should expose C++ serialization method
        # The outer type has no parameters of its own, so it reuses the
        # inner (storage) extension type's payload.
        return self.storage_type.__arrow_ext_serialize__()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild the type, checking the inner type's payload round-trips.

        *storage_type* is the inner extension type. *serialized* is the
        inner type's payload (see ``__arrow_ext_serialize__``).
        """
        # The inner hook has the signature (storage_type, serialized), so it
        # needs the inner type's own storage type as well as the payload.
        deserialized_storage_type = storage_type.__arrow_ext_deserialize__(
            storage_type.storage_type, serialized)
        assert deserialized_storage_type == storage_type
        return cls()
class UuidScalarType(pa.ExtensionScalar):
    """Scalar class for ``UuidType`` values; converts to :class:`uuid.UUID`."""

    def as_py(self, **kwargs):
        """Return the value as a ``uuid.UUID``, or ``None`` for a null.

        Any keyword arguments are forwarded to the storage scalar's
        ``as_py``. This keeps the override call-compatible if a caller passes
        conversion options. NOTE(review): newer pyarrow releases appear to
        pass such options from ``to_pylist``; verify against the installed
        version.
        """
        storage = self.value
        if storage is None:
            return None
        return UUID(bytes=storage.as_py(**kwargs))
class UuidType(pa.ExtensionType):
    """Parameterless UUID extension type over ``binary(16)`` storage.

    Scalars of this type use :class:`UuidScalarType`, which converts them
    to ``uuid.UUID``.
    """

    def __init__(self):
        super().__init__(pa.binary(16), 'pyarrow.tests.UuidType')

    def __arrow_ext_scalar_class__(self):
        return UuidScalarType

    def __arrow_ext_serialize__(self):
        # There are no parameters to persist.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Nothing was serialized, so every instance is interchangeable.
        return cls()
class UuidType2(pa.ExtensionType):
    """Second ``binary(16)`` extension type with its own extension name.

    Unlike ``UuidType``, it does not define a custom scalar class.
    """

    def __init__(self):
        super().__init__(pa.binary(16), 'pyarrow.tests.UuidType2')

    def __arrow_ext_serialize__(self):
        # Parameterless: empty payload.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is needed to rebuild a parameterless type.
        return cls()
class LabelType(pa.ExtensionType):
    """Parameterless extension type backed by ``string`` storage."""

    def __init__(self):
        super().__init__(pa.string(), 'pyarrow.tests.LabelType')

    def __arrow_ext_serialize__(self):
        # No parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Parameterless: the storage type and payload are not inspected.
        return cls()
class ParamExtType(pa.ExtensionType):
    """Extension type over ``binary(width)``, parameterized by ``width``."""

    def __init__(self, width):
        # Stored before calling the base initializer, as in the original.
        self._width = width
        super().__init__(pa.binary(width), 'pyarrow.tests.ParamExtType')

    @property
    def width(self):
        """Byte width of the ``binary`` storage type."""
        return self._width

    def __arrow_ext_serialize__(self):
        # The width is the only parameter; persist it as its decimal text.
        as_text = str(self._width)
        return as_text.encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        parsed_width = int(serialized.decode())
        # The storage must agree with the width recorded in the payload.
        assert storage_type == pa.binary(parsed_width)
        return cls(parsed_width)
class MyStructType(pa.ExtensionType):
    """Extension type over ``struct<left: int64, right: int64>`` storage."""

    # Class-level storage type, used both by __init__ and by the
    # deserialization check below.
    storage_type = pa.struct([
        ('left', pa.int64()),
        ('right', pa.int64()),
    ])

    def __init__(self):
        super().__init__(self.storage_type, 'pyarrow.tests.MyStructType')

    def __arrow_ext_serialize__(self):
        # Parameterless: empty payload.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Both the payload and the struct layout must round-trip unchanged.
        assert serialized == b''
        assert storage_type == cls.storage_type
        return cls()
class MyListType(pa.ExtensionType):
    """Extension type that wraps any ``pa.ListType`` storage type."""

    def __init__(self, storage_type):
        # Only list storage is accepted.
        assert isinstance(storage_type, pa.ListType)
        super().__init__(storage_type, 'pyarrow.tests.MyListType')

    def __arrow_ext_serialize__(self):
        # The list storage type fully describes the type; no payload needed.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # Rebuild directly from the storage type supplied by the caller.
        return cls(storage_type)
class AnnotatedType(pa.ExtensionType):
    """
    Generic extension type that can store any storage type.

    ``annotation`` is a Python-side attribute only. It is not part of the
    serialized payload (which is always empty), so instances rebuilt through
    ``__arrow_ext_deserialize__`` have ``annotation=None``.
    """

    def __init__(self, storage_type, annotation=None):
        self.annotation = annotation
        super().__init__(storage_type, 'pyarrow.tests.AnnotatedType')

    def __arrow_ext_serialize__(self):
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # The annotation is never serialized, so it cannot be recovered here;
        # pass it explicitly as None rather than omitting the argument.
        return cls(storage_type, annotation=None)
class LegacyIntType(pa.PyExtensionType):
    """Legacy ``PyExtensionType`` subclass over ``int8`` storage.

    ``__reduce__`` is provided so instances can be pickled and rebuilt with
    no constructor arguments.
    """

    def __init__(self):
        super().__init__(pa.int8())

    def __reduce__(self):
        # Rebuild by calling LegacyIntType() with no arguments.
        return LegacyIntType, ()
def ipc_write_batch(batch):
    """Serialize *batch* into an in-memory Arrow IPC stream.

    Parameters
    ----------
    batch : pyarrow.RecordBatch
        The batch to write; its schema becomes the stream schema.

    Returns
    -------
    pyarrow.Buffer
        The buffer holding the finished stream.
    """
    sink = pa.BufferOutputStream()
    # The context manager closes the writer even if write_batch raises,
    # so the writer is never left open on the error path.
    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
        writer.write_batch(batch)
    return sink.getvalue()
def ipc_read_batch(buf):
    """Return the first record batch from the Arrow IPC stream in *buf*."""
    return pa.RecordBatchStreamReader(buf).read_next_batch()
def test_ext_type_basics():
    """A UuidType reports the extension name passed to its constructor."""
    assert UuidType().extension_name == "pyarrow.tests.UuidType"
def test_ext_type_str():
    """str() and the base DataType.__str__ render an extension type the same way."""
    integer_type = IntegerType()
    want = "extension<pyarrow.tests.IntegerType<IntegerType>>"
    assert str(integer_type) == want
    # Calling the base-class method directly must give the same text.
    assert pa.DataType.__str__(integer_type) == want
def test_ext_type_repr():
    """repr() shows the Python class name wrapping the storage DataType."""
    assert repr(IntegerType()) == "IntegerType(DataType(int64))"
def test_ext_type__lifetime():
    """Deleting the only reference to an extension type frees it."""
    uuid_type = UuidType()
    tracker = weakref.ref(uuid_type)
    del uuid_type
    assert tracker() is None
def test_ext_type__storage_type():
    """storage_type matches the constructor's storage and the class is kept."""
    uuid_type = UuidType()
    assert uuid_type.storage_type == pa.binary(16)
    assert uuid_type.__class__ is UuidType

    # A parameterized type derives its storage from the parameter.
    param_type = ParamExtType(5)
    assert param_type.storage_type == pa.binary(5)
    assert param_type.__class__ is ParamExtType
def test_ext_type_as_py():
    """UuidType scalars, arrays and chunked arrays convert to uuid.UUID."""
    ty = UuidType()

    # Scalar built directly from raw bytes.
    expected = uuid4()
    scalar = pa.ExtensionScalar.from_storage(ty, expected.bytes)
    assert scalar.as_py() == expected

    # Array: check conversion through both __getitem__ and __iter__.
    uuids = [uuid4() for _ in range(3)]
    storage = pa.array([u.bytes for u in uuids], type=pa.binary(16))
    arr = pa.ExtensionArray.from_storage(ty, storage)
    # zip() stops at the shorter input, so pin the length explicitly;
    # otherwise missing elements would go unnoticed.
    assert len(arr) == len(uuids)
    for i, expected_uuid in enumerate(uuids):
        assert arr[i].as_py() == expected_uuid
    for result, expected_uuid in zip(arr, uuids):
        assert result.as_py() == expected_uuid

    # Chunked array made of two copies of the same chunk.
    chunks = [
        pa.ExtensionArray.from_storage(ty, storage),
        pa.ExtensionArray.from_storage(ty, storage),
    ]
    carr = pa.chunked_array(chunks)
    all_uuids = uuids + uuids
    assert len(carr) == len(all_uuids)
    for i, expected_uuid in enumerate(all_uuids):
        assert carr[i].as_py() == expected_uuid
    for result, expected_uuid in zip(carr, all_uuids):
        assert result.as_py() == expected_uuid
def test_uuid_type_pickle(pickle_module):
    """UuidType round-trips through every pickle protocol without leaking."""
    highest = pickle_module.HIGHEST_PROTOCOL
    for protocol in range(highest + 1):
        payload = pickle_module.dumps(UuidType(), protocol=protocol)
        restored = pickle_module.loads(payload)
        tracker = weakref.ref(restored)
        assert restored.extension_name == "pyarrow.tests.UuidType"
        # The unpickled type must be freed once its last reference goes.
        del restored
        assert tracker() is None
def test_ext_type_equality():
    """Extension types compare equal only for the same class and parameters."""
    width5 = ParamExtType(5)
    width6 = ParamExtType(6)
    width6_again = ParamExtType(6)
    assert width5 != width6
    assert width6 == width6_again

    uuid_a = UuidType()
    uuid_b = UuidType()
    # A different extension class never equals a ParamExtType.
    assert width5 != uuid_a
    assert uuid_a == uuid_b
def test_ext_array_basics():
    """from_storage keeps the given type object and wraps equal storage."""
    param_type = ParamExtType(3)
    raw = pa.array([b"foo", b"bar"], type=pa.binary(3))
    wrapped = pa.ExtensionArray.from_storage(param_type, raw)
    wrapped.validate()
    # The array holds the very same type instance, not a copy.
    assert wrapped.type is param_type
    assert wrapped.storage.equals(raw)
def test_ext_array_lifetime():
    """An extension array, its type and its storage are all freed after del."""
    ty = ParamExtType(3)
    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
    arr = pa.ExtensionArray.from_storage(ty, storage)
    trackers = list(map(weakref.ref, (ty, arr, storage)))
    del ty, storage, arr
    for tracker in trackers:
        assert tracker() is None
Loading ...