From cd8575b2512b29b7855c48b024df0411adbe1b36 Mon Sep 17 00:00:00 2001 From: Olaf Lipinski <5785856+olipinski@users.noreply.github.com> Date: Mon, 6 Nov 2023 14:36:42 +0000 Subject: [PATCH] Add M_previous^n metric (#1) Add M_previous^n metric Signed-off-by: Olaf Lipinski --- emlangkit/language.py | 54 ++++++++ emlangkit/metrics/__init__.py | 2 + emlangkit/metrics/mpn.py | 99 +++++++++++++++ tests/test_metrics.py | 224 ++++++++++++++++++++++++++++++++++ 4 files changed, 379 insertions(+) create mode 100644 emlangkit/metrics/mpn.py diff --git a/emlangkit/language.py b/emlangkit/language.py index c35e79b..7037a30 100644 --- a/emlangkit/language.py +++ b/emlangkit/language.py @@ -12,12 +12,34 @@ class Language: It takes the messages and observations for an emergent language, and allows calculations of the most commonly used metrics. + + Parameters + ---------- + messages : numpy.ndarray + Numpy array containing the messages. + observations : numpy.ndarray, optional + Numpy array containing the observations. Default is None. + seed : int, optional + Seed value for random number generation. Default is 42. + + Examples + -------- + Create a Language object with messages and observations: + >>> messages = np.array([1, 2, 3, 4, 5]) + >>> observations = np.array([6, 7, 8, 9, 10]) + >>> lang = Language(messages, observations) + + Create a Language object with only messages and default seed: + >>> messages = np.array([1, 2, 3, 4, 5]) + >>> lang = Language(messages) """ def __init__( self, messages: np.ndarray, observations: Optional[np.ndarray] = None, + prev_horizon: int = 8, + seed: int = 42, ): if not isinstance(messages, np.ndarray): raise ValueError("Language only accepts numpy arrays!") @@ -34,6 +56,8 @@ def __init__( self.messages = messages self.observations = observations + self.__rng = np.random.default_rng(seed=seed) + # Placeholders self.__topsim_value = None self.__posdis_value = None @@ -42,6 +66,10 @@ def __init__( self.__observation_entropy_value = None self.__mutual_information_value = None + # M_previous^n placeholders + self.__mpn_value = None + self.prev_horizon = prev_horizon + def topsim(self) -> tuple[float, float]: """ Calculate the topographic similarity score for the language. @@ -194,3 +222,29 @@ def mutual_information(self): ) return self.__mutual_information_value + + # M_previous_n metric + + def mpn(self): + """ + Calculate the M_previous^n score for the language. + + This method requires observations to be set in the class. + + Returns + ------- + float: The highest M_previous^n value. + + Raises + ------ + ValueError: If observations are not set. + """ + if self.observations is None: + raise ValueError("Observations are needed to calculate M_previous^n.") + + if self.__mpn_value is None: + self.__mpn_value = metrics.compute_mpn( + self.messages, self.observations, self.prev_horizon + ) + + return self.__mpn_value diff --git a/emlangkit/metrics/__init__.py b/emlangkit/metrics/__init__.py index 541ae36..07f022d 100644 --- a/emlangkit/metrics/__init__.py +++ b/emlangkit/metrics/__init__.py @@ -1,6 +1,7 @@ """Root __init__ of the metrics.""" from emlangkit.metrics.bosdis import compute_bosdis from emlangkit.metrics.entropy import compute_entropy +from emlangkit.metrics.mpn import compute_mpn from emlangkit.metrics.mutual_information import compute_mutual_information from emlangkit.metrics.posdis import compute_posdis from emlangkit.metrics.topsim import compute_topographic_similarity @@ -12,4 +13,5 @@ "compute_mutual_information", "compute_posdis", "compute_topographic_similarity", + "compute_mpn", ] diff --git a/emlangkit/metrics/mpn.py b/emlangkit/metrics/mpn.py new file mode 100644 index 0000000..1f93140 --- /dev/null +++ b/emlangkit/metrics/mpn.py @@ -0,0 +1,99 @@ +""" +The M_previous^n metric. + +Adapted from https://arxiv.org/abs/2310.06555 +""" + +import numpy as np + + +def compute_mpn( + messages: np.ndarray, + observations: np.ndarray, + prev_horizon: int, + return_stats: bool = False, +): + """ + Calculate the M_previous^n metric. + + This function assumes that the messages and observations are temporally ordered (i.e., index 0 is first timestep, + last index is last timestep). + + The metric will be computed for all horizons up to and including prev_horizon, i.e., [1,prev_horizon]. + + Parameters + ---------- + messages : np.ndarray + The messages temporally ordered messages. + observations : np.ndarray + The temporally ordered observations. + prev_horizon : int + The horizon up to which to calculate the metric. + + Returns + ------- + mpn : np.ndarray + The highest M_previous^n value for each horizon. + msg_stats : dict + The stats for each unique message. Only returned if `return_stats` is True. + """ + msgs, inverse, msg_counts = np.unique( + messages, return_counts=True, return_inverse=True, axis=0 + ) + + msg_stats = { + f"{msg}": { + "count": msg_counts[idx], + "same_as_previous_obj": np.zeros(shape=prev_horizon + 1, dtype=np.int32), + "prev_use_percentage": np.zeros(shape=prev_horizon + 1, dtype=np.float32), + } + for idx, msg in enumerate(msgs) + } + + # Times that the object was the same as the previous object + # first we go from first to last observation + for i in range(len(observations)): + # Then starting at a given observation we look to the future to see if it repeats + for horizon in range(1, prev_horizon + 1): + # We cannot look beyond the end of the array + if horizon + i >= len(observations): + break + # If it repeats, then we count it as a possible temporal reference for a given message + if np.array_equal(observations[i], observations[horizon + i]): + # We get the message index from the inverse of the messages + # This looks complicated, but it's not too bad + msg_stats[f"{msgs[inverse[horizon + i]]}"]["same_as_previous_obj"][ + horizon + ] += 1 + # Break, otherwise if there are multiple repeats in a horizon + # They could get labelled twice, and incorrectly + break + + mpn = np.zeros(shape=prev_horizon, dtype=np.float32) + + for msg in msg_stats: + if msg_stats[f"{msg}"]["count"] > 0: + for horizon in range(1, prev_horizon + 1): + if msg_stats[f"{msg}"]["same_as_previous_obj"][horizon] != 0: + msg_stats[f"{msg}"]["prev_use_percentage"][horizon] = ( + round( + msg_stats[f"{msg}"]["same_as_previous_obj"][horizon] + / msg_stats[f"{msg}"]["count"], + 3, + ) + * 100 + ) + if ( + msg_stats[f"{msg}"]["prev_use_percentage"][horizon] + > mpn[horizon] + ): + mpn[horizon] = msg_stats[f"{msg}"]["prev_use_percentage"][ + horizon + ] + else: + msg_stats[f"{msg}"]["prev_use_percentage"][horizon] = 0 + + if return_stats: + return mpn, msg_stats + else: + return mpn diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 98375ba..c1ca26a 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -501,3 +501,227 @@ def test_bosdis(): 1.0, 2, ) + + +def test_mpn(): + """Tests to see if M_previous^n is calculated correctly.""" + test_obs = np.array([[x, y] for x in range(4) for y in range(4)]) + + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [0, 0, 0], + [0, 0, 1], + [0, 0, 2], + [0, 0, 3], + [0, 1, 0], + [0, 1, 1], + [0, 1, 2], + [0, 1, 3], + [2, 0, 0], + [2, 0, 1], + [2, 0, 2], + [2, 0, 3], + [2, 1, 0], + [2, 1, 1], + [2, 1, 2], + [2, 1, 3], + ] + ), + observations=test_obs, + prev_horizon=8, + )[0], + 0, + 2, + ) + + # noinspection PyTypeChecker + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [1], + [2], + [2], + [2], + ] + ), + observations=np.array( + [ + [4], + [4], + [4], + [4], + ] + ), + prev_horizon=8, + )[1], + 100, + 2, + ) + + # noinspection PyTypeChecker + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [1], + [2], + [2], + [2], + ] + ), + observations=np.array( + [ + [4], + [4], + [4], + [4], + ] + ), + prev_horizon=8, + )[2], + 0, + 2, + ) + + # noinspection PyTypeChecker + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [4, 1], + [2, 1], + [4, 1], + [3, 1], + [1, 1], + ] + ), + observations=np.array( + [ + [1], + [2], + [3], + [4], + [1], + [4], + [3], + [1], + [1], + ] + ), + prev_horizon=8, + )[1], + 100, + 2, + ) + + # noinspection PyTypeChecker + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [4, 1], + [2, 1], + [4, 1], + [3, 1], + [1, 1], + ] + ), + observations=np.array( + [ + [1], + [2], + [3], + [4], + [1], + [4], + [3], + [1], + [1], + ] + ), + prev_horizon=8, + )[2], + 100, + 2, + ) + + # noinspection PyTypeChecker + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [4, 1], + [2, 1], + [4, 1], + [3, 1], + [1, 1], + ] + ), + observations=np.array( + [ + [1], + [2], + [3], + [4], + [1], + [4], + [3], + [1], + [1], + ] + ), + prev_horizon=8, + )[3], + 100, + 2, + ) + + # noinspection PyTypeChecker + np.testing.assert_almost_equal( + metrics.compute_mpn( + messages=np.array( + [ + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [4, 1], + [2, 1], + [4, 1], + [3, 1], + [1, 1], + ] + ), + observations=np.array( + [ + [1], + [2], + [3], + [4], + [1], + [4], + [3], + [1], + [1], + ] + ), + prev_horizon=8, + )[4], + 100, + 2, + )