Skip to content

Commit

Permalink
Add passing ax down the plotting functions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625379714
Change-Id: I66332c26e6cafbda1dc1076beec9f096c14b65b8
  • Loading branch information
vezhnick authored and copybara-github committed Apr 16, 2024
1 parent 8242497 commit 449cd8f
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions concordia/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def plot_line_measurement_channel(measurements_obj: measurements.Measurements,
channel_name: str,
group_by: str = 'player',
xaxis: str = 'time',
yaxis: str = 'value_float') -> None:
yaxis: str = 'value_float',
ax: plt.Axes = None) -> None:
"""Plots a pie chart of a measurement channel."""
if channel_name not in measurements_obj.available_channels():
raise ValueError(f'Unknown channel: {channel_name}')
Expand All @@ -36,7 +37,7 @@ def plot_line_measurement_channel(measurements_obj: measurements.Measurements,
channel.subscribe(on_next=data.append)

plot_df_line(pd.DataFrame(data), channel_name, group_by=group_by, xaxis=xaxis,
yaxis=yaxis)
yaxis=yaxis, ax=ax)


def plot_pie_measurement_channel(measurements_obj: measurements.Measurements,
Expand Down Expand Up @@ -91,7 +92,8 @@ def plot_df_line(df: pd.DataFrame,
title: str = 'Metric',
group_by: str = 'player',
xaxis: str = 'time',
yaxis: str = 'value_float') -> None:
yaxis: str = 'value_float',
ax: plt.Axes = None) -> None:
"""Plots a line chart of a dataframe.
Args:
Expand All @@ -102,8 +104,11 @@ def plot_df_line(df: pd.DataFrame,
the same value in this field, the y-axis values are averaged.
yaxis: The name of the column to use as the y-axis. The values in this
column must be numerical.
ax: The axis to plot on. If None, uses the current axis.
"""
ax = plt.gca()
if ax is None:
ax = plt.gca()

for player, group_df in df.groupby(group_by):
group_df = group_df.groupby(xaxis).mean(numeric_only=True).reset_index()
group_df.plot(x=xaxis, y=yaxis, label=player, ax=ax)
Expand Down

0 comments on commit 449cd8f

Please sign in to comment.