Skip to content

Commit

Permalink
black python formatted
Browse files (browse the repository at this point in the history)
  • Loading branch information
rehanguha committed Sep 15, 2024
1 parent 7de4711 commit 320cd7c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 53 deletions.
23 changes: 0 additions & 23 deletions .github/workflows/pylint.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down
89 changes: 60 additions & 29 deletions pdistmap/intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import matplotlib.pyplot as plt
from typing import Tuple


class KDEIntersection:
"""
A class used to compute the intersection area between two Kernel Density Estimations (KDEs).
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self, A: np.ndarray, B: np.ndarray) -> None:
self.A = A
self.B = B
self.executed = False

def resetVectors(self, A: np.ndarray, B: np.ndarray) -> None:
"""
Resets the vectors A and B and sets the execution flag to False.
Expand All @@ -72,7 +73,7 @@ def resetVectors(self, A: np.ndarray, B: np.ndarray) -> None:
self.A = A
self.B = B
self.executed = False

def _validate_numeric_array(self, arr: np.ndarray) -> np.ndarray:
"""
Validates that the input is a numpy array and contains only numeric values (integers or floats).
Expand All @@ -97,10 +98,12 @@ def _validate_numeric_array(self, arr: np.ndarray) -> np.ndarray:
"""
if not isinstance(arr, np.ndarray):
raise TypeError("Input is not a numpy array.")

if not np.issubdtype(arr.dtype, np.number):
raise ValueError("Array contains non-numeric data. Only integers and floats are allowed.")

raise ValueError(
"Array contains non-numeric data. Only integers and floats are allowed."
)

if np.isnan(arr).any():
raise ValueError("Array contains NaN values, which are not allowed.")

Expand Down Expand Up @@ -133,8 +136,10 @@ def _check_limit(self, value: float, lower: float, upper: float) -> float:
return value
else:
raise ValueError(f"Value should be between [{lower}, {upper}].")

def _calculate_kde(self, data: np.ndarray, bw_method: str = "scott") -> gaussian_kde:

def _calculate_kde(
self, data: np.ndarray, bw_method: str = "scott"
) -> gaussian_kde:
"""
Calculates the Kernel Density Estimation (KDE) of the input data.
Expand All @@ -152,7 +157,9 @@ def _calculate_kde(self, data: np.ndarray, bw_method: str = "scott") -> gaussian
"""
return gaussian_kde(data, bw_method=bw_method)

def _min_max_finder(self, data1: np.ndarray, data2: np.ndarray, adjustment_factor: float = 0) -> Tuple[float, float]:
def _min_max_finder(
self, data1: np.ndarray, data2: np.ndarray, adjustment_factor: float = 0
) -> Tuple[float, float]:
"""
Finds the minimum and maximum values from two datasets, adjusting them with a factor.
Expand All @@ -172,7 +179,7 @@ def _min_max_finder(self, data1: np.ndarray, data2: np.ndarray, adjustment_facto
"""
xmin = min(data1.min(), data2.min())
xmax = max(data1.max(), data2.max())

adjustment_factor = self._check_limit(adjustment_factor, lower=0, upper=1)

dx = adjustment_factor * (xmax - xmin)
Expand All @@ -181,7 +188,9 @@ def _min_max_finder(self, data1: np.ndarray, data2: np.ndarray, adjustment_facto

return adjusted_xmin, adjusted_xmax

def _build_linespace(self, xmin: float, xmax: float, linespace_num: int = 10000) -> np.ndarray:
def _build_linespace(
self, xmin: float, xmax: float, linespace_num: int = 10000
) -> np.ndarray:
"""
Builds a linearly spaced array between two values.
Expand All @@ -200,8 +209,14 @@ def _build_linespace(self, xmin: float, xmax: float, linespace_num: int = 10000)
The linearly spaced array.
"""
return np.linspace(xmin, xmax, linespace_num)

def intersection_area(self, adjustment_factor: float = 0.2, bw_method: str = "scott", linespace_num: int = 10000, plot: bool = False) -> float:

def intersection_area(
self,
adjustment_factor: float = 0.2,
bw_method: str = "scott",
linespace_num: int = 10000,
plot: bool = False,
) -> float:
"""
Calculates the intersection area between the KDEs of A and B.
Expand All @@ -223,26 +238,30 @@ def intersection_area(self, adjustment_factor: float = 0.2, bw_method: str = "sc
"""
A = self._validate_numeric_array(self.A)
B = self._validate_numeric_array(self.B)

kdeA = self._calculate_kde(A, bw_method=bw_method)
kdeB = self._calculate_kde(B, bw_method=bw_method)

data_min, data_max = self._min_max_finder(A, B, adjustment_factor=adjustment_factor)

self.data_linespace = self._build_linespace(data_min, data_max, linespace_num=linespace_num)
data_min, data_max = self._min_max_finder(
A, B, adjustment_factor=adjustment_factor
)

self.data_linespace = self._build_linespace(
data_min, data_max, linespace_num=linespace_num
)

self.kdeA_data = kdeA(self.data_linespace)
self.kdeB_data = kdeB(self.data_linespace)

self.inters = np.minimum(self.kdeA_data, self.kdeB_data)
self.area_inters = np.trapezoid(self.inters, self.data_linespace)
self.executed = True

if plot:
self.plot()

return self.area_inters

def plot(self) -> None:
"""
Plots the KDEs of A and B and their intersection.
Expand All @@ -253,18 +272,30 @@ def plot(self) -> None:
If `intersection_area()` has not been called before plotting.
"""
if not self.executed:
raise Exception("'intersection_area()' needs to be executed first. Or use 'intersection_area(plot=True)'")
raise Exception(
"'intersection_area()' needs to be executed first. Or use 'intersection_area(plot=True)'"
)

plt.plot(self.data_linespace, self.kdeA_data, color='b', label='A')
plt.fill_between(self.data_linespace, self.kdeA_data, 0, color='b', alpha=0.2)
plt.plot(self.data_linespace, self.kdeB_data, color='orange', label='B')
plt.fill_between(self.data_linespace, self.kdeB_data, 0, color='orange', alpha=0.2)
plt.plot(self.data_linespace, self.inters, color='r')
plt.fill_between(self.data_linespace, self.inters, 0, facecolor='none', edgecolor='r', hatch='x', label='Intersection')
plt.plot(self.data_linespace, self.kdeA_data, color="b", label="A")
plt.fill_between(self.data_linespace, self.kdeA_data, 0, color="b", alpha=0.2)
plt.plot(self.data_linespace, self.kdeB_data, color="orange", label="B")
plt.fill_between(
self.data_linespace, self.kdeB_data, 0, color="orange", alpha=0.2
)
plt.plot(self.data_linespace, self.inters, color="r")
plt.fill_between(
self.data_linespace,
self.inters,
0,
facecolor="none",
edgecolor="r",
hatch="x",
label="Intersection",
)

handles, labels = plt.gca().get_legend_handles_labels()
labels[2] += f': {self.area_inters * 100:.1f} %'
plt.legend(handles, labels, title='')
plt.title('KDE Intersection')
labels[2] += f": {self.area_inters * 100:.1f} %"
plt.legend(handles, labels, title="")
plt.title("KDE Intersection")
plt.tight_layout()
plt.show()

0 comments on commit 320cd7c

Please sign in to comment.