compute running variance for leaf nodes #91

Merged: 2 commits, Jun 28, 2023
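In brief, from the diff below: the sampler previously scaled each leaf-value proposal by a per-particle kfactor (drawn from UniformSampler(0.33, 0.75) while tuning); this PR removes kfactor and instead tracks the standard deviation of the fitted leaf values with a Welford-style running estimator (RunningSd), updated once per iteration during tuning. It also exposes the depth-prior exponent beta in the public BART signature (new defaults alpha=0.95, beta=2) and lowers the default number of particles from 20 to 10.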
pymc_bart/bart.py: 12 changes (8 additions & 4 deletions)
@@ -36,7 +36,7 @@ class BARTRV(RandomVariable):

name: str = "BART"
ndim_supp = 1
- ndims_params: List[int] = [2, 1, 0, 0, 1]
+ ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = List[List[Tree]]
@@ -45,7 +45,9 @@ def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return dist_params[0].shape[:1]

@classmethod
- def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, size=None):
+ def rng_fn(
+ cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
+ ):
if not cls.all_trees:
if size is not None:
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -94,7 +96,8 @@ def __new__(
X: TensorLike,
Y: TensorLike,
m: int = 50,
- alpha: float = 0.25,
+ alpha: float = 0.95,
+ beta: float = 2,
response: str = "constant",
split_prior: Optional[List[float]] = None,
**kwargs,
@@ -120,6 +123,7 @@ def __new__(
m=m,
response=response,
alpha=alpha,
+ beta=beta,
split_prior=split_prior,
),
)()
@@ -131,7 +135,7 @@ def get_moment(rv, size, *rv_inputs):
return cls.get_moment(rv, size, *rv_inputs)

cls.rv_op = bart_op
- params = [X, Y, m, alpha, split_prior]
+ params = [X, Y, m, alpha, beta, split_prior]
return super().__new__(cls, name, *params, **kwargs)

@classmethod
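Taken together, the bart.py changes add beta next to alpha in the BART signature and thread it through rng_fn and the op parameters. A minimal sketch of a model using the new argument (toy data; assumes pymc and pymc-bart are installed, and uses the defaults introduced by this PR):

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
Y = X[:, 0] + rng.normal(scale=0.1, size=100)

with pm.Model():
    # p(a node at depth d splits) = alpha * (1 + d) ** (-beta),
    # so alpha=0.95, beta=2 favors shallow trees
    mu = pmb.BART("mu", X, Y, m=50, alpha=0.95, beta=2)
    pm.Normal("y", mu=mu, sigma=0.1, observed=Y)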
pymc_bart/pgbart.py: 100 changes (66 additions & 34 deletions)
@@ -31,18 +31,16 @@
class ParticleTree:
"""Particle tree."""

__slots__ = "tree", "expansion_nodes", "log_weight", "kfactor"
__slots__ = "tree", "expansion_nodes", "log_weight"

- def __init__(self, tree: Tree, kfactor: float = 0.75):
+ def __init__(self, tree: Tree):
self.tree: Tree = tree.copy()
self.expansion_nodes: List[int] = [0]
self.log_weight: float = 0
- self.kfactor: float = kfactor

def copy(self) -> "ParticleTree":
p = ParticleTree(self.tree)
p.expansion_nodes = self.expansion_nodes.copy()
- p.kfactor = self.kfactor
return p

def sample_tree(
@@ -53,6 +51,7 @@ def sample_tree(
X,
missing_data,
sum_trees,
+ leaf_sd,
m,
response,
normal,
@@ -73,10 +72,10 @@ def sample_tree(
X,
missing_data,
sum_trees,
+ leaf_sd,
m,
response,
normal,
- self.kfactor,
shape,
)
if idx_new_nodes is not None:
@@ -95,7 +94,7 @@ class PGBART(ArrayStepShared):
vars: list
List of value variables for sampler
num_particles : tuple
- Number of particles. Defaults to 20
+ Number of particles. Defaults to 10
batch : int or tuple
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
during tuning and after tuning. If a tuple is passed the first element is the batch size
@@ -112,7 +111,7 @@ class PGBART(ArrayStepShared):
def __init__(
self,
vars=None, # pylint: disable=redefined-builtin
- num_particles: int = 20,
+ num_particles: int = 10,
batch: Tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
):
@@ -141,17 +140,20 @@ def __init__(
self.alpha_vec = self.bart.split_prior
else:
self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)

init_mean = self.bart.Y.mean()
- self.num_observations = self.X.shape[0]
- self.num_variates = self.X.shape[1]
- self.available_predictors = list(range(self.num_variates))

# if data is binary
y_unique = np.unique(self.bart.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
- mu_std = 3 / self.m**0.5
+ self.leaf_sd = 3 / self.m**0.5
else:
- mu_std = self.bart.Y.std() / self.m**0.5
+ self.leaf_sd = self.bart.Y.std() / self.m**0.5

+ self.num_observations = self.X.shape[0]
+ self.num_variates = self.X.shape[1]
+ self.available_predictors = list(range(self.num_variates))
+ self.running_sd = RunningSd(shape)

self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
config.floatX
@@ -164,10 +166,9 @@ def __init__(
shape=self.shape,
)

- self.normal = NormalSampler(mu_std, self.shape)
+ self.normal = NormalSampler(1, self.shape)
self.uniform = UniformSampler(0, 1)
- self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
- self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
+ self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha, self.bart.beta)
self.ssv = SampleSplittingVariable(self.alpha_vec)

self.tune = True
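
Reading the __init__ changes together: the initial self.leaf_sd takes over the role of the old mu_std, NormalSampler now draws with unit scale, and the per-particle uniform_kf jitter disappears; draws are rescaled by the current leaf_sd at the call sites instead. A rough before/after sketch of the proposal scale (hypothetical numbers, not code from this PR):

import numpy as np

rng = np.random.default_rng(7)
m = 50
y = rng.normal(size=200)

# before: scale fixed at init, jittered per particle by kfactor in [0.33, 0.75)
mu_std = y.std() / m**0.5
kfactor = rng.uniform(0.33, 0.75)
old_draw = rng.normal(scale=mu_std) * kfactor

# after: unit-scale draw, rescaled by the running leaf SD where it is used
leaf_sd = y.std() / m**0.5  # initial value; later updated by RunningSd
new_draw = rng.normal(scale=1.0) * leaf_sd
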
@@ -212,6 +213,7 @@ def astep(self, _):
self.X,
self.missing_data,
self.sum_trees,
+ self.leaf_sd,
self.m,
self.response,
self.normal,
@@ -235,16 +237,25 @@ def astep(self, _):
particles, normalized_weights
)
# Update the sum of trees
- self.sum_trees = self.sum_trees_noi + new_tree._predict()
+ new = new_tree._predict()
+ self.sum_trees = self.sum_trees_noi + new
# To reduce memory usage, we trim the tree
self.all_trees[tree_id] = new_tree.trim()

if self.tune:
# Update the splitting variable and the splitting variable sampler
if self.iter > self.m:
self.ssv = SampleSplittingVariable(self.alpha_vec)

for index in new_tree.get_split_variables():
self.alpha_vec[index] += 1

+ # update standard deviation at leaf nodes
+ if self.iter > 2:
+ self.leaf_sd = self.running_sd.update(new)
+ else:
+ self.running_sd.update(new)

else:
# update the variable inclusion
for index in new_tree.get_split_variables():
@@ -320,10 +331,7 @@ def init_particles(self, tree_id: int) -> List[ParticleTree]:
self.update_weight(p0)
particles: List[ParticleTree] = [p0]

- particles.extend(
- ParticleTree(self.a_tree, self.uniform_kf.rvs() if self.tune else p0.kfactor)
- for _ in self.indices
- )
+ particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
return particles

def update_weight(self, particle: ParticleTree) -> None:
@@ -344,6 +352,34 @@ def competence(var, has_grad):
return Competence.INCOMPATIBLE


+ class RunningSd:
+ def __init__(self, shape: tuple) -> None:
+ self.count = 0 # number of data points
+ self.mean = np.zeros(shape) # running mean
+ self.m_2 = np.zeros(shape) # running second moment
+
+ def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+ self.count = self.count + 1
+ self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
+ return fast_mean(std)
+
+
+ @njit
+ def _update(
+ count: int,
+ mean: npt.NDArray[np.float_],
+ m_2: npt.NDArray[np.float_],
+ new_value: npt.NDArray[np.float_],
+ ) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
+ delta = new_value - mean
+ mean += delta / count
+ delta2 = new_value - mean
+ m_2 += delta * delta2
+
+ std = (m_2 / count) ** 0.5
+ return mean, m_2, std
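
The RunningSd/_update pair added above is Welford's online algorithm for mean and variance, vectorized over the leaf-value array (update then averages the per-element SDs via fast_mean). A standalone check of the same recurrence in plain NumPy, scalar case (illustrative only; welford_sd is not part of this PR):

import numpy as np

def welford_sd(samples):
    # same recurrence as _update
    count, mean, m_2 = 0, 0.0, 0.0
    for x in samples:
        count += 1
        delta = x - mean
        mean += delta / count
        m_2 += delta * (x - mean)  # second factor uses the updated mean
    return (m_2 / count) ** 0.5  # population SD, as in _update

data = np.random.default_rng(0).normal(loc=2.0, scale=1.5, size=10_000)
print(round(welford_sd(data), 2), round(data.std(), 2))  # both close to 1.5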


class SampleSplittingVariable:
def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
"""
@@ -362,30 +398,26 @@ def rvs(self) -> Union[int, Tuple[int, float]]:
return self.enu[-1]


- def compute_prior_probability(alpha) -> List[float]:
+ def compute_prior_probability(alpha: int, beta: int) -> List[float]:
"""
Calculate the probability of the node being a leaf node (1 - p(being split node)).

Taken from equation 19 in [Rockova2018].

Parameters
----------
alpha : float
+ beta: float

Returns
-------
list with probabilities for leaf nodes

References
----------
.. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
arXiv, `link <https://arxiv.org/abs/1810.00787>`__
"""
prior_leaf_prob: List[float] = [0]
- depth = 1
- while prior_leaf_prob[-1] < 1:
- prior_leaf_prob.append(1 - alpha**depth)
+ depth = 0
+ while prior_leaf_prob[-1] < 0.9999:
+ prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
depth += 1
+ prior_leaf_prob.append(1)

return prior_leaf_prob
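
With the new defaults (alpha=0.95, beta=2) the revised loop produces a leaf probability that rises with depth; the first few entries would be (shown rounded in the comment):

probs = compute_prior_probability(alpha=0.95, beta=2)
print(probs[:5])  # [0, 0.05, 0.7625, 0.8944, 0.9406]
# The loop stops once 1 - alpha * (1 + depth) ** (-beta) reaches 0.9999;
# a final 1 is then appended, so sufficiently deep nodes are always leaves.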


@@ -397,10 +429,10 @@ def grow_tree(
X,
missing_data,
sum_trees,
+ leaf_sd,
m,
response,
normal,
- kfactor,
shape,
):
current_node = tree.get_node(index_leaf_node)
@@ -432,7 +464,7 @@ def grow_tree(
y_mu_pred=sum_trees[:, idx_data_point],
x_mu=X[idx_data_point, selected_predictor],
m=m,
- norm=normal.rvs() * kfactor,
+ norm=normal.rvs() * leaf_sd,
shape=shape,
response=response,
)
@@ -493,7 +525,7 @@ def draw_leaf_value(
if response == "linear":
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m)

- draw = norm + mu_mean
+ draw = mu_mean + norm
return draw, linear_params

