Skip to content

Commit

Permalink
Fix the dimension index (#321)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijian-liu authored Jul 31, 2024
1 parent 8da35e0 commit 9eb3b88
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchsparse/nn/modules/bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor:
self.kernel, 0, torch.div(coords[:, self.dim], stride).trunc().long()
)
feats = (feats.unsqueeze(dim=-1) * kernel).sum(1) + self.bias
coords = (coords - self.offset).t()[[3] + self.bev_dims].long()
coords = (coords - self.offset).t()[[0] + self.bev_dims].long()
coords[1:] = torch.div(coords[1:], stride).trunc().long()
indices = (
coords[0] * int(self.bev_shape.prod())
Expand Down Expand Up @@ -197,7 +197,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor:
assert isinstance(stride, torch.Tensor), type(stride)

# [b, x, y, z]
coords = (coords - self.offset).t()[[3] + self.bev_dims + [self.dim]].long()
coords = (coords - self.offset).t()[[0] + self.bev_dims + [self.dim]].long()
shape = self.shape[self.bev_dims + [self.dim]]

# now stride must be torch.Tensor since input.s is tuple.
Expand Down

0 comments on commit 9eb3b88

Please sign in to comment.