Skip to content

Commit

Permalink
Merge pull request #6 from cms-pepr/gravnet_jit
Browse files Browse the repository at this point in the history
Minor changes so that gravnet will jit out of the box
  • Loading branch information
lgray authored Nov 19, 2021
2 parents 68ef833 + 10d46bb commit 377366b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch_cmspepr/gravnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ def global_exchange(x: Tensor, batch: Tensor) -> Tensor:
"""
n_hits_per_event = scatter_count(batch)
n_hits, n_features = x.size()
batch_size = batch.max()+1
batch_size = int(batch.max()) + 1

# minmeanmax: (batch_size x 3*n_features)
meanminmax = torch.cat((
scatter_mean(x, batch, dim=0),
scatter_min(x, batch, dim=0)[0],
scatter_max(x, batch, dim=0)[0]
), dim=1)
assert meanminmax.size() == (batch_size, 3*n_features)
assert list(meanminmax.size()) == [batch_size, 3*n_features]

meanminmax = torch.repeat_interleave(meanminmax, n_hits_per_event, dim=0)
assert meanminmax.size() == (n_hits, 3*n_features)
assert list(meanminmax.size()) == [n_hits, 3*n_features]

out = torch.cat((meanminmax, x), dim=1)
assert out.size() == (n_hits, 4*n_features)
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
self.gravnet_layer = GravNetConv(
in_channels, out_channels,
space_dimensions, propagate_dimensions, k
)
).jittable()
self.post_gravnet = nn.Sequential(
nn.BatchNorm1d(out_channels),
nn.Linear(out_channels, 128),
Expand Down

0 comments on commit 377366b

Please sign in to comment.