Skip to content

Commit

Permalink
Add M_previous^n metric (#1)
Browse files Browse the repository at this point in the history
Add M_previous^n metric

Signed-off-by: Olaf Lipinski <[email protected]>
  • Loading branch information
olipinski authored Nov 6, 2023
1 parent c4b107d commit cd8575b
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 0 deletions.
54 changes: 54 additions & 0 deletions emlangkit/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,34 @@ class Language:
It takes the messages and observations for an emergent language, and
allows calculations of the most commonly used metrics.
Parameters
----------
messages : numpy.ndarray
Numpy array containing the messages.
observations : numpy.ndarray, optional
Numpy array containing the observations. Default is None.
seed : int, optional
Seed value for random number generation. Default is 42.
Examples
--------
Create a Language object with messages and observations:
>>> messages = np.array([1, 2, 3, 4, 5])
>>> observations = np.array([6, 7, 8, 9, 10])
>>> lang = Language(messages, observations)
Create a Language object with only messages and default seed:
>>> messages = np.array([1, 2, 3, 4, 5])
>>> lang = Language(messages)
"""

def __init__(
self,
messages: np.ndarray,
observations: Optional[np.ndarray] = None,
prev_horizon: int = 8,
seed: int = 42,
):
if not isinstance(messages, np.ndarray):
raise ValueError("Language only accepts numpy arrays!")
Expand All @@ -34,6 +56,8 @@ def __init__(
self.messages = messages
self.observations = observations

self.__rng = np.random.default_rng(seed=seed)

# Placeholders
self.__topsim_value = None
self.__posdis_value = None
Expand All @@ -42,6 +66,10 @@ def __init__(
self.__observation_entropy_value = None
self.__mutual_information_value = None

# M_previous^n placeholders
self.__mpn_value = None
self.prev_horizon = prev_horizon

def topsim(self) -> tuple[float, float]:
"""
Calculate the topographic similarity score for the language.
Expand Down Expand Up @@ -194,3 +222,29 @@ def mutual_information(self):
)

return self.__mutual_information_value

# M_previous_n metric

def mpn(self):
    """
    Return the M_previous^n scores for the language.

    The scores are computed once, from the stored messages and
    observations and with ``self.prev_horizon`` as the horizon, and
    are then cached on the instance for subsequent calls.

    Returns
    -------
    numpy.ndarray
        The result of ``metrics.compute_mpn``: the highest
        M_previous^n value found for each horizon.

    Raises
    ------
    ValueError
        If observations are not set.
    """
    if self.observations is None:
        raise ValueError("Observations are needed to calculate M_previous^n.")

    # Serve the memoised result when an earlier call already computed it.
    cached = self.__mpn_value
    if cached is not None:
        return cached

    self.__mpn_value = metrics.compute_mpn(
        self.messages, self.observations, self.prev_horizon
    )
    return self.__mpn_value
2 changes: 2 additions & 0 deletions emlangkit/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Root __init__ of the metrics."""
from emlangkit.metrics.bosdis import compute_bosdis
from emlangkit.metrics.entropy import compute_entropy
from emlangkit.metrics.mpn import compute_mpn
from emlangkit.metrics.mutual_information import compute_mutual_information
from emlangkit.metrics.posdis import compute_posdis
from emlangkit.metrics.topsim import compute_topographic_similarity
Expand All @@ -12,4 +13,5 @@
"compute_mutual_information",
"compute_posdis",
"compute_topographic_similarity",
"compute_mpn",
]
99 changes: 99 additions & 0 deletions emlangkit/metrics/mpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
The M_previous^n metric.
Adapted from https://arxiv.org/abs/2310.06555
"""

import numpy as np


