# Repository URL to install this package:
# Version: 4.0.1
import pandas as pd
from sarus_statistics.ops.utils import rescale_weights
def test_rescale_weights():
    """Check that `rescale_weights` caps each user's private weight sum.

    Builds a six-row frame (one public and one private row per user),
    calls `rescale_weights` with ``max_multiplicity=2.5`` and asserts that:

    * the result is a ``pandas.DataFrame``;
    * for every user, the weights of the rows flagged ``private == 1`` sum
      to at most ``max_multiplicity``.
    """
    data = pd.DataFrame(
        {
            'user': ['A', 'A', 'B', 'B', 'C', 'C'],
            'data': [100, 200, 100, 200, 300, 300],
            'private': [0, 1, 0, 1, 0, 1],  # 0 for public, 1 for private
            'weight': [1.0, 2.0, 1.5, 2.5, 1.0, 3.0],
        }
    )
    # Define a maximum multiplicity. User C's private weight (3.0) is the
    # only input above it, so the private rows are what exercise the cap.
    max_multiplicity = 2.5

    # Apply the `rescale_weights` function
    rescaled_data = rescale_weights(
        data=data,
        user_col='user',
        private_col='private',
        weight_col='weight',
        max_multiplicity=max_multiplicity,
    )

    # Check that the output is a DataFrame
    assert isinstance(
        rescaled_data, pd.DataFrame
    ), "Output should be a DataFrame."

    # Select the private rows (flag == 1). The previous mask was negated and
    # therefore picked the public rows, whose input weights are all already
    # below the cap, so the assertion could never fail.
    is_private = rescaled_data['private'].astype('bool')
    private_data = rescaled_data[is_private]
    # Guard against a vacuous pass: an empty selection would satisfy the
    # bound check below without testing anything.
    assert not private_data.empty, "Expected private rows in the output."

    # Check that no user's private weight sum exceeds the max_multiplicity.
    # A tiny tolerance absorbs floating-point rounding from the rescaling.
    weight_sums = private_data.groupby('user')['weight'].sum()
    tolerance = 1e-9
    assert (
        weight_sums <= max_multiplicity + tolerance
    ).all(), "No user's private weight sum should exceed max_multiplicity."