Skip to content

Commit

Permalink
update theory overview (#55)
Browse files Browse the repository at this point in the history
* update theory overview

* ops: update pylint rules

---------

Co-authored-by: M-Mouhcine <[email protected]>
  • Loading branch information
jdalch and M-Mouhcine authored Jul 4, 2024
1 parent 812fa09 commit d09f773
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 56 deletions.
3 changes: 3 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ disable=
C0302, # allow too many lines in module
C0411, # allow custom import order

E0606, # allow false positive used-before-assignment

R0801, # allow similar lines in 2 files
R0915, # allow too many statements

W0105, # allow no effect string statement
W0102, # allow dangerous default value []
Expand Down
30 changes: 17 additions & 13 deletions deel/puncc/api/nonconformity_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
This module provides nonconformity scores for conformal prediction. To be used
when building a :ref:`calibrator <calibration>`.
"""
import pkgutil
import importlib
from typing import Callable
from typing import Iterable

Expand All @@ -33,13 +33,13 @@
from deel.puncc.api.utils import logit_normalization_check
from deel.puncc.api.utils import supported_types_check

if pkgutil.find_loader("pandas") is not None:
if importlib.util.find_spec("pandas") is not None:
import pandas as pd

if pkgutil.find_loader("tensorflow") is not None:
if importlib.util.find_spec("tensorflow") is not None:
import tensorflow as tf

if pkgutil.find_loader("torch") is not None:
if importlib.util.find_spec("torch") is not None:
import torch


Expand Down Expand Up @@ -176,7 +176,7 @@ def difference(y_pred: Iterable, y_true: Iterable) -> Iterable:
"""
supported_types_check(y_pred, y_true)

if pkgutil.find_loader("torch") is not None and isinstance(
if importlib.util.find_spec("torch") is not None and isinstance(
y_pred, torch.Tensor
):
y_pred = y_pred.cpu().detach().numpy()
Expand Down Expand Up @@ -205,7 +205,9 @@ def absolute_difference(y_pred: Iterable, y_true: Iterable) -> Iterable:
return abs(difference(y_pred, y_true))


def scaled_ad(Y_pred: Iterable, y_true: Iterable, eps: float = 1e-12) -> Iterable:
def scaled_ad(
Y_pred: Iterable, y_true: Iterable, eps: float = 1e-12
) -> Iterable:
"""Scaled Absolute Deviation, normalized by an estimation of the conditional
mean absolute deviation (conditional MAD). Considering
:math:`Y_{\\text{pred}} = (\mu_{\\text{pred}}, \sigma_{\\text{pred}})`:
Expand Down Expand Up @@ -235,7 +237,7 @@ def scaled_ad(Y_pred: Iterable, y_true: Iterable, eps: float = 1e-12) -> Iterabl
if len(y_true.shape) != 1:
raise RuntimeError("Each y_true must contain a point observation.")

if pkgutil.find_loader("pandas") is not None and isinstance(
if importlib.util.find_spec("pandas") is not None and isinstance(
Y_pred, pd.DataFrame
):
y_pred, sigma_pred = Y_pred.iloc[:, 0], Y_pred.iloc[:, 1]
Expand All @@ -245,8 +247,10 @@ def scaled_ad(Y_pred: Iterable, y_true: Iterable, eps: float = 1e-12) -> Iterabl
# The MAD, then the scaled MAD, are computed
mean_absolute_deviation = absolute_difference(y_pred, y_true)
if np.any(sigma_pred + eps <= 0):
print("Warning: calibration points with MAD predictions"
" below -eps won't be used for calibration.")
print(
"Warning: calibration points with MAD predictions"
" below -eps won't be used for calibration."
)

nonneg = sigma_pred + eps > 0
return mean_absolute_deviation[nonneg] / (sigma_pred[nonneg] + eps)
Expand Down Expand Up @@ -282,7 +286,7 @@ def cqr_score(Y_pred: Iterable, y_true: Iterable) -> Iterable:
if len(y_true.shape) != 1:
raise RuntimeError("Each y_pred must contain a point observation.")

if pkgutil.find_loader("pandas") is not None and isinstance(
if importlib.util.find_spec("pandas") is not None and isinstance(
Y_pred, pd.DataFrame
):
q_lo, q_hi = Y_pred.iloc[:, 0], Y_pred.iloc[:, 1]
Expand All @@ -295,20 +299,20 @@ def cqr_score(Y_pred: Iterable, y_true: Iterable) -> Iterable:
if isinstance(diff_lo, np.ndarray):
return np.maximum(diff_lo, diff_hi)

if pkgutil.find_loader("pandas") is not None and isinstance(
if importlib.util.find_spec("pandas") is not None and isinstance(
diff_lo, (pd.DataFrame, pd.Series)
):
return (pd.concat([diff_lo, diff_hi]).groupby(level=0)).max()
# raise NotImplementedError(
# "CQR score not implemented for DataFrames. Please provide ndarray or tensors."
# )

if pkgutil.find_loader("tensorflow") is not None and isinstance(
if importlib.util.find_spec("tensorflow") is not None and isinstance(
diff_lo, tf.Tensor
):
return tf.math.maximum(diff_lo, diff_hi)

# if pkgutil.find_loader("torch") is not None and isinstance(
# if importlib.util.find_spec("torch") is not None and isinstance(
# diff_lo, torch.Tensor
# ):
# return torch.maximum(diff_lo, diff_hi)
Expand Down
6 changes: 3 additions & 3 deletions deel/puncc/api/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"""
This module provides data splitting schemes.
"""
import pkgutil
import importlib
from abc import ABC
from typing import Iterable
from typing import List
Expand All @@ -36,7 +36,7 @@
from deel.puncc.api.utils import sample_len_check
from deel.puncc.api.utils import supported_types_check

if pkgutil.find_loader("pandas") is not None:
if importlib.util.find_spec("pandas") is not None:
import pandas as pd


Expand Down Expand Up @@ -171,7 +171,7 @@ def __call__(
folds = []

for fit, calib in kfold.split(X):
if pkgutil.find_loader("pandas") is not None and isinstance(
if importlib.util.find_spec("pandas") is not None and isinstance(
X, pd.DataFrame
):
if isinstance(y, pd.DataFrame):
Expand Down
20 changes: 10 additions & 10 deletions deel/puncc/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
This module implements utility functions.
"""
import logging
import pkgutil
import importlib
import sys
from typing import Any
from typing import Iterable
Expand All @@ -34,13 +34,13 @@

import numpy as np

if pkgutil.find_loader("pandas") is not None:
if importlib.util.find_spec("pandas") is not None:
import pandas as pd

if pkgutil.find_loader("tensorflow") is not None:
if importlib.util.find_spec("tensorflow") is not None:
import tensorflow as tf

if pkgutil.find_loader("torch") is not None:
if importlib.util.find_spec("torch") is not None:
import torch

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,15 +126,15 @@ def supported_types_check(*data: Iterable):
if isinstance(a, np.ndarray):
pass

elif pkgutil.find_loader("pandas") is not None and isinstance(
elif importlib.util.find_spec("pandas") is not None and isinstance(
a, (pd.DataFrame, pd.Series)
):
pass
elif pkgutil.find_loader("tensorflow") is not None and isinstance(
elif importlib.util.find_spec("tensorflow") is not None and isinstance(
a, tf.Tensor
):
pass
elif pkgutil.find_loader("torch") is not None and isinstance(
elif importlib.util.find_spec("torch") is not None and isinstance(
a, torch.Tensor
):
pass
Expand Down Expand Up @@ -277,15 +277,15 @@ def quantile(

if isinstance(a, np.ndarray):
pass
elif pkgutil.find_loader("pandas") is not None and isinstance(
elif importlib.util.find_spec("pandas") is not None and isinstance(
a, pd.DataFrame
):
a = a.to_numpy()
elif pkgutil.find_loader("tensorflow") is not None and isinstance(
elif importlib.util.find_spec("tensorflow") is not None and isinstance(
a, tf.Tensor
):
a = a.numpy()
# elif pkgutil.find_loader("torch") is not None:
# elif importlib.util.find_spec("torch") is not None:
# if isinstance(a, torch.Tensor):
# a = a.cpu().detach().numpy()
else:
Expand Down
5 changes: 4 additions & 1 deletion deel/puncc/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def plot_prediction_intervals(
"""

# Initialisation
current_rcparams = None

# Figure size configuration
if "figsize" in fig_kw.keys():
figsize = fig_kw["figsize"]
Expand Down Expand Up @@ -256,7 +259,7 @@ def plot_prediction_intervals(
ax.set_xlim(X[0] - int_size * 0.01, X[-1] + int_size * 0.01)

# re-establish rcparams
if restablish_rcparams:
if current_rcparams is not None and restablish_rcparams:
matplotlib.rcParams.update(current_rcparams)

return ax
Expand Down
Loading

0 comments on commit d09f773

Please sign in to comment.