Commit 686d54f

improve plot_variable_importance
aloctavodia committed Aug 22, 2024
1 parent e8f258a commit 686d54f
Showing 1 changed file with 45 additions and 12 deletions.
57 changes: 45 additions & 12 deletions pymc_bart/utils.py
@@ -8,10 +8,11 @@
import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
from numba import jit
from pytensor.tensor.variable import Variable
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import norm, pearsonr
from scipy.stats import norm

from .tree import Tree

@@ -700,8 +701,9 @@ def plot_variable_importance( # noqa: PLR0915
method: str = "VI",
figsize: Optional[Tuple[float, float]] = None,
xlabel_angle: float = 0,
samples: int = 100,
samples: int = 50,
random_seed: Optional[int] = None,
plot_kwargs: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
"""
@@ -733,6 +735,14 @@
Number of predictions used to compute correlation for subsets of variables. Defaults to 50
random_seed : Optional[int]
random_seed used to sample from the posterior. Defaults to None.
plot_kwargs : dict
Additional keyword arguments for the plot. Defaults to None.
Valid keys are:
- color_r2: matplotlib valid color for error bars
- marker_r2: matplotlib valid marker for the mean R squared
- marker_fc_r2: matplotlib valid marker face color for the mean R squared
- ls_ref: matplotlib valid linestyle for the reference line
- color_ref: matplotlib valid color for the reference line
ax : axes
Matplotlib axes.
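
A usage sketch for the new plot_kwargs argument (editorial, not part of the commit): the toy data, model, and styling values below are assumptions, and the call assumes the usual idata, bartrv, X positional arguments of plot_variable_importance.

import numpy as np
import pymc as pm
import pymc_bart as pmb

# Hypothetical toy data; only the plot_kwargs usage matters here.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X[:, 0] - X[:, 1] + rng.normal(scale=0.5, size=100)

with pm.Model():
    mu = pmb.BART("mu", X, y)
    pm.Normal("y_obs", mu=mu, sigma=1.0, observed=y)
    idata = pm.sample(random_seed=0)

pmb.plot_variable_importance(
    idata,
    mu,
    X,
    samples=50,
    plot_kwargs={
        "color_r2": "C0",      # error-bar color
        "marker_r2": "s",      # marker for the mean R²
        "marker_fc_r2": "C0",  # marker face color
        "ls_ref": ":",         # linestyle of the reference line
        "color_ref": "C3",     # color of the reference line and band
    },
)
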
@@ -745,6 +755,9 @@

all_trees = bartrv.owner.op.all_trees

if plot_kwargs is None:
plot_kwargs = {}

if bartrv.ndim == 1: # type: ignore
shape = 1
else:
@@ -773,6 +786,10 @@
all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
)

r_2_ref = np.array(
[pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)]
)

if method == "VI":
idxs = np.argsort(
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
@@ -794,10 +811,7 @@
shape=shape,
)
r_2 = np.array(
[
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
for j in range(samples)
]
[pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
)
r2_mean[idx] = np.mean(r_2)
r2_hdi[idx] = az.hdi(r_2)
@@ -833,10 +847,7 @@
# Calculate Pearson correlation for each sample and find the mean
r_2 = np.zeros(samples)
for j in range(samples):
r_2[j] = (
(pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0])
** 2
)
r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j])
mean_r_2 = np.mean(r_2, dtype=float)
# Identify the least important combination of variables
# based on the maximum mean squared Pearson correlation
@@ -872,9 +883,21 @@
ticks,
r2_mean,
np.array((r2_yerr_min, r2_yerr_max)),
color="C0",
color=plot_kwargs.get("color_r2", "k"),
fmt=plot_kwargs.get("marker_r2", "o"),
mfc=plot_kwargs.get("marker_fc_r2", "white"),
)
ax.axhline(
np.mean(r_2_ref),
ls=plot_kwargs.get("ls_ref", "--"),
color=plot_kwargs.get("color_ref", "grey"),
)
ax.fill_between(
[-0.5, n_vars - 0.5],
*az.hdi(r_2_ref),
alpha=0.1,
color=plot_kwargs.get("color_ref", "grey"),
)
ax.axhline(r2_mean[-1], ls="--", color="0.5")
ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
ax.set_ylabel("R²", rotation=0, labelpad=12)
ax.set_ylim(0, 1)
@@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include):
else:
sequences = [()]
return sequences


@jit(nopython=True)
def pearsonr2(A, B):
"""Compute the squared Pearson correlation coefficient"""
A = A.flatten()
B = B.flatten()
am = A - np.mean(A)
bm = B - np.mean(B)
return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
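
A quick editorial check, assuming numba and scipy are installed and pymc_bart is a checkout containing this commit, that the jitted pearsonr2 helper matches the squared scipy.stats.pearsonr call it replaces:

import numpy as np
from scipy.stats import pearsonr

from pymc_bart.utils import pearsonr2

rng = np.random.default_rng(42)
a = rng.normal(size=(4, 25))                   # e.g. predictions with a leading draw dimension
b = a + rng.normal(scale=0.3, size=a.shape)

expected = pearsonr(a.flatten(), b.flatten())[0] ** 2
assert np.isclose(pearsonr2(a, b), expected)
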
