diff --git a/tests/test_utils_concat.py b/tests/test_utils_concat.py index 1c370729..2344a0e9 100644 --- a/tests/test_utils_concat.py +++ b/tests/test_utils_concat.py @@ -552,7 +552,6 @@ def test_concat(objs, overwrite, expected): pd.testing.assert_frame_equal(obj, expected) -# @pytest.mark.parametrize( 'objs, aggregate_function, expected', [ @@ -900,3 +899,64 @@ def test_concat_aggregate_function(objs, aggregate_function, expected): pd.testing.assert_series_equal(obj, expected) else: pd.testing.assert_frame_equal(obj, expected) + + +@pytest.mark.parametrize( + 'objs, aggregate_function, expected', + [ + # empty + ( + [], + None, + pd.Series([], pd.Index([]), dtype='object'), + ), + # identical values + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + None, + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ], + np.mean, + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + ), + # different values + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ], + None, + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ), + ( + [ + pd.Series([1, 2], pd.Index(['a', 'b']), dtype='float'), + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ], + np.mean, + pd.Series([2, 3], pd.Index(['a', 'b']), dtype='float'), + ), + ] +) +def test_concat_overwrite_aggregate_function( + objs, + aggregate_function, + expected, +): + obj = audformat.utils.concat( + objs, + overwrite=True, + aggregate_function=aggregate_function, + ) + if isinstance(obj, pd.Series): + pd.testing.assert_series_equal(obj, expected) + else: + pd.testing.assert_frame_equal(obj, expected)