def compute_mpn(
    messages: np.ndarray,
    observations: np.ndarray,
    prev_horizon: int,
    return_stats: bool = False,
):
    """
    Calculate the M_previous^n metric.

    This function assumes that the messages and observations are temporally
    ordered (i.e., index 0 is the first timestep, the last index is the last
    timestep), and that ``messages[t]`` was sent for ``observations[t]``.

    The metric is computed for all horizons up to and including
    ``prev_horizon``, i.e., [1, prev_horizon]. For every timestep, the
    nearest later timestep (within the horizon) with an identical observation
    is located, and the message sent at that later timestep is credited with
    one "same as previous" use at that distance. For each message and
    horizon, the credited uses are divided by the total number of uses of the
    message and expressed as a percentage (rounded to 3 decimal places before
    scaling by 100).

    Parameters
    ----------
    messages : np.ndarray
        The temporally ordered messages. Unique messages are determined
        along axis 0.
    observations : np.ndarray
        The temporally ordered observations.
    prev_horizon : int
        The horizon up to which to calculate the metric.
    return_stats : bool, optional
        If True, also return the per-message statistics. Default is False.

    Returns
    -------
    mpn : np.ndarray
        Array of length ``prev_horizon + 1``. ``mpn[h]`` is the highest
        percentage over all messages for horizon ``h``; index 0 is unused
        and always 0, mirroring the layout of the per-message arrays.
    msg_stats : dict
        The stats for each unique message, keyed by the string form of the
        message. Each entry holds ``"count"``, ``"same_as_previous_obj"`` and
        ``"prev_use_percentage"`` (the latter two indexed by horizon).
        Only returned if `return_stats` is True.
    """
    msgs, inverse, msg_counts = np.unique(
        messages, return_counts=True, return_inverse=True, axis=0
    )
    # The inverse is used as a flat "timestep -> unique message index" map.
    # NOTE(review): some NumPy 2.0 releases appear to return a non-flat
    # inverse when `axis` is given; flattening keeps the lookup valid.
    inverse = np.ravel(inverse)

    # Stats are tracked by unique-message index so the hot loop does not have
    # to re-format a message into its string key on every lookup.
    per_message = [
        {
            "count": msg_counts[idx],
            "same_as_previous_obj": np.zeros(shape=prev_horizon + 1, dtype=np.int32),
            "prev_use_percentage": np.zeros(shape=prev_horizon + 1, dtype=np.float32),
        }
        for idx in range(len(msgs))
    ]
    msg_stats = {f"{msg}": stats for msg, stats in zip(msgs, per_message)}

    n_obs = len(observations)

    # Times that the object was the same as the previous object:
    # first we go from first to last observation
    for i in range(n_obs):
        # Then starting at a given observation we look to the future to see
        # if it repeats
        for horizon in range(1, prev_horizon + 1):
            # We cannot look beyond the end of the array
            if horizon + i >= n_obs:
                break
            # If it repeats, then we count it as a possible temporal
            # reference for the message sent at the repeat
            if np.array_equal(observations[i], observations[horizon + i]):
                per_message[inverse[horizon + i]]["same_as_previous_obj"][
                    horizon
                ] += 1
                # Only the nearest repeat counts; otherwise multiple repeats
                # within one horizon would be labelled more than once
                break

    # Sized prev_horizon + 1 so that it can be indexed directly by horizon
    # (1..prev_horizon); a length of prev_horizon overflows on the last horizon.
    mpn = np.zeros(shape=prev_horizon + 1, dtype=np.float32)

    for stats in per_message:
        repeats = stats["same_as_previous_obj"]
        percentages = stats["prev_use_percentage"]
        for horizon in range(1, prev_horizon + 1):
            # Entries without any repeat keep their initial 0 percentage
            if repeats[horizon] == 0:
                continue
            percentages[horizon] = round(repeats[horizon] / stats["count"], 3) * 100
            if percentages[horizon] > mpn[horizon]:
                mpn[horizon] = percentages[horizon]

    if return_stats:
        return mpn, msg_stats
    return mpn
Loading

0 comments on commit cd8575b

Please sign in to comment.