Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

Version: 19.0.0.dev70 

/ tests / test_extension_type.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import contextlib
import os
import shutil
import subprocess
import weakref
from uuid import uuid4, UUID
import sys

import pytest
try:
    import numpy as np
except ImportError:
    np = None

import pyarrow as pa
from pyarrow.vendored.version import Version


@contextlib.contextmanager
def registered_extension_type(ext_type):
    """
    Keep *ext_type* registered with pyarrow while the ``with`` body runs.

    The type is registered on entry and unregistered (looked up by its
    ``extension_name``) on exit, whether or not the body raised.
    """
    pa.register_extension_type(ext_type)
    try:
        yield
    finally:
        registered_name = ext_type.extension_name
        pa.unregister_extension_type(registered_name)


@contextlib.contextmanager
def enabled_auto_load():
    """
    Switch ``pa.PyExtensionType`` auto-load on for the ``with`` body.

    Auto-load is set back to ``False`` on exit, whether or not the body
    raised.
    """
    legacy_type = pa.PyExtensionType
    legacy_type.set_auto_load(True)
    try:
        yield
    finally:
        legacy_type.set_auto_load(False)


class TinyIntType(pa.ExtensionType):
    """Parameter-free extension type with ``int8`` storage."""

    def __init__(self):
        storage = pa.int8()
        super().__init__(storage, 'pyarrow.tests.TinyIntType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check the payload first, then the storage type.
        assert serialized == b''
        assert storage_type == pa.int8()
        instance = cls()
        return instance


class IntegerType(pa.ExtensionType):
    """Parameter-free extension type with ``int64`` storage."""

    def __init__(self):
        storage = pa.int64()
        super().__init__(storage, 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check the payload first, then the storage type.
        assert serialized == b''
        assert storage_type == pa.int64()
        instance = cls()
        return instance


class IntegerEmbeddedType(pa.ExtensionType):
    """
    Extension type whose storage type is itself an extension type
    (``IntegerType``).
    """

    def __init__(self):
        # NOTE: the extension name is the same string that IntegerType
        # uses; it is deliberately left unchanged here.
        super().__init__(IntegerType(), 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # XXX pa.BaseExtensionType should expose C++ serialization method
        return self.storage_type.__arrow_ext_serialize__()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # ``storage_type`` is an extension type instance whose own
        # deserializer is a classmethod taking (storage_type, serialized).
        # Its nested storage type must therefore be passed explicitly;
        # passing ``serialized`` alone raises TypeError.
        deserialized_storage_type = storage_type.__arrow_ext_deserialize__(
            storage_type.storage_type, serialized)
        assert deserialized_storage_type == storage_type
        return cls()


class ExampleUuidScalarType(pa.ExtensionScalar):
    """Extension scalar whose Python value is a ``uuid.UUID``."""

    def as_py(self):
        """Return a ``UUID`` built from the storage bytes, or ``None``."""
        if self.value is None:
            return None
        return UUID(bytes=self.value.as_py())


class ExampleUuidType(pa.ExtensionType):
    """
    Extension type with ``binary(16)`` storage that uses
    ``ExampleUuidScalarType`` as its scalar class.
    """

    def __init__(self):
        storage = pa.binary(16)
        super().__init__(storage, 'pyarrow.tests.ExampleUuidType')

    def __reduce__(self):
        # Reconstruct by calling the class with no arguments.
        no_args = ()
        return ExampleUuidType, no_args

    def __arrow_ext_scalar_class__(self):
        return ExampleUuidScalarType

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is inspected; a fresh instance is returned.
        instance = cls()
        return instance


class ExampleUuidType2(pa.ExtensionType):
    """Second extension type with ``binary(16)`` storage and its own name."""

    def __init__(self):
        storage = pa.binary(16)
        super().__init__(storage, 'pyarrow.tests.ExampleUuidType2')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is inspected; a fresh instance is returned.
        instance = cls()
        return instance


class LabelType(pa.ExtensionType):
    """Parameter-free extension type with ``string`` storage."""

    def __init__(self):
        storage = pa.string()
        super().__init__(storage, 'pyarrow.tests.LabelType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is inspected; a fresh instance is returned.
        instance = cls()
        return instance


class ParamExtType(pa.ExtensionType):
    """
    Parametrized extension type: ``binary(width)`` storage, with the width
    serialized as its decimal string.
    """

    def __init__(self, width):
        # The attribute is set before the base initializer runs.
        self._width = width
        storage = pa.binary(width)
        super().__init__(storage, 'pyarrow.tests.ParamExtType')

    @property
    def width(self):
        """The ``width`` this type was constructed with."""
        return self._width

    def __arrow_ext_serialize__(self):
        width_text = str(self._width)
        return width_text.encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Recover the width from the payload and check it against the
        # storage type that was passed in.
        width_text = serialized.decode()
        width = int(width_text)
        assert storage_type == pa.binary(width)
        return cls(width)


class MyStructType(pa.ExtensionType):
    """
    Extension type over a struct with two ``int64`` fields, ``left`` and
    ``right``.
    """

    # Class-level storage type, used by both __init__ and the deserializer.
    storage_type = pa.struct(
        [(field_name, pa.int64()) for field_name in ('left', 'right')])

    def __init__(self):
        storage = self.storage_type
        super().__init__(storage, 'pyarrow.tests.MyStructType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check the payload first, then the storage type.
        assert serialized == b''
        assert storage_type == cls.storage_type
        instance = cls()
        return instance


class MyListType(pa.ExtensionType):
    """Extension type wrapping a caller-supplied ``pa.ListType`` storage."""

    def __init__(self, storage_type):
        # Only list storage is accepted.
        assert isinstance(storage_type, pa.ListType)
        extension_name = 'pyarrow.tests.MyListType'
        super().__init__(storage_type, extension_name)

    def __arrow_ext_serialize__(self):
        # Nothing besides the storage type needs to be persisted.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        instance = cls(storage_type)
        return instance


class MyFixedListType(pa.ExtensionType):
    """
    Extension type wrapping a caller-supplied ``pa.FixedSizeListType``
    storage.
    """

    def __init__(self, storage_type):
        # Only fixed-size list storage is accepted.
        assert isinstance(storage_type, pa.FixedSizeListType)
        extension_name = 'pyarrow.tests.MyFixedListType'
        super().__init__(storage_type, extension_name)

    def __arrow_ext_serialize__(self):
        # Nothing besides the storage type needs to be persisted.
        payload = b''
        return payload

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        instance = cls(storage_type)
        return instance


class AnnotatedType(pa.ExtensionType):
    """
    Generic extension type that can store any storage type.

    The ``annotation`` is kept as a plain Python attribute on the instance;
    it is not part of the serialized representation, which is always empty.
    """

    def __init__(self, storage_type, annotation):
        self.annotation = annotation
        super().__init__(storage_type, 'pyarrow.tests.AnnotatedType')

    def __arrow_ext_serialize__(self):
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # The annotation is not serialized, so it cannot be recovered here.
        # Pass None explicitly: __init__ requires the argument, and calling
        # cls(storage_type) alone raises TypeError.
        return cls(storage_type, None)


class LegacyIntType(pa.PyExtensionType):
    """``pa.PyExtensionType`` subclass with ``int8`` storage."""

    def __init__(self):
        storage = pa.int8()
        pa.PyExtensionType.__init__(self, storage)

    def __reduce__(self):
        # Reconstruct by calling the class with no arguments.
        no_args = ()
        return LegacyIntType, no_args


def ipc_write_batch(batch):
    """
    Write *batch* to an in-memory IPC stream and return the resulting buffer.

    The writer is used as a context manager so that it is closed even if
    writing the batch raises.
    """
    stream = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(stream, batch.schema) as writer:
        writer.write_batch(batch)
    # The writer is closed at this point, so the stream content is complete.
    return stream.getvalue()


def ipc_read_batch(buf):
    """Open *buf* as an IPC stream and return the next record batch."""
    return pa.RecordBatchStreamReader(buf).read_next_batch()


def test_ext_type_basics():
    """The extension name passed to the constructor is exposed as-is."""
    uuid_type = ExampleUuidType()
    expected_name = "pyarrow.tests.ExampleUuidType"
    assert uuid_type.extension_name == expected_name


def test_ext_type_str():
    """``str()`` and ``pa.DataType.__str__`` give the same rendering."""
    int_type = IntegerType()
    expected = "extension<pyarrow.tests.IntegerType<IntegerType>>"
    for render in (str, pa.DataType.__str__):
        assert render(int_type) == expected


def test_ext_type_repr():
    """``repr()`` shows the class name wrapping the storage type."""
    rendered = repr(IntegerType())
    assert rendered == "IntegerType(DataType(int64))"


def test_ext_type_lifetime():
    """An extension type instance is collected once its last name is deleted."""
    instance = ExampleUuidType()
    ref = weakref.ref(instance)
    # Drop the only strong reference held by this test.
    del instance
    assert ref() is None


def test_ext_type_storage_type():
    """Check ``storage_type`` and the concrete class of two extension types."""
    # Parameter-free type.
    uuid_type = ExampleUuidType()
    assert uuid_type.storage_type == pa.binary(16)
    assert uuid_type.__class__ is ExampleUuidType

    # Parametrized type.
    param_type = ParamExtType(5)
    assert param_type.storage_type == pa.binary(5)
    assert param_type.__class__ is ParamExtType


def test_ext_type_byte_width():
    """``byte_width`` works for fixed-size storage and raises otherwise."""
    # Test for fixed-size binary types
    uuid_type = pa.uuid()
    assert uuid_type.byte_width == 16
    param_type = ParamExtType(5)
    assert param_type.byte_width == 5

    # Test for non fixed-size binary types
    label_type = LabelType()
    with pytest.raises(ValueError, match="Non-fixed width type"):
        _ = label_type.byte_width


def test_ext_type_bit_width():
    """``bit_width`` works for fixed-size storage and raises otherwise."""
    # Test for fixed-size binary types
    uuid_type = pa.uuid()
    assert uuid_type.bit_width == 128
    param_type = ParamExtType(5)
    assert param_type.bit_width == 40

    # Test for non fixed-size binary types
    label_type = LabelType()
    with pytest.raises(ValueError, match="Non-fixed width type"):
        _ = label_type.bit_width


def test_ext_type_as_py():
    ty = ExampleUuidType()
    expected = uuid4()
    scalar = pa.ExtensionScalar.from_storage(ty, expected.bytes)
    assert scalar.as_py() == expected

    # test array
    uuids = [uuid4() for _ in range(3)]
    storage = pa.array([uuid.bytes for uuid in uuids], type=pa.binary(16))
    arr = pa.ExtensionArray.from_storage(ty, storage)

    # Works for __get_item__
    for i, expected in enumerate(uuids):
        assert arr[i].as_py() == expected

    # Works for __iter__
    for result, expected in zip(arr, uuids):
        assert result.as_py() == expected

    # test chunked array
    data = [
        pa.ExtensionArray.from_storage(ty, storage),
        pa.ExtensionArray.from_storage(ty, storage)
    ]
    carr = pa.chunked_array(data)
    for i, expected in enumerate(uuids + uuids):
        assert carr[i].as_py() == expected

    for result, expected in zip(carr, uuids + uuids):
Loading ...