Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate_function argument to utils.concat() #401

Merged
merged 24 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 101 additions & 23 deletions audformat/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def concat(
objs: typing.Sequence[typing.Union[pd.Series, pd.DataFrame]],
*,
overwrite: bool = False,
aggregate_function: typing.Callable[[pd.Series], typing.Any] = None,
) -> typing.Union[pd.Series, pd.DataFrame]:
r"""Concatenate objects.

Expand Down Expand Up @@ -64,18 +65,35 @@ def concat(
or one column contains ``NaN``.
If ``overwrite`` is set to ``True``,
the value of the last object in the list is kept.
If ``overwrite`` is set to ``False``,
a custom aggregation function can be provided
with ``aggregate_function``
that converts the overlapping values
into a single value.

Args:
objs: objects
overwrite: overwrite values where indices overlap
aggregate_function: function to aggregate overlapping values.
The function gets a :class:`pandas.Series`
with overlapping values
as input.
E.g. set to
``lambda y: y.mean()``
to average the values
or to
``tuple``
to return them as a tuple

Returns:
concatenated objects

Raises:
ValueError: if level and dtypes of object indices do not match
ValueError: if columns with the same name have different dtypes
ValueError: if values in the same position do not match
ValueError: if ``aggregate_function`` is ``None``,
``overwrite`` is ``False``,
and values in the same position do not match

Examples:
>>> concat(
Expand All @@ -97,6 +115,15 @@ def concat(
0 0 1
>>> concat(
... [
... pd.Series([1], index=pd.Index([0])),
... pd.Series([2], index=pd.Index([0])),
... ],
... aggregate_function=np.sum,
... )
0 3
dtype: Int64
>>> concat(
... [
... pd.Series(
... [0., 1.],
... index=pd.Index(
Expand Down Expand Up @@ -194,6 +221,7 @@ def concat(

# reindex all columns to the new index
columns_reindex = {}
overlapping_values = {}
for column in columns:

# if we already have a column with that name, we have to merge them
Expand Down Expand Up @@ -223,36 +251,50 @@ def concat(
columns_reindex[column.name].astype('float64')
)

# overlapping values must match or have to be nan in one column
# Handle overlapping values
if not overwrite:
intersection = intersect(
[
columns_reindex[column.name].index,
column.index,
columns_reindex[column.name].dropna().index,
column.dropna().index,
]
)
# We use len() here as index.empty takes a very long time
if len(intersection) > 0:
combine = pd.DataFrame(
{
'left': columns_reindex[column.name][intersection],
'right': column[intersection]
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
msg_tail = '\n...' \
if len(overlap) > max_display \
else ''
raise ValueError(
"Found overlapping data in column "
f"'{column.name}':\n"
f"{msg_overlap}{msg_tail}"

# Store overlap if custom aggregate function is provided
if aggregate_function is not None:
if column.name not in overlapping_values:
overlapping_values[column.name] = []
overlapping_values[column.name].append(
column.loc[intersection]
)
column = column.loc[~column.index.isin(intersection)]

# Raise error if values don't match and are not NaN
else:
combine = pd.DataFrame(
{
'left':
columns_reindex[column.name][intersection],
'right':
column[intersection],
}
)
combine.dropna(inplace=True)
differ = combine['left'] != combine['right']
if np.any(differ):
max_display = 10
overlap = combine[differ]
msg_overlap = str(overlap[:max_display])
msg_tail = '\n...' \
if len(overlap) > max_display \
else ''
raise ValueError(
"Found overlapping data in column "
f"'{column.name}':\n"
f"{msg_overlap}{msg_tail}"
)

# drop NaN to avoid overwriting values from other column
column = column.dropna()
Expand All @@ -269,6 +311,42 @@ def concat(
)
columns_reindex[column.name][column.index] = column

# Apply custom aggregation function
# on collected overlapping data
# (no overlapping data is collected
# when no aggregation function is provided)
if len(overlapping_values) > 0:
for column in overlapping_values:

# Prepend the values already stored in the merged column
# for every index entry that overlaps with a later column
union_index = union(
[y.index for y in overlapping_values[column]]
)
overlapping_values[column].insert(
0,
columns_reindex[column].loc[union_index]
)

# Convert list of overlapping data series to data frame
# and apply aggregate function
df = pd.concat(
overlapping_values[column],
axis=1,
ignore_index=True,
)
dtype = columns_reindex[column].dtype
y = df.apply(aggregate_function, axis=1)

# Restore the original dtype if possible
try:
y = y.astype(dtype)
except (TypeError, ValueError):
columns_reindex[column] = columns_reindex[column].astype(
y.dtype
)
columns_reindex[column].loc[y.index] = y

# Use `None` to force `{}` to return the correct index, see
# https://github.com/pandas-dev/pandas/issues/52404
df = pd.DataFrame(columns_reindex or None, index=index)
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
'audeer': ('https://audeering.github.io/audeer/', None),
'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None),
'python': ('https://docs.python.org/3/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
}
# Ignore package dependencies while building the docs
autodoc_mock_imports = [
Expand Down
Loading