Joseba/multivariate calibrator #47

Merged (7 commits, Jan 3, 2024)
91 changes: 71 additions & 20 deletions deel/puncc/api/calibration.py
@@ -36,6 +36,7 @@

from deel.puncc.api.utils import alpha_calib_check
from deel.puncc.api.utils import quantile
from deel.puncc.api.corrections import bonferroni

logger = logging.getLogger(__name__)

@@ -134,6 +135,7 @@ def __init__(
self._len_calib = 0
self._residuals = None
self._norm_weights = None
self._feature_axis = None

def fit(
self,
@@ -153,6 +155,8 @@ def fit(
logger.debug(f"Shape of y_true: {y_true.shape}")
self._residuals = self.nonconf_score_func(y_pred, y_true)
self._len_calib = len(self._residuals)
if y_pred.ndim > 1:
self._feature_axis = -1
logger.debug("Nonconformity scores computed !")

def calibrate(
@@ -161,6 +165,7 @@ def calibrate(
alpha: float,
y_pred: Iterable,
weights: Optional[Iterable] = None,
correction: Optional[Callable] = bonferroni,
) -> Tuple[np.ndarray]:
"""Compute calibrated prediction sets for new examples w.r.t a
significance level :math:`\\alpha`.
@@ -170,6 +175,9 @@
:param Iterable weights: weights to be associated to the nonconformity
scores. Defaults to None when all the scores
are equiprobable.
:param Callable correction: correction for multiple hypothesis testing
in the case of multivariate regression.
Defaults to Bonferroni correction.

:returns: prediction set.
In case of regression, returns (y_lower, y_upper).
@@ -181,26 +189,7 @@
calibration set.

"""

if self._residuals is None:
raise RuntimeError("Run `fit` method before calling `calibrate`.")

# Check consistency of alpha w.r.t the size of calibration data
if weights is None:
alpha_calib_check(alpha=alpha, n=self._len_calib)

# Compute weighted quantiles
## Lemma 1 of Tibshirani's paper (https://arxiv.org/pdf/1904.06019.pdf)
## The coverage guarantee holds with 1) the inflated
## (1-\alpha)(1+1/n)-th quantile or 2) when adding an infinite term to
## the sequence and computing the $(1-\alpha)$-th empirical quantile.
infty_array = np.array([np.inf])
lemma_residuals = np.concatenate((self._residuals, infty_array))
residuals_Q = quantile(
lemma_residuals,
1 - alpha,
w=weights,
)
residuals_Q = self.compute_quantile(alpha=alpha, weights=weights, correction=correction)

return self.pred_set_func(y_pred, scores_quantile=residuals_Q)
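
Note (reviewer sketch, not part of the diff): the new `correction` hook is applied as `alpha = correction(alpha)` before the quantile lookup, so a multivariate correction is presumably bound in by the caller. The `partial(bonferroni, nvars=2)` idiom below is an assumption based on that call site, together with the usual `alpha / nvars` Bonferroni rule:

```python
from functools import partial

import numpy as np

from deel.puncc.api.calibration import BaseCalibrator
from deel.puncc.api.corrections import bonferroni

# Toy multivariate setup: 100 calibration points, 2 output dimensions.
y_pred_calib = np.random.randn(100, 2)
y_calib = np.random.randn(100, 2)
y_pred_test = np.random.randn(50, 2)

calibrator = BaseCalibrator(
    nonconf_score_func=lambda y_pred, y_true: np.abs(y_pred - y_true),
    pred_set_func=lambda y_pred, scores_quantile: (
        y_pred - scores_quantile,
        y_pred + scores_quantile,
    ),
)
calibrator.fit(y_pred=y_pred_calib, y_true=y_calib)

# Bonferroni over the 2 outputs: each coordinate is calibrated at alpha / 2.
y_lo, y_hi = calibrator.calibrate(
    y_pred=y_pred_test,
    alpha=0.1,
    correction=partial(bonferroni, nvars=2),
)
```

With `nvars` now defaulting to 1, the default correction is a no-op, which is also what lets the new test pass a per-output vector `alpha = np.array([0.1, 0.3])` straight through.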

@@ -230,6 +219,61 @@ def get_nonconformity_scores(self) -> np.ndarray:
"""
return self._residuals

def compute_quantile(
self,
*,
alpha: float,
weights: Optional[Iterable] = None,
correction: Optional[Callable] = bonferroni,
) -> np.ndarray:
"""Compute quantile of scores w.r.t a
significance level :math:`\\alpha`.

:param float alpha: significance level (max miscoverage target).
:param Iterable weights: weights to be associated to the nonconformity
scores. Defaults to None when all the scores
are equiprobable.
:param Callable correction: correction for multiple hypothesis testing
in the case of multivariate regression.
Defaults to Bonferroni correction.

:returns: quantile
:rtype: ndarray

:raises RuntimeError: :meth:`compute_quantile` called before :meth:`fit`.
:raises ValueError: failed check on :data:`alpha` w.r.t size of the
calibration set.

"""

if self._residuals is None:
raise RuntimeError("Run `fit` method before calling `compute_quantile`.")

alpha = correction(alpha)

# Check consistency of alpha w.r.t the size of calibration data
if weights is None:
alpha_calib_check(alpha=alpha, n=self._len_calib)

# Compute weighted quantiles
## Lemma 1 of Tibshirani's paper (https://arxiv.org/pdf/1904.06019.pdf)
## The coverage guarantee holds with 1) the inflated
## (1-\alpha)(1+1/n)-th quantile or 2) when adding an infinite term to
## the sequence and computing the $(1-\alpha)$-th empirical quantile.
if self._residuals.ndim > 1:
infty_array = np.full((1, self._residuals.shape[-1]), np.inf)
else:
infty_array = np.array([np.inf])
lemma_residuals = np.concatenate((self._residuals, infty_array), axis=0)
residuals_Q = quantile(
lemma_residuals,
1 - alpha,
w=weights,
feature_axis=self._feature_axis
)

return residuals_Q
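
Aside (reviewer sketch): the Lemma 1 equivalence cited in the comment above can be checked numerically with plain numpy, independently of puncc's weighted `quantile` helper:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.exponential(size=100)  # stand-in nonconformity scores, n = 100
alpha = 0.1

# (1) inflated (1 - alpha)(1 + 1/n)-th empirical quantile of the n scores
q_inflated = np.quantile(scores, (1 - alpha) * (1 + 1 / len(scores)), method="higher")

# (2) plain (1 - alpha)-th quantile after appending an infinite score
q_lemma = np.quantile(np.concatenate([scores, [np.inf]]), 1 - alpha, method="higher")

# Both routes select the ceil((n + 1)(1 - alpha))-th smallest score (the 91st here).
assert q_inflated == q_lemma
```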

@staticmethod
def barber_weights(weights: np.ndarray) -> np.ndarray:
"""Compute and normalize inference weights of the nonconformity distribution
@@ -418,6 +462,7 @@ class CvPlusCalibrator:
def __init__(self, kfold_calibrators: dict):
self.kfold_calibrators_dict = kfold_calibrators
self._len_calib = None
self._feature_axis = None

# Sanity checks:
# - The collection of calibrators is not None
@@ -483,6 +528,10 @@ def calibrate(
if y_pred is None:
raise RuntimeError("No prediction obtained with cv+.")

# Check for multivariate predictions
if y_pred.ndim > 1:
self._feature_axis = -1

# nonconformity scores
kth_calibrator = self.kfold_calibrators_dict[k]
nconf_scores = kth_calibrator.get_nonconformity_scores()
@@ -536,11 +585,13 @@ def calibrate(
(1 - alpha) * (1 + 1 / self._len_calib),
w=weights,
axis=1,
feature_axis=self._feature_axis
)
y_hi = quantile(
concat_y_hi,
(1 - alpha) * (1 + 1 / self._len_calib),
w=weights,
axis=1,
feature_axis=self._feature_axis
)
return y_lo, y_hi
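
Note (reviewer sketch): threading `feature_axis` through `quantile` exists so that, for multivariate outputs, quantiles are taken per coordinate rather than over the flattened scores. Roughly, in plain numpy (an illustration, not puncc's weighted implementation):

```python
import numpy as np

# Multivariate nonconformity scores: one column per output dimension.
scores = np.abs(np.random.randn(100, 2))  # shape (n_samples, n_outputs)
alpha = 0.1

# One empirical quantile per output coordinate -> shape (2,)
q_per_output = np.quantile(scores, 1 - alpha, axis=0, method="higher")

# Rectangular prediction sets: broadcast the per-output radii around y_pred.
y_pred = np.random.randn(50, 2)
y_lo, y_hi = y_pred - q_per_output, y_pred + q_per_output
```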
4 changes: 2 additions & 2 deletions deel/puncc/api/corrections.py
@@ -27,7 +27,7 @@
import numpy as np


def bonferroni(alpha: float, nvars: int) -> float:
def bonferroni(alpha: float, nvars: int = 1) -> float:
"""Bonferroni correction for multiple comparisons.

:param float alpha: significance level (nominal miscoverage).
@@ -38,7 +38,7 @@ def bonferroni(alpha: float, nvars: int) -> float:
:rtype: float.
"""
# Sanity checks
if alpha <= 0 or alpha >= 1:
if np.any(alpha <= 0) or np.any(alpha >= 1):
raise ValueError("alpha must be in (0,1)")

if nvars <= 0:
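
Aside (reviewer sketch): a quick sanity check of the relaxed signature. The function body is collapsed in this diff, so the `alpha / nvars` form below is an assumption (it is the standard Bonferroni rule):

```python
import numpy as np

from deel.puncc.api.corrections import bonferroni

# Global miscoverage budget 0.1 split over 2 outputs (assumed alpha / nvars).
print(bonferroni(0.1, nvars=2))  # expected: 0.05

# With nvars defaulting to 1, the correction is a no-op, and the np.any(...)
# checks now accept a per-output vector of alphas as well.
print(bonferroni(np.array([0.1, 0.3])))  # expected: array([0.1, 0.3])
```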
49 changes: 49 additions & 0 deletions tests/api/test_calibration.py
@@ -152,3 +152,52 @@ def test_anomaly_detection_calibrator(
assert anomalies.shape == (117, 2)
assert not_anomalies is not None
assert not_anomalies.shape == (33, 2)


@pytest.mark.parametrize(
"alpha, random_state",
[[0.1, 42], [0.3, 42], [0.5, 42], [0.7, 42], [0.9, 42],
[np.array([0.1, 0.3]), 42], [np.array([0.5, 0.7]), 42]]
)
def test_multivariate_regression_calibrator(
rand_multivariate_reg_data, alpha, random_state
):
# Generate data
(y_pred_calib, y_calib, y_pred_test, y_test) = rand_multivariate_reg_data

# Nonconformity score function that takes as argument
# the predicted values y_pred = model(X) and the true labels y_true. In
# this example, we reimplement the mean absolute deviation that is
# already defined in `deel.puncc.api.nonconformity_scores.mad`
def nonconformity_function(y_pred, y_true):
return np.abs(y_pred - y_true)

# Prediction sets are computed based on point predictions and
# the quantiles of nonconformity scores. The function below returns a
# fixed size rectangle around the point predictions.
def prediction_set_function(y_pred, scores_quantile):
y_lo = y_pred - scores_quantile
y_hi = y_pred + scores_quantile
return y_lo, y_hi

# The calibrator is instantiated by passing the two functions defined
# above to the constructor.
calibrator = BaseCalibrator(
nonconf_score_func=nonconformity_function,
pred_set_func=prediction_set_function,
)

# The nonconformity scores are computed by calling the `fit` method
# on the calibration dataset.
calibrator.fit(y_pred=y_pred_calib, y_true=y_calib)

# The lower and upper bounds of the prediction interval are then returned
# by the call to calibrate on the new data w.r.t a risk level alpha.
y_pred_lower, y_pred_upper = calibrator.calibrate(
y_pred=y_pred_test, alpha=alpha
)

assert y_pred_lower is not None
assert y_pred_upper is not None
assert not np.isnan(y_pred_lower).any()
assert not np.isnan(y_pred_upper).any()
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -80,6 +80,19 @@ def rand_reg_data():
return y_pred_calib, y_calib, y_pred_test, y_test


@pytest.fixture
def rand_multivariate_reg_data():
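# The feature matrices below are unused; the fixture returns only predictions and labels.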
X_pred_calib = 10 * np.random.randn(100, 8)
X_pred_test = 10 * np.random.randn(100, 8)
X_test = 10 * np.random.randn(100, 8)
y_pred_calib = 4 * np.random.randn(100, 2) + 1
y_calib = 2 * np.random.randn(100, 2) + 1
y_pred_test = 4 * np.random.randn(100, 2) + 2
y_test = 2 * np.random.randn(100, 2) + 1

return y_pred_calib, y_calib, y_pred_test, y_test


@pytest.fixture
def rand_class_data():
X_pred_calib = 10 * np.random.randn(100, 4)