"""
The M_previous^n metric.

Adapted from https://arxiv.org/abs/2310.06555
"""

# Provenance: recovered from the patch "Add M_previous^n metric"
# (commit ab98eec5549804d56cfe827e4e79fb644296122a, Olaf Lipinski, 2023-11-06).
# In that patch `compute_mpn` lives in emlangkit/metrics/mpn.py, is re-exported
# from emlangkit/metrics/__init__.py (import + "compute_mpn" in __all__), and
# `test_mpn` is appended to tests/test_metrics.py.

import numpy as np


def compute_mpn(
    messages: np.ndarray,
    observations: np.ndarray,
    prev_horizon: int,
    return_stats: bool = False,
):
    """
    Calculate the M_previous^n metric.

    This function assumes that the messages and observations are temporally
    ordered (i.e., index 0 is the first timestep, the last index is the last
    timestep).

    For every timestep ``i`` the nearest later timestep ``i + h`` (with
    ``1 <= h <= prev_horizon``) showing an identical observation is located,
    and the message sent at ``i + h`` is credited once at horizon ``h``. For
    each unique message, the credited count at a horizon divided by the total
    number of uses of that message gives its usage percentage at that horizon.

    The metric is computed for all horizons up to and including prev_horizon,
    i.e., [1, prev_horizon].

    Parameters
    ----------
    messages : np.ndarray
        The temporally ordered messages.
    observations : np.ndarray
        The temporally ordered observations.
    prev_horizon : int
        The horizon up to which to calculate the metric.
    return_stats : bool
        Whether to also return the per-message statistics.

    Returns
    -------
    mpn : np.ndarray
        Array of length ``prev_horizon + 1``. ``mpn[h]`` is the highest usage
        percentage (0-100) over all unique messages at horizon ``h``.
        Index 0 is unused and always 0.
    msg_stats : dict
        The stats for each unique message, keyed by the string form of the
        message. Each entry holds ``"count"``, ``"same_as_previous_obj"`` and
        ``"prev_use_percentage"`` (the latter two indexed by horizon).
        Only returned if `return_stats` is True.
    """
    msgs, inverse, msg_counts = np.unique(
        messages, return_counts=True, return_inverse=True, axis=0
    )
    # The shape of the inverse returned together with `axis=` has differed
    # between NumPy releases; flatten it so it is a plain per-timestep index.
    inverse = np.ravel(inverse)

    # Format each unique message once instead of on every lookup.
    keys = [f"{msg}" for msg in msgs]

    msg_stats = {
        key: {
            "count": msg_counts[idx],
            "same_as_previous_obj": np.zeros(shape=prev_horizon + 1, dtype=np.int32),
            "prev_use_percentage": np.zeros(shape=prev_horizon + 1, dtype=np.float32),
        }
        for idx, key in enumerate(keys)
    }

    n_steps = len(observations)
    for start in range(n_steps):
        # Never look past the end of the episode.
        furthest = min(prev_horizon, n_steps - 1 - start)
        for horizon in range(1, furthest + 1):
            if np.array_equal(observations[start], observations[start + horizon]):
                # Credit the message that was sent when the observation reappeared.
                key = keys[inverse[start + horizon]]
                msg_stats[key]["same_as_previous_obj"][horizon] += 1
                # Only the nearest repeat counts; otherwise one observation
                # could be labelled at several horizons.
                break

    # One slot per horizon value so that index == horizon, matching the
    # per-message arrays. Allocating only `prev_horizon` slots made
    # `mpn[prev_horizon]` an out-of-bounds access.
    mpn = np.zeros(shape=prev_horizon + 1, dtype=np.float32)

    for stats in msg_stats.values():
        repeats = stats["same_as_previous_obj"]
        percentages = stats["prev_use_percentage"]
        for horizon in range(1, prev_horizon + 1):
            if repeats[horizon] != 0:
                # Ratio is rounded to 3 decimals before scaling to a percentage.
                percentages[horizon] = round(repeats[horizon] / stats["count"], 3) * 100
        # Keep, per horizon, the best percentage seen over all messages.
        np.maximum(mpn, percentages, out=mpn)

    if return_stats:
        return mpn, msg_stats
    return mpn


def test_mpn():
    """Tests to see if M_previous^n is calculated correctly."""
    # No observation ever repeats, so every horizon must stay at zero.
    unique_obs = np.array([[x, y] for x in range(4) for y in range(4)])
    unique_msgs = np.array(
        [[a, b, c] for a in (0, 2) for b in (0, 1) for c in range(4)]
    )
    np.testing.assert_almost_equal(
        compute_mpn(messages=unique_msgs, observations=unique_obs, prev_horizon=8),
        np.zeros(9),
        2,
    )

    # A constant observation: message [2] is always used right after a repeat.
    constant = compute_mpn(
        messages=np.array([[1], [2], [2], [2]]),
        observations=np.array([[4], [4], [4], [4]]),
        prev_horizon=8,
    )
    np.testing.assert_almost_equal(constant[1], 100, 2)
    np.testing.assert_almost_equal(constant[2], 0, 2)

    # Dedicated messages for repeats at horizons 1 to 4.
    mixed = compute_mpn(
        messages=np.array(
            [
                [0, 1],
                [0, 1],
                [0, 1],
                [0, 1],
                [4, 1],
                [2, 1],
                [4, 1],
                [3, 1],
                [1, 1],
            ]
        ),
        observations=np.array([[1], [2], [3], [4], [1], [4], [3], [1], [1]]),
        prev_horizon=8,
    )
    for horizon in (1, 2, 3, 4):
        np.testing.assert_almost_equal(mixed[horizon], 100, 2)

    # Regression: a repeat exactly `prev_horizon` steps back must be reported
    # rather than indexing past the end of the result.
    boundary = compute_mpn(
        messages=np.array([[1], [2], [3]]),
        observations=np.array([[1], [2], [1]]),
        prev_horizon=2,
    )
    np.testing.assert_almost_equal(boundary, [0, 0, 100], 2)