compute running variance for leaf nodes (#91)
* use running variance for leaf nodes

aloctavodia authored Jun 28, 2023
1 parent de582f7 commit d1bb5b7
Showing 2 changed files with 74 additions and 38 deletions.
pymc_bart/bart.py (12 changes: 8 additions & 4 deletions)
@@ -36,7 +36,7 @@ class BARTRV(RandomVariable):
 
     name: str = "BART"
    ndim_supp = 1
-    ndims_params: List[int] = [2, 1, 0, 0, 1]
+    ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
    dtype: str = "floatX"
    _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
    all_trees = List[List[Tree]]
@@ -45,7 +45,9 @@ def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
         return dist_params[0].shape[:1]
 
     @classmethod
-    def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, size=None):
+    def rng_fn(
+        cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
+    ):
         if not cls.all_trees:
             if size is not None:
                 return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -94,7 +96,8 @@ def __new__(
         X: TensorLike,
         Y: TensorLike,
         m: int = 50,
-        alpha: float = 0.25,
+        alpha: float = 0.95,
+        beta: float = 2,
         response: str = "constant",
         split_prior: Optional[List[float]] = None,
         **kwargs,
@@ -120,6 +123,7 @@ def __new__(
                 m=m,
                 response=response,
                 alpha=alpha,
+                beta=beta,
                 split_prior=split_prior,
             ),
         )()
@@ -131,7 +135,7 @@ def get_moment(rv, size, *rv_inputs):
             return cls.get_moment(rv, size, *rv_inputs)
 
         cls.rv_op = bart_op
-        params = [X, Y, m, alpha, split_prior]
+        params = [X, Y, m, alpha, beta, split_prior]
         return super().__new__(cls, name, *params, **kwargs)
 
     @classmethod
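Taken together, the bart.py changes replace the single-parameter leaf prior (1 - alpha ** depth) with the conventional two-parameter depth prior, where a node at depth d splits with probability alpha * (1 + d) ** -beta. A minimal usage sketch with assumed toy data (the data and likelihood below are placeholders, not part of this commit):

    import numpy as np
    import pymc as pm
    import pymc_bart as pmb

    # Assumed toy data; any (n, p) predictor matrix and length-n response work.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 3))
    Y = X[:, 0] + rng.normal(scale=0.1, size=100)

    with pm.Model() as model:
        # alpha, beta parameterize P(split at depth d) = alpha * (1 + d) ** -beta;
        # this commit adds beta and moves the defaults to alpha=0.95, beta=2.
        mu = pmb.BART("mu", X, Y, m=50, alpha=0.95, beta=2)
        sigma = pm.HalfNormal("sigma", 1.0)
        pm.Normal("y", mu=mu, sigma=sigma, observed=Y)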
pymc_bart/pgbart.py (100 changes: 66 additions & 34 deletions)
@@ -31,18 +31,16 @@
 class ParticleTree:
     """Particle tree."""
 
-    __slots__ = "tree", "expansion_nodes", "log_weight", "kfactor"
+    __slots__ = "tree", "expansion_nodes", "log_weight"
 
-    def __init__(self, tree: Tree, kfactor: float = 0.75):
+    def __init__(self, tree: Tree):
         self.tree: Tree = tree.copy()
         self.expansion_nodes: List[int] = [0]
         self.log_weight: float = 0
-        self.kfactor: float = kfactor
 
     def copy(self) -> "ParticleTree":
         p = ParticleTree(self.tree)
         p.expansion_nodes = self.expansion_nodes.copy()
-        p.kfactor = self.kfactor
         return p
 
     def sample_tree(
@@ -53,6 +51,7 @@ def sample_tree(
         X,
         missing_data,
         sum_trees,
+        leaf_sd,
         m,
         response,
         normal,
@@ -73,10 +72,10 @@
                 X,
                 missing_data,
                 sum_trees,
+                leaf_sd,
                 m,
                 response,
                 normal,
-                self.kfactor,
                 shape,
             )
             if idx_new_nodes is not None:
@@ -95,7 +94,7 @@ class PGBART(ArrayStepShared):
     vars: list
         List of value variables for sampler
     num_particles : tuple
-        Number of particles. Defaults to 20
+        Number of particles. Defaults to 10
     batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
         during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -112,7 +111,7 @@ class PGBART(ArrayStepShared):
     def __init__(
         self,
         vars=None,  # pylint: disable=redefined-builtin
-        num_particles: int = 20,
+        num_particles: int = 10,
         batch: Tuple[float, float] = (0.1, 0.1),
         model: Optional[Model] = None,
     ):
@@ -141,17 +140,20 @@ def __init__(
             self.alpha_vec = self.bart.split_prior
         else:
             self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
 
         init_mean = self.bart.Y.mean()
-        self.num_observations = self.X.shape[0]
-        self.num_variates = self.X.shape[1]
-        self.available_predictors = list(range(self.num_variates))
 
         # if data is binary
         y_unique = np.unique(self.bart.Y)
         if y_unique.size == 2 and np.all(y_unique == [0, 1]):
-            mu_std = 3 / self.m**0.5
+            self.leaf_sd = 3 / self.m**0.5
         else:
-            mu_std = self.bart.Y.std() / self.m**0.5
+            self.leaf_sd = self.bart.Y.std() / self.m**0.5
 
+        self.num_observations = self.X.shape[0]
+        self.num_variates = self.X.shape[1]
+        self.available_predictors = list(range(self.num_variates))
+        self.running_sd = RunningSd(shape)
 
         self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
             config.floatX
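The 1 / m ** 0.5 factor keeps the prior scale of the whole forest fixed: the fit is a sum of m trees, so if each leaf value has standard deviation leaf_sd, the sum has standard deviation roughly leaf_sd * m ** 0.5. A quick standalone check of that scaling, with assumed toy numbers:

    import numpy as np

    m = 50  # number of trees
    rng = np.random.default_rng(0)
    # Each of the m trees contributes noise with sd 1 / m ** 0.5 ...
    contributions = rng.normal(scale=1 / m**0.5, size=(m, 100_000))
    # ... so their sum has sd ~= 1, matching the target scale of Y.
    print(contributions.sum(axis=0).std())  # ~1.0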
@@ -164,10 +166,9 @@ def __init__(
             shape=self.shape,
         )
 
-        self.normal = NormalSampler(mu_std, self.shape)
+        self.normal = NormalSampler(1, self.shape)
         self.uniform = UniformSampler(0, 1)
-        self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
-        self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
+        self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta)
         self.ssv = SampleSplittingVariable(self.alpha_vec)
 
         self.tune = True
Expand Down Expand Up @@ -212,6 +213,7 @@ def astep(self, _):
self.X,
self.missing_data,
self.sum_trees,
self.leaf_sd,
self.m,
self.response,
self.normal,
@@ -235,16 +237,25 @@
                 particles, normalized_weights
             )
             # Update the sum of trees
-            self.sum_trees = self.sum_trees_noi + new_tree._predict()
+            new = new_tree._predict()
+            self.sum_trees = self.sum_trees_noi + new
             # To reduce memory usage, we trim the tree
             self.all_trees[tree_id] = new_tree.trim()
 
             if self.tune:
                 # Update the splitting variable and the splitting variable sampler
                 if self.iter > self.m:
                     self.ssv = SampleSplittingVariable(self.alpha_vec)
 
                 for index in new_tree.get_split_variables():
                     self.alpha_vec[index] += 1
+
+                # update standard deviation at leaf nodes
+                if self.iter > 2:
+                    self.leaf_sd = self.running_sd.update(new)
+                else:
+                    self.running_sd.update(new)
+
             else:
                 # update the variable inclusion
                 for index in new_tree.get_split_variables():
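With this change the fixed per-particle kfactor is gone: after a two-iteration warm-up of the accumulator, leaf_sd tracks the running standard deviation of the per-tree predictions. A condensed sketch of that loop with stand-in values (RunningSd should be importable from pymc_bart.pgbart as of this commit; the other names here are hypothetical stand-ins for PGBART attributes):

    import numpy as np
    from pymc_bart.pgbart import RunningSd  # added in this commit

    shape, num_observations = 1, 200                 # assumed dimensions
    running_sd = RunningSd((shape, num_observations))
    leaf_sd = 1.0                                    # PGBART initializes this from Y's scale

    for iteration in range(1, 101):                  # stand-in for tuning iterations
        new = np.random.normal(size=(shape, num_observations))  # a tree's predictions
        if iteration > 2:
            leaf_sd = running_sd.update(new)         # adopt the running estimate
        else:
            running_sd.update(new)                   # warm up the accumulator only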
@@ -320,10 +331,7 @@ def init_particles(self, tree_id: int) -> List[ParticleTree]:
         self.update_weight(p0)
         particles: List[ParticleTree] = [p0]
 
-        particles.extend(
-            ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor)
-            for _ in self.indices
-        )
+        particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
         return particles
 
     def update_weight(self, particle: ParticleTree) -> None:
@@ -344,6 +352,34 @@ def competence(var, has_grad):
         return Competence.INCOMPATIBLE
 
 
+class RunningSd:
+    def __init__(self, shape: tuple) -> None:
+        self.count = 0  # number of data points
+        self.mean = np.zeros(shape)  # running mean
+        self.m_2 = np.zeros(shape)  # running second moment
+
+    def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+        self.count = self.count + 1
+        self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
+        return fast_mean(std)
+
+
+@njit
+def _update(
+    count: int,
+    mean: npt.NDArray[np.float_],
+    m_2: npt.NDArray[np.float_],
+    new_value: npt.NDArray[np.float_],
+) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
+    delta = new_value - mean
+    mean += delta / count
+    delta2 = new_value - mean
+    m_2 += delta * delta2
+
+    std = (m_2 / count) ** 0.5
+    return mean, m_2, std
 class SampleSplittingVariable:
     def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
         """
@@ -362,30 +398,26 @@ def rvs(self) -> Union[int, Tuple[int, float]]:
         return self.enu[-1]
 
 
-def compute_prior_probability(alpha) -> List[float]:
+def compute_prior_probability(alpha: int, beta: int) -> List[float]:
     """
     Calculate the probability of the node being a leaf node (1 - p(being split node)).
     Taken from equation 19 in [Rockova2018].
     Parameters
     ----------
     alpha : float
+    beta: float
     Returns
     -------
     list with probabilities for leaf nodes
     References
     ----------
     .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
     arXiv, `link <https://arxiv.org/abs/1810.00787>`__
     """
     prior_leaf_prob: List[float] = [0]
-    depth = 1
-    while prior_leaf_prob[-1] < 1:
-        prior_leaf_prob.append(1 - alpha**depth)
+    depth = 0
+    while prior_leaf_prob[-1] < 0.9999:
+        prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
         depth += 1
+    prior_leaf_prob.append(1)
+
     return prior_leaf_prob
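The old loop implemented 1 - alpha ** depth; the new form 1 - alpha * (1 + depth) ** -beta follows equation 19 of Rockova and Saha, stops once the leaf probability passes 0.9999, and appends a final 1 so sufficiently deep nodes are always leaves. A quick check with the new defaults (assuming the function is importable from pymc_bart.pgbart):

    from pymc_bart.pgbart import compute_prior_probability

    # With alpha=0.95, beta=2 the leaf probabilities rise quickly with depth:
    # the leading entry is 0, then 1 - 0.95 * (1 + d) ** -2 for d = 0, 1, 2, ...
    probs = compute_prior_probability(alpha=0.95, beta=2)
    print(probs[:5])  # [0, 0.05, 0.7625, 0.8944..., 0.940625]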


@@ -397,10 +429,10 @@ def grow_tree(
     X,
     missing_data,
     sum_trees,
+    leaf_sd,
     m,
     response,
     normal,
-    kfactor,
     shape,
 ):
     current_node = tree.get_node(index_leaf_node)
@@ -432,7 +464,7 @@
             y_mu_pred=sum_trees[:, idx_data_point],
             x_mu=X[idx_data_point, selected_predictor],
             m=m,
-            norm=normal.rvs() * kfactor,
+            norm=normal.rvs() * leaf_sd,
             shape=shape,
             response=response,
         )
@@ -493,7 +525,7 @@ def draw_leaf_value(
     if response == "linear":
         mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m)
 
-    draw = norm + mu_mean
+    draw = mu_mean + norm
     return draw, linear_params
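The reordering to draw = mu_mean + norm is cosmetic on its own, but combined with NormalSampler(1, ...) it makes the proposal explicit: a standard-normal draw rescaled by the shared leaf_sd, centered on the shrunken leaf mean. A toy standalone illustration with assumed values:

    import numpy as np

    m, shape, leaf_sd = 50, 1, 0.3               # assumed toy values
    rng = np.random.default_rng(0)
    y_mu_pred = np.array([0.5, 1.2, 0.9])        # partial residuals at the leaf
    mu_mean = y_mu_pred.mean() / m               # "constant" response: shrunken mean
    norm = rng.normal(size=shape) * leaf_sd      # standard normal scaled by the running sd
    draw = mu_mean + norm                        # the proposed leaf value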

