Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate_function argument to utils.concat() #401

Merged
merged 24 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 82 additions & 20 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def concat(
objs: typing.Sequence[typing.Union[pd.Series, pd.DataFrame]],
*,
overwrite: bool = False,
aggregate_function: typing.Callable[[pd.Series], typing.Any] = None,
) -> typing.Union[pd.Series, pd.DataFrame]:
r"""Concatenate objects.

Expand Down Expand Up @@ -64,18 +65,35 @@ def concat(
or one column contains ``NaN``.
If ``overwrite`` is set to ``True``,
the value of the last object in the list is kept.
If ``overwrite`` is set to ``False``,
a custom aggregation function can be provided
with ``aggregate_function``
that converts the overlapping values
into a single value.

Args:
objs: objects
overwrite: overwrite values where indices overlap
aggregate_function: function to aggregate values
for all entries
that contain more than one value per index.
hagenw marked this conversation as resolved.
Show resolved Hide resolved
The function gets a dataframe row as input.
frankenjoe marked this conversation as resolved.
Show resolved Hide resolved
E.g. set to
:func:`numpy.mean`
to average the values
or to
``lambda row: tuple(row.to_list())``
frankenjoe marked this conversation as resolved.
Show resolved Hide resolved
to return all values as a tuple
hagenw marked this conversation as resolved.
Show resolved Hide resolved

Returns:
concatenated objects

Raises:
ValueError: if level and dtypes of object indices do not match
ValueError: if columns with the same name have different dtypes
ValueError: if values in the same position do not match
ValueError: if ``aggregate_function`` is ``None``,
``overwrite`` is ``False``,
and values in the same position do not match

Examples:
>>> concat(
Expand All @@ -97,6 +115,15 @@ def concat(
0 0 1
>>> concat(
... [
... pd.Series([1], index=pd.Index([0])),
... pd.Series([2], index=pd.Index([0])),
... ],
... aggregate_function=np.sum,
... )
0 3
dtype: Int64
>>> concat(
... [
... pd.Series(
... [0., 1.],
... index=pd.Index(
Expand Down Expand Up @@ -194,6 +221,7 @@ def concat(

# reindex all columns to the new index
columns_reindex = {}
overlapping_values = {}
for column in columns:

# if we already have a column with that name, we have to merge them
Expand Down Expand Up @@ -233,26 +261,41 @@ def concat(
)
# We use len() here as index.empty takes a very long time
if len(intersection) > 0:
combine = pd.DataFrame(
{
'left': columns_reindex[column.name][intersection],
'right': column[intersection]
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
msg_tail = '\n...' \
if len(overlap) > max_display \
else ''
raise ValueError(
"Found overlapping data in column "
f"'{column.name}':\n"
f"{msg_overlap}{msg_tail}"

# Custom handling of overlapping values
if aggregate_function is not None:
if column.name not in overlapping_values:
overlapping_values[column.name] = [
columns_reindex[column.name].loc[intersection]
]
overlapping_values[column.name].append(
column.loc[intersection]
)
column = column.loc[~column.index.isin(intersection)]

else:
combine = pd.DataFrame(
{
'left':
columns_reindex[column.name][intersection],
'right':
column[intersection]
hagenw marked this conversation as resolved.
Show resolved Hide resolved
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
msg_tail = '\n...' \
if len(overlap) > max_display \
else ''
raise ValueError(
"Found overlapping data in column "
f"'{column.name}':\n"
f"{msg_overlap}{msg_tail}"
)

# drop NaN to avoid overwriting values from other column
column = column.dropna()
Expand All @@ -269,6 +312,25 @@ def concat(
)
columns_reindex[column.name][column.index] = column

# Apply custom aggregation function
# on collected overlapping data
# (no overlapping data is collected
# when no aggregation function is provided)
if len(overlapping_values) > 0:
for column in overlapping_values:
df = pd.concat(
overlapping_values[column],
axis=1,
ignore_index=True,
)
dtype = columns_reindex[column].dtype
columns_reindex[column] = df.apply(aggregate_function, axis=1)
# Restore the original dtype if possible
try:
columns_reindex[column] = columns_reindex[column].astype(dtype)
except (TypeError, ValueError):
pass
frankenjoe marked this conversation as resolved.
Show resolved Hide resolved

# Use `None` to force `{}` return the correct index, see
# https://github.com/pandas-dev/pandas/issues/52404
df = pd.DataFrame(columns_reindex or None, index=index)
Expand Down
Loading