Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

/ tests / test_udf.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import pytest

import numpy as np

import pyarrow as pa
from pyarrow import compute as pc

# Every UDF in this module is exercised through a dataset scan, so the
# whole module carries the ``dataset`` marker.
pytestmark = pytest.mark.dataset

# Most of the tests here don't care about UDF function docs, so they share
# this placeholder doc.
empty_udf_doc = {"summary": "", "description": ""}

# pyarrow.dataset is optional: ``ds`` stays None when it cannot be imported.
ds = None
try:
    import pyarrow.dataset as ds
except ImportError:
    pass  # leave ``ds`` as None


def mock_udf_context(batch_length=10):
    """
    Build a UDF context backed by the default memory pool.

    Parameters
    ----------
    batch_length : int, default 10
        Batch length reported by the returned context.

    Returns
    -------
    The object produced by ``pyarrow._compute._get_udf_context``, presumably
    usable as the ``ctx`` argument when calling a UDF's Python function
    directly in a test.
    """
    # Private helper, imported lazily so that the module import does not
    # depend on it.
    from pyarrow._compute import _get_udf_context
    pool = pa.default_memory_pool()
    return _get_udf_context(pool, batch_length)


class MyError(RuntimeError):
    """Custom exception raised by the ``raising_func_fixture`` scalar UDF."""


@pytest.fixture(scope="session")
def sum_agg_func_fixture():
    """
    Register a unary aggregate function (sum) over a float64 argument.

    The UDF reduces its input with ``np.nansum``, so NaN values are skipped,
    and wraps the result in an Arrow scalar. The declared output type is
    float64.

    Returns
    -------
    tuple
        The Python callable and the name it was registered under
        (``"sum_udf"``).
    """
    def func(ctx, x, *args):
        # Any extra positional arguments are accepted but ignored.
        return pa.scalar(np.nansum(x))

    func_name = "sum_udf"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.float64(),
                                   },
                                   pa.float64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def exception_agg_func_fixture():
    """
    Register an int64 aggregate function whose body always raises.

    Calling the registered UDF raises ``RuntimeError("Oops")``.

    Returns
    -------
    tuple
        The Python callable and the name it was registered under.
    """
    def func(ctx, x):
        # Unconditional raise: this UDF never produces a value.
        raise RuntimeError("Oops")

    func_name = "y=exception_len(x)"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                   },
                                   pa.int64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def wrong_output_dtype_agg_func_fixture():
    """
    Register an aggregate function whose result type mismatches its declaration.

    The function is declared to return int64, but the Python callable
    returns an int32 scalar holding ``len(x)``.

    Returns
    -------
    tuple
        The Python callable and the name it was registered under.
    """
    def func(ctx, x):
        # Deliberately int32 while the declared output type is int64.
        return pa.scalar(len(x), pa.int32())

    func_name = "y=wrong_output_dtype(x)"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                   },
                                   pa.int64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def wrong_output_type_agg_func_fixture():
    """
    Register an aggregate function that returns a non-Arrow value.

    The function is declared to return int64, but the Python callable
    returns a plain Python ``int`` (``len(x)``) instead of an Arrow scalar.

    Returns
    -------
    tuple
        The Python callable and the name it was registered under.
    """
    def func(ctx, x):
        # Deliberately a bare Python int rather than a pa.Scalar.
        return len(x)

    func_name = "y=wrong_output_type(x)"
    func_doc = empty_udf_doc

    pc.register_aggregate_function(func,
                                   func_name,
                                   func_doc,
                                   {
                                       "x": pa.int64(),
                                   },
                                   pa.int64()
                                   )
    return func, func_name


@pytest.fixture(scope="session")
def binary_func_fixture():
    """
    Register a two-argument scalar function computing ``m * x`` on int64.

    Returns the Python callable together with its registered name.
    """
    func_name = "y=mx"

    def binary_function(ctx, m, x):
        # Allocate the product from the pool supplied by the UDF context.
        return pc.call_function("multiply", [m, x],
                                memory_pool=ctx.memory_pool)

    binary_doc = {
        "summary": "y=mx",
        "description": "find y from y = mx",
    }
    arg_types = {"m": pa.int64(), "x": pa.int64()}
    pc.register_scalar_function(binary_function, func_name, binary_doc,
                                arg_types, pa.int64())
    return binary_function, func_name


@pytest.fixture(scope="session")
def ternary_func_fixture():
    """
    Register a three-argument scalar function computing ``m * x + c``.

    All three arguments and the result are declared as int64. Returns the
    Python callable together with its registered name.
    """
    func_name = "y=mx+c"

    def ternary_function(ctx, m, x, c):
        pool = ctx.memory_pool
        product = pc.call_function("multiply", [m, x], memory_pool=pool)
        return pc.call_function("add", [product, c], memory_pool=pool)

    ternary_doc = {
        "summary": "y=mx+c",
        "description": "find y from y = mx + c",
    }
    # The registered argument names need not match the Python parameters.
    arg_types = {
        "array1": pa.int64(),
        "array2": pa.int64(),
        "array3": pa.int64(),
    }
    pc.register_scalar_function(ternary_function, func_name, ternary_doc,
                                arg_types, pa.int64())
    return ternary_function, func_name


