Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

/ _substrait.pyx

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# cython: language_level = 3
from cython.operator cimport dereference as deref
from libcpp.vector cimport vector as std_vector

from pyarrow import Buffer, py_buffer
from pyarrow._compute cimport Expression
from pyarrow.lib import frombytes, tobytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *


# TODO GH-37235: Fix exception handling
# NOTE(review): because this function is declared ``noexcept``, a Python
# exception raised by the user's provider (or while converting its result)
# cannot propagate back to the caller; per the Cython docs such exceptions
# are reported as unraisable rather than raised. Presumably this is the
# problem GH-37235 tracks -- confirm.
#
# Named-table callback that run_query binds via BindFunction. It converts the
# C++ table-name parts and the expected schema into Python objects, calls the
# user callable stored under the "provider" key of ``named_args``, and wraps
# the returned table in a "table_source" declaration that has no inputs.
cdef CDeclaration _create_named_table_provider(
    dict named_args, const std_vector[c_string]& names, const CSchema& schema
) noexcept:
    cdef:
        c_string c_name
        shared_ptr[CTable] c_in_table
        shared_ptr[CTableSourceNodeOptions] c_tablesourceopts
        shared_ptr[CExecNodeOptions] c_input_node_opts
        vector[CDeclaration.Input] no_c_inputs

    # Decode each C++ name part into a Python str.
    py_names = []
    for i in range(names.size()):
        c_name = names[i]
        py_names.append(frombytes(c_name))
    # Copy the borrowed schema reference into a shared_ptr so it can be
    # exposed to Python as a pyarrow.Schema.
    py_schema = pyarrow_wrap_schema(make_shared[CSchema](schema))

    # Call the user-supplied provider as provider(list_of_names, schema).
    py_table = named_args["provider"](py_names, py_schema)
    # NOTE(review): py_table is not checked to be a pyarrow.Table here;
    # confirm what pyarrow_unwrap_table yields for other objects (possibly an
    # empty pointer).
    c_in_table = pyarrow_unwrap_table(py_table)
    c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table)
    # Upcast to the base options type that CDeclaration takes.
    c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
        c_tablesourceopts)
    # Source node: the "table_source" factory with no upstream inputs.
    return CDeclaration(tobytes("table_source"),
                        no_c_inputs, c_input_node_opts)


def run_query(plan, *, table_provider=None, use_threads=True):
    """
    Execute a Substrait plan and read the results as a RecordBatchReader.

    Parameters
    ----------
    plan : Union[Buffer, bytes]
        The serialized Substrait plan to execute.
    table_provider : callable, optional
        A function to resolve any NamedTable relation to a table.
        The function will receive two arguments: a list of strings
        representing the table name and a pyarrow.Schema representing
        the expected schema. It should return a pyarrow.Table.
    use_threads : bool, default True
        If True then multiple threads will be used to run the query.  If False then
        all CPU intensive work will be done on the calling thread.

    Returns
    -------
    RecordBatchReader
        A reader containing the result of the executed query

    Raises
    ------
    TypeError
        If `plan` is neither a pyarrow.Buffer nor bytes.

    Examples
    --------
    >>> import pyarrow as pa
    >>> from pyarrow.lib import tobytes
    >>> import pyarrow.substrait as substrait
    >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
    >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
    >>> def table_provider(names, schema):
    ...     if not names:
    ...         raise Exception("No names provided")
    ...     elif names[0] == "t1":
    ...         return test_table_1
    ...     elif names[0] == "t2":
    ...         return test_table_2
    ...     else:
    ...         raise Exception("Unrecognized table name")
    ...
    >>> substrait_query = '''
    ...         {
    ...             "relations": [
    ...             {"rel": {
    ...                 "read": {
    ...                 "base_schema": {
    ...                     "struct": {
    ...                     "types": [
    ...                                 {"i64": {}}
    ...                             ]
    ...                     },
    ...                     "names": [
    ...                             "x"
    ...                             ]
    ...                 },
    ...                 "namedTable": {
    ...                         "names": ["t1"]
    ...                 }
    ...                 }
    ...             }}
    ...             ]
    ...         }
    ... '''
    >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
    >>> reader = pa.substrait.run_query(buf, table_provider=table_provider)
    >>> reader.read_all()
    pyarrow.Table
    x: int64
    ----
    x: [[1,2,3]]
    """

    cdef:
        CResult[shared_ptr[CRecordBatchReader]] c_res_reader
        shared_ptr[CRecordBatchReader] c_reader
        RecordBatchReader reader
        shared_ptr[CBuffer] c_buf_plan
        CConversionOptions c_conversion_options
        c_bool c_use_threads

    c_use_threads = use_threads
    if isinstance(plan, bytes):
        c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan))
    elif isinstance(plan, Buffer):
        c_buf_plan = pyarrow_unwrap_buffer(plan)
    else:
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(plan)}'")

    if table_provider is not None:
        # The user callable travels to the C++ side inside this dict, which is
        # bound as the first argument of _create_named_table_provider.
        named_table_args = {
            "provider": table_provider
        }
        c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
            &_create_named_table_provider, named_table_args)

    # Run the plan without holding the GIL; the result is unwrapped afterwards
    # (presumably raising on failure).
    with nogil:
        c_res_reader = ExecuteSerializedPlan(
            deref(c_buf_plan), default_extension_id_registry(),
            GetFunctionRegistry(), c_conversion_options, c_use_threads)

    c_reader = GetResultValue(c_res_reader)

    # Wrap the C++ reader without going through RecordBatchReader.__init__.
    reader = RecordBatchReader.__new__(RecordBatchReader)
    reader.reader = c_reader
    return reader


def _parse_json_plan(plan):
    """
    Parse a JSON plan into equivalent serialized Protobuf.

    Parameters
    ----------
    plan : bytes or str
        Substrait plan in JSON. A str is encoded to bytes (UTF-8) first.

    Returns
    -------
    Buffer
        A buffer containing the serialized Protobuf plan.
    """

    cdef:
        CResult[shared_ptr[CBuffer]] c_res_buffer
        c_string c_str_plan
        shared_ptr[CBuffer] c_buf_plan

    # Accept str as well as bytes: pyarrow.lib.tobytes encodes a str and
    # passes bytes through unchanged, so existing bytes callers are unaffected.
    c_str_plan = tobytes(plan)
    c_res_buffer = SerializeJsonPlan(c_str_plan)
    with nogil:
        c_buf_plan = GetResultValue(c_res_buffer)
    return pyarrow_wrap_buffer(c_buf_plan)


def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False):
    """
    Serialize a collection of expressions into Substrait

    Substrait expressions must be bound to a schema.  For example,
    the Substrait expression ``a:i32 + b:i32`` is different from the
    Substrait expression ``a:i64 + b:i64``.  Pyarrow expressions are
    typically unbound.  For example, both of the above expressions
    would be represented as ``a + b`` in pyarrow.

    This means a schema must be provided when serializing an expression.
    It also means that the serialization may fail if a matching function
    call cannot be found for the expression.

    Parameters
    ----------
    exprs : list of Expression
        The expressions to serialize
    names : list of str
        Names for the expressions
    schema : Schema
        The schema the expressions will be bound to
    allow_arrow_extensions : bool, default False
        If False then only functions that are part of the core Substrait function
        definitions will be allowed.  Set this to True to allow pyarrow-specific functions
        and user defined functions but the result may not be accepted by other
        compute libraries.

    Returns
    -------
    Buffer
        An ExtendedExpression message containing the serialized expressions

    Raises
    ------
    ValueError
        If `exprs` and `names` have different lengths.
    TypeError
        If an element of `exprs` is not an Expression, an element of `names`
        is not a str, or `schema` is not a Schema.
    """
    cdef:
        CResult[shared_ptr[CBuffer]] c_res_buffer
        shared_ptr[CBuffer] c_buffer
        CNamedExpression c_named_expr
        CBoundExpressions c_bound_exprs
        CConversionOptions c_conversion_options

    if len(exprs) != len(names):
        raise ValueError("exprs and names need to have the same length")
    for expr, name in zip(exprs, names):
        if not isinstance(expr, Expression):
            raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs")
        if not isinstance(name, str):
            raise TypeError(f"Expected str, got '{type(name)}' in names")
        # push_back copies, so reusing c_named_expr across iterations is safe.
        c_named_expr.expression = (<Expression> expr).unwrap()
        c_named_expr.name = tobytes(<str> name)
        c_bound_exprs.named_expressions.push_back(c_named_expr)

    # ``<Schema>`` below is an unchecked cast: validate first so a wrong type
    # (e.g. None) raises TypeError instead of reading an invalid object.
    if not isinstance(schema, Schema):
        raise TypeError(f"Expected 'pyarrow.Schema', got '{type(schema)}'")
    c_bound_exprs.schema = (<Schema> schema).sp_schema

    c_conversion_options.allow_arrow_extensions = allow_arrow_extensions

    with nogil:
        c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options)
        c_buffer = GetResultValue(c_res_buffer)
    return pyarrow_wrap_buffer(c_buffer)


cdef class BoundExpressions(_Weakrefable):
    """
    A collection of named expressions and the schema they are bound to

    This is equivalent to the Substrait ExtendedExpression message

    Instances are not created through ``__init__`` (it raises TypeError);
    within this module they are produced by ``deserialize_expressions``.
    """

    cdef:
        # Wrapped C++ value: the shared schema plus the named expressions.
        CBoundExpressions c_bound_exprs

    def __init__(self):
        # Direct construction is not supported; Cython code builds instances
        # with the ``wrap`` factory below.
        msg = 'BoundExpressions is an abstract class thus cannot be initialized.'
        raise TypeError(msg)

    cdef void init(self, CBoundExpressions bound_expressions):
        # Stores a copy of the C++ struct (it is passed by value).
        self.c_bound_exprs = bound_expressions

    @property
    def schema(self):
        """
        The common schema that all expressions are bound to
        """
        return pyarrow_wrap_schema(self.c_bound_exprs.schema)

    @property
    def expressions(self):
        """
        A dict from expression name to expression

        A new dict of freshly wrapped Expression objects is built on every
        access. If two expressions share a name, the later one wins.
        """
        expr_dict = {}
        for named_expr in self.c_bound_exprs.named_expressions:
            name = frombytes(named_expr.name)
            expr = Expression.wrap(named_expr.expression)
            # Plain assignment: a repeated name overwrites the earlier entry.
            expr_dict[name] = expr
        return expr_dict

    @staticmethod
    cdef wrap(const CBoundExpressions& bound_expressions):
        # __new__ skips __init__ (which would raise), then init() copies the
        # C++ value into the new wrapper.
        cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions)
        self.init(bound_expressions)
        return self


def deserialize_expressions(buf):
    """
    Deserialize an ExtendedExpression Substrait message into a BoundExpressions object

    Parameters
    ----------
    buf : Buffer or bytes
        The serialized ExtendedExpression message. Raw bytes are wrapped in
        a pyarrow Buffer before being handed to the C++ deserializer.

    Returns
    -------
    BoundExpressions
        The decoded expressions together with their names and the schema
        they are bound to

    Raises
    ------
    TypeError
        If `buf` is neither bytes nor a pyarrow.Buffer.
    """
    cdef:
        shared_ptr[CBuffer] c_buffer
        CResult[CBoundExpressions] c_res_bound_exprs
        CBoundExpressions c_bound_exprs

    # Normalize the input to a pyarrow Buffer so there is one unwrap path.
    if isinstance(buf, bytes):
        buf = py_buffer(buf)
    elif not isinstance(buf, Buffer):
        raise TypeError(
            f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")
    c_buffer = pyarrow_unwrap_buffer(buf)

    # Decode and unwrap the result without holding the GIL.
    with nogil:
        c_res_bound_exprs = DeserializeExpressions(deref(c_buffer))
        c_bound_exprs = GetResultValue(c_res_bound_exprs)

    return BoundExpressions.wrap(c_bound_exprs)


def get_supported_functions():
    """
    Get a list of Substrait functions that the underlying
    engine currently supports.

    Returns
    -------
    list[str]
        A list of function ids encoded as '{uri}#{name}'
    """

    cdef:
        ExtensionIdRegistry* c_id_registry
        std_vector[c_string] c_ids

    c_id_registry = default_extension_id_registry()
    c_ids = c_id_registry.GetSupportedSubstraitFunctions()
Loading ...