diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index f63919d..31a148c 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -22,6 +22,9 @@ def knn_interpolate( squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + y_idx, x_idx = y_idx.to(x.device), x_idx.to(x.device) + weights = weights.to(x.device) + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")