Repository URL to install this package:
|
Version:
0.35.0 ▾
|
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
"""Bfloat16 support."""
from typing import Any, Optional, Union
import pyarrow as pa
from lance.dependencies import numpy as np
from ..lance import BFloat16
from ..lance import bfloat16_array as bfloat16_array
class BFloat16Array(pa.ExtensionArray):
    """Bfloat16 PyArrow Array."""

    def __repr__(self):
        """Return an object header line followed by the ``str()`` rendering."""
        return f"<lance.arrow.BFloat16Array object at 0x{id(self):016x}>\n{self!s}"

    def __str__(self):
        """Render the values one per line between square brackets.

        Arrays with fewer than 22 elements are printed in full; longer arrays
        show only the first 10 and the last 10 values, separated by ``...``.
        """

        def _rows(values):
            return ",\n".join(f" {v}" for v in values)

        if len(self) < 22:
            return "[\n" + _rows(self) + "\n]"
        return "[\n" + _rows(self[:10]) + ",\n...\n" + _rows(self[-10:]) + "\n]"

    def to_numpy(self, zero_copy_only=False):
        """Convert to a NumPy array of ``ml_dtypes.bfloat16``.

        The result is created with ``np.frombuffer`` over the Arrow data
        buffer, so the values themselves are not copied.

        Parameters
        ----------
        zero_copy_only : bool, default False
            Not used by this implementation.

        Returns
        -------
        numpy.ndarray
            One-dimensional array holding ``len(self)`` elements.

        Raises
        ------
        ValueError
            If the array contains null values.
        """
        if self.null_count > 0:
            raise ValueError("Cannot convert null values to numpy")

        from ml_dtypes import bfloat16

        length = len(self)
        if length == 0:
            # An empty array may not carry a data buffer that can be wrapped.
            return np.empty(0, dtype=bfloat16)

        storage = self.storage
        # A sliced array shares its parent's data buffer, so the view has to be
        # limited to this array's own window (element offset -> byte offset).
        return np.frombuffer(
            storage.buffers()[1],
            dtype=bfloat16,
            count=length,
            offset=storage.offset * np.dtype(bfloat16).itemsize,
        )

    @classmethod
    def from_numpy(cls, array: np.ndarray):
        """Create a :class:`BFloat16Array` from a NumPy array.

        Can only convert from a one-dimensional NumPy array of dtype
        ``bfloat16`` from the ``ml_dtypes`` module. The array's bytes are
        copied (``array.tobytes()``) into a new Arrow buffer.

        Raises
        ------
        ValueError
            If ``array`` is not of dtype ``bfloat16`` or is not one-dimensional.

        Examples
        --------
        >>> import numpy as np
        >>> from ml_dtypes import bfloat16
        >>> from lance.arrow import BFloat16Array
        >>> arr = np.array([1.0, 2.0, 3.0], dtype=bfloat16)
        >>> print(BFloat16Array.from_numpy(arr))
        [
         1,
         2,
         3
        ]
        """
        from ml_dtypes import bfloat16

        if array.dtype != bfloat16:
            raise ValueError("Cannot convert non-bfloat16 values to BFloat16Array")
        if array.ndim != 1:
            raise ValueError("Cannot convert multi-dimensional array to BFloat16Array")
        data = pa.py_buffer(array.tobytes())
        # No validity buffer: every element is non-null.
        return pa.Array.from_buffers(BFloat16Type(), len(array), [None, data])
class BFloat16Scalar(pa.ExtensionScalar):
    """Scalar class returned by ``BFloat16Type.__arrow_ext_scalar_class__``."""

    def as_py(self, **kwargs) -> Optional[BFloat16]:
        """Return the value as a ``BFloat16``, or ``None`` if there is no value.

        Keyword arguments are forwarded to the storage scalar's ``as_py``.
        """
        storage = self.value
        if storage is not None:
            return BFloat16.from_bytes(storage.as_py(**kwargs))
        return None

    def __eq__(self, other: Any):
        """Compare with a ``BFloat16``, ``BFloat16Scalar`` or ``ml_dtypes.bfloat16``.

        Operands of any other type compare unequal.
        """
        from ml_dtypes import bfloat16

        # Normalise the right-hand side to what ``as_py`` yields, then compare.
        if isinstance(other, BFloat16):
            rhs = other
        elif isinstance(other, BFloat16Scalar):
            rhs = other.as_py()
        elif isinstance(other, bfloat16):
            rhs = BFloat16.from_bytes(other.tobytes())
        else:
            return False
        return self.as_py() == rhs
class BFloat16Type(pa.ExtensionType):
    """PyArrow extension type named ``lance.bfloat16`` with ``pa.binary(2)`` storage."""

    def __init__(self):
        super().__init__(pa.binary(2), "lance.bfloat16")

    def __arrow_ext_serialize__(self):
        """Return the serialized type metadata (currently always empty)."""
        # TODO: encode endianness
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        """Rebuild the type; ``storage_type`` and ``serialized`` are not inspected."""
        # TODO: decode endianness
        return BFloat16Type()

    def __arrow_ext_class__(self):
        """Return the array class used for arrays of this type."""
        return BFloat16Array

    def __arrow_ext_scalar_class__(self):
        """Return the scalar class used for elements of this type."""
        return BFloat16Scalar

    def to_pandas_dtype(self):
        """Return a ``PandasBFloat16Type`` instance.

        ``PandasBFloat16Type`` is only defined when pandas could be imported.
        """
        return PandasBFloat16Type()
try:
    from pandas.api.extensions import (
        ExtensionArray,
        ExtensionDtype,
        register_extension_dtype,
    )
except ImportError:
    # pandas could not be imported: the pandas extension classes below are
    # simply not defined.
    pass
else:
    # Define Pandas and register Pandas extensions
    @register_extension_dtype
    class PandasBFloat16Type(ExtensionDtype):
        """Pandas extension dtype named ``lance.bfloat16``."""

        kind = "f"
        na_value = None
        name = "lance.bfloat16"
        names = None
        type = BFloat16
        _is_numeric = True

        def __from_arrow__(
            self, array: Union[pa.Array, pa.ChunkedArray]
        ) -> ExtensionArray:
            """Wrap an Arrow array in a ``PandasBFloat16Array``."""
            return PandasBFloat16Array(array)

        def construct_array_type(self):
            """Return the array class associated with this dtype."""
            return PandasBFloat16Array

        @classmethod
        def construct_from_string(cls, string):
            """Build the dtype from its name.

            Raises
            ------
            TypeError
                If ``string`` is not ``"lance.bfloat16"``.
            """
            if string == "lance.bfloat16":
                return cls()
            raise TypeError(
                f"Cannot construct a '{cls.__name__}' from '{string}'"
            )

    class PandasBFloat16Array(ExtensionArray):
        """Pandas extension array that wraps an Arrow array in ``self.data``."""

        dtype = PandasBFloat16Type()

        def __init__(self, data):
            self.data = data

        @classmethod
        def _from_sequence(
            cls,
            scalars,
            *,
            dtype: Optional[PandasBFloat16Type] = None,
            copy: bool = False,
        ):
            """Build an array from a sequence of scalars via ``bfloat16_array``.

            ``dtype`` and ``copy`` are accepted for interface compatibility and
            are not used. ``dtype`` defaults to ``None`` so the method can be
            called without it.
            """
            return cls(bfloat16_array(scalars))

        def __getitem__(self, item):
            """Select a single element, a slice, or the rows of a boolean mask.

            Raises
            ------
            NotImplementedError
                For any other kind of indexer.
            """
            # NumPy integer scalars are accepted alongside built-in ints.
            if isinstance(item, (int, np.integer)):
                return self.data[int(item)].as_py()
            if isinstance(item, slice):
                return PandasBFloat16Array(self.data[item])
            if isinstance(item, np.ndarray) and np.issubdtype(item.dtype, np.bool_):
                return PandasBFloat16Array(self.data.filter(pa.array(item)))
            raise NotImplementedError()

        def __len__(self):
            return len(self.data)

        def isna(self):
            """Return a NumPy boolean array marking the null entries."""
            return self.data.is_null().to_numpy(zero_copy_only=False)

        def to_numpy(self, *args, **kwargs):
            """Return ``self.data.to_numpy()``; extra arguments are ignored."""
            return self.data.to_numpy()

        def __arrow_array__(self, type=None):
            """Return the wrapped Arrow array; ``type`` is not used."""
            return self.data

        @classmethod
        def from_numpy(cls, array):
            """Build an array from a NumPy array via ``BFloat16Array.from_numpy``."""
            inner = BFloat16Array.from_numpy(array)
            return cls(inner)
# Register a BFloat16Type instance with PyArrow when this module is imported.
pa.register_extension_type(BFloat16Type())