// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/python/udf.h"
#include "arrow/array/array_nested.h"
#include "arrow/array/builder_base.h"
#include "arrow/buffer_builder.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/function.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/row/grouper.h"
#include "arrow/python/common.h"
#include "arrow/python/vendored/pythoncapi_compat.h"
#include "arrow/table.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
namespace arrow {
using compute::ExecSpan;
using compute::Grouper;
using compute::KernelContext;
using compute::KernelState;
using internal::checked_cast;
namespace py {
namespace {
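// Kernel state that keeps a strong reference to the Python callable for the
// lifetime of the kernel.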
struct PythonUdfKernelState : public compute::KernelState {
// NOTE: this KernelState constructor doesn't require the GIL.
// If it did, the corresponding KernelInit::operator() would need to be
// wrapped in SafeCallIntoPython (GH-43487).
explicit PythonUdfKernelState(std::shared_ptr<OwnedRefNoGIL> function)
: function(std::move(function)) {}
std::shared_ptr<OwnedRefNoGIL> function;
};
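// Kernel init functor for scalar/vector UDFs: hands the shared Python
// callable to each newly created kernel state.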
struct PythonUdfKernelInit {
explicit PythonUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function)
: function(std::move(function)) {}
Result<std::unique_ptr<compute::KernelState>> operator()(
compute::KernelContext*, const compute::KernelInitArgs&) {
return std::make_unique<PythonUdfKernelState>(function);
}
std::shared_ptr<OwnedRefNoGIL> function;
};
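// State interface for scalar (non-grouped) aggregate UDFs.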
struct ScalarUdfAggregator : public compute::KernelState {
virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0;
virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0;
virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
};
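// State interface for hash (grouped) aggregate UDFs.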
struct HashUdfAggregator : public compute::KernelState {
virtual Status Resize(KernelContext* ctx, int64_t size) = 0;
virtual Status Consume(KernelContext* ctx, const ExecSpan& batch) = 0;
virtual Status Merge(KernelContext* ct, KernelState&& other, const ArrayData&) = 0;
virtual Status Finalize(KernelContext* ctx, Datum* out) = 0;
};
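// Thin adapters forwarding the aggregate kernel callbacks to the state object
// stored in the KernelContext.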
Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
}
Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
compute::KernelState* dst) {
return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx, std::move(src));
}
Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
}
Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) {
return checked_cast<HashUdfAggregator*>(ctx->state())->Resize(ctx, size);
}
Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) {
return checked_cast<HashUdfAggregator*>(ctx->state())->Consume(ctx, batch);
}
Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src,
const ArrayData& group_id_mapping) {
return checked_cast<HashUdfAggregator*>(ctx->state())
->Merge(ctx, std::move(src), group_id_mapping);
}
Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) {
return checked_cast<HashUdfAggregator*>(ctx->state())->Finalize(ctx, out);
}
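// Kernel init for tabular UDFs: the registered Python object is a factory
// that is invoked with an empty argument tuple to produce the callable held
// in the kernel state.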
struct PythonTableUdfKernelInit {
PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
UdfWrapperCallback cb)
: function_maker(std::move(function_maker)), cb(std::move(cb)) {}
Result<std::unique_ptr<compute::KernelState>> operator()(
compute::KernelContext* ctx, const compute::KernelInitArgs&) {
return SafeCallIntoPython(
[this, ctx]() -> Result<std::unique_ptr<compute::KernelState>> {
UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0};
OwnedRef empty_tuple(PyTuple_New(0));
auto function = std::make_shared<OwnedRefNoGIL>(
cb(function_maker->obj(), udf_context, empty_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
if (!PyCallable_Check(function->obj())) {
return Status::TypeError("Expected a callable Python object.");
}
return std::make_unique<PythonUdfKernelState>(std::move(function));
});
}
std::shared_ptr<OwnedRefNoGIL> function_maker;
UdfWrapperCallback cb;
};
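// Non-decomposable scalar aggregate UDF: Consume and MergeFrom only buffer
// the input batches; Finalize concatenates them and invokes the Python
// function once over the whole input.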
struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
PythonUdfScalarAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function,
UdfWrapperCallback cb,
std::vector<std::shared_ptr<DataType>> input_types,
std::shared_ptr<DataType> output_type)
: function(std::move(function)),
cb(std::move(cb)),
output_type(std::move(output_type)) {
std::vector<std::shared_ptr<Field>> fields;
for (size_t i = 0; i < input_types.size(); i++) {
fields.push_back(field("", input_types[i]));
}
input_schema = schema(std::move(fields));
};
Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override {
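// Only buffer the batch here; the Python function is invoked in Finalize.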
ARROW_ASSIGN_OR_RAISE(
auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
values.push_back(std::move(rb));
return Status::OK();
}
Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override {
auto& other_values = checked_cast<PythonUdfScalarAggregatorImpl&>(src).values;
values.insert(values.end(), std::make_move_iterator(other_values.begin()),
std::make_move_iterator(other_values.end()));
other_values.erase(other_values.begin(), other_values.end());
return Status::OK();
}
Status Finalize(compute::KernelContext* ctx, Datum* out) override {
auto state =
arrow::internal::checked_cast<PythonUdfScalarAggregatorImpl*>(ctx->state());
const int num_args = input_schema->num_fields();
// Note: concatenating the accumulated batches this way temporarily doubles
// memory usage. This is acceptable for now because a non-decomposable
// aggregate UDF is meant to be used with segmented aggregation, where the
// segment size is more or less constant, so doubling it is not a big deal.
// This could be improved in the future with a more efficient way to
// concatenate.
ARROW_ASSIGN_OR_RAISE(auto table,
arrow::Table::FromRecordBatches(input_schema, values));
ARROW_ASSIGN_OR_RAISE(table, table->CombineChunks(ctx->memory_pool()));
UdfContext udf_context{ctx->memory_pool(), table->num_rows()};
if (table->num_rows() == 0) {
return Status::Invalid("Finalized is called with empty inputs");
}
RETURN_NOT_OK(SafeCallIntoPython([&] {
std::unique_ptr<OwnedRef> result;
OwnedRef arg_tuple(PyTuple_New(num_args));
RETURN_NOT_OK(CheckPyError());
for (int arg_id = 0; arg_id < num_args; arg_id++) {
// Since we combined chunks there is only one chunk
std::shared_ptr<Array> c_data = table->column(arg_id)->chunk(0);
PyObject* data = wrap_array(c_data);
PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
}
result =
std::make_unique<OwnedRef>(cb(function->obj(), udf_context, arg_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
// Unwrap the output and check it against the expected output type
if (is_scalar(result->obj())) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> val, unwrap_scalar(result->obj()));
if (*output_type != *val->type) {
return Status::TypeError("Expected output datatype ", output_type->ToString(),
", but function returned datatype ",
val->type->ToString());
}
out->value = std::move(val);
return Status::OK();
}
return Status::TypeError("Unexpected output type: ",
Py_TYPE(result->obj())->tp_name, " (expected Scalar)");
}));
return Status::OK();
}
std::shared_ptr<OwnedRefNoGIL> function;
UdfWrapperCallback cb;
std::vector<std::shared_ptr<RecordBatch>> values;
std::shared_ptr<Schema> input_schema;
std::shared_ptr<DataType> output_type;
};
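// Grouped (hash) aggregate UDF: Consume buffers each batch together with its
// group ids, Merge remaps and appends another state's group ids, and Finalize
// partitions the accumulated rows by group and calls the Python function once
// per group, producing one scalar per group.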
struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
PythonUdfHashAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function,
UdfWrapperCallback cb,
std::vector<std::shared_ptr<DataType>> input_types,
std::shared_ptr<DataType> output_type)
: function(std::move(function)),
cb(std::move(cb)),
output_type(std::move(output_type)) {
std::vector<std::shared_ptr<Field>> fields;
fields.reserve(input_types.size());
for (size_t i = 0; i < input_types.size(); i++) {
fields.push_back(field("", input_types[i]));
}
input_schema = schema(std::move(fields));
};
// Same as ApplyGroupings in partition.cc; the code is replicated here to
// avoid complicating the dependencies
static Result<RecordBatchVector> ApplyGroupings(
const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) {
ARROW_ASSIGN_OR_RAISE(Datum sorted,
compute::Take(batch, groupings.data()->child_data[0]));
const auto& sorted_batch = *sorted.record_batch();
RecordBatchVector out(static_cast<size_t>(groupings.length()));
for (size_t i = 0; i < out.size(); ++i) {
out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i));
}
return out;
}
Status Resize(KernelContext* ctx, int64_t new_num_groups) override {
// Like other hash aggregate kernels, Resize only needs to update num_groups
num_groups = new_num_groups;
return Status::OK();
}
Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
ARROW_ASSIGN_OR_RAISE(
std::shared_ptr<RecordBatch> rb,
batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
// This is similar to GroupedListImpl: the last value in the batch is the
// group id array
const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array;
DCHECK_EQ(groups_array_data.offset, 0);
int64_t batch_num_values = groups_array_data.length;
const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1);
RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values));
values.push_back(std::move(rb));
num_values += batch_num_values;
return Status::OK();
}
Status Merge(KernelContext* ctx, KernelState&& other_state,
const ArrayData& group_id_mapping) override {
// This is similar to GroupedListImpl
auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
auto& other_values = other.values;
const uint32_t* other_raw_groups = other.groups.data();
values.insert(values.end(), std::make_move_iterator(other_values.begin()),
std::make_move_iterator(other_values.end()));
auto g = group_id_mapping.GetValues<uint32_t>(1);
for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < other.num_values;
++other_g) {
// Different states can have different group id mappings, so the ids need
// to be translated
RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]]));
}
num_values += other.num_values;
return Status::OK();
}
Status Finalize(KernelContext* ctx, Datum* out) override {
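// Reassemble the buffered batches, partition the rows by group id, and call
// the Python function once per group.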
// Exclude the last column which is the group id
const int num_args = input_schema->num_fields() - 1;
ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish());
ARROW_ASSIGN_OR_RAISE(auto groupings,
Grouper::MakeGroupings(UInt32Array(num_values, groups_buffer),
static_cast<uint32_t>(num_groups)));
ARROW_ASSIGN_OR_RAISE(auto table,
arrow::Table::FromRecordBatches(input_schema, values));
ARROW_ASSIGN_OR_RAISE(auto rb, table->CombineChunksToBatch(ctx->memory_pool()));
UdfContext udf_context{ctx->memory_pool(), table->num_rows()};
if (rb->num_rows() == 0) {
*out = Datum();
return Status::OK();
}
ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb));
return SafeCallIntoPython([&] {
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> builder,
MakeBuilder(output_type, ctx->memory_pool()));
for (auto& group_rb : rbs) {
std::unique_ptr<OwnedRef> result;
OwnedRef arg_tuple(PyTuple_New(num_args));
RETURN_NOT_OK(CheckPyError());
for (int arg_id = 0; arg_id < num_args; arg_id++) {
// Each column of the per-group batch is a single Array (the table was
// combined into one batch above)
std::shared_ptr<Array> c_data = group_rb->column(arg_id);
PyObject* data = wrap_array(c_data);
PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
}
result =
std::make_unique<OwnedRef>(cb(function->obj(), udf_context, arg_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
// Unwrap the output and check it against the expected output type
if (is_scalar(result->obj())) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> val,
unwrap_scalar(result->obj()));
if (*output_type != *val->type) {
return Status::TypeError("Expected output datatype ", output_type->ToString(),
", but function returned datatype ",
val->type->ToString());
}
ARROW_RETURN_NOT_OK(builder->AppendScalar(std::move(*val)));
} else {
return Status::TypeError("Unexpected output type: ",
Py_TYPE(result->obj())->tp_name, " (expected Scalar)");
}
}
ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
out->value = std::move(result->data());
return Status::OK();
});
}
std::shared_ptr<OwnedRefNoGIL> function;
UdfWrapperCallback cb;
// Accumulated input batches
std::vector<std::shared_ptr<RecordBatch>> values;
// Group ids - extracted from the last column from the batch
TypedBufferBuilder<uint32_t> groups;
int64_t num_groups = 0;
int64_t num_values = 0;
std::shared_ptr<Schema> input_schema;
std::shared_ptr<DataType> output_type;
};
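// Kernel data for scalar/vector UDFs: holds the Python callable together with
// the declared input/output types and caches the resolved output type for the
// common case where the actual input types match the declaration.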
struct PythonUdf : public PythonUdfKernelState {
PythonUdf(std::shared_ptr<OwnedRefNoGIL> function, UdfWrapperCallback cb,
std::vector<TypeHolder> input_types, compute::OutputType output_type)
: PythonUdfKernelState(std::move(function)),
cb(std::move(cb)),
input_types(std::move(input_types)),
output_type(std::move(output_type)) {}
UdfWrapperCallback cb;
std::vector<TypeHolder> input_types;
compute::OutputType output_type;
TypeHolder resolved_type;
Result<TypeHolder> ResolveType(compute::KernelContext* ctx,
const std::vector<TypeHolder>& types) {
if (input_types == types) {
if (!resolved_type) {
ARROW_ASSIGN_OR_RAISE(resolved_type, output_type.Resolve(ctx, input_types));
}
return resolved_type;
}
return output_type.Resolve(ctx, types);
}
Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
compute::ExecResult* out) {
auto state = arrow::internal::checked_cast<PythonUdfKernelState*>(ctx->state());
PyObject* function = state->function->obj();
const int num_args = batch.num_values();
UdfContext udf_context{ctx->memory_pool(), batch.length};
OwnedRef arg_tuple(PyTuple_New(num_args));
RETURN_NOT_OK(CheckPyError());
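// Wrap each argument as a pyarrow Scalar or Array and pack it into the
// argument tuple for the Python call; PyTuple_SetItem steals the reference
// returned by wrap_scalar/wrap_array.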
for (int arg_id = 0; arg_id < num_args; arg_id++) {
if (batch[arg_id].is_scalar()) {
std::shared_ptr<Scalar> c_data = batch[arg_id].scalar->GetSharedPtr();
PyObject* data = wrap_scalar(c_data);
PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
} else {
std::shared_ptr<Array> c_data = batch[arg_id].array.ToArray();
PyObject* data = wrap_array(c_data);
PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
}
}
OwnedRef result(cb(function, udf_context, arg_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
// Unwrap the output and check it against the expected output type
if (is_array(result.obj())) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> val, unwrap_array(result.obj()));
ARROW_ASSIGN_OR_RAISE(TypeHolder type, ResolveType(ctx, batch.GetTypes()));
if (type.type == NULLPTR) {
return Status::TypeError("expected output datatype is null");
}
if (*type.type != *val->type()) {
return Status::TypeError("Expected output datatype ", type.type->ToString(),
", but function returned datatype ",
val->type()->ToString());
}
out->value = std::move(val->data());
return Status::OK();
} else {
return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name,
" (expected Array)");
}
return Status::OK();
}
};
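// Exec function shared by all scalar/vector UDF kernels: re-enters Python
// under the GIL via SafeCallIntoPython and dispatches to PythonUdf::Exec.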
Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
compute::ExecResult* out) {
auto udf = static_cast<PythonUdf*>(ctx->kernel()->data.get());
return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
}
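// Common registration path for scalar and vector UDFs: builds the kernel
// signature from the declared types, attaches the Python callable as kernel
// data, and adds the resulting function to the registry.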
template <class Function, class Kernel>
Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
UdfWrapperCallback cb, const UdfOptions& options,
compute::FunctionRegistry* registry) {
if (!PyCallable_Check(function)) {
return Status::TypeError("Expected a callable Python object.");
}
auto scalar_func =
std::make_shared<Function>(options.func_name, options.arity, options.func_doc);
std::vector<compute::InputType> input_types;
for (const auto& in_dtype : options.input_types) {
input_types.emplace_back(in_dtype);
}
compute::OutputType output_type(options.output_type);
// Take a reference before wrapping with OwnedRefNoGIL
Py_INCREF(function);
auto udf_data = std::make_shared<PythonUdf>(
std::make_shared<OwnedRefNoGIL>(function), cb,
TypeHolder::FromTypes(options.input_types), options.output_type);
Kernel kernel(
compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
options.arity.is_varargs),
PythonUdfExec, kernel_init);
kernel.data = std::move(udf_data);
kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
if (registry == NULLPTR) {
registry = compute::GetFunctionRegistry();
}
RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
return Status::OK();
}
} // namespace
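// Entry point for registering scalar UDFs from Python
// (pyarrow.compute.register_scalar_function). A rough, illustrative
// Python-side sketch (argument details may differ across versions):
//
//   import pyarrow as pa
//   import pyarrow.compute as pc
//
//   def add_one(ctx, x):  # ctx is the UdfContext built by the kernel
//       return pc.add(x, 1)
//
//   pc.register_scalar_function(
//       add_one, "add_one",
//       {"summary": "add one", "description": "Adds 1 to each value"},
//       {"x": pa.int64()}, pa.int64())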
Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
options, registry);
}
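// Same as RegisterScalarFunction, but the UDF is registered as a
// compute::VectorFunction with a VectorKernel.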
Status RegisterVectorFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
return RegisterUdf<compute::VectorFunction, compute::VectorKernel>(
function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
options, registry);
}
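// Tabular UDFs take no arguments and must produce a struct type; the
// registered Python object is a factory (see PythonTableUdfKernelInit), and
// results are consumed through CallTabularFunction below.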
Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
if (options.arity.num_args != 0 || options.arity.is_varargs) {
return Status::NotImplemented("tabular function of non-null arity");
}
if (options.output_type->id() != Type::type::STRUCT) {
return Status::Invalid("tabular function with non-struct output");
}
return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
function, PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function), cb},
cb, options, registry);
}
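// Registers the non-grouped aggregate variant of the UDF, backed by
// PythonUdfScalarAggregatorImpl.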
Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
if (!PyCallable_Check(function)) {
return Status::TypeError("Expected a callable Python object.");
}
if (registry == NULLPTR) {
registry = compute::GetFunctionRegistry();
}
static auto default_scalar_aggregate_options =
compute::ScalarAggregateOptions::Defaults();
auto aggregate_func = std::make_shared<compute::ScalarAggregateFunction>(
options.func_name, options.arity, options.func_doc,
&default_scalar_aggregate_options);
std::vector<compute::InputType> input_types;
for (const auto& in_dtype : options.input_types) {
input_types.emplace_back(in_dtype);
}
compute::OutputType output_type(options.output_type);
// Take a reference before wrapping with OwnedRefNoGIL
Py_INCREF(function);
auto function_ref = std::make_shared<OwnedRefNoGIL>(function);
compute::KernelInit init = [cb, function_ref, options](
compute::KernelContext* ctx,
const compute::KernelInitArgs& args)
-> Result<std::unique_ptr<compute::KernelState>> {
return std::make_unique<PythonUdfScalarAggregatorImpl>(
function_ref, cb, options.input_types, options.output_type);
};
auto sig = compute::KernelSignature::Make(
std::move(input_types), std::move(output_type), options.arity.is_varargs);
compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init),
AggregateUdfConsume, AggregateUdfMerge,
AggregateUdfFinalize, /*ordered=*/false);
RETURN_NOT_OK(aggregate_func->AddKernel(std::move(kernel)));
RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func)));
return Status::OK();
}
/// \brief Create a new UdfOptions adjusted for the hash aggregate kernel
/// \param options User-provided UDF options
UdfOptions AdjustForHashAggregate(const UdfOptions& options) {
UdfOptions hash_options;
// Prepend "hash_" to the function name to distinguish it from the scalar
// version
hash_options.func_name = "hash_" + options.func_name;
// Extend the input types with the group id, which is appended by the
// group-by aggregation node; both the arity and the input types change here
if (options.arity.is_varargs) {
hash_options.arity = options.arity;
} else {
hash_options.arity = compute::Arity(options.arity.num_args + 1, false);
}
// Changing the function doc shouldn't be necessary because the group id is
// not user visible, but it is currently needed to pass function validation.
// The name group_id_array is consistent with the hash kernels in
// hash_aggregate.cc
hash_options.func_doc = options.func_doc;
hash_options.func_doc.arg_names.emplace_back("group_id_array");
std::vector<std::shared_ptr<DataType>> input_dtypes = options.input_types;
input_dtypes.emplace_back(uint32());
hash_options.input_types = std::move(input_dtypes);
hash_options.output_type = options.output_type;
return hash_options;
}
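// Registers the grouped ("hash_"-prefixed) aggregate variant of the UDF,
// backed by PythonUdfHashAggregatorImpl.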
Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
if (!PyCallable_Check(function)) {
return Status::TypeError("Expected a callable Python object.");
}
if (registry == NULLPTR) {
registry = compute::GetFunctionRegistry();
}
UdfOptions hash_options = AdjustForHashAggregate(options);
std::vector<compute::InputType> input_types;
for (const auto& in_dtype : hash_options.input_types) {
input_types.emplace_back(in_dtype);
}
compute::OutputType output_type(hash_options.output_type);
static auto default_hash_aggregate_options =
compute::ScalarAggregateOptions::Defaults();
auto hash_aggregate_func = std::make_shared<compute::HashAggregateFunction>(
hash_options.func_name, hash_options.arity, hash_options.func_doc,
&default_hash_aggregate_options);
// Take a reference before wrapping with OwnedRefNoGIL
Py_INCREF(function);
auto function_ref = std::make_shared<OwnedRefNoGIL>(function);
compute::KernelInit init = [function_ref, cb, hash_options](
compute::KernelContext* ctx,
const compute::KernelInitArgs& args)
-> Result<std::unique_ptr<compute::KernelState>> {
return std::make_unique<PythonUdfHashAggregatorImpl>(
function_ref, cb, hash_options.input_types, hash_options.output_type);
};
auto sig = compute::KernelSignature::Make(
std::move(input_types), std::move(output_type), hash_options.arity.is_varargs);
compute::HashAggregateKernel kernel(
std::move(sig), std::move(init), HashAggregateUdfResize, HashAggregateUdfConsume,
HashAggregateUdfMerge, HashAggregateUdfFinalize, /*ordered=*/false);
RETURN_NOT_OK(hash_aggregate_func->AddKernel(std::move(kernel)));
RETURN_NOT_OK(registry->AddFunction(std::move(hash_aggregate_func)));
return Status::OK();
}
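// Registers both the scalar and the hash aggregate variants so the UDF works
// with and without group-by keys.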
Status RegisterAggregateFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
RETURN_NOT_OK(RegisterScalarAggregateFunction(function, cb, options, registry));
RETURN_NOT_OK(RegisterHashAggregateFunction(function, cb, options, registry));
return Status::OK();
}
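// Drives a registered tabular UDF: repeatedly executes its kernel, exposing
// the resulting struct arrays as record batches through a RecordBatchReader
// and stopping once the function returns an empty array.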
Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction(
const std::string& func_name, const std::vector<Datum>& args,
compute::FunctionRegistry* registry) {
if (args.size() != 0) {
return Status::NotImplemented("non-empty arguments to tabular function");
}
if (registry == NULLPTR) {
registry = compute::GetFunctionRegistry();
}
ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name));
if (func->kind() != compute::Function::SCALAR) {
return Status::Invalid("tabular function of non-scalar kind");
}
auto arity = func->arity();
if (arity.num_args != 0 || arity.is_varargs) {
return Status::NotImplemented("tabular function of non-null arity");
}
auto kernels =
arrow::internal::checked_pointer_cast<compute::ScalarFunction>(func)->kernels();
if (kernels.size() != 1) {
return Status::NotImplemented("tabular function with non-single kernel");
}
const compute::ScalarKernel* kernel = kernels[0];
auto out_type = kernel->signature->out_type();
if (out_type.kind() != compute::OutputType::FIXED) {
return Status::Invalid("tabular kernel of non-fixed kind");
}
auto datatype = out_type.type();
if (datatype->id() != Type::type::STRUCT) {
return Status::Invalid("tabular kernel with non-struct output");
}
auto struct_type = arrow::internal::checked_cast<StructType*>(datatype.get());
auto schema = ::arrow::schema(struct_type->fields());
std::vector<TypeHolder> in_types;
ARROW_ASSIGN_OR_RAISE(auto func_exec,
GetFunctionExecutor(func_name, in_types, NULLPTR, registry));
auto next_func = [schema, func_exec = std::move(
func_exec)]() -> Result<std::shared_ptr<RecordBatch>> {
std::vector<Datum> args;
// A passed_length of -1 or 0 together with an empty args vector leads to an
// empty ExecSpanIterator in exec.cc, so the source function would never be
// invoked; 1 is passed instead
// TODO: GH-33612: Support batch size in user-defined tabular functions
ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, /*passed_length=*/1));
if (!datum.is_array()) {
return Status::Invalid("UDF result of non-array kind");
}
std::shared_ptr<Array> array = datum.make_array();
if (array->length() == 0) {
return IterationTraits<std::shared_ptr<RecordBatch>>::End();
}
ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(std::move(array)));
if (!schema->Equals(batch->schema())) {
return Status::Invalid("UDF result with shape not conforming to schema");
}
return std::move(batch);
};
return RecordBatchReader::MakeFromIterator(MakeFunctionIterator(std::move(next_func)),
schema);
}
} // namespace py
} // namespace arrow