# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
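
# Tests for the Acero-backed join and filter helpers used by pyarrow
# (``_perform_join`` and ``_filter_table`` from ``pyarrow.acero``).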
import pytest

import pyarrow as pa
import pyarrow.compute as pc

from .test_extension_type import IntegerType

try:
    import pyarrow.dataset as ds
except ImportError:
    pass

try:
    from pyarrow.acero import _perform_join, _filter_table
except ImportError:
    pass

pytestmark = pytest.mark.acero


def test_joins_corner_cases():
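    # Invalid inputs should fail fast: empty join keys, a missing table,
    # and an unknown join type.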
t1 = pa.Table.from_pydict({
"colA": [1, 2, 3, 4, 5, 6],
"col2": ["a", "b", "c", "d", "e", "f"]
})
t2 = pa.Table.from_pydict({
"colB": [1, 2, 3, 4, 5],
"col3": ["A", "B", "C", "D", "E"]
})
with pytest.raises(pa.ArrowInvalid):
_perform_join("left outer", t1, "", t2, "")
with pytest.raises(TypeError):
_perform_join("left outer", None, "colA", t2, "colB")
with pytest.raises(ValueError):
_perform_join("super mario join", t1, "colA", t2, "colB")
@pytest.mark.parametrize("jointype,expected", [
("left semi", {
"colA": [1, 2],
"col2": ["a", "b"]
}),
("right semi", {
"colB": [1, 2],
"col3": ["A", "B"]
}),
("left anti", {
"colA": [6],
"col2": ["f"]
}),
("right anti", {
"colB": [99],
"col3": ["Z"]
}),
("inner", {
"colA": [1, 2],
"col2": ["a", "b"],
"colB": [1, 2],
"col3": ["A", "B"]
}),
("left outer", {
"colA": [1, 2, 6],
"col2": ["a", "b", "f"],
"colB": [1, 2, None],
"col3": ["A", "B", None]
}),
("right outer", {
"colA": [1, 2, None],
"col2": ["a", "b", None],
"colB": [1, 2, 99],
"col3": ["A", "B", "Z"]
}),
("full outer", {
"colA": [1, 2, 6, None],
"col2": ["a", "b", "f", None],
"colB": [1, 2, None, 99],
"col3": ["A", "B", None, "Z"]
})
])
@pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize("coalesce_keys", [True, False])
@pytest.mark.parametrize("use_datasets",
[False, pytest.param(True, marks=pytest.mark.dataset)])
def test_joins(jointype, expected, use_threads, coalesce_keys, use_datasets):
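    # Exercise every join type against small tables (optionally wrapped in
    # datasets) and compare with the expected output, adjusting for key
    # coalescing below.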
    # Build the expected table here instead of in the parametrize list so
    # that Arrow-allocated memory is not kept alive for the whole session.
expected = pa.table(expected)
t1 = pa.Table.from_pydict({
"colA": [1, 2, 6],
"col2": ["a", "b", "f"]
})
t2 = pa.Table.from_pydict({
"colB": [99, 2, 1],
"col3": ["Z", "B", "A"]
})
if use_datasets:
t1 = ds.dataset([t1])
t2 = ds.dataset([t2])
r = _perform_join(jointype, t1, "colA", t2, "colB",
use_threads=use_threads, coalesce_keys=coalesce_keys)
r = r.combine_chunks()
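    # The join output order is not guaranteed, so sort on the key column
    # before comparing with the expected table.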
if "right" in jointype:
r = r.sort_by("colB")
else:
r = r.sort_by("colA")
if coalesce_keys:
if jointype in ("inner", "left outer"):
expected = expected.drop(["colB"])
elif jointype == "right outer":
expected = expected.drop(["colA"])
elif jointype == "full outer":
expected = expected.drop(["colB"]).set_column(0, "colA", [[1, 2, 6, 99]])
assert r == expected


def test_table_join_collisions():
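    # Join tables that share column names: without suffixes the colliding
    # columns are kept side by side, while ``right_suffix`` renames the
    # right-hand duplicates.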
t1 = pa.table({
"colA": [1, 2, 6],
"colB": [10, 20, 60],
"colVals": ["a", "b", "f"]
})
t2 = pa.table({
"colB": [99, 20, 10],
"colVals": ["Z", "B", "A"],
"colUniq": [100, 200, 300],
"colA": [99, 2, 1],
})
result = _perform_join(
"full outer", t1, ["colA", "colB"], t2, ["colA", "colB"])
result = result.combine_chunks()
result = result.sort_by("colUniq")
assert result == pa.table([
[None, 2, 1, 6],
[None, 20, 10, 60],
[None, "b", "a", "f"],
[99, 20, 10, None],
["Z", "B", "A", None],
[100, 200, 300, None],
[99, 2, 1, None],
], names=["colA", "colB", "colVals", "colB", "colVals", "colUniq", "colA"])
result = _perform_join("full outer", t1, "colA",
t2, "colA", right_suffix="_r",
coalesce_keys=False)
result = result.combine_chunks()
result = result.sort_by("colA")
assert result == pa.table({
"colA": [1, 2, 6, None],
"colB": [10, 20, 60, None],
"colVals": ["a", "b", "f", None],
"colB_r": [10, 20, None, 99],
"colVals_r": ["A", "B", None, "Z"],
"colUniq": [300, 200, None, 100],
"colA_r": [1, 2, None, 99],
})
result = _perform_join("full outer", t1, "colA",
t2, "colA", right_suffix="_r",
coalesce_keys=True)
result = result.combine_chunks()
result = result.sort_by("colA")
assert result == pa.table({
"colA": [1, 2, 6, 99],
"colB": [10, 20, 60, None],
"colVals": ["a", "b", "f", None],
"colB_r": [10, 20, None, 99],
"colVals_r": ["A", "B", None, "Z"],
"colUniq": [300, 200, None, 100]
})


def test_table_join_keys_order():
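    # Columns should keep their original positions in the output; only the
    # colliding non-key column (colVals) gets the suffixes.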
t1 = pa.table({
"colB": [10, 20, 60],
"colA": [1, 2, 6],
"colVals": ["a", "b", "f"]
})
t2 = pa.table({
"colVals": ["Z", "B", "A"],
"colX": [99, 2, 1],
})
result = _perform_join("full outer", t1, "colA", t2, "colX",
left_suffix="_l", right_suffix="_r",
coalesce_keys=True)
result = result.combine_chunks()
result = result.sort_by("colA")
assert result == pa.table({
"colB": [10, 20, 60, None],
"colA": [1, 2, 6, 99],
"colVals_l": ["a", "b", "f", None],
"colVals_r": ["A", "B", None, "Z"],
})


def test_filter_table_errors():
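    # A non-boolean filter expression and a reference to a missing column
    # should both raise.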
t = pa.table({
"a": [1, 2, 3, 4, 5],
"b": [10, 20, 30, 40, 50]
})
with pytest.raises(pa.ArrowTypeError):
_filter_table(t, pc.divide(pc.field("a"), pc.scalar(2)))
with pytest.raises(pa.ArrowInvalid):
_filter_table(t, (pc.field("Z") <= pc.scalar(2)))


def test_filter_table():
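    # Simple boolean expressions built from pyarrow.compute fields and scalars.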
t = pa.table({
"a": [1, 2, 3, 4, 5],
"b": [10, 20, 30, 40, 50]
})
result = _filter_table(
t, (pc.field("a") <= pc.scalar(3)) & (pc.field("b") == pc.scalar(20)),
)
assert result == pa.table({
"a": [2],
"b": [20]
})
result = _filter_table(t, pc.field("b") > pc.scalar(30))
assert result == pa.table({
"a": [4, 5],
"b": [40, 50]
})


def test_filter_table_ordering():
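    # Filtering should preserve the order of the input chunks.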
table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a'] * 4})
table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b'] * 4})
table = pa.concat_tables([table1, table2])
for _ in range(20):
        # 20 iterations seem enough to consistently trigger failures when
        # order is not preserved. If the ordering bug is reintroduced, this
        # test will become flaky, which is still a signal that order is not
        # preserved.
r = _filter_table(table, pc.field('a') == 1)
assert r["b"] == pa.chunked_array([["a"], ["b"]])


def test_complex_filter_table():
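    # Filter expressions that combine compute kernels (bit_wise_and, multiply)
    # with comparisons.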
t = pa.table({
"a": [1, 2, 3, 4, 5, 6, 6],
"b": [10, 20, 30, 40, 50, 60, 61]
})
result = _filter_table(
t, ((pc.bit_wise_and(pc.field("a"), pc.scalar(1)) == pc.scalar(0)) &
(pc.multiply(pc.field("a"), pc.scalar(10)) == pc.field("b")))
)
assert result == pa.table({
"a": [2, 4, 6], # second six must be omitted because 6*10 != 61
"b": [20, 40, 60]
})


def test_join_extension_array_column():
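    # Extension-typed columns should survive joins both as key columns and as
    # payload columns, including when a dictionary column is present.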
storage = pa.array([1, 2, 3], type=pa.int64())
ty = IntegerType()
ext_array = pa.ExtensionArray.from_storage(ty, storage)
dict_array = pa.DictionaryArray.from_arrays(
pa.array([0, 2, 1]), pa.array(['a', 'b', 'c']))
t1 = pa.table({
"colA": [1, 2, 6],
"colB": ext_array,
"colVals": ext_array,
})
t2 = pa.table({
"colA": [99, 2, 1],
"colC": ext_array,
})
t3 = pa.table({
"colA": [99, 2, 1],
"colC": ext_array,
"colD": dict_array,
})
result = _perform_join(
"left outer", t1, ["colA"], t2, ["colA"])
assert result["colVals"] == pa.chunked_array(ext_array)
result = _perform_join(
"left outer", t1, ["colB"], t2, ["colC"])
assert result["colB"] == pa.chunked_array(ext_array)
result = _perform_join(
"left outer", t1, ["colA"], t3, ["colA"])
assert result["colVals"] == pa.chunked_array(ext_array)
result = _perform_join(
"left outer", t1, ["colB"], t3, ["colC"])
assert result["colB"] == pa.chunked_array(ext_array)


def test_group_by_ordering():
# GH-36709 - preserve ordering in groupby by setting use_threads=False
table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a'] * 4})
table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b'] * 4})
table = pa.concat_tables([table1, table2])
for _ in range(50):
        # 50 iterations seem enough to consistently trigger failures when
        # order is not preserved. If the ordering bug is reintroduced, this
        # test will become flaky, which is still a signal that order is not
        # preserved.
result = table.group_by("b", use_threads=False).aggregate([])
assert result["b"] == pa.chunked_array([["a"], ["b"]])