Skip to content

Commit

Permalink
add error bars to variable importance (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Jun 26, 2023
1 parent 7cf8595 commit de582f7
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import norm, pearsonr
from xarray import concat

from .tree import Tree

Expand Down Expand Up @@ -742,6 +743,12 @@ def plot_variable_importance(
labels = X.columns
X = X.values

n_draws = idata["posterior"].dims["draw"]
half = n_draws // 2
f_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(0, half - 1))
s_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(half, n_draws))

var_imp_chains = concat([f_half, s_half], dim="chain", join="override").mean(("draw")).values
var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
if labels is None:
labels_ary = np.arange(len(var_imp))
Expand All @@ -759,7 +766,16 @@ def plot_variable_importance(
indices = idxs[::-1]
else:
indices = np.arange(len(var_imp))
axes[0].plot((var_imp / var_imp.sum())[indices], "o-")

chains_mean = (var_imp / var_imp.sum())[indices]
chains_hdi = az.hdi((var_imp_chains.T / var_imp_chains.sum(axis=1)).T)[indices]

axes[0].errorbar(
ticks,
chains_mean,
np.array((chains_mean - chains_hdi[:, 0], chains_hdi[:, 1] - chains_mean)),
color="C0",
)
axes[0].set_xticks(ticks)
axes[0].set_xticklabels(labels_ary[indices])
axes[0].set_xlabel("covariables")
Expand Down Expand Up @@ -790,8 +806,9 @@ def plot_variable_importance(
ev_mean[idx] = np.mean(pearson)
ev_hdi[idx] = az.hdi(pearson)

axes[1].errorbar(ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)))

axes[1].errorbar(
ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)), color="C0"
)
axes[1].axhline(ev_mean[-1], ls="--", color="0.5")
axes[1].set_xticks(ticks)
axes[1].set_xticklabels(ticks + 1)
Expand Down

0 comments on commit de582f7

Please sign in to comment.