Skip to content

Commit

Permalink
Allow Y to be a tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Aug 19, 2024
1 parent 0e22798 commit 4dfb08a
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ def __init__( # noqa: PLR0915
else:
self.X = self.bart.X

if isinstance(self.bart.Y, Variable):
self.Y = self.bart.Y.eval()
else:
self.Y = self.bart.Y


self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.response = self.bart.response
Expand Down Expand Up @@ -166,26 +172,26 @@ def __init__( # noqa: PLR0915
if rule is ContinuousSplitRule:
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))

init_mean = self.bart.Y.mean()
init_mean = self.Y.mean()
self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
self.available_predictors = list(range(self.num_variates))

# if data is binary
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))

y_unique = np.unique(self.bart.Y)
y_unique = np.unique(self.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
self.leaf_sd *= 3 / self.m**0.5
else:
self.leaf_sd *= self.bart.Y.std() / self.m**0.5
self.leaf_sd *= self.Y.std() / self.m**0.5

self.running_sd = [
RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
]

self.sum_trees = np.full(
(self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
(self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean
).astype(config.floatX)
self.sum_trees_noi = self.sum_trees - init_mean
self.a_tree = Tree.new_tree(
Expand Down

0 comments on commit 4dfb08a

Please sign in to comment.