Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

/ tests / test_extension_type.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import contextlib
import os
import shutil
import subprocess
import weakref
from uuid import uuid4, UUID
import sys

import numpy as np
import pyarrow as pa
from pyarrow.vendored.version import Version

import pytest


@contextlib.contextmanager
def registered_extension_type(ext_type):
    """
    Context manager keeping *ext_type* registered for the enclosed block.

    The type is registered on entry and unregistered by its extension
    name on exit, whether or not the block raised.
    """
    pa.register_extension_type(ext_type)
    try:
        yield
    finally:
        name = ext_type.extension_name
        pa.unregister_extension_type(name)


@contextlib.contextmanager
def enabled_auto_load():
    """
    Context manager switching ``PyExtensionType`` auto-loading on.

    Auto-loading is set to True on entry and unconditionally set to False
    on exit (the previous setting is not restored).
    """
    toggle = pa.PyExtensionType.set_auto_load
    toggle(True)
    try:
        yield
    finally:
        toggle(False)


class TinyIntType(pa.ExtensionType):
    """
    Parameter-less extension type stored as ``int8``.
    """

    def __init__(self):
        storage = pa.int8()
        super().__init__(storage, 'pyarrow.tests.TinyIntType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check the payload first, then the storage, before rebuilding.
        expected_storage = pa.int8()
        assert serialized == b''
        assert storage_type == expected_storage
        return cls()


class IntegerType(pa.ExtensionType):
    """
    Parameter-less extension type stored as ``int64``.
    """

    def __init__(self):
        storage = pa.int64()
        super().__init__(storage, 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check the payload first, then the storage, before rebuilding.
        expected_storage = pa.int64()
        assert serialized == b''
        assert storage_type == expected_storage
        return cls()


class IntegerEmbeddedType(pa.ExtensionType):
    """
    Extension type whose storage is itself an extension type (IntegerType).
    """

    def __init__(self):
        # NOTE(review): the extension name is the same string IntegerType
        # uses; left untouched because other tests may depend on it --
        # confirm whether a distinct name was intended.
        super().__init__(IntegerType(), 'pyarrow.tests.IntegerType')

    def __arrow_ext_serialize__(self):
        # XXX pa.BaseExtensionType should expose C++ serialization method
        return self.storage_type.__arrow_ext_serialize__()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # The payload was produced by the storage extension type, so hand it
        # back to that type's deserializer.  That classmethod takes both the
        # (inner) storage type and the payload; passing only the payload
        # raised a TypeError.
        deserialized_storage_type = storage_type.__arrow_ext_deserialize__(
            storage_type.storage_type, serialized)
        assert deserialized_storage_type == storage_type
        return cls()


class UuidScalarType(pa.ExtensionScalar):
    """
    Extension scalar converting its storage value to ``uuid.UUID``.
    """

    def as_py(self):
        """Return the value as a ``UUID``, or None when there is none."""
        storage_scalar = self.value
        if storage_scalar is None:
            return None
        return UUID(bytes=storage_scalar.as_py())


class UuidType(pa.ExtensionType):
    """
    Extension type stored as ``binary(16)`` with a custom scalar class.
    """

    def __init__(self):
        storage = pa.binary(16)
        super().__init__(storage, 'pyarrow.tests.UuidType')

    def __arrow_ext_scalar_class__(self):
        # Scalars of this type are produced as UuidScalarType instances.
        return UuidScalarType

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is inspected; a fresh instance is returned.
        return cls()


class UuidType2(pa.ExtensionType):
    """
    Second ``binary(16)`` extension type, under its own extension name.
    """

    def __init__(self):
        storage = pa.binary(16)
        super().__init__(storage, 'pyarrow.tests.UuidType2')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is inspected; a fresh instance is returned.
        return cls()


class LabelType(pa.ExtensionType):
    """
    Parameter-less extension type stored as ``string``.
    """

    def __init__(self):
        storage = pa.string()
        super().__init__(storage, 'pyarrow.tests.LabelType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Neither argument is inspected; a fresh instance is returned.
        return cls()


class ParamExtType(pa.ExtensionType):
    """
    Extension type parametrized by the byte width of its binary storage.
    """

    def __init__(self, width):
        # The width is stored before the base initializer runs.
        self._width = width
        storage = pa.binary(width)
        super().__init__(storage, 'pyarrow.tests.ParamExtType')

    @property
    def width(self):
        """Byte width this type was constructed with."""
        return self._width

    def __arrow_ext_serialize__(self):
        # The payload is the encoded decimal text of the width.
        text = str(self._width)
        return text.encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Recover the width from the payload and make sure it agrees with
        # the storage type before rebuilding.
        text = serialized.decode()
        width = int(text)
        assert storage_type == pa.binary(width)
        return cls(width)


class MyStructType(pa.ExtensionType):
    """
    Extension type stored as a struct of two ``int64`` fields.
    """

    # Class-level storage definition, used both to initialize instances and
    # to validate the storage on deserialization.
    storage_type = pa.struct([('left', pa.int64()),
                              ('right', pa.int64())])

    def __init__(self):
        struct_storage = self.storage_type
        super().__init__(struct_storage, 'pyarrow.tests.MyStructType')

    def __arrow_ext_serialize__(self):
        # The type has no parameters, so the payload is empty.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Check the payload first, then the storage, before rebuilding.
        assert serialized == b''
        assert storage_type == cls.storage_type
        return cls()


class MyListType(pa.ExtensionType):
    """
    Extension type wrapping a caller-supplied list storage type.
    """

    def __init__(self, storage_type):
        # Only ``pa.ListType`` instances are accepted as storage.
        assert isinstance(storage_type, pa.ListType)
        name = 'pyarrow.tests.MyListType'
        super().__init__(storage_type, name)

    def __arrow_ext_serialize__(self):
        # Nothing is persisted besides the storage type itself.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Rebuild around whatever storage type was handed in.
        assert serialized == b''
        return cls(storage_type)


class AnnotatedType(pa.ExtensionType):
    """
    Generic extension type that can store any storage type.

    The ``annotation`` is kept as a plain instance attribute; it is not
    part of the serialized form.
    """

    def __init__(self, storage_type, annotation):
        self.annotation = annotation
        super().__init__(storage_type, 'pyarrow.tests.AnnotatedType')

    def __arrow_ext_serialize__(self):
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        assert serialized == b''
        # The annotation is never serialized, so it cannot be recovered
        # here.  ``__init__`` requires it, and calling ``cls(storage_type)``
        # raised a TypeError; rebuild with no annotation instead.
        return cls(storage_type, None)


class LegacyIntType(pa.PyExtensionType):
    """
    ``PyExtensionType``-based type stored as ``int8``.
    """

    def __init__(self):
        super().__init__(pa.int8())

    def __reduce__(self):
        # Pickling rebuilds the type by calling the class with no arguments.
        return (LegacyIntType, ())


def ipc_write_batch(batch):
    """
    Serialize *batch* to an in-memory IPC stream and return the buffer.

    The stream is written with the batch's own schema.  The writer is used
    as a context manager so that it is closed even if writing raises.
    """
    stream = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(stream, batch.schema) as writer:
        writer.write_batch(batch)
    return stream.getvalue()


def ipc_read_batch(buf):
    """
    Read and return the next record batch from the IPC stream in *buf*.
    """
    return pa.RecordBatchStreamReader(buf).read_next_batch()


def test_ext_type_basics():
    """The extension name of a UuidType instance is the expected string."""
    uuid_type = UuidType()
    name = uuid_type.extension_name
    assert name == "pyarrow.tests.UuidType"


def test_ext_type_str():
    """``str()`` and ``DataType.__str__`` give the same rendering."""
    int_type = IntegerType()
    rendering = "extension<pyarrow.tests.IntegerType<IntegerType>>"
    # Through the instance ...
    assert str(int_type) == rendering
    # ... and through the base class implementation called explicitly.
    assert pa.DataType.__str__(int_type) == rendering


def test_ext_type_repr():
    """``repr()`` shows the class name wrapping the storage type repr."""
    rendered = repr(IntegerType())
    assert rendered == "IntegerType(DataType(int64))"


def test_ext_type__lifetime():
    """
    An extension type instance is released once its only reference is
    deleted.
    """
    ty = UuidType()
    # Keep only a weak reference so the instance can be observed without
    # being kept alive.
    wr = weakref.ref(ty)
    del ty
    # No explicit garbage collection is triggered: the weak reference is
    # expected to be dead right after the ``del``.
    assert wr() is None


def test_ext_type__storage_type():
    """Instances expose their storage type and keep their Python class."""
    uuid_type = UuidType()
    assert uuid_type.storage_type == pa.binary(16)
    assert uuid_type.__class__ is UuidType

    # Same checks for a parametrized type.
    param_type = ParamExtType(5)
    assert param_type.storage_type == pa.binary(5)
    assert param_type.__class__ is ParamExtType


def test_ext_type_as_py():
    """
    ``as_py()`` yields ``UUID`` objects for UuidType scalars, whether they
    come from a lone scalar, an array or a chunked array.
    """
    uuid_type = UuidType()

    # Lone scalar built from raw storage bytes.
    single = uuid4()
    scalar = pa.ExtensionScalar.from_storage(uuid_type, single.bytes)
    assert scalar.as_py() == single

    # Extension array over three random UUIDs.
    ids = [uuid4() for _ in range(3)]
    storage = pa.array([item.bytes for item in ids], type=pa.binary(16))
    ext_arr = pa.ExtensionArray.from_storage(uuid_type, storage)

    # Element access through __getitem__
    for pos, wanted in enumerate(ids):
        assert ext_arr[pos].as_py() == wanted

    # Element access through __iter__
    for element, wanted in zip(ext_arr, ids):
        assert element.as_py() == wanted

    # Chunked array made of two chunks over the same storage.
    chunks = [
        pa.ExtensionArray.from_storage(uuid_type, storage),
        pa.ExtensionArray.from_storage(uuid_type, storage)
    ]
    chunked = pa.chunked_array(chunks)
    for pos, wanted in enumerate(ids + ids):
        assert chunked[pos].as_py() == wanted

    for element, wanted in zip(chunked, ids + ids):
        assert element.as_py() == wanted


def test_uuid_type_pickle(pickle_module):
    """
    Round-trip UuidType through every protocol of *pickle_module*.

    ``pickle_module`` is presumably supplied by a pytest fixture defined
    outside this file -- confirm.  For each protocol from 0 up to
    ``HIGHEST_PROTOCOL``, the unpickled type must report the original
    extension name and must be released once its last reference is
    deleted.
    """
    for proto in range(0, pickle_module.HIGHEST_PROTOCOL + 1):
        ty = UuidType()
        ser = pickle_module.dumps(ty, protocol=proto)
        # Drop the original so only the unpickled instance is examined.
        del ty
        ty = pickle_module.loads(ser)
        wr = weakref.ref(ty)
        assert ty.extension_name == "pyarrow.tests.UuidType"
        del ty
        # The weak reference must be dead without any explicit collection.
        assert wr() is None


def test_ext_type_equality():
    """Equality distinguishes parameters and extension classes."""
    width5 = ParamExtType(5)
    width6 = ParamExtType(6)
    width6_again = ParamExtType(6)
    # Same class: differing widths are unequal, equal widths are equal.
    assert width5 != width6
    assert width6 == width6_again

    uuid_a = UuidType()
    uuid_b = UuidType()
    # Different classes are unequal; two parameter-less instances are equal.
    assert width5 != uuid_a
    assert uuid_a == uuid_b


def test_ext_array_basics():
    """
    An extension array built from storage validates, keeps the very same
    type object, and exposes storage equal to its input.
    """
    ext_type = ParamExtType(3)
    raw = pa.array([b"foo", b"bar"], type=pa.binary(3))
    ext_arr = pa.ExtensionArray.from_storage(ext_type, raw)
    ext_arr.validate()
    # Identity, not mere equality, is expected for the type.
    assert ext_arr.type is ext_type
    assert ext_arr.storage.equals(raw)


def test_ext_array_lifetime():
    """
    The extension type, the extension array and its storage array are all
    released once their local references are deleted.
    """
    ty = ParamExtType(3)
    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
    arr = pa.ExtensionArray.from_storage(ty, storage)

    # Weak references let each object be observed without keeping it alive.
    refs = [weakref.ref(ty), weakref.ref(arr), weakref.ref(storage)]
    del ty, storage, arr
    # No explicit garbage collection is triggered before the checks.
    for ref in refs:
        assert ref() is None

Loading ...