@pytest.fixture(scope="session")
def varargs_func_fixture():
    """
    Register a varargs scalar function that adds all of its arguments.

    The Python callable accepts one or more arguments and folds them with
    ``add``. It is registered with two declared int64 arguments and an
    int64 result. Returns the callable together with its registered name.
    """
    func_name = "z=ax+by+c"

    def varargs_function(ctx, first, *values):
        pool = ctx.memory_pool
        total = first
        for operand in values:
            total = pc.call_function("add", [total, operand],
                                     memory_pool=pool)
        return total

    varargs_doc = {
        "summary": "z=ax+by+c",
        "description": "find z from z = ax + by + c",
    }
    arg_types = {
        "array1": pa.int64(),
        "array2": pa.int64(),
    }
    pc.register_scalar_function(varargs_function, func_name, varargs_doc,
                                arg_types, pa.int64())
    return varargs_function, func_name


@pytest.fixture(scope="session")
def nullary_func_fixture():
    """
    Register a scalar function that takes no arguments.

    The UDF emits an int64 array filled with the constant 42, one element
    per row of the current batch (``context.batch_length``).

    NOTE(review): the registered doc calls it a "random" function, but the
    values are constant.

    Returns the Python callable together with its registered name.
    """
    func_name = "test_nullary_func"

    def nullary_func(context):
        values = [42] * context.batch_length
        return pa.array(values, type=pa.int64(),
                        memory_pool=context.memory_pool)

    func_doc = {
        "summary": "random function",
        "description": "generates a random value",
    }
    pc.register_scalar_function(nullary_func, func_name, func_doc,
                                {}, pa.int64())
    return nullary_func, func_name


@pytest.fixture(scope="session")
def wrong_output_type_func_fixture():
    """
    Register a scalar function that returns neither an Arrow scalar nor an
    Arrow array.

    The function is declared to return int64 but yields the plain Python
    int 42. Returns the callable together with its registered name.
    """
    def wrong_output_type(ctx):
        return 42

    doc = {
        "summary": "return wrong output type",
        "description": "",
    }
    pc.register_scalar_function(wrong_output_type,
                                "test_wrong_output_type",
                                doc,
                                {},
                                pa.int64())
    return wrong_output_type, "test_wrong_output_type"


@pytest.fixture(scope="session")
def wrong_output_datatype_func_fixture():
    """
    Register a scalar function whose actual output type differs from the
    declared one.

    Adding 1 to an int64 input produces int64, while the declared output
    type is int16. Returns the callable together with its registered name.
    """
    def wrong_output_datatype(ctx, array):
        return pc.call_function("add", [array, 1])

    func_name = "test_wrong_output_datatype"
    doc = {
        "summary": "return wrong output datatype",
        "description": "",
    }
    # Declared int16 on purpose; the real result is int64.
    pc.register_scalar_function(wrong_output_datatype, func_name, doc,
                                {"array": pa.int64()}, pa.int16())
    return wrong_output_datatype, func_name


@pytest.fixture(scope="session")
def wrong_signature_func_fixture():
    """
    Register a scalar function whose Python signature is wrong.

    The callable omits the leading context parameter that the other UDFs
    in this module accept. Returns the callable together with its
    registered name.
    """
    def wrong_signature():  # no ``ctx`` parameter
        return pa.scalar(1, type=pa.int64())

    func_name = "test_wrong_signature"
    doc = {
        "summary": "UDF with wrong signature",
        "description": "",
    }
    pc.register_scalar_function(wrong_signature, func_name, doc,
                                {}, pa.int64())
    return wrong_signature, func_name


@pytest.fixture(scope="session")
def raising_func_fixture():
    """
    Register a nullary scalar function that always raises ``MyError``.

    Returns the callable together with its registered name.
    """
    func_name = "test_raise"

    def raising_func(ctx):
        raise MyError("error raised by scalar UDF")

    doc = {
        "summary": "raising function",
        "description": "",
    }
    pc.register_scalar_function(raising_func, func_name, doc,
                                {}, pa.int64())
    return raising_func, func_name


@pytest.fixture(scope="session")
def unary_vector_func_fixture():
    """
    Register a unary vector function computing percentile rank.

    The input is converted to pandas and ranked with ``rank(pct=True)``;
    both the input and the output are declared float64. Returns the
    callable together with its registered name.
    """
    def pct_rank(ctx, x):
        # Rank a copy of the pandas data to get around a pandas 1.0 issue.
        series = x.to_pandas().copy()
        return pa.array(series.rank(pct=True))

    func_name = "y=pct_rank(x)"
    pc.register_vector_function(pct_rank, func_name, empty_udf_doc,
                                {'x': pa.float64()}, pa.float64())
    return pct_rank, func_name


@pytest.fixture(scope="session")
def struct_vector_func_fixture():
    """
    Register a vector function that returns a struct array.

    The UDF pivots ``(k, v, c)`` rows through pandas so that each distinct
    ``c`` value becomes a column, indexed by ``k``. The declared output is a
    struct of ``k`` (int64), ``v1`` and ``v2`` (float64); this presumably
    expects ``c`` to take exactly the values "v1" and "v2". Returns the
    callable together with its registered name.
    """
    def pivot(ctx, k, v, c):
        batch = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c'])
        frame = batch.to_pandas()
        wide = frame.pivot(columns='c', values='v', index='k').reset_index()
        return pa.RecordBatch.from_pandas(wide).to_struct_array()

    func_name = "y=pivot(x)"
    in_types = {'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()}
    out_type = pa.struct([
        ('k', pa.int64()),
        ('v1', pa.float64()),
        ('v2', pa.float64()),
    ])
    pc.register_vector_function(pivot, func_name, empty_udf_doc,
                                in_types, out_type)
    return pivot, func_name


def check_scalar_function(func_fixture,
                          inputs, *,
                          run_in_dataset=True,
                          batch_length=None):
    function, name = func_fixture
    if batch_length is None:
Loading ...