Skip to content

Commit

Permalink
Updated ClusterMapper
Browse files Browse the repository at this point in the history
  • Loading branch information
rehanguha committed Oct 16, 2024
1 parent 5460757 commit a63a497
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pip install pdistmap

```python

from pdistmap.intersection import KDEIntersection
from pdistmap.set import KDEIntersection
import numpy as np

A = np.array([25, 40, 70, 65, 69, 75, 80, 85])
Expand Down
83 changes: 83 additions & 0 deletions pdistmap/map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np
from pdistmap.set import KDEIntersection

class ClusterMapper:
def __init__(self, Adict: dict, Bdict: dict):
"""
Initialize the ClusterMapper with two dictionaries containing lists.
Parameters:
- Adict (dict): A dictionary where each key corresponds to a list of values for comparison.
- Bdict (dict): A dictionary where each key corresponds to another list of values for comparison.
"""
self.Adict = Adict
self.Bdict = Bdict

def _check_valid_n(self, n: int) -> None:
"""
Checks if the value of n is valid.
Parameters:
- n: Number of top matches to validate.
Raises:
- ValueError: If n is not greater than 1 or not less than the size of the lists in Adict or Bdict.
"""
# Ensure n is valid
if n < 1:
raise ValueError("n must be greater than 1.")

# Check the size of the lists in Adict
for key in self.Adict:
if n >= len(self.Adict[key]):
raise ValueError(f"n must be less than the size of the list for key '{key}' in Adict.")

# Check the size of the lists in Bdict
for key in self.Bdict:
if n >= len(self.Bdict[key]):
raise ValueError(f"n must be less than the size of the list for key '{key}' in Bdict.")


def list_similarity(self, list1: list, list2: list) -> float:
"""
Calculate the similarity between two lists using the KDEIntersection method.
Parameters:
- list1 (list): The first list of values.
- list2 (list): The second list of values.
Returns:
- float: A similarity score calculated as the intersection area between the two lists.
"""
return KDEIntersection(np.array(list1), np.array(list2)).intersection_area()

def find_top_n_matches(self, n: int = 2) -> dict:
"""
Find the top 'n' closest matches between the lists of the two dictionaries.
Parameters:
- n (int): The number of top matches to return for each key in Adict. Default is 2.
Returns:
- dict: A dictionary mapping each key from Adict to its top 'n' matches from Bdict,
represented as a list of tuples containing the matching key and its similarity score.
"""

self._check_valid_n(n)

matches = {}

# Iterate over each item in Adict
for key_A, value_A in self.Adict.items():
similarities = []

# Iterate over each item in Bdict to calculate similarity
for key_B, value_B in self.Bdict.items():
similarity = self.list_similarity(value_A, value_B)
similarities.append((key_B, similarity))

# Sort the similarities and store the top 'n' matches
top_n_matches = sorted(similarities, key=lambda x: (-x[1], x[0]))[:n]
matches[key_A] = top_n_matches

return matches
File renamed without changes.
87 changes: 87 additions & 0 deletions tests/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest
import numpy as np
from pdistmap.set import KDEIntersection
from pdistmap.map import ClusterMapper

# Mocking the KDEIntersection class to avoid actual computation during tests
class MockKDEIntersection:
def __init__(self, list1, list2):
self.list1 = list1
self.list2 = list2

def intersection_area(self):
# A mock similarity score for testing purposes
return np.random.random()

# Replace the actual KDEIntersection with the mock
KDEIntersection = MockKDEIntersection


@pytest.fixture
def sample_data():
"""Fixture providing sample input data for testing."""
Adict = {
"A_o": [1, 2, 3],
"B_o": [4, 5, 6],
}
Bdict = {
"A": [1, 2, 3],
"B": [7, 8, 9],
}
return Adict, Bdict


def test_initialization(sample_data):
"""Test the initialization of the ClusterMapper class."""
Adict, Bdict = sample_data
cluster_mapper = ClusterMapper(Adict, Bdict)

assert cluster_mapper.Adict == Adict
assert cluster_mapper.Bdict == Bdict


def test_list_similarity(sample_data):
"""Test the list_similarity method."""
Adict, Bdict = sample_data
cluster_mapper = ClusterMapper(Adict, Bdict)

# Assuming the mock will generate a random similarity score
similarity = cluster_mapper.list_similarity(Adict["A_o"], Bdict["A"])
assert isinstance(similarity, float) # Ensure the result is a float


def test_find_top_n_matches(sample_data):
"""Test the find_top_n_matches method."""
Adict, Bdict = sample_data
cluster_mapper = ClusterMapper(Adict, Bdict)

matches = cluster_mapper.find_top_n_matches(n=1)

assert len(matches) == len(Adict) # Ensure we have as many matches as keys in Adict
for key_A in matches.keys():
assert len(matches[key_A]) == 1 # Each key should return one match


def test_find_top_n_matches_multiple(sample_data):
"""Test find_top_n_matches with more matches."""
Adict, Bdict = sample_data
cluster_mapper = ClusterMapper(Adict, Bdict)

matches = cluster_mapper.find_top_n_matches(n=2)

assert len(matches) == len(Adict) # Ensure we have as many matches as keys in Adict
for key_A in matches.keys():
assert len(matches[key_A]) == 2 # Each key should return two matches


def test_invalid_n_value(sample_data):
"""Test handling of invalid 'n' values in find_top_n_matches."""
Adict, Bdict = sample_data
cluster_mapper = ClusterMapper(Adict, Bdict)

with pytest.raises(ValueError):
cluster_mapper.find_top_n_matches(n=-1) # Invalid number of matches


if __name__ == "__main__":
pytest.main()
2 changes: 1 addition & 1 deletion tests/test_intersection.py → tests/test_set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import numpy as np
from pdistmap.intersection import KDEIntersection
from pdistmap.set import KDEIntersection

# Test cases for valid inputs
def test_intersection_area_valid():
Expand Down

0 comments on commit a63a497

Please sign in to comment.