-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
172 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import numpy as np | ||
from pdistmap.set import KDEIntersection | ||
|
||
class ClusterMapper: | ||
def __init__(self, Adict: dict, Bdict: dict): | ||
""" | ||
Initialize the ClusterMapper with two dictionaries containing lists. | ||
Parameters: | ||
- Adict (dict): A dictionary where each key corresponds to a list of values for comparison. | ||
- Bdict (dict): A dictionary where each key corresponds to another list of values for comparison. | ||
""" | ||
self.Adict = Adict | ||
self.Bdict = Bdict | ||
|
||
def _check_valid_n(self, n: int) -> None: | ||
""" | ||
Checks if the value of n is valid. | ||
Parameters: | ||
- n: Number of top matches to validate. | ||
Raises: | ||
- ValueError: If n is not greater than 1 or not less than the size of the lists in Adict or Bdict. | ||
""" | ||
# Ensure n is valid | ||
if n < 1: | ||
raise ValueError("n must be greater than 1.") | ||
|
||
# Check the size of the lists in Adict | ||
for key in self.Adict: | ||
if n >= len(self.Adict[key]): | ||
raise ValueError(f"n must be less than the size of the list for key '{key}' in Adict.") | ||
|
||
# Check the size of the lists in Bdict | ||
for key in self.Bdict: | ||
if n >= len(self.Bdict[key]): | ||
raise ValueError(f"n must be less than the size of the list for key '{key}' in Bdict.") | ||
|
||
|
||
def list_similarity(self, list1: list, list2: list) -> float: | ||
""" | ||
Calculate the similarity between two lists using the KDEIntersection method. | ||
Parameters: | ||
- list1 (list): The first list of values. | ||
- list2 (list): The second list of values. | ||
Returns: | ||
- float: A similarity score calculated as the intersection area between the two lists. | ||
""" | ||
return KDEIntersection(np.array(list1), np.array(list2)).intersection_area() | ||
|
||
def find_top_n_matches(self, n: int = 2) -> dict: | ||
""" | ||
Find the top 'n' closest matches between the lists of the two dictionaries. | ||
Parameters: | ||
- n (int): The number of top matches to return for each key in Adict. Default is 2. | ||
Returns: | ||
- dict: A dictionary mapping each key from Adict to its top 'n' matches from Bdict, | ||
represented as a list of tuples containing the matching key and its similarity score. | ||
""" | ||
|
||
self._check_valid_n(n) | ||
|
||
matches = {} | ||
|
||
# Iterate over each item in Adict | ||
for key_A, value_A in self.Adict.items(): | ||
similarities = [] | ||
|
||
# Iterate over each item in Bdict to calculate similarity | ||
for key_B, value_B in self.Bdict.items(): | ||
similarity = self.list_similarity(value_A, value_B) | ||
similarities.append((key_B, similarity)) | ||
|
||
# Sort the similarities and store the top 'n' matches | ||
top_n_matches = sorted(similarities, key=lambda x: (-x[1], x[0]))[:n] | ||
matches[key_A] = top_n_matches | ||
|
||
return matches |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import pytest | ||
import numpy as np | ||
from pdistmap.set import KDEIntersection | ||
from pdistmap.map import ClusterMapper | ||
|
||
# Mocking the KDEIntersection class to avoid actual computation during tests | ||
class MockKDEIntersection: | ||
def __init__(self, list1, list2): | ||
self.list1 = list1 | ||
self.list2 = list2 | ||
|
||
def intersection_area(self): | ||
# A mock similarity score for testing purposes | ||
return np.random.random() | ||
|
||
# Replace the actual KDEIntersection with the mock | ||
KDEIntersection = MockKDEIntersection | ||
|
||
|
||
@pytest.fixture | ||
def sample_data(): | ||
"""Fixture providing sample input data for testing.""" | ||
Adict = { | ||
"A_o": [1, 2, 3], | ||
"B_o": [4, 5, 6], | ||
} | ||
Bdict = { | ||
"A": [1, 2, 3], | ||
"B": [7, 8, 9], | ||
} | ||
return Adict, Bdict | ||
|
||
|
||
def test_initialization(sample_data): | ||
"""Test the initialization of the ClusterMapper class.""" | ||
Adict, Bdict = sample_data | ||
cluster_mapper = ClusterMapper(Adict, Bdict) | ||
|
||
assert cluster_mapper.Adict == Adict | ||
assert cluster_mapper.Bdict == Bdict | ||
|
||
|
||
def test_list_similarity(sample_data): | ||
"""Test the list_similarity method.""" | ||
Adict, Bdict = sample_data | ||
cluster_mapper = ClusterMapper(Adict, Bdict) | ||
|
||
# Assuming the mock will generate a random similarity score | ||
similarity = cluster_mapper.list_similarity(Adict["A_o"], Bdict["A"]) | ||
assert isinstance(similarity, float) # Ensure the result is a float | ||
|
||
|
||
def test_find_top_n_matches(sample_data): | ||
"""Test the find_top_n_matches method.""" | ||
Adict, Bdict = sample_data | ||
cluster_mapper = ClusterMapper(Adict, Bdict) | ||
|
||
matches = cluster_mapper.find_top_n_matches(n=1) | ||
|
||
assert len(matches) == len(Adict) # Ensure we have as many matches as keys in Adict | ||
for key_A in matches.keys(): | ||
assert len(matches[key_A]) == 1 # Each key should return one match | ||
|
||
|
||
def test_find_top_n_matches_multiple(sample_data): | ||
"""Test find_top_n_matches with more matches.""" | ||
Adict, Bdict = sample_data | ||
cluster_mapper = ClusterMapper(Adict, Bdict) | ||
|
||
matches = cluster_mapper.find_top_n_matches(n=2) | ||
|
||
assert len(matches) == len(Adict) # Ensure we have as many matches as keys in Adict | ||
for key_A in matches.keys(): | ||
assert len(matches[key_A]) == 2 # Each key should return two matches | ||
|
||
|
||
def test_invalid_n_value(sample_data): | ||
"""Test handling of invalid 'n' values in find_top_n_matches.""" | ||
Adict, Bdict = sample_data | ||
cluster_mapper = ClusterMapper(Adict, Bdict) | ||
|
||
with pytest.raises(ValueError): | ||
cluster_mapper.find_top_n_matches(n=-1) # Invalid number of matches | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters