diff --git a/src/prototree/train.py b/src/prototree/train.py
index dbea4a062..2df3749c4 100644
--- a/src/prototree/train.py
+++ b/src/prototree/train.py
@@ -83,25 +83,30 @@ def update_leaf_distributions(
     """
     batch_size, num_classes = logits.shape
 
-    y_true_one_hot = F.one_hot(y_true, num_classes=num_classes)
-    y_true_logits = torch.log(y_true_one_hot)
+    # Other sparse formats may be better than COO.
+    # TODO: This is a bit convoluted. Why is there no sparse version of torch.nn.functional.one_hot ?
+    y_true_range = torch.arange(0, batch_size)
+    y_true_indices = torch.stack((y_true_range, y_true))
+    y_true_one_hot = torch.sparse_coo_tensor(
+        y_true_indices, torch.ones_like(y_true, dtype=torch.bool), logits.shape
+    )
 
     for leaf in root.leaves:
-        update_leaf(leaf, node_to_prob, logits, y_true_logits, smoothing_factor)
+        update_leaf(leaf, node_to_prob, logits, y_true_one_hot, smoothing_factor)
 
 
 def update_leaf(
     leaf: Leaf,
     node_to_prob: dict[Node, NodeProbabilities],
     logits: torch.Tensor,
-    y_true_logits: torch.Tensor,
+    y_true_one_hot: torch.Tensor,
     smoothing_factor: float,
 ):
     """
     :param leaf:
     :param node_to_prob:
     :param logits: of shape (batch_size, num_classes)
-    :param y_true_logits: of shape (batch_size, num_classes)
+    :param y_true_one_hot: boolean tensor of shape (batch_size, num_classes)
     :param smoothing_factor:
     :return:
     """
@@ -110,15 +115,15 @@ def update_leaf(
     # shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
     leaf_logits = leaf.y_logits()
 
-    # TODO: y_true_logits is mostly -Inf terms (the rest being 0s) that won't contribute to the total, and we are also
-    # summing together tensors of different shapes. We should be able to express this more clearly and efficiently by
-    # taking advantage of this sparsity.
-    log_dist_update = torch.logsumexp(
-        log_p_arrival + leaf_logits + y_true_logits - logits,
-        dim=0,
-    )
+    masked_logits = logits.sparse_mask(y_true_one_hot)
+    masked_log_p_arrival = y_true_one_hot * log_p_arrival
+    masked_leaf_logits = y_true_one_hot * leaf_logits
+    masked_log_combined = masked_log_p_arrival + masked_leaf_logits - masked_logits
+
+    # TODO: Can't use logsumexp because masked tensors don't support it.
+    masked_dist_update = torch.logsumexp(masked_log_combined, dim=0)
 
-    dist_update = torch.exp(log_dist_update)
+    dist_update = masked_dist_update.to_tensor(0.0)
 
     # This scaling (subtraction of `-1/n_batches * c` in the ProtoTree paper) seems to be a form of exponentially
     # weighted moving average, designed to ensure stability of the leaf class probability distributions (
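
For reference, here is a minimal, self-contained sketch (not part of the diff) of the sparsity argument in the removed TODO: because the one-hot matrix has exactly one non-zero per row, the per-class `logsumexp` over the batch only needs each sample's true-class entry, which can be gathered directly instead of materializing a tensor full of `-Inf` terms. The shapes of `logits` and `leaf_logits` follow the docstring and comments above; the `(batch_size, 1)` shape assumed for `log_p_arrival` and the gather/scatter formulation are illustrative only, not what the diff itself does.

```python
import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 4
y_true = torch.arange(batch_size) % num_classes                    # every class appears at least once
logits = torch.randn(batch_size, num_classes)                      # (batch_size, num_classes), per docstring
leaf_logits = torch.log_softmax(torch.randn(num_classes), dim=0)   # (num_classes,), per the comment above
log_p_arrival = torch.randn(batch_size, 1)                         # assumed shape (batch_size, 1)

# Dense reference, mirroring the removed code: log(one_hot) is 0 on the true class
# and -inf everywhere else, so the -inf entries contribute nothing to the logsumexp.
y_true_logits = torch.log(F.one_hot(y_true, num_classes=num_classes))
dense_update = torch.logsumexp(log_p_arrival + leaf_logits + y_true_logits - logits, dim=0)

# Sparsity-aware version: each sample contributes exactly one term (its true class),
# so gather that term per sample and do a per-class logsumexp via max-shift + scatter.
per_sample = (
    log_p_arrival.squeeze(1)
    + leaf_logits[y_true]
    - logits.gather(1, y_true.unsqueeze(1)).squeeze(1)
)                                                                   # shape (batch_size,)
class_max = torch.full((num_classes,), float("-inf")).scatter_reduce(
    0, y_true, per_sample, reduce="amax"
)
sum_exp = torch.zeros(num_classes).scatter_add(0, y_true, torch.exp(per_sample - class_max[y_true]))
gathered_update = class_max + torch.log(sum_exp)

print(torch.allclose(dense_update, gathered_update))                # expected: True
```

This sidesteps both the dense `-Inf` tensor and the sparse/masked-tensor limitations flagged in the new TODOs, at the cost of a two-pass scatter (max, then sum) for numerical stability.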