Initial example
Clifford-appliedAI-GmbH committed Apr 13, 2023
1 parent f93d950 commit 4450b61
Showing 1 changed file with 18 additions and 13 deletions.
src/prototree/train.py
@@ -83,25 +83,30 @@ def update_leaf_distributions(
     """
     batch_size, num_classes = logits.shape
 
-    y_true_one_hot = F.one_hot(y_true, num_classes=num_classes)
-    y_true_logits = torch.log(y_true_one_hot)
+    # Other sparse formats may be better than COO.
+    # TODO: This is a bit convoluted. Why is there no sparse version of torch.nn.functional.one_hot?
+    y_true_range = torch.arange(0, batch_size)
+    y_true_indices = torch.stack((y_true_range, y_true))
+    y_true_one_hot = torch.sparse_coo_tensor(
+        y_true_indices, torch.ones_like(y_true, dtype=torch.bool), logits.shape
+    )
 
     for leaf in root.leaves:
-        update_leaf(leaf, node_to_prob, logits, y_true_logits, smoothing_factor)
+        update_leaf(leaf, node_to_prob, logits, y_true_one_hot, smoothing_factor)
 
 
 def update_leaf(
     leaf: Leaf,
     node_to_prob: dict[Node, NodeProbabilities],
     logits: torch.Tensor,
-    y_true_logits: torch.Tensor,
+    y_true_one_hot: torch.Tensor,
     smoothing_factor: float,
 ):
     """
     :param leaf:
     :param node_to_prob:
     :param logits: of shape (batch_size, num_classes)
-    :param y_true_logits: of shape (batch_size, num_classes)
+    :param y_true_one_hot: boolean tensor of shape (batch_size, num_classes)
     :param smoothing_factor:
     :return:
     """
@@ -110,15 +115,15 @@ def update_leaf(
     # shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
     leaf_logits = leaf.y_logits()
 
-    # TODO: y_true_logits is mostly -Inf terms (the rest being 0s) that won't contribute to the total, and we are also
-    # summing together tensors of different shapes. We should be able to express this more clearly and efficiently by
-    # taking advantage of this sparsity.
-    log_dist_update = torch.logsumexp(
-        log_p_arrival + leaf_logits + y_true_logits - logits,
-        dim=0,
-    )
+    masked_logits = logits.sparse_mask(y_true_one_hot)
+    masked_log_p_arrival = y_true_one_hot * log_p_arrival
+    masked_leaf_logits = y_true_one_hot * leaf_logits
+    masked_log_combined = masked_log_p_arrival + masked_leaf_logits - masked_logits
+
+    # TODO: Can't use logsumexp because masked tensors don't support it.
+    masked_dist_update = torch.logsumexp(masked_log_combined, dim=0)
 
-    dist_update = torch.exp(log_dist_update)
+    dist_update = masked_dist_update.to_tensor(0.0)
 
     # This scaling (subtraction of `-1/n_batches * c` in the ProtoTree paper) seems to be a form of exponentially
     # weighted moving average, designed to ensure stability of the leaf class probability distributions (
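The removed TODO argues that y_true_logits is mostly -inf terms and that the reduction should exploit this sparsity. A standalone sketch (illustrative names; assumes log_p_arrival has shape (batch_size, 1)) confirming that the dense logsumexp only ever sums the entries the one-hot mask selects:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch_size, num_classes = 5, 3
    y_true = torch.tensor([0, 2, 1, 2, 0])
    logits = torch.randn(batch_size, num_classes)
    leaf_logits = torch.randn(num_classes)
    log_p_arrival = torch.randn(batch_size, 1)

    # Dense route (the removed code): log(one_hot) is 0 at the true class and
    # -inf elsewhere, so off-label entries drop out of the logsumexp.
    y_true_logits = torch.log(F.one_hot(y_true, num_classes=num_classes).float())
    dense = torch.logsumexp(log_p_arrival + leaf_logits + y_true_logits - logits, dim=0)

    # Sparse route: visit only the single live entry per row, then reduce per class.
    contrib = (
        log_p_arrival.squeeze(1)
        + leaf_logits[y_true]
        - logits[torch.arange(batch_size), y_true]
    )
    sparse = torch.full((num_classes,), float("-inf"))
    for c in range(num_classes):  # plain loop; a scatter-style logsumexp would also work
        mask = y_true == c
        if mask.any():
            sparse[c] = torch.logsumexp(contrib[mask], dim=0)

    assert torch.allclose(dense, sparse)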

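The trailing comment is cut off by the diff fold, but the scaling it describes has the shape of an exponentially weighted moving average. A hedged sketch of that general form, assuming the recurrence c <- (1 - 1/n_batches) * c + dist_update; this is not the repository's exact code:

    import torch

    def ewma_leaf_update(
        c: torch.Tensor, dist_update: torch.Tensor, n_batches: int
    ) -> torch.Tensor:
        # Decay the old distribution by 1/n_batches each batch, then add the
        # new evidence; older batches' contributions shrink geometrically.
        return c - c / n_batches + dist_update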