diff --git a/torch_cmspepr/gravnet_model.py b/torch_cmspepr/gravnet_model.py index 4d42f84..dab5345 100644 --- a/torch_cmspepr/gravnet_model.py +++ b/torch_cmspepr/gravnet_model.py @@ -16,7 +16,7 @@ def global_exchange(x: Tensor, batch: Tensor) -> Tensor: """ n_hits_per_event = scatter_count(batch) n_hits, n_features = x.size() - batch_size = batch.max()+1 + batch_size = int(batch.max()) + 1 # minmeanmax: (batch_size x 3*n_features) meanminmax = torch.cat(( @@ -24,10 +24,10 @@ def global_exchange(x: Tensor, batch: Tensor) -> Tensor: scatter_min(x, batch, dim=0)[0], scatter_max(x, batch, dim=0)[0] ), dim=1) - assert meanminmax.size() == (batch_size, 3*n_features) + assert list(meanminmax.size()) == [batch_size, 3*n_features] meanminmax = torch.repeat_interleave(meanminmax, n_hits_per_event, dim=0) - assert meanminmax.size() == (n_hits, 3*n_features) + assert list(meanminmax.size()) == [n_hits, 3*n_features] out = torch.cat((meanminmax, x), dim=1) assert out.size() == (n_hits, 4*n_features) @@ -68,7 +68,7 @@ def __init__( self.gravnet_layer = GravNetConv( in_channels, out_channels, space_dimensions, propagate_dimensions, k - ) + ).jittable() self.post_gravnet = nn.Sequential( nn.BatchNorm1d(out_channels), nn.Linear(out_channels, 128),