Skip to content

Commit

Permalink
Add M_previous^n metric
Browse files Browse the repository at this point in the history
Signed-off-by: Olaf Lipinski <[email protected]>
  • Loading branch information
olipinski committed Nov 6, 2023
1 parent c4b107d commit ab98eec
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 0 deletions.
2 changes: 2 additions & 0 deletions emlangkit/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Root __init__ of the metrics."""
from emlangkit.metrics.bosdis import compute_bosdis
from emlangkit.metrics.entropy import compute_entropy
from emlangkit.metrics.mpn import compute_mpn
from emlangkit.metrics.mutual_information import compute_mutual_information
from emlangkit.metrics.posdis import compute_posdis
from emlangkit.metrics.topsim import compute_topographic_similarity
Expand All @@ -12,4 +13,5 @@
"compute_mutual_information",
"compute_posdis",
"compute_topographic_similarity",
"compute_mpn",
]
99 changes: 99 additions & 0 deletions emlangkit/metrics/mpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
The M_previous^n metric.
Adapted from https://arxiv.org/abs/2310.06555
"""

import numpy as np


def compute_mpn(
messages: np.ndarray,
observations: np.ndarray,
prev_horizon: int,
return_stats: bool = False,
):
"""
Calculate the M_previous^n metric.
This function assumes that the messages and observations are temporally ordered (i.e., index 0 is first timestep,
last index is last timestep).
The metric will be computed for all horizons up to and including prev_horizon, i.e., [1,prev_horizon].
Parameters
----------
messages : np.ndarray
The messages temporally ordered messages.
observations : np.ndarray
The temporally ordered observations.
prev_horizon : int
The horizon up to which to calculate the metric.
Returns
-------
mpn : np.ndarray
The highest M_previous^n value for each horizon.
msg_stats : dict
The stats for each unique message. Only returned if `return_stats` is True.
"""
msgs, inverse, msg_counts = np.unique(
messages, return_counts=True, return_inverse=True, axis=0
)

msg_stats = {
f"{msg}": {
"count": msg_counts[idx],
"same_as_previous_obj": np.zeros(shape=prev_horizon + 1, dtype=np.int32),
"prev_use_percentage": np.zeros(shape=prev_horizon + 1, dtype=np.float32),
}
for idx, msg in enumerate(msgs)
}

# Times that the object was the same as the previous object
# first we go from first to last observation
for i in range(len(observations)):
# Then starting at a given observation we look to the future to see if it repeats
for horizon in range(1, prev_horizon + 1):
# We cannot look beyond the end of the array
if horizon + i >= len(observations):
break
# If it repeats, then we count it as a possible temporal reference for a given message
if np.array_equal(observations[i], observations[horizon + i]):
# We get the message index from the inverse of the messages
# This looks complicated, but it's not too bad
msg_stats[f"{msgs[inverse[horizon + i]]}"]["same_as_previous_obj"][
horizon
] += 1
# Break, otherwise if there are multiple repeats in a horizon
# They could get labelled twice, and incorrectly
break

mpn = np.zeros(shape=prev_horizon, dtype=np.float32)

for msg in msg_stats:
if msg_stats[f"{msg}"]["count"] > 0:
for horizon in range(1, prev_horizon + 1):
if msg_stats[f"{msg}"]["same_as_previous_obj"][horizon] != 0:
msg_stats[f"{msg}"]["prev_use_percentage"][horizon] = (
round(
msg_stats[f"{msg}"]["same_as_previous_obj"][horizon]
/ msg_stats[f"{msg}"]["count"],
3,
)
* 100
)
if (
msg_stats[f"{msg}"]["prev_use_percentage"][horizon]
> mpn[horizon]
):
mpn[horizon] = msg_stats[f"{msg}"]["prev_use_percentage"][
horizon
]
else:
msg_stats[f"{msg}"]["prev_use_percentage"][horizon] = 0

if return_stats:
return mpn, msg_stats
else:
return mpn
224 changes: 224 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,227 @@ def test_bosdis():
1.0,
2,
)


def test_mpn():
"""Tests to see if M_previous^n is calculated correctly."""
test_obs = np.array([[x, y] for x in range(4) for y in range(4)])

np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[0, 0, 0],
[0, 0, 1],
[0, 0, 2],
[0, 0, 3],
[0, 1, 0],
[0, 1, 1],
[0, 1, 2],
[0, 1, 3],
[2, 0, 0],
[2, 0, 1],
[2, 0, 2],
[2, 0, 3],
[2, 1, 0],
[2, 1, 1],
[2, 1, 2],
[2, 1, 3],
]
),
observations=test_obs,
prev_horizon=8,
)[0],
0,
2,
)

# noinspection PyTypeChecker
np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[1],
[2],
[2],
[2],
]
),
observations=np.array(
[
[4],
[4],
[4],
[4],
]
),
prev_horizon=8,
)[1],
100,
2,
)

# noinspection PyTypeChecker
np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[1],
[2],
[2],
[2],
]
),
observations=np.array(
[
[4],
[4],
[4],
[4],
]
),
prev_horizon=8,
)[2],
0,
2,
)

# noinspection PyTypeChecker
np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[4, 1],
[2, 1],
[4, 1],
[3, 1],
[1, 1],
]
),
observations=np.array(
[
[1],
[2],
[3],
[4],
[1],
[4],
[3],
[1],
[1],
]
),
prev_horizon=8,
)[1],
100,
2,
)

# noinspection PyTypeChecker
np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[4, 1],
[2, 1],
[4, 1],
[3, 1],
[1, 1],
]
),
observations=np.array(
[
[1],
[2],
[3],
[4],
[1],
[4],
[3],
[1],
[1],
]
),
prev_horizon=8,
)[2],
100,
2,
)

# noinspection PyTypeChecker
np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[4, 1],
[2, 1],
[4, 1],
[3, 1],
[1, 1],
]
),
observations=np.array(
[
[1],
[2],
[3],
[4],
[1],
[4],
[3],
[1],
[1],
]
),
prev_horizon=8,
)[3],
100,
2,
)

# noinspection PyTypeChecker
np.testing.assert_almost_equal(
metrics.compute_mpn(
messages=np.array(
[
[0, 1],
[0, 1],
[0, 1],
[0, 1],
[4, 1],
[2, 1],
[4, 1],
[3, 1],
[1, 1],
]
),
observations=np.array(
[
[1],
[2],
[3],
[4],
[1],
[4],
[3],
[1],
[1],
]
),
prev_horizon=8,
)[4],
100,
2,
)

0 comments on commit ab98eec

Please sign in to comment.