Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

/ tests / test_acero.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

import pyarrow as pa
import pyarrow.compute as pc
from pyarrow.compute import field

try:
    from pyarrow.acero import (
        Declaration,
        TableSourceNodeOptions,
        FilterNodeOptions,
        ProjectNodeOptions,
        AggregateNodeOptions,
        OrderByNodeOptions,
        HashJoinNodeOptions,
        AsofJoinNodeOptions,
    )
except ImportError:
    pass

try:
    import pyarrow.dataset as ds
    from pyarrow.acero import ScanNodeOptions
except ImportError:
    ds = None

# Module-level ``pytestmark``: pytest applies this marker to every test in
# this module, so the whole module can be selected/deselected with
# ``-m acero``.  NOTE(review): presumably a conftest skips this marker when
# pyarrow was built without Acero (the imports above fall through silently
# on ImportError) — confirm.
pytestmark = pytest.mark.acero


@pytest.fixture
def table_source():
    """Provide a ``table_source`` Declaration over a small two-column table.

    The underlying table has columns ``a`` = [1, 2, 3] and ``b`` = [4, 5, 6].
    """
    source_table = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
    return Declaration(
        "table_source", options=TableSourceNodeOptions(source_table)
    )


def test_declaration():
    """Build a table_source -> filter plan in two ways and check both.

    The plan is assembled once with ``Declaration.from_sequence`` and once by
    wiring the source into the filter through ``inputs=``.  Both must keep
    only the rows where ``a > 1``.
    """
    table = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
    source_options = TableSourceNodeOptions(table)
    keep_a_gt_1 = FilterNodeOptions(field('a') > 1)
    expected = table.slice(1, 2)

    # Chain the nodes as an ordered sequence.
    chained = Declaration.from_sequence([
        Declaration("table_source", options=source_options),
        Declaration("filter", options=keep_a_gt_1),
    ])
    assert chained.to_table().equals(expected)

    # Wire the same nodes explicitly through the ``inputs`` argument.
    source_decl = Declaration("table_source", options=source_options)
    wired = Declaration("filter", options=keep_a_gt_1, inputs=[source_decl])
    assert wired.to_table().equals(expected)


def test_declaration_repr(table_source):
    """Both ``str()`` and ``repr()`` of a Declaration name the node kind."""
    for render in (str, repr):
        assert "TableSourceNode" in render(table_source)


def test_declaration_to_reader(table_source):
    """Consume a plan through ``to_reader()`` and compare schema and data."""
    expected_schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
    with table_source.to_reader() as reader:
        assert reader.schema == expected_schema
        collected = reader.read_all()
    assert collected.equals(pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]}))


def test_table_source():
    """TableSourceNodeOptions rejects record batches and null tables."""
    # Only a Table is accepted; passing a RecordBatch is a type error.
    with pytest.raises(TypeError):
        TableSourceNodeOptions(pa.record_batch([pa.array([1, 2, 3])], ["a"]))

    # ``None`` is accepted when building the options but fails on execution.
    null_decl = Declaration("table_source", TableSourceNodeOptions(None))
    with pytest.raises(
        ValueError, match="TableSourceNode requires table which is not null"
    ):
        null_decl.to_table()


def test_filter(table_source):
    """Validate FilterNodeOptions input checks and unknown-field errors."""
    # Referencing a column the source does not have fails at execution time.
    decl = Declaration.from_sequence([
        table_source,
        Declaration("filter", options=FilterNodeOptions(field("c") > 1))
    ])
    # ``match`` is a regular expression (applied with re.search); escape the
    # '.' so it matches the literal "FieldRef.Name" rather than any character.
    with pytest.raises(ValueError, match=r"No match for FieldRef\.Name\(c\)"):
        _ = decl.to_table()

    # The filter must be a pyarrow Expression; arrays and None are rejected.
    with pytest.raises(TypeError):
        FilterNodeOptions(pa.array([True, False, True]))
    with pytest.raises(TypeError):
        FilterNodeOptions(None)


