From a63a497994de0863c96d352ccb3889b029a17fcb Mon Sep 17 00:00:00 2001 From: Rehan Guha Date: Wed, 16 Oct 2024 19:18:08 +0530 Subject: [PATCH] Updated ClusterMapper --- README.md | 2 +- pdistmap/map.py | 83 ++++++++++++++++++++ pdistmap/{intersection.py => set.py} | 0 tests/test_map.py | 87 +++++++++++++++++++++ tests/{test_intersection.py => test_set.py} | 2 +- 5 files changed, 172 insertions(+), 2 deletions(-) create mode 100644 pdistmap/map.py rename pdistmap/{intersection.py => set.py} (100%) create mode 100644 tests/test_map.py rename tests/{test_intersection.py => test_set.py} (98%) diff --git a/README.md b/README.md index cb07db0..88ef386 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ pip install pdistmap ```python -from pdistmap.intersection import KDEIntersection +from pdistmap.set import KDEIntersection import numpy as np A = np.array([25, 40, 70, 65, 69, 75, 80, 85]) diff --git a/pdistmap/map.py b/pdistmap/map.py new file mode 100644 index 0000000..d7a5824 --- /dev/null +++ b/pdistmap/map.py @@ -0,0 +1,83 @@ +import numpy as np +from pdistmap.set import KDEIntersection + +class ClusterMapper: + def __init__(self, Adict: dict, Bdict: dict): + """ + Initialize the ClusterMapper with two dictionaries containing lists. + + Parameters: + - Adict (dict): A dictionary where each key corresponds to a list of values for comparison. + - Bdict (dict): A dictionary where each key corresponds to another list of values for comparison. + """ + self.Adict = Adict + self.Bdict = Bdict + + def _check_valid_n(self, n: int) -> None: + """ + Checks if the value of n is valid. + + Parameters: + - n: Number of top matches to validate. + + Raises: + - ValueError: If n is not greater than 1 or not less than the size of the lists in Adict or Bdict. + """ + # Ensure n is valid + if n < 1: + raise ValueError("n must be greater than 1.") + + # Check the size of the lists in Adict + for key in self.Adict: + if n >= len(self.Adict[key]): + raise ValueError(f"n must be less than the size of the list for key '{key}' in Adict.") + + # Check the size of the lists in Bdict + for key in self.Bdict: + if n >= len(self.Bdict[key]): + raise ValueError(f"n must be less than the size of the list for key '{key}' in Bdict.") + + + def list_similarity(self, list1: list, list2: list) -> float: + """ + Calculate the similarity between two lists using the KDEIntersection method. + + Parameters: + - list1 (list): The first list of values. + - list2 (list): The second list of values. + + Returns: + - float: A similarity score calculated as the intersection area between the two lists. + """ + return KDEIntersection(np.array(list1), np.array(list2)).intersection_area() + + def find_top_n_matches(self, n: int = 2) -> dict: + """ + Find the top 'n' closest matches between the lists of the two dictionaries. + + Parameters: + - n (int): The number of top matches to return for each key in Adict. Default is 2. + + Returns: + - dict: A dictionary mapping each key from Adict to its top 'n' matches from Bdict, + represented as a list of tuples containing the matching key and its similarity score. + """ + + self._check_valid_n(n) + + matches = {} + + # Iterate over each item in Adict + for key_A, value_A in self.Adict.items(): + similarities = [] + + # Iterate over each item in Bdict to calculate similarity + for key_B, value_B in self.Bdict.items(): + similarity = self.list_similarity(value_A, value_B) + similarities.append((key_B, similarity)) + + # Sort the similarities and store the top 'n' matches + top_n_matches = sorted(similarities, key=lambda x: (-x[1], x[0]))[:n] + matches[key_A] = top_n_matches + + return matches diff --git a/pdistmap/intersection.py b/pdistmap/set.py similarity index 100% rename from pdistmap/intersection.py rename to pdistmap/set.py diff --git a/tests/test_map.py b/tests/test_map.py new file mode 100644 index 0000000..3561e8e --- /dev/null +++ b/tests/test_map.py @@ -0,0 +1,87 @@ +import pytest +import numpy as np +from pdistmap.set import KDEIntersection +from pdistmap.map import ClusterMapper + +# Mocking the KDEIntersection class to avoid actual computation during tests +class MockKDEIntersection: + def __init__(self, list1, list2): + self.list1 = list1 + self.list2 = list2 + + def intersection_area(self): + # A mock similarity score for testing purposes + return np.random.random() + +# Replace the actual KDEIntersection with the mock +KDEIntersection = MockKDEIntersection + + +@pytest.fixture +def sample_data(): + """Fixture providing sample input data for testing.""" + Adict = { + "A_o": [1, 2, 3], + "B_o": [4, 5, 6], + } + Bdict = { + "A": [1, 2, 3], + "B": [7, 8, 9], + } + return Adict, Bdict + + +def test_initialization(sample_data): + """Test the initialization of the ClusterMapper class.""" + Adict, Bdict = sample_data + cluster_mapper = ClusterMapper(Adict, Bdict) + + assert cluster_mapper.Adict == Adict + assert cluster_mapper.Bdict == Bdict + + +def test_list_similarity(sample_data): + """Test the list_similarity method.""" + Adict, Bdict = sample_data + cluster_mapper = ClusterMapper(Adict, Bdict) + + # Assuming the mock will generate a random similarity score + similarity = cluster_mapper.list_similarity(Adict["A_o"], Bdict["A"]) + assert isinstance(similarity, float) # Ensure the result is a float + + +def test_find_top_n_matches(sample_data): + """Test the find_top_n_matches method.""" + Adict, Bdict = sample_data + cluster_mapper = ClusterMapper(Adict, Bdict) + + matches = cluster_mapper.find_top_n_matches(n=1) + + assert len(matches) == len(Adict) # Ensure we have as many matches as keys in Adict + for key_A in matches.keys(): + assert len(matches[key_A]) == 1 # Each key should return one match + + +def test_find_top_n_matches_multiple(sample_data): + """Test find_top_n_matches with more matches.""" + Adict, Bdict = sample_data + cluster_mapper = ClusterMapper(Adict, Bdict) + + matches = cluster_mapper.find_top_n_matches(n=2) + + assert len(matches) == len(Adict) # Ensure we have as many matches as keys in Adict + for key_A in matches.keys(): + assert len(matches[key_A]) == 2 # Each key should return two matches + + +def test_invalid_n_value(sample_data): + """Test handling of invalid 'n' values in find_top_n_matches.""" + Adict, Bdict = sample_data + cluster_mapper = ClusterMapper(Adict, Bdict) + + with pytest.raises(ValueError): + cluster_mapper.find_top_n_matches(n=-1) # Invalid number of matches + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_intersection.py b/tests/test_set.py similarity index 98% rename from tests/test_intersection.py rename to tests/test_set.py index 0d6347a..7536270 100644 --- a/tests/test_intersection.py +++ b/tests/test_set.py @@ -1,6 +1,6 @@ import pytest import numpy as np -from pdistmap.intersection import KDEIntersection +from pdistmap.set import KDEIntersection # Test cases for valid inputs def test_intersection_area_valid():