Merge pull request #47 from deel-ai/joseba/multivariate-calibrator
Joseba/multivariate calibrator
jdalch authored Jan 3, 2024
2 parents 56b19f0 + 01da866 commit 6a2b131
Showing 4 changed files with 135 additions and 22 deletions.
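
For context, here is a minimal sketch of how the new multivariate support is meant to be used, mirroring the test added below. The import path and array shapes are assumptions and the data is synthetic:

    import numpy as np
    from deel.puncc.api.calibration import BaseCalibrator  # import path assumed

    # Synthetic bivariate regression outputs: scores and sets are per output dimension.
    y_pred_calib = np.random.randn(100, 2)
    y_true_calib = y_pred_calib + 0.1 * np.random.randn(100, 2)
    y_pred_test = np.random.randn(20, 2)

    calibrator = BaseCalibrator(
        # Mean absolute deviation as the nonconformity score, as in the new test.
        nonconf_score_func=lambda y_pred, y_true: np.abs(y_pred - y_true),
        # Rectangular prediction sets around the point predictions.
        pred_set_func=lambda y_pred, scores_quantile: (
            y_pred - scores_quantile,
            y_pred + scores_quantile,
        ),
    )
    calibrator.fit(y_pred=y_pred_calib, y_true=y_true_calib)

    # `correction` defaults to the Bonferroni adjustment introduced in this PR.
    y_lo, y_hi = calibrator.calibrate(y_pred=y_pred_test, alpha=0.1)
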
91 changes: 71 additions & 20 deletions deel/puncc/api/calibration.py
@@ -36,6 +36,7 @@

from deel.puncc.api.utils import alpha_calib_check
from deel.puncc.api.utils import quantile
from deel.puncc.api.corrections import bonferroni

logger = logging.getLogger(__name__)

@@ -134,6 +135,7 @@ def __init__(
self._len_calib = 0
self._residuals = None
self._norm_weights = None
self._feature_axis = None

def fit(
self,
@@ -153,6 +155,8 @@ def fit(
logger.debug(f"Shape of y_true: {y_true.shape}")
self._residuals = self.nonconf_score_func(y_pred, y_true)
self._len_calib = len(self._residuals)
if y_pred.ndim > 1:
self._feature_axis = -1
logger.debug("Nonconformity scores computed !")

def calibrate(
@@ -161,6 +165,7 @@ def calibrate(
alpha: float,
y_pred: Iterable,
weights: Optional[Iterable] = None,
correction: Optional[Callable] = bonferroni,
) -> Tuple[np.ndarray]:
"""Compute calibrated prediction sets for new examples w.r.t a
significance level :math:`\\alpha`.
@@ -170,6 +175,9 @@
:param Iterable weights: weights to be associated to the nonconformity
scores. Defaults to None when all the scores
are equiprobable.
:param Callable correction: correction for multiple hypothesis testing
in the case of multivariate regression.
Defaults to Bonferroni correction.
:returns: prediction set.
In case of regression, returns (y_lower, y_upper).
@@ -181,26 +189,7 @@
calibration set.
"""

if self._residuals is None:
raise RuntimeError("Run `fit` method before calling `calibrate`.")

# Check consistency of alpha w.r.t the size of calibration data
if weights is None:
alpha_calib_check(alpha=alpha, n=self._len_calib)

# Compute weighted quantiles
## Lemma 1 of Tibshirani's paper (https://arxiv.org/pdf/1904.06019.pdf)
## The coverage guarantee holds with 1) the inflated
## (1-\alpha)(1+1/n)-th quantile or 2) when adding an infinite term to
## the sequence and computing the $(1-\alpha)$-th empirical quantile.
infty_array = np.array([np.inf])
lemma_residuals = np.concatenate((self._residuals, infty_array))
residuals_Q = quantile(
lemma_residuals,
1 - alpha,
w=weights,
)
residuals_Q = self.compute_quantile(alpha=alpha, weights=weights, correction=correction)

return self.pred_set_func(y_pred, scores_quantile=residuals_Q)

@@ -230,6 +219,61 @@ def get_nonconformity_scores(self) -> np.ndarray:
"""
return self._residuals

def compute_quantile(
self,
*,
alpha: float,
weights: Optional[Iterable] = None,
correction: Optional[Callable] = bonferroni,
) -> np.ndarray:
"""Compute quantile of scores w.r.t a
significance level :math:`\\alpha`.
:param float alpha: significance level (max miscoverage target).
:param Iterable weights: weights to be associated to the nonconformity
scores. Defaults to None when all the scores
are equiprobable.
:param Callable correction: correction for multiple hypothesis testing
in the case of multivariate regression.
Defaults to Bonferroni correction.
:returns: quantile
:rtype: ndarray
:raises RuntimeError: :meth:`compute_quantile` called before :meth:`fit`.
:raise ValueError: failed check on :data:`alpha` w.r.t size of the
calibration set.
"""

if self._residuals is None:
raise RuntimeError("Run `fit` method before calling `calibrate`.")

alpha = correction(alpha)

# Check consistency of alpha w.r.t the size of calibration data
if weights is None:
alpha_calib_check(alpha=alpha, n=self._len_calib)

# Compute weighted quantiles
## Lemma 1 of Tibshirani's paper (https://arxiv.org/pdf/1904.06019.pdf)
## The coverage guarantee holds with 1) the inflated
## (1-\alpha)(1+1/n)-th quantile or 2) when adding an infinite term to
## the sequence and computing the $(1-\alpha)$-th empirical quantile.
if self._residuals.ndim > 1:
infty_array = np.full((1,self._residuals.shape[-1]), np.inf)
else:
infty_array = np.array([np.inf])
lemma_residuals = np.concatenate((self._residuals, infty_array), axis=0)
residuals_Q = quantile(
lemma_residuals,
1 - alpha,
w=weights,
feature_axis=self._feature_axis
)

return residuals_Q
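
Side note on the Lemma 1 comment above: appending an infinite score and taking the plain (1 - alpha) empirical quantile is equivalent to taking the inflated (1 - alpha)(1 + 1/n)-th quantile of the original scores. A quick self-contained check in plain numpy (not puncc's own quantile helper; method="inverted_cdf" needs NumPy >= 1.22):

    import numpy as np

    rng = np.random.default_rng(0)
    scores = rng.exponential(size=50)   # n nonconformity scores
    alpha = 0.1
    n = len(scores)

    # (1) inflated quantile: the ceil((1 - alpha) * (n + 1))-th smallest score
    k = int(np.ceil((1 - alpha) * (n + 1)))
    q_inflated = np.sort(scores)[k - 1]

    # (2) append +inf and take the plain (1 - alpha) empirical quantile;
    #     "inverted_cdf" picks the smallest value whose empirical CDF >= 1 - alpha
    augmented = np.concatenate([scores, [np.inf]])
    q_with_inf = np.quantile(augmented, 1 - alpha, method="inverted_cdf")

    assert q_inflated == q_with_inf
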

@staticmethod
def barber_weights(weights: np.ndarray) -> np.ndarray:
"""Compute and normalize inference weights of the nonconformity distribution
@@ -418,6 +462,7 @@ class CvPlusCalibrator:
def __init__(self, kfold_calibrators: dict):
self.kfold_calibrators_dict = kfold_calibrators
self._len_calib = None
self._feature_axis = None

# Sanity checks:
# - The collection of calibrators is not None
@@ -483,6 +528,10 @@ def calibrate(
if y_pred is None:
raise RuntimeError("No prediction obtained with cv+.")

# Check for multivariate predictions
if y_pred.ndim > 1:
self._feature_axis = -1

# nonconformity scores
kth_calibrator = self.kfold_calibrators_dict[k]
nconf_scores = kth_calibrator.get_nonconformity_scores()
@@ -536,11 +585,13 @@ def calibrate(
(1 - alpha) * (1 + 1 / self._len_calib),
w=weights,
axis=1,
feature_axis=self._feature_axis
)
y_hi = quantile(
concat_y_hi,
(1 - alpha) * (1 + 1 / self._len_calib),
w=weights,
axis=1,
feature_axis=self._feature_axis
)
return y_lo, y_hi
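
The feature_axis plumbing above amounts to taking the CV+ quantile over the calibration axis independently for each output dimension. A rough numpy analogue, with assumed shapes rather than puncc's actual helper:

    import numpy as np

    # One candidate bound per (test point, calibration point), plus a trailing
    # feature axis for multivariate targets.
    n_test, n_calib, n_features = 20, 100, 2
    concat_y_hi = np.random.randn(n_test, n_calib, n_features)

    alpha = 0.1
    level = min((1 - alpha) * (1 + 1 / n_calib), 1.0)

    # Quantile over the calibration axis, computed independently for each output
    # dimension -- the role played by `feature_axis=-1` in puncc's quantile helper.
    y_hi = np.quantile(concat_y_hi, level, axis=1)
    print(y_hi.shape)  # (20, 2): one bound per test point and per output dimension
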
4 changes: 2 additions & 2 deletions deel/puncc/api/corrections.py
@@ -27,7 +27,7 @@
import numpy as np


def bonferroni(alpha: float, nvars: int) -> float:
def bonferroni(alpha: float, nvars: int = 1) -> float:
"""Bonferroni correction for multiple comparisons.
:param float alpha: nominal coverage level.
@@ -38,7 +38,7 @@ def bonferroni(alpha: float, nvars: int = 1) -> float:
:rtype: float.
"""
# Sanity checks
if alpha <= 0 or alpha >= 1:
if np.any(alpha <= 0) or np.any(alpha >= 1):
raise ValueError("alpha must be in (0,1)")

if nvars <= 0:
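
For reference, the Bonferroni adjustment used as the default correction divides the risk level by the number of output dimensions (assuming the usual alpha / nvars rule), and the np.any checks above are what let alpha be an array, as the new tests exercise:

    import numpy as np

    alpha, nvars = 0.1, 2
    alpha_per_dim = alpha / nvars   # 0.05 miscoverage budget per output dimension

    # Vectorised sanity check, mirroring the relaxed np.any() condition above.
    alphas = np.array([0.1, 0.3])
    assert not (np.any(alphas <= 0) or np.any(alphas >= 1))
    print(alphas / nvars)           # [0.05 0.15]
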
49 changes: 49 additions & 0 deletions tests/api/test_calibration.py
@@ -152,3 +152,52 @@ def test_anomaly_detection_calibrator(
assert anomalies.shape == (117, 2)
assert not_anomalies is not None
assert not_anomalies.shape == (33, 2)


@pytest.mark.parametrize(
"alpha, random_state",
[[0.1, 42], [0.3, 42], [0.5, 42], [0.7, 42], [0.9, 42],
[np.array([0.1, 0.3]), 42], [np.array([0.5, 0.7]), 42]]
)
def test_multivariate_regression_calibrator(
rand_multivariate_reg_data, alpha, random_state
):
# Generate data
(y_pred_calib, y_calib, y_pred_test, y_test) = rand_multivariate_reg_data

# Nonconformity score function that takes as argument
# the predicted values y_pred = model(X) and the true labels y_true. In
# this example, we reimplement the mean absolute deviation that is
# already defined in `deel.puncc.api.nonconformity_scores.mad`
def nonconformity_function(y_pred, y_true):
return np.abs(y_pred - y_true)

# Prediction sets are computed based on point predictions and
# the quantiles of nonconformity scores. The function below returns a
# fixed size rectangle around the point predictions.
def prediction_set_function(y_pred, scores_quantile):
y_lo = y_pred - scores_quantile
y_hi = y_pred + scores_quantile
return y_lo, y_hi

# The calibrator is instantiated by passing the two functions defined
# above to the constructor.
calibrator = BaseCalibrator(
nonconf_score_func=nonconformity_function,
pred_set_func=prediction_set_function,
)

# The nonconformity scores are computed by calling the `fit` method
# on the calibration dataset.
calibrator.fit(y_pred=y_pred_calib, y_true=y_calib)

# The lower and upper bounds of the prediction interval are then returned
# by the call to calibrate on the new data w.r.t a risk level alpha.
y_pred_lower, y_pred_upper = calibrator.calibrate(
y_pred=y_pred_test, alpha=alpha
)

assert y_pred_lower is not None
assert y_pred_upper is not None
assert not np.isnan(y_pred_lower).any()
assert not np.isnan(y_pred_upper).any()
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -80,6 +80,19 @@ def rand_reg_data():
return y_pred_calib, y_calib, y_pred_test, y_test


@pytest.fixture
def rand_multivariate_reg_data():
X_pred_calib = 10 * np.random.randn(100, 8)
X_pred_test = 10 * np.random.randn(100, 8)
X_test = 10 * np.random.randn(100, 8)
y_pred_calib = 4 * np.random.randn(100, 2) + 1
y_calib = 2 * np.random.randn(100, 2) + 1
y_pred_test = 4 * np.random.randn(100, 2) + 2
y_test = 2 * np.random.randn(100, 2) + 1

return y_pred_calib, y_calib, y_pred_test, y_test


@pytest.fixture
def rand_class_data():
X_pred_calib = 10 * np.random.randn(100, 4)