def test_project(table_source):
    """Exercise ProjectNodeOptions naming, validation and scalar-only rule."""
    doubled_a = pc.multiply(field("a"), 2)

    def project_plan(options):
        # table_source -> project, using the given ProjectNodeOptions.
        return Declaration.from_sequence(
            [table_source, Declaration("project", options)]
        )

    # Without explicit names the output column is named after the expression.
    unnamed = project_plan(ProjectNodeOptions([doubled_a])).to_table()
    assert unnamed.schema.names == ["multiply(a, 2)"]
    assert unnamed[0].to_pylist() == [2, 4, 6]

    # An explicit name replaces the default one.
    named = project_plan(ProjectNodeOptions([doubled_a], ["a2"])).to_table()
    assert named.schema.names == ["a2"]
    assert named["a2"].to_pylist() == [2, 4, 6]

    # The number of names must match the number of expressions.
    with pytest.raises(ValueError):
        ProjectNodeOptions([doubled_a], ["a2", "b2"])

    # Aggregating (non-scalar) expressions cannot be projected.
    aggregating = project_plan(ProjectNodeOptions([pc.sum(field("a"))]))
    with pytest.raises(ValueError, match="cannot Execute non-scalar expression"):
        _ = aggregating.to_table()


def test_aggregate_scalar(table_source):
    """Scalar (ungrouped) aggregation through the aggregate node."""
    def plan(source, aggr_opts):
        # source -> aggregate, using the given AggregateNodeOptions.
        return Declaration.from_sequence(
            [source, Declaration("aggregate", aggr_opts)]
        )

    # Plain sum with default function options.
    result = plan(
        table_source, AggregateNodeOptions([("a", "sum", None, "a_sum")])
    ).to_table()
    assert result.schema.names == ["a_sum"]
    assert result["a_sum"].to_pylist() == [6]

    # Function options are forwarded: with skip_nulls=False a null input
    # makes the sum null.
    with_null = pa.table({'a': [1, 2, None]})
    keep_nulls = AggregateNodeOptions(
        [("a", "sum", pc.ScalarAggregateOptions(skip_nulls=False), "a_sum")]
    )
    result = plan(
        Declaration("table_source", TableSourceNodeOptions(with_null)),
        keep_nulls,
    ).to_table()
    assert result.schema.names == ["a_sum"]
    assert result["a_sum"].to_pylist() == [None]

    # The target column may be a name, an index, a field reference, or a
    # single-element list of any of those.
    targets = ["a", field("a"), 0, field(0), ["a"], [field("a")], [0]]
    for target in targets:
        result = plan(
            table_source,
            AggregateNodeOptions([(target, "sum", None, "a_sum")]),
        ).to_table()
        assert result.schema.names == ["a_sum"]
        assert result["a_sum"].to_pylist() == [6]

    # Passing two target columns to a unary function is reported clearly.
    two_targets = plan(
        table_source,
        AggregateNodeOptions([(["a", "b"], "sum", None, "a_sum")]),
    )
    with pytest.raises(
        ValueError, match="Function 'sum' accepts 1 arguments but 2 passed"
    ):
        _ = two_targets.to_table()

    # A hash_* function cannot be used without grouping keys.
    no_keys = plan(
        table_source, AggregateNodeOptions([("a", "hash_sum", None, "a_sum")])
    )
    with pytest.raises(ValueError, match="is a hash aggregate function"):
        _ = no_keys.to_table()


def test_aggregate_hash():
    """Grouped aggregation: hash_* functions combined with grouping keys."""
    table = pa.table({'a': [1, 2, None], 'b': ["foo", "bar", "foo"]})
    source = Declaration("table_source", options=TableSourceNodeOptions(table))

    def grouped(aggregates, keys):
        # source -> aggregate, grouping by ``keys``.
        node = Declaration(
            "aggregate", AggregateNodeOptions(aggregates, keys=keys)
        )
        return Declaration.from_sequence([source, node])

    # With default options hash_count skips the null in group "foo".
    expected = pa.table({"b": ["foo", "bar"], "count(a)": [1, 1]})
    plan = grouped([("a", "hash_count", None, "count(a)")], ["b"])
    assert plan.to_table().equals(expected)

    # CountOptions("all") counts the null as well.
    plan = grouped(
        [("a", "hash_count", pc.CountOptions("all"), "count(a)")], ["b"]
    )
    expected_all = pa.table({"b": ["foo", "bar"], "count(a)": [2, 1]})
    assert plan.to_table().equals(expected_all)

    # Keys may be given as field references instead of names.
    plan = grouped([("a", "hash_count", None, "count(a)")], [field("b")])
    assert plan.to_table().equals(expected)

    # wrong type of (aggregation) function
    # TODO test with kernel that matches number of arguments (arity) -> avoid segfault
    plan = grouped([("a", "sum", None, "a_sum")], ["b"])
    with pytest.raises(ValueError):
        _ = plan.to_table()


