# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from cython.operator cimport dereference as deref
from libcpp.vector cimport vector as std_vector
from pyarrow import Buffer, py_buffer
from pyarrow._compute cimport Expression
from pyarrow.lib import frombytes, tobytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *
# TODO GH-37235: Fix exception handling
cdef CDeclaration _create_named_table_provider(
    dict named_args, const std_vector[c_string]& names, const CSchema& schema
) noexcept:
    """
    Build a "table_source" declaration from a user-supplied table provider.

    The Python callable stored under ``named_args["provider"]`` is invoked
    with the table names (decoded to ``str``) and the expected schema
    (wrapped as a ``pyarrow.Schema``). The table it returns is used as the
    source of the returned declaration.

    NOTE: this function is declared ``noexcept``, so an exception raised by
    the provider cannot propagate to the caller through this function.
    """
    cdef:
        shared_ptr[CTable] c_table
        shared_ptr[CTableSourceNodeOptions] c_source_opts
        shared_ptr[CExecNodeOptions] c_node_opts
        vector[CDeclaration.Input] c_no_inputs

    # Decode every C++ name so the callback receives a list of Python strings.
    table_names = [frombytes(names[idx]) for idx in range(names.size())]

    # The schema is copied into a shared_ptr so it can be wrapped for Python.
    expected_schema = pyarrow_wrap_schema(make_shared[CSchema](schema))

    provider = named_args["provider"]
    c_table = pyarrow_unwrap_table(provider(table_names, expected_schema))

    # Wrap the table in source-node options and upcast them to the generic
    # options type expected by CDeclaration.
    c_source_opts = make_shared[CTableSourceNodeOptions](c_table)
    c_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
        c_source_opts)

    # A table source has no upstream inputs, hence the empty input vector.
    return CDeclaration(tobytes("table_source"), c_no_inputs, c_node_opts)
def run_query(plan, *, table_provider=None, use_threads=True):
    """
    Execute a Substrait plan and read the results as a RecordBatchReader.

    Parameters
    ----------
    plan : Union[Buffer, bytes]
        The serialized Substrait plan to execute.
    table_provider : object (optional)
        A function to resolve any NamedTable relation to a table.
        The function will receive two arguments which will be a list
        of strings representing the table name and a pyarrow.Schema representing
        the expected schema and should return a pyarrow.Table.
    use_threads : bool, default True
        If True then multiple threads will be used to run the query.  If False then
        all CPU intensive work will be done on the calling thread.

    Returns
    -------
    RecordBatchReader
        A reader containing the result of the executed query

    Raises
    ------
    TypeError
        If ``plan`` is neither a ``pyarrow.Buffer`` nor ``bytes``.

    Examples
    --------
    >>> import pyarrow as pa
    >>> from pyarrow.lib import tobytes
    >>> import pyarrow.substrait as substrait
    >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
    >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
    >>> def table_provider(names, schema):
    ...     if not names:
    ...         raise Exception("No names provided")
    ...     elif names[0] == "t1":
    ...         return test_table_1
    ...     elif names[0] == "t2":
    ...         return test_table_2
    ...     else:
    ...         raise Exception("Unrecognized table name")
    ...
    >>> substrait_query = '''
    ... {
    ...     "relations": [
    ...         {"rel": {
    ...             "read": {
    ...                 "base_schema": {
    ...                     "struct": {
    ...                         "types": [
    ...                             {"i64": {}}
    ...                         ]
    ...                     },
    ...                     "names": [
    ...                         "x"
    ...                     ]
    ...                 },
    ...                 "namedTable": {
    ...                     "names": ["t1"]
    ...                 }
    ...             }
    ...         }}
    ...     ]
    ... }
    ... '''
    >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
    >>> reader = pa.substrait.run_query(buf, table_provider=table_provider)
    >>> reader.read_all()
    pyarrow.Table
    x: int64
    ----
    x: [[1,2,3]]
    """

    cdef:
        CResult[shared_ptr[CRecordBatchReader]] c_res_reader
        shared_ptr[CRecordBatchReader] c_reader
        RecordBatchReader reader
        shared_ptr[CBuffer] c_buf_plan
        CConversionOptions c_conversion_options
        c_bool c_use_threads

    # Convert to a C bool up front so it can be used without the GIL below.
    c_use_threads = use_threads

    # Accept either raw bytes (wrapped in a Buffer first) or a Buffer.
    if isinstance(plan, bytes):
        c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan))
    elif isinstance(plan, Buffer):
        c_buf_plan = pyarrow_unwrap_buffer(plan)
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(plan)}'")

    if table_provider is not None:
        # The Python callable travels to the C-level callback inside a dict
        # bound as the callback's first argument.
        named_table_args = {
            "provider": table_provider
        }
        c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
            &_create_named_table_provider, named_table_args)

    # Only C++ objects are touched here, so the GIL is released for the call.
    with nogil:
        c_res_reader = ExecuteSerializedPlan(
            deref(c_buf_plan), default_extension_id_registry(),
            GetFunctionRegistry(), c_conversion_options, c_use_threads)

    c_reader = GetResultValue(c_res_reader)

    # Bypass __init__ and attach the C++ reader directly.
    reader = RecordBatchReader.__new__(RecordBatchReader)
    reader.reader = c_reader
    return reader
def _parse_json_plan(plan):
    """
    Convert a Substrait plan written as JSON into its serialized Protobuf form.

    Parameters
    ----------
    plan : bytes
        The JSON text of the Substrait plan.

    Returns
    -------
    Buffer
        The Protobuf-serialized plan.
    """
    cdef:
        c_string c_json = plan
        CResult[shared_ptr[CBuffer]] c_result
        shared_ptr[CBuffer] c_serialized

    c_result = SerializeJsonPlan(c_json)
    # Unpack the result with the GIL released.
    with nogil:
        c_serialized = GetResultValue(c_result)
    return pyarrow_wrap_buffer(c_serialized)
