Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve plot_variable_importance #182

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
from numba import jit
from pytensor.tensor.variable import Variable
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import norm, pearsonr
from scipy.stats import norm

from .tree import Tree

Expand Down Expand Up @@ -700,8 +701,9 @@ def plot_variable_importance( # noqa: PLR0915
method: str = "VI",
figsize: Optional[Tuple[float, float]] = None,
xlabel_angle: float = 0,
samples: int = 100,
samples: int = 50,
random_seed: Optional[int] = None,
plot_kwargs: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
"""
Expand Down Expand Up @@ -733,6 +735,14 @@ def plot_variable_importance( # noqa: PLR0915
Number of predictions used to compute correlation for subsets of variables. Defaults to 100
random_seed : Optional[int]
random_seed used to sample from the posterior. Defaults to None.
plot_kwargs : dict
Additional keyword arguments for the plot. Defaults to None.
Valid keys are:
- color_r2: matplotlib valid color for error bars
- marker_r2: matplotlib valid marker for the mean R squared
- marker_fc_r2: matplotlib valid marker face color for the mean R squared
- ls_ref: matplotlib valid linestyle for the reference line
- color_ref: matplotlib valid color for the reference line
ax : axes
Matplotlib axes.

Expand All @@ -745,6 +755,9 @@ def plot_variable_importance( # noqa: PLR0915

all_trees = bartrv.owner.op.all_trees

if plot_kwargs is None:
plot_kwargs = {}

if bartrv.ndim == 1: # type: ignore
shape = 1
else:
Expand Down Expand Up @@ -773,6 +786,10 @@ def plot_variable_importance( # noqa: PLR0915
all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
)

r_2_ref = np.array(
[pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)]
)

if method == "VI":
idxs = np.argsort(
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
Expand All @@ -794,10 +811,7 @@ def plot_variable_importance( # noqa: PLR0915
shape=shape,
)
r_2 = np.array(
[
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
for j in range(samples)
]
[pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
)
r2_mean[idx] = np.mean(r_2)
r2_hdi[idx] = az.hdi(r_2)
Expand Down Expand Up @@ -833,10 +847,7 @@ def plot_variable_importance( # noqa: PLR0915
# Calculate Pearson correlation for each sample and find the mean
r_2 = np.zeros(samples)
for j in range(samples):
r_2[j] = (
(pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0])
** 2
)
r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j])
mean_r_2 = np.mean(r_2, dtype=float)
# Identify the least important combination of variables
# based on the maximum mean squared Pearson correlation
Expand Down Expand Up @@ -872,9 +883,21 @@ def plot_variable_importance( # noqa: PLR0915
ticks,
r2_mean,
np.array((r2_yerr_min, r2_yerr_max)),
color="C0",
color=plot_kwargs.get("color_r2", "k"),
fmt=plot_kwargs.get("marker_r2", "o"),
mfc=plot_kwargs.get("marker_fc_r2", "white"),
)
ax.axhline(
np.mean(r_2_ref),
ls=plot_kwargs.get("ls_ref", "--"),
color=plot_kwargs.get("color_ref", "grey"),
)
ax.fill_between(
[-0.5, n_vars - 0.5],
*az.hdi(r_2_ref),
alpha=0.1,
color=plot_kwargs.get("color_ref", "grey"),
)
ax.axhline(r2_mean[-1], ls="--", color="0.5")
ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
ax.set_ylabel("R²", rotation=0, labelpad=12)
ax.set_ylim(0, 1)
Expand All @@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include):
else:
sequences = [()]
return sequences


@jit(nopython=True)
def pearsonr2(A, B):
"""Compute the squared Pearson correlation coefficient"""
A = A.flatten()
B = B.flatten()
am = A - np.mean(A)
bm = B - np.mean(B)
return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
Loading