# Repository URL to install this package: (URL missing from this copy)
# Version: 2022.10.0
import math
import random as rnd
import pytest
import dask.bag as db
from dask.bag import random
def test_choices_size_exactly_k():
    """choices() with k below the bag size yields exactly k members of the bag."""
    population = range(20)
    bag = db.from_sequence(population, npartitions=3)
    lazy_draw = random.choices(bag, k=2)
    drawn = list(lazy_draw.compute())
    assert len(drawn) == 2
    assert set(drawn).issubset(population)
def test_choices_k_bigger_than_bag_size():
    """Sampling with replacement may request more items than the bag holds."""
    population = range(3)
    bag = db.from_sequence(population, npartitions=3)
    lazy_draw = random.choices(bag, k=4)
    drawn = list(lazy_draw.compute())
    assert len(drawn) == 4
    assert set(drawn).issubset(population)
def test_choices_empty_partition():
    """choices() copes with a bag in which one partition is empty."""
    population = range(10)
    # Repartitioning the (9, 1) layout into three leaves the middle partition
    # empty; the first assertion below pins that layout.
    bag = db.from_sequence(population, partition_size=9).repartition(3)
    drawn = list(random.choices(bag, k=2).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (9, 0, 1)
    assert len(drawn) == 2
    assert set(drawn).issubset(population)
def test_choices_k_bigger_than_smallest_partition_size():
    """choices() works when k exceeds the size of the smallest partition."""
    population = range(10)
    bag = db.from_sequence(population, partition_size=9)
    drawn = list(random.choices(bag, k=2).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (9, 1)
    assert len(drawn) == 2
    assert set(drawn).issubset(population)
def test_choices_k_equal_bag_size_with_unbalanced_partitions():
    """choices() with k equal to the bag size over unevenly sized partitions."""
    population = range(10)
    bag = db.from_sequence(population, partition_size=9)
    drawn = list(random.choices(bag, k=10).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (9, 1)
    assert len(drawn) == 10
    assert set(drawn).issubset(population)
def test_choices_with_more_bag_partitons():
    """choices() when the bag has more partitions (10) than split_every (8)."""
    population = range(100)
    bag = db.from_sequence(population, npartitions=10)
    lazy_draw = random.choices(bag, k=10, split_every=8)
    drawn = list(lazy_draw.compute())
    assert bag.map_partitions(len).compute() == (10,) * 10
    assert len(drawn) == 10
    assert set(drawn).issubset(population)
def test_sample_with_more_bag_partitons():
    """sample() when the bag has more partitions (10) than split_every (8)."""
    population = range(100)
    bag = db.from_sequence(population, npartitions=10)
    lazy_draw = random.sample(bag, k=10, split_every=8)
    drawn = list(lazy_draw.compute())
    assert bag.map_partitions(len).compute() == (10,) * 10
    assert len(drawn) == 10
    assert set(drawn).issubset(population)
    # Without replacement: every drawn item must be distinct.
    assert len(drawn) == len(set(drawn))
def test_sample_size_exactly_k():
    """sample() with k below the bag size yields k distinct members of the bag."""
    population = range(20)
    bag = db.from_sequence(population, npartitions=3)
    drawn = list(random.sample(bag, k=2).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (7, 7, 6)
    assert len(drawn) == 2
    assert set(drawn).issubset(population)
    assert len(drawn) == len(set(drawn))
def test_sample_k_bigger_than_bag_size():
    """Sampling without replacement must reject k larger than the bag."""
    bag = db.from_sequence(range(3), npartitions=3)
    # The whole sample-and-compute chain stays inside the context manager so the
    # ValueError is caught whether it is raised at build or at compute time.
    with pytest.raises(ValueError, match="Sample larger than population"):
        random.sample(bag, k=4).compute()
def test_sample_empty_partition():
    """sample() copes with a bag in which one partition is empty."""
    population = range(10)
    # Repartitioning the (9, 1) layout into three leaves the middle partition
    # empty; the first assertion below pins that layout.
    bag = db.from_sequence(population, partition_size=9).repartition(3)
    drawn = list(random.sample(bag, k=2).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (9, 0, 1)
    assert len(drawn) == 2
    assert set(drawn).issubset(population)
    assert len(drawn) == len(set(drawn))
def test_sample_size_k_bigger_than_smallest_partition_size():
    """sample() works when k exceeds the size of the smallest partition."""
    population = range(10)
    bag = db.from_sequence(population, partition_size=9)
    drawn = list(random.sample(bag, k=2).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (9, 1)
    assert len(drawn) == 2
    assert set(drawn).issubset(population)
    assert len(drawn) == len(set(drawn))
def test_sample_k_equal_bag_size_with_unbalanced_partitions():
    """sample() with k equal to the bag size over unevenly sized partitions."""
    population = range(10)
    bag = db.from_sequence(population, partition_size=9)
    drawn = list(random.sample(bag, k=10).compute())
    partition_sizes = bag.map_partitions(len).compute()
    assert partition_sizes == (9, 1)
    assert len(drawn) == 10
    assert set(drawn).issubset(population)
    assert len(drawn) == len(set(drawn))
def test_sample_k_larger_than_partitions():
    """sample() draws k=8 items even though each partition holds at most 3.

    Besides the length, check that every drawn item comes from the source
    range and that no item repeats. Sampling is without replacement, so both
    must hold, matching the assertions made by the other ``sample`` tests.
    """
    population = range(10)
    bag = db.from_sequence(population, partition_size=3)
    sampled_bag = random.sample(bag, k=8, split_every=2)
    sampled = list(sampled_bag.compute())
    assert len(sampled) == 8
    assert all(item in population for item in sampled)
    assert len(set(sampled)) == len(sampled)
def test_weighted_sampling_without_replacement():
    """The weighted helper returns exactly k distinct members of the population.

    Checking only ``len(set(sampled)) == k`` would also pass if more than k
    items came back with duplicates among them. The exact length is therefore
    asserted separately, together with population membership.
    """
    population = range(4)
    p = [0.01, 0.33, 0.33, 0.33]
    k = 3
    # list() makes len() valid regardless of the concrete iterable returned.
    sampled = list(
        random._weighted_sampling_without_replacement(
            population=population, weights=p, k=k
        )
    )
    assert len(sampled) == k
    assert len(set(sampled)) == k
    assert all(item in population for item in sampled)
def test_sample_return_bag():
    """random.sample hands back a db.Bag rather than a computed collection."""
    bag = db.from_sequence(range(20), npartitions=3)
    result = random.sample(bag, k=2)
    assert isinstance(result, db.Bag)
def test_partitions_are_coerced_to_lists():
    """Regression test for https://github.com/dask/dask/issues/6906.

    Zips a plain bag with a repartitioned choices() result drawn from a
    flattened bag, and checks that the zip yields one pair per element.
    """
    nested = db.from_sequence([[1, 2], [3, 4, 5], [6], [7]])
    letters = db.from_sequence(["a", "b", "c", "d"])
    n_letters = letters.count().compute()
    picks = random.choices(nested.flatten(), k=n_letters).repartition(4)
    zipped = db.zip(letters, picks).compute()
    assert len(zipped) == 4
def test_reservoir_sample_map_partitions_correctness():
    """_sample_map_partitions picks elements as often as stdlib random.sample.

    Both samplers are run ``reps`` times over the same sequence. Their pick
    frequencies are normalized into probability vectors, and the two vectors
    must be near-identical according to ``bhattacharyya`` (which returns one
    minus the Bhattacharyya coefficient).
    """
    N, k = 20, 10
    reps = 2000
    seq = list(range(N))
    observed = [0] * N
    expected = [0] * N
    for _ in range(reps):
        reservoir, _length = random._sample_map_partitions(seq, k)
        for value in reservoir:
            observed[value] += 1
        for value in rnd.sample(seq, k=k):
            expected[value] += 1
    total_picks = reps * k
    observed_p = [count / total_picks for count in observed]
    expected_p = [count / total_picks for count in expected]
    # Identical distributions give 0; allow a small sampling-noise tolerance.
    distance = bhattacharyya(observed_p, expected_p)
    assert math.isclose(0.0, distance, abs_tol=1e-2)
def test_reservoir_sample_with_replacement_map_partitions_correctness():
    """_sample_with_replacement_map_partitions matches stdlib random.choices.

    Both samplers are run ``reps`` times over the same sequence. Their pick
    frequencies are normalized into probability vectors, and the two vectors
    must be near-identical according to ``bhattacharyya`` (which returns one
    minus the Bhattacharyya coefficient).
    """
    N, k = 20, 10
    reps = 2000
    seq = list(range(N))
    observed = [0] * N
    expected = [0] * N
    for _ in range(reps):
        reservoir, _length = random._sample_with_replacement_map_partitions(seq, k)
        for value in reservoir:
            observed[value] += 1
        for value in rnd.choices(seq, k=k):
            expected[value] += 1
    total_picks = reps * k
    observed_p = [count / total_picks for count in observed]
    expected_p = [count / total_picks for count in expected]
    # Identical distributions give 0; allow a small sampling-noise tolerance.
    distance = bhattacharyya(observed_p, expected_p)
    assert math.isclose(0.0, distance, abs_tol=1e-2)
def bhattacharyya(h1, h2):
    """Return 1 minus the Bhattacharyya coefficient of two histograms.

    The coefficient is ``sum(sqrt(p_i * q_i))`` over paired bins. Pairing
    follows ``zip``, so extra bins in the longer input are ignored. Two
    identical probability vectors give 0.0. Note that this is not the
    ``-ln(coefficient)`` form usually called the Bhattacharyya distance.
    """
    overlap = 0
    for p_i, q_i in zip(h1, h2):
        overlap += math.sqrt(p_i * q_i)
    return 1 - overlap