def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False):
    """
    Serialize a collection of expressions into Substrait

    Substrait expressions must be bound to a schema.  For example,
    the Substrait expression ``a:i32 + b:i32`` is different from the
    Substrait expression ``a:i64 + b:i64``.  Pyarrow expressions are
    typically unbound.  For example, both of the above expressions
    would be represented as ``a + b`` in pyarrow.

    This means a schema must be provided when serializing an expression.
    It also means that the serialization may fail if a matching function
    call cannot be found for the expression.

    Parameters
    ----------
    exprs : list of Expression
        The expressions to serialize
    names : list of str
        Names for the expressions
    schema : Schema
        The schema the expressions will be bound to
    allow_arrow_extensions : bool, default False
        If False then only functions that are part of the core Substrait function
        definitions will be allowed.  Set this to True to allow pyarrow-specific functions
        and user defined functions but the result may not be accepted by other
        compute libraries.

    Returns
    -------
    Buffer
        An ExtendedExpression message containing the serialized expressions

    Raises
    ------
    ValueError
        If ``exprs`` and ``names`` do not have the same length.
    TypeError
        If an element of ``exprs`` is not an Expression, an element of
        ``names`` is not a str, or ``schema`` is not a Schema.
    """
    cdef:
        CResult[shared_ptr[CBuffer]] c_res_buffer
        shared_ptr[CBuffer] c_buffer
        CNamedExpression c_named_expr
        CBoundExpressions c_bound_exprs
        CConversionOptions c_conversion_options

    if len(exprs) != len(names):
        raise ValueError("exprs and names need to have the same length")

    # The cast to Schema below is unchecked, so reject anything else (including
    # None) here instead of reading through an invalid pointer.
    if not isinstance(schema, Schema):
        raise TypeError(f"Expected Schema, got '{type(schema)}' for schema")

    for expr, name in zip(exprs, names):
        if not isinstance(expr, Expression):
            raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs")
        if not isinstance(name, str):
            raise TypeError(f"Expected str, got '{type(name)}' in names")
        c_named_expr.expression = (<Expression> expr).unwrap()
        c_named_expr.name = tobytes(<str> name)
        # push_back copies the struct, so c_named_expr can be reused.
        c_bound_exprs.named_expressions.push_back(c_named_expr)

    c_bound_exprs.schema = (<Schema> schema).sp_schema

    c_conversion_options.allow_arrow_extensions = allow_arrow_extensions

    # Only C++ objects are touched here, so the GIL is released.
    with nogil:
        c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options)
        c_buffer = GetResultValue(c_res_buffer)
    return pyarrow_wrap_buffer(c_buffer)
cdef class BoundExpressions(_Weakrefable):
    """
    Named expressions together with the schema they are bound to.

    This mirrors the Substrait ExtendedExpression message.  Instances cannot
    be created from Python with the constructor; they are produced by the
    ``wrap`` static method.
    """

    cdef:
        # The underlying C++ value holding the schema and named expressions.
        CBoundExpressions c_bound_exprs

    def __init__(self):
        # Direct construction from Python is always rejected.
        raise TypeError(
            'BoundExpressions is an abstract class thus cannot be initialized.')

    cdef void init(self, CBoundExpressions bound_expressions):
        # Store the C++ value this object exposes.
        self.c_bound_exprs = bound_expressions

    @property
    def schema(self):
        """
        The schema shared by all of the bound expressions
        """
        return pyarrow_wrap_schema(self.c_bound_exprs.schema)

    @property
    def expressions(self):
        """
        A dict mapping each expression name to its expression
        """
        by_name = {}
        for entry in self.c_bound_exprs.named_expressions:
            # A repeated name overwrites the earlier entry.
            by_name[frombytes(entry.name)] = Expression.wrap(entry.expression)
        return by_name

    @staticmethod
    cdef wrap(const CBoundExpressions& bound_expressions):
        # __new__ skips __init__, which would otherwise raise.
        cdef BoundExpressions wrapped = BoundExpressions.__new__(BoundExpressions)
        wrapped.init(bound_expressions)
        return wrapped
def deserialize_expressions(buf):
    """
    Turn a Substrait ExtendedExpression message into a BoundExpressions object

    Parameters
    ----------
    buf : Buffer or bytes
        The serialized message

    Returns
    -------
    BoundExpressions
        The expressions, their names, and the schema they are bound to
    """
    cdef:
        shared_ptr[CBuffer] c_message
        CResult[CBoundExpressions] c_result
        CBoundExpressions c_bound

    # Normalise the input to a Buffer first, then unwrap it once.
    if isinstance(buf, bytes):
        message = py_buffer(buf)
    elif isinstance(buf, Buffer):
        message = buf
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")
    c_message = pyarrow_unwrap_buffer(message)

    # Only C++ objects are touched here, so the GIL is released.
    with nogil:
        c_result = DeserializeExpressions(deref(c_message))
        c_bound = GetResultValue(c_result)

    return BoundExpressions.wrap(c_bound)
def get_supported_functions():
"""
Get a list of Substrait functions that the underlying
engine currently supports.
Returns
-------
list[str]
A list of function ids encoded as '{uri}#{name}'
"""
cdef:
ExtensionIdRegistry* c_id_registry
std_vector[c_string] c_ids
c_id_registry = default_extension_id_registry()
c_ids = c_id_registry.GetSupportedSubstraitFunctions()
Loading ...