From 10d46bba57f7684d8bbf9b6dd39bf95dff53d45e Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Fri, 19 Nov 2021 08:41:23 -0600 Subject: [PATCH] minor changes so that gravnet will jit out of the box --- torch_cmspepr/gravnet_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_cmspepr/gravnet_model.py b/torch_cmspepr/gravnet_model.py index 4d42f84..dab5345 100644 --- a/torch_cmspepr/gravnet_model.py +++ b/torch_cmspepr/gravnet_model.py @@ -16,7 +16,7 @@ def global_exchange(x: Tensor, batch: Tensor) -> Tensor: """ n_hits_per_event = scatter_count(batch) n_hits, n_features = x.size() - batch_size = batch.max()+1 + batch_size = int(batch.max()) + 1 # minmeanmax: (batch_size x 3*n_features) meanminmax = torch.cat(( @@ -24,10 +24,10 @@ def global_exchange(x: Tensor, batch: Tensor) -> Tensor: scatter_min(x, batch, dim=0)[0], scatter_max(x, batch, dim=0)[0] ), dim=1) - assert meanminmax.size() == (batch_size, 3*n_features) + assert list(meanminmax.size()) == [batch_size, 3*n_features] meanminmax = torch.repeat_interleave(meanminmax, n_hits_per_event, dim=0) - assert meanminmax.size() == (n_hits, 3*n_features) + assert list(meanminmax.size()) == [n_hits, 3*n_features] out = torch.cat((meanminmax, x), dim=1) assert out.size() == (n_hits, 4*n_features) @@ -68,7 +68,7 @@ def __init__( self.gravnet_layer = GravNetConv( in_channels, out_channels, space_dimensions, propagate_dimensions, k - ) + ).jittable() self.post_gravnet = nn.Sequential( nn.BatchNorm1d(out_channels), nn.Linear(out_channels, 128),