# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

try:
    import numpy as np
except ImportError:
    np = None

import pyarrow as pa
from pyarrow import compute as pc

# UDFs are all tested with a dataset scan
pytestmark = pytest.mark.dataset

# For convenience, most of the tests here don't care about UDF function docs
empty_udf_doc = {"summary": "", "description": ""}

try:
    import pyarrow.dataset as ds
except ImportError:
    ds = None
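
# Helper to build a UDF context directly (outside of an actual kernel call),
# so that UDF implementations can also be invoked as plain Python functions
# in tests.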
def mock_udf_context(batch_length=10):
from pyarrow._compute import _get_udf_context
return _get_udf_context(pa.default_memory_pool(), batch_length)
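
# Custom exception type used to check that errors raised inside a UDF
# propagate back to the caller.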
class MyError(RuntimeError):
pass
@pytest.fixture(scope="session")
def sum_agg_func_fixture():
"""
    Register a unary aggregate function (sum).
"""
def func(ctx, x, *args):
return pa.scalar(np.nansum(x))
func_name = "sum_udf"
func_doc = empty_udf_doc
pc.register_aggregate_function(func,
func_name,
func_doc,
{
"x": pa.float64(),
},
pa.float64()
)
return func, func_name
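
# Once registered, the aggregate can be exercised directly, e.g. (sketch,
# assuming the fixture above has been requested so "sum_udf" exists):
#   pc.call_function("sum_udf", [pa.array([1.0, 2.0, 3.0])])
# which should produce a float64 scalar equal to 6.0.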
@pytest.fixture(scope="session")
def exception_agg_func_fixture():
def func(ctx, x):
        raise RuntimeError("Oops")
        # Never reached: this UDF exists only to exercise error propagation.
        return pa.scalar(len(x))
func_name = "y=exception_len(x)"
func_doc = empty_udf_doc
pc.register_aggregate_function(func,
func_name,
func_doc,
{
"x": pa.int64(),
},
pa.int64()
)
return func, func_name
@pytest.fixture(scope="session")
def wrong_output_dtype_agg_func_fixture():
def func(ctx, x):
return pa.scalar(len(x), pa.int32())
func_name = "y=wrong_output_dtype(x)"
func_doc = empty_udf_doc
pc.register_aggregate_function(func,
func_name,
func_doc,
{
"x": pa.int64(),
},
pa.int64()
)
return func, func_name
@pytest.fixture(scope="session")
def wrong_output_type_agg_func_fixture():
def func(ctx, x):
return len(x)
func_name = "y=wrong_output_type(x)"
func_doc = empty_udf_doc
pc.register_aggregate_function(func,
func_name,
func_doc,
{
"x": pa.int64(),
},
pa.int64()
)
return func, func_name
@pytest.fixture(scope="session")
def binary_func_fixture():
"""
Register a binary scalar function.
"""
def binary_function(ctx, m, x):
return pc.call_function("multiply", [m, x],
memory_pool=ctx.memory_pool)
func_name = "y=mx"
binary_doc = {"summary": "y=mx",
"description": "find y from y = mx"}
pc.register_scalar_function(binary_function,
func_name,
binary_doc,
{"m": pa.int64(),
"x": pa.int64(),
},
pa.int64())
return binary_function, func_name
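
# Example usage (sketch): once registered, the UDF goes through the regular
# compute machinery, e.g.
#   pc.call_function("y=mx", [pa.array([2, 3]), pa.array([10, 20])])
# which multiplies the inputs elementwise, giving [20, 60].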
@pytest.fixture(scope="session")
def ternary_func_fixture():
"""
Register a ternary scalar function.
"""
def ternary_function(ctx, m, x, c):
mx = pc.call_function("multiply", [m, x],
memory_pool=ctx.memory_pool)
return pc.call_function("add", [mx, c],
memory_pool=ctx.memory_pool)
ternary_doc = {"summary": "y=mx+c",
"description": "find y from y = mx + c"}
func_name = "y=mx+c"
pc.register_scalar_function(ternary_function,
func_name,
ternary_doc,
{
"array1": pa.int64(),
"array2": pa.int64(),
"array3": pa.int64(),
},
pa.int64())
return ternary_function, func_name
@pytest.fixture(scope="session")
def varargs_func_fixture():
"""
Register a varargs scalar function with at least two arguments.
"""
def varargs_function(ctx, first, *values):
acc = first
for val in values:
acc = pc.call_function("add", [acc, val],
memory_pool=ctx.memory_pool)
return acc
func_name = "z=ax+by+c"
varargs_doc = {"summary": "z=ax+by+c",
"description": "find z from z = ax + by + c"
}
pc.register_scalar_function(varargs_function,
func_name,
varargs_doc,
{
"array1": pa.int64(),
"array2": pa.int64(),
},
pa.int64())
return varargs_function, func_name
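
# Nullary UDFs have no input arrays to infer a result length from, so the
# context's batch_length tells them how many rows to produce.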
@pytest.fixture(scope="session")
def nullary_func_fixture():
"""
Register a nullary scalar function.
"""
def nullary_func(context):
return pa.array([42] * context.batch_length, type=pa.int64(),
memory_pool=context.memory_pool)
    func_doc = {
        "summary": "nullary function",
        "description": "returns a constant value for each row in the batch"
    }
func_name = "test_nullary_func"
pc.register_scalar_function(nullary_func,
func_name,
func_doc,
{},
pa.int64())
return nullary_func, func_name
@pytest.fixture(scope="session")
def ephemeral_nullary_func_fixture():
"""
    Register a nullary scalar function from an ephemeral Python function.

    The fixture returns only the function name, which checks that the Python
    function object is kept alive by the registration itself.
"""
def nullary_func(context):
return pa.array([42] * context.batch_length, type=pa.int64(),
memory_pool=context.memory_pool)
    func_doc = {
        "summary": "nullary function",
        "description": "returns a constant value for each row in the batch"
    }
func_name = "test_ephemeral_nullary_func"
pc.register_scalar_function(nullary_func,
func_name,
func_doc,
{},
pa.int64())
return func_name
@pytest.fixture(scope="session")
def wrong_output_type_func_fixture():
"""
    Register a scalar function which returns something that is neither
    an Arrow scalar nor an Arrow array.
"""
def wrong_output_type(ctx):
return 42
func_name = "test_wrong_output_type"
in_types = {}
out_type = pa.int64()
doc = {
"summary": "return wrong output type",
"description": ""
}
pc.register_scalar_function(wrong_output_type, func_name, doc,
in_types, out_type)
return wrong_output_type, func_name
@pytest.fixture(scope="session")
def wrong_output_datatype_func_fixture():
"""
Register a scalar function whose actual output DataType doesn't
match the declared output DataType.
"""
def wrong_output_datatype(ctx, array):
return pc.call_function("add", [array, 1])
func_name = "test_wrong_output_datatype"
in_types = {"array": pa.int64()}
# The actual output DataType will be int64.
out_type = pa.int16()
doc = {
"summary": "return wrong output datatype",
"description": ""
}
pc.register_scalar_function(wrong_output_datatype, func_name, doc,
in_types, out_type)
return wrong_output_datatype, func_name
@pytest.fixture(scope="session")
def wrong_signature_func_fixture():
"""
Register a scalar function with the wrong signature.
"""
# Missing the context argument
def wrong_signature():
return pa.scalar(1, type=pa.int64())
func_name = "test_wrong_signature"
in_types = {}
out_type = pa.int64()
doc = {
"summary": "UDF with wrong signature",
"description": ""
}
pc.register_scalar_function(wrong_signature, func_name, doc,
in_types, out_type)
return wrong_signature, func_name
@pytest.fixture(scope="session")
def raising_func_fixture():
"""
Register a scalar function which raises a custom exception.
"""
def raising_func(ctx):
raise MyError("error raised by scalar UDF")
func_name = "test_raise"
doc = {
"summary": "raising function",
"description": ""
}
pc.register_scalar_function(raising_func, func_name, doc,
{}, pa.int64())
return raising_func, func_name
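
# Unlike a scalar UDF, where each output row depends only on the corresponding
# input row, a vector UDF's output may depend on the whole input; the
# percent-rank function below needs the full column to compute ranks.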
@pytest.fixture(scope="session")
def unary_vector_func_fixture():
"""
    Register a unary vector function (percent rank).
"""
def pct_rank(ctx, x):
# copy here to get around pandas 1.0 issue
return pa.array(x.to_pandas().copy().rank(pct=True))
func_name = "y=pct_rank(x)"
doc = empty_udf_doc
pc.register_vector_function(pct_rank, func_name, doc, {
'x': pa.float64()}, pa.float64())
return pct_rank, func_name