Skip to content

Commit

Permalink
knn_interpolate gpu patch
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwzd committed Aug 15, 2024
1 parent 31ee1e9 commit 2970757
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def knn_interpolate(
squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

y_idx, x_idx = y_idx.to(x.device), x_idx.to(x.device)
weights = weights.to(x.device)

den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")

Expand Down

0 comments on commit 2970757

Please sign in to comment.