def test_order_by():
    """Sorting with OrderByNodeOptions, including null placement and errors."""
    table = pa.table({'a': [1, 2, 3, 4], 'b': [1, 3, None, 2]})
    source = Declaration("table_source", TableSourceNodeOptions(table))

    def ordered(ord_opts):
        # Run source -> order_by and materialize the result.
        return Declaration.from_sequence(
            [source, Declaration("order_by", ord_opts)]
        ).to_table()

    cases = [
        # sort key given by column name; nulls end up last by default
        (OrderByNodeOptions([("b", "ascending")]),
         {"a": [1, 4, 2, 3], "b": [1, 2, 3, None]}),
        # sort key given as a field reference
        (OrderByNodeOptions([(field("b"), "descending")]),
         {"a": [2, 4, 1, 3], "b": [3, 2, 1, None]}),
        # sort key given by column index, with nulls placed first
        (OrderByNodeOptions([(1, "descending")], null_placement="at_start"),
         {"a": [3, 2, 4, 1], "b": [None, 3, 2, 1]}),
    ]
    for ord_opts, expected_columns in cases:
        assert ordered(ord_opts).equals(pa.table(expected_columns))

    # An empty ordering is accepted by the options but rejected on execution.
    empty_plan = Declaration.from_sequence(
        [source, Declaration("order_by", OrderByNodeOptions([]))]
    )
    with pytest.raises(
        ValueError, match="`ordering` must be an explicit non-empty ordering"
    ):
        _ = empty_plan.to_table()

    # Unknown sort-order / null-placement strings are rejected immediately.
    with pytest.raises(ValueError, match="\"decreasing\" is not a valid sort order"):
        _ = OrderByNodeOptions([("b", "decreasing")])

    with pytest.raises(ValueError, match="\"start\" is not a valid null placement"):
        _ = OrderByNodeOptions([("b", "ascending")], null_placement="start")


def test_hash_join():
    """Hash join node: join types, key specs, suffixes and output selection.

    Every result is sorted by the left-side column ``a`` before it is
    compared.  The outer-join checks already did this, and the inner-join
    checks now do the same, so no assertion depends on the order in which
    the hash join emits rows.
    """
    left = pa.table({'key': [1, 2, 3], 'a': [4, 5, 6]})
    left_source = Declaration("table_source", options=TableSourceNodeOptions(left))
    right = pa.table({'key': [2, 3, 4], 'b': [4, 5, 6]})
    right_source = Declaration("table_source", options=TableSourceNodeOptions(right))

    def run_join(join_opts):
        # Join left_source with right_source; return rows ordered by "a".
        joined = Declaration(
            "hashjoin", options=join_opts, inputs=[left_source, right_source])
        return joined.to_table().sort_by("a")

    # inner join
    expected = pa.table(
        [[2, 3], [5, 6], [2, 3], [4, 5]],
        names=["key", "a", "key", "b"])
    join_opts = HashJoinNodeOptions("inner", left_keys="key", right_keys="key")
    assert run_join(join_opts).equals(expected)

    # keys given as a field reference or as a list of names/field references
    for keys in [field("key"), ["key"], [field("key")]]:
        join_opts = HashJoinNodeOptions("inner", left_keys=keys, right_keys=keys)
        assert run_join(join_opts).equals(expected)

    # left join: unmatched left rows get nulls in the right-side columns
    join_opts = HashJoinNodeOptions(
        "left outer", left_keys="key", right_keys="key")
    expected = pa.table(
        [[1, 2, 3], [4, 5, 6], [None, 2, 3], [None, 4, 5]],
        names=["key", "a", "key", "b"]
    )
    assert run_join(join_opts).equals(expected)

    # suffixes: the duplicated "key" column names get distinguished
    join_opts = HashJoinNodeOptions(
        "left outer", left_keys="key", right_keys="key",
        output_suffix_for_left="_left", output_suffix_for_right="_right")
    expected = pa.table(
        [[1, 2, 3], [4, 5, 6], [None, 2, 3], [None, 4, 5]],
        names=["key_left", "a", "key_right", "b"]
    )
    assert run_join(join_opts).equals(expected)

    # manually specifying output columns
    join_opts = HashJoinNodeOptions(
        "left outer", left_keys="key", right_keys="key",
        left_output=["key", "a"], right_output=[field("b")])
    expected = pa.table(
        [[1, 2, 3], [4, 5, 6], [None, 4, 5]],
        names=["key", "a", "b"]
    )
    assert run_join(join_opts).equals(expected)

Loading ...