Skip to content

Commit

Permalink
Lint and documentation
Browse files: browse the repository at this point in the history
Signed-off-by: Olaf Lipinski <[email protected]>
  • Loading branch information
olipinski committed Nov 7, 2023
1 parent f69b95d commit c61c813
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 6 deletions.
205 changes: 204 additions & 1 deletion emlangkit/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
observations: Optional[np.ndarray] = None,
prev_horizon: int = 8,
seed: int = 42,
has_threshold: float = 0.8,
):
if not isinstance(messages, np.ndarray):
raise ValueError("Language only accepts numpy arrays!")
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
self.prev_horizon = prev_horizon

# HAS placeholders
self.has_threshold = has_threshold
self.__alpha = None
self.__freq = None
self.__branching_entropy = None
Expand Down Expand Up @@ -99,6 +101,11 @@ def topsim(self) -> tuple[float, float]:
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.observations is None:
raise ValueError(
Expand All @@ -125,6 +132,11 @@ def posdis(self):
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.observations is None:
raise ValueError(
Expand All @@ -150,6 +162,11 @@ def bosdis(self):
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.observations is None:
raise ValueError(
Expand All @@ -175,6 +192,11 @@ def language_entropy(self):
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
# This may have been calculated previously
if self.__langauge_entropy_value is None:
Expand All @@ -195,6 +217,11 @@ def observation_entropy(self):
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.observations is None:
raise ValueError(
Expand All @@ -221,6 +248,11 @@ def mutual_information(self):
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.observations is None:
raise ValueError("Observations are needed to calculate mutual information!")
Expand Down Expand Up @@ -253,6 +285,11 @@ def mpn(self):
Raises
------
ValueError: If observations are not set.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.observations is None:
raise ValueError("Observations are needed to calculate M_previous^n.")
Expand All @@ -266,6 +303,18 @@ def mpn(self):

# Harris' Articulation Scheme metrics
def branching_entropy(self):
"""
Calculate the branching entropy for a given language.
Returns
-------
float: The calculated branching entropy value.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.__branching_entropy is None:
if self.__freq is None:
self.__alpha, self.__freq = metrics.has_init(self.messages)
Expand All @@ -276,6 +325,19 @@ def branching_entropy(self):
return self.__branching_entropy

def conditional_entropy(self):
"""
Calculate the conditional entropy for a given language.
Returns
-------
float
The calculated conditional entropy value.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
# No need to even check for __freq as branching entropy already requires that
if self.__conditional_entropy is None:
if self.__branching_entropy is None:
Expand All @@ -287,11 +349,41 @@ def conditional_entropy(self):
return self.__conditional_entropy

def boundaries(self, return_count: bool = False, return_mean: bool = False):
"""
Calculate the HAS boundaries for a given language.
Parameters
----------
return_count : bool, optional
If True, the method will return the boundaries and the count of each boundary.
Default is False.
return_mean : bool, optional
If True, the method will return the boundaries, the count of each boundary,
and the mean count. Default is False.
Returns
-------
boundaries : list of lists
A list of boundary lists for each message in the language.
Optional Returns:
If `return_count` is True, the method will also return `nb`, which is a list
containing the count of each boundary.
If `return_mean` is True, the method will also return `nb` and `mean`. `nb` is
a list containing the count of each boundary, and `mean` is the mean count.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.__boundaries is None:
if self.__branching_entropy is None:
self.branching_entropy()
self.__boundaries = metrics.compute_boundaries(
self.messages, self.__branching_entropy, threshold=0.8
self.messages, self.__branching_entropy, threshold=self.has_threshold
)

if return_count:
Expand All @@ -311,6 +403,39 @@ def random_boundaries(
return_mean: bool = False,
recompute: bool = False,
):
"""
Calculate the random HAS boundaries for a given language.
Parameters
----------
return_count : bool, optional
If True, returns the random boundaries along with the number of boundary items for each boundary.
Default is False.
return_mean : bool, optional
If True, returns the random boundaries along with the number of boundary items for each boundary,
as well as the mean number of boundary items across all boundaries.
Default is False.
recompute : bool, optional
If True, forces the recomputation of the random boundaries.
Default is False.
Returns
-------
boundaries : list of lists
A list of random boundary lists for each message in the language.
Optional Returns:
If `return_count` is True, the method will also return `nb`, which is a list
containing the count of each boundary.
If `return_mean` is True, the method will also return `nb` and `mean`. `nb` is
a list containing the count of each boundary, and `mean` is the mean count.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.__random_boundaries is None and not recompute:
if self.__boundaries is None:
self.boundaries()
Expand All @@ -330,6 +455,34 @@ def random_boundaries(
return self.__random_boundaries

def segments(self, return_ids: bool = False, return_hashed_segments: bool = False):
"""
Calculate the HAS segments for a given language.
Parameters
----------
return_ids : bool, optional
If True, returns the segments along with their corresponding segment ids.
Default is False.
return_hashed_segments : bool, optional
If True, returns the segments along with their hashed versions.
Default is False.
Returns
-------
numpy.ndarray
Array of segments.
Optional Returns:
If `return_ids` is True, the method will also return segment_ids.
If `return_hashed_segments` is True, the method will also return the hashed segments.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.__segments is None:
if self.__boundaries is None:
self.boundaries()
Expand All @@ -356,6 +509,32 @@ def random_segments(
return_hashed_segments: bool = False,
recompute: bool = False,
):
"""
Calculate the random HAS segments for a given language.
Parameters
----------
return_ids : bool, optional
Specifies whether to return segment IDs along with the segments. Default is False.
return_hashed_segments : bool, optional
Specifies whether to return hashed segments along with the segments. Default is False.
recompute : bool, optional
Specifies whether to recompute the random segments. Default is False.
Returns
-------
numpy.ndarray
Array of segments.
Optional Returns:
If `return_ids` is True, the method will also return segment_ids.
If `return_hashed_segments` is True, the method will also return the hashed segments.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.__random_segments is None and not recompute:
if self.__random_boundaries is None and not recompute:
self.random_boundaries()
Expand All @@ -381,6 +560,30 @@ def random_segments(
return self.__random_segments

def has_stats(self, compute_topsim: bool = False) -> dict:
"""
Calculate the HAS statistics for a given language.
Parameters
----------
compute_topsim : bool, optional
Flag indicating whether to compute topographic similarity. Default is False.
Returns
-------
dict
A dictionary containing various statistics related to the language.
Raises
------
ValueError
If observations are None and compute_topsim is True.
Notes
-----
The result is cached and will only be computed once.
Subsequent calls to this method will return the cached value.
"""
if self.__has_stats is None:
if self.observations is None and compute_topsim:
raise ValueError(
Expand Down
7 changes: 5 additions & 2 deletions emlangkit/metrics/topsim.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Calculate topographic similarity for a given language."""
from typing import Literal, Tuple
from typing import Tuple

import editdistance
import numpy as np
Expand Down Expand Up @@ -33,7 +33,10 @@ def compute_topographic_similarity(
Topographic similarity score.
"""
if message_dist_metric == "editdistance":
msg_metric = lambda x, y: editdistance.eval(x, y) / ((len(x) + len(y)) / 2)

def msg_metric(x, y):
    """
    Return the edit distance between two messages, normalised by their mean length.

    Parameters
    ----------
    x, y : sequence
        Messages to compare. They must support ``len`` and be accepted by
        ``editdistance.eval``.

    Returns
    -------
    float
        ``editdistance.eval(x, y)`` divided by the mean length of the two
        messages, or 0.0 when both messages are empty.
    """
    total_length = len(x) + len(y)
    if total_length == 0:
        # Two empty messages are identical; without this guard the
        # normalisation below would raise ZeroDivisionError.
        return 0.0
    return editdistance.eval(x, y) / (total_length / 2)

else:
msg_metric = message_dist_metric

Expand Down
7 changes: 4 additions & 3 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def test_mpn():


def test_has():
"""Tests to see if HAS metrics are calculated correctly."""
rng = np.random.default_rng(seed=42)

messages = np.array(
Expand All @@ -781,12 +782,12 @@ def test_has():

be = metrics.compute_branching_entropy(alpha, freq)

ce = metrics.compute_conditional_entropy(be, freq)
metrics.compute_conditional_entropy(be, freq)

boundaries = metrics.compute_boundaries(messages, be, 0.5)

segments = metrics.compute_segments(messages, boundaries)
metrics.compute_segments(messages, boundaries)

random_boundaries = metrics.compute_random_boundaries(messages, boundaries, rng)

random_segments = metrics.compute_segments(messages, random_boundaries)
metrics.compute_segments(messages, random_boundaries)

0 comments on commit c61c813

Please sign in to comment.