Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
getdaft / utils.py
Size: Mime:
from __future__ import annotations

from typing import Any, Callable

from daft.dependencies import pa


def get_arrow_version():
    """Return the installed pyarrow version as a tuple of integers.

    The version string is split on ``"."`` and only the purely numeric
    segments are kept. For example, ``"14.0.2"`` becomes ``(14, 0, 2)``, and a
    segment such as ``"dev10"`` is dropped.
    """
    segments = pa.__version__.split(".")
    return tuple(map(int, filter(str.isnumeric, segments)))


def in_notebook():
    """Return whether the current process is running inside a Jupyter kernel.

    Detection looks for an ``"IPKernelApp"`` entry in the active IPython
    shell's config. The result is ``False`` when:

    - IPython cannot be imported (``ImportError``);
    - reading the shell config raises ``AttributeError`` (typically because
      ``get_ipython()`` returned ``None`` outside an IPython session);
    - the config has no ``"IPKernelApp"`` entry.
    """
    try:
        from IPython import get_ipython

        has_kernel_app = "IPKernelApp" in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        return False
    return has_kernel_app


def pydict_to_rows(pydict: dict[str, list]) -> list[frozenset[tuple[str, Any]]]:
    """Turn a column-oriented dataframe pydict into a list of row frozensets.

    Each row is a ``frozenset`` of ``(column_name, value)`` pairs, with every
    value passed through ``freeze``. Because rows are sets, column order within
    a row does not affect equality. Rows are produced with ``zip`` over the
    columns, so the output has as many rows as the shortest column.

    e.g.
    {
        "fruit": ["apple", "banana", "carrot"],
        "number": [1, 2, 3],
    }

    becomes
    [
        frozenset({("fruit", "apple"), ("number", 1)}),
        frozenset({("fruit", "banana"), ("number", 2)}),
        frozenset({("fruit", "carrot"), ("number", 3)}),
    ]
    """
    column_names = list(pydict)
    rows = []
    for row_values in zip(*pydict.values()):
        cells = zip(column_names, row_values)
        rows.append(frozenset((name, freeze(cell)) for name, cell in cells))
    return rows


def freeze(input: dict | list | Any) -> frozenset | tuple | Any:  # noqa: A002 - name kept for keyword callers
    """Recursively replace mutable containers with immutable equivalents.

    This lets nested values be compared for equality and stored in sets, for
    example as cells of the row frozensets built by ``pydict_to_rows``.

    - ``dict`` -> ``frozenset`` of ``(key, frozen_value)`` pairs
    - ``list`` / ``tuple`` -> ``tuple`` of frozen items
    - ``set`` / ``frozenset`` -> ``frozenset`` of frozen items
    - anything else is returned unchanged

    Args:
        input: Value to freeze. It shadows the ``input`` builtin, but the name
            is kept for backward compatibility with keyword callers.

    Returns:
        The frozen equivalent of ``input``. Leaf values are returned as-is, so
        the result is hashable as long as every leaf is.
    """
    if isinstance(input, dict):
        return frozenset((key, freeze(value)) for key, value in input.items())
    if isinstance(input, (list, tuple)):
        # Recurse into tuples as well: a tuple holding a list, such as a
        # ``(key, [values])`` pair, is otherwise unhashable.
        return tuple(freeze(item) for item in input)
    if isinstance(input, (set, frozenset)):
        # A plain ``set`` is mutable and unhashable; ``frozenset`` compares equal to it.
        return frozenset(freeze(item) for item in input)
    return input


def map_operator_arrow_semantics_bool(
    operator: Callable[[Any, Any], Any],
    left_pylist: list,
    right_pylist: list,
) -> list[bool | None]:
    """Apply ``operator`` pairwise and coerce each result to ``bool``.

    A ``None`` on either side yields ``None`` for that position, and
    ``operator`` is not called for it. Pairs are formed with ``zip``, so the
    output is as long as the shorter input list.
    """
    results: list[bool | None] = []
    for lhs, rhs in zip(left_pylist, right_pylist):
        if lhs is None or rhs is None:
            results.append(None)
        else:
            results.append(bool(operator(lhs, rhs)))
    return results


def python_list_membership_check(
    left_pylist: list,
    right_pylist: list,
) -> list:
    """Return, for each item of ``left_pylist``, whether it occurs in ``right_pylist``.

    A hash set of ``right_pylist`` is used for fast lookups. If building the set
    or any lookup raises ``TypeError`` (e.g. an unhashable value on either
    side), the whole check is redone with linear list scans instead.
    """

    def _lookup_all(haystack) -> list:
        return [needle in haystack for needle in left_pylist]

    try:
        return _lookup_all(set(right_pylist))
    except TypeError:
        return _lookup_all(right_pylist)


def python_list_between_check(value_pylist: list, lower_pylist: list, upper_pylist: list) -> list:
    """Element-wise inclusive range check ``lower <= value <= upper``.

    Nulls propagate the same way as in the other comparison helpers in this
    module. If ``value``, ``lower`` or ``upper`` is ``None``, the result for
    that position is ``None``. Without this, comparing against ``None`` would
    raise ``TypeError``. The three lists are walked with ``zip``, so the output
    is as long as the shortest input.

    Returns:
        A list holding, per position, the comparison result or ``None``.
    """
    return [
        # Keep the original evaluation order (upper bound first) for non-null rows.
        None if (value is None or lower is None or upper is None) else (value <= upper and value >= lower)
        for value, lower, upper in zip(value_pylist, lower_pylist, upper_pylist)
    ]


def map_operator_arrow_semantics(
    operator: Callable[[Any, Any], Any],
    left_pylist: list,
    right_pylist: list,
) -> list:
    """Apply ``operator`` pairwise with null propagation.

    A ``None`` on either side yields ``None`` and ``operator`` is not called.
    Otherwise the raw result of ``operator(left, right)`` is kept. Iteration
    stops at the shorter input list.
    """

    def _apply(lhs: Any, rhs: Any) -> Any:
        if lhs is None or rhs is None:
            return None
        return operator(lhs, rhs)

    return list(map(_apply, left_pylist, right_pylist))


def pyarrow_supports_fixed_shape_tensor() -> bool:
    """Report whether pyarrow's ``fixed_shape_tensor`` canonical extension type is usable.

    Requires ``pa.fixed_shape_tensor`` to exist. When the runner returned by
    ``get_context().get_or_create_runner()`` is named ``"ray"``, the installed
    pyarrow must also be at least 13.0.0. NOTE(review): the reason for the
    Ray-specific version floor is not visible here; confirm before changing it.
    """
    from daft.context import get_context

    if not hasattr(pa, "fixed_shape_tensor"):
        return False

    runner_name = get_context().get_or_create_runner().name
    if runner_name != "ray":
        return True
    return get_arrow_version() >= (13, 0, 0)