Skip to content

Commit

Permalink
clean and update
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 3, 2022
1 parent f32bc75 commit 1861124
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pymc_bart.utils import plot_dependence, plot_variable_importance

__all__ = ["BART", "PGBART"]
__version__ = "0.1.0"
__version__ = "0.2.0"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
20 changes: 8 additions & 12 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,47 +80,43 @@ def __init__(
else:
self.X = self.bart.X

self.Y = self.bart.Y
self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.alpha = self.bart.alpha
shape = initial_values[value_bart.name].shape
if len(shape) == 1:
self.shape = 1
else:
self.shape = shape[0]

# self.alpha_vec = self.bart.split_prior
if self.bart.split_prior:
self.alpha_vec = self.bart.split_prior
else:
self.alpha_vec = np.ones(self.X.shape[1])
self.init_mean = self.Y.mean()
init_mean = self.bart.Y.mean()
# if data is binary
y_unique = np.unique(self.Y)
y_unique = np.unique(self.bart.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
mu_std = 3 / self.m**0.5
# maybe we need to check for count data
else:
mu_std = self.Y.std() / self.m**0.5
mu_std = self.bart.Y.std() / self.m**0.5

self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
self.available_predictors = list(range(self.num_variates))

self.sum_trees = np.full((self.shape, self.Y.shape[0]), self.init_mean).astype(
self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
config.floatX
)
self.sum_trees_noi = self.sum_trees - (self.init_mean / self.m)
self.sum_trees_noi = self.sum_trees - (init_mean / self.m)
self.a_tree = Tree(
leaf_node_value=self.init_mean / self.m,
leaf_node_value=init_mean / self.m,
idx_data_points=np.arange(self.num_observations, dtype="int32"),
num_observations=self.num_observations,
shape=self.shape,
)
self.normal = NormalSampler(mu_std, self.shape)
self.uniform = UniformSampler(0.33, 0.75, self.shape)
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
self.ssv = SampleSplittingVariable(self.alpha_vec)

self.tune = True
Expand All @@ -143,7 +139,7 @@ def __init__(
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
self.all_particles = []
for _ in range(self.m):
self.a_tree.leaf_node_value = self.init_mean / self.m
self.a_tree.leaf_node_value = init_mean / self.m
p = ParticleTree(self.a_tree)
self.all_particles.append(p)
self.all_trees = np.array([p.tree for p in self.all_particles])
Expand Down
3 changes: 0 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@ def get_version():
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown",
packages=find_packages(),
# because of an upload-size limit by PyPI, we're temporarily removing docs from the tarball.
# Also see MANIFEST.in
# package_data={'docs': ['*']},
include_package_data=True,
classifiers=classifiers,
python_requires=">=3.8",
Expand Down

0 comments on commit 1861124

Please sign in to comment.