Skip to content

Commit

Permalink
escn cpu fix (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
mshuaibii authored Jul 11, 2023
1 parent dafa193 commit ab7833d
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions ocpmodels/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def __init__(
self.sphere_channels_all = self.num_resolutions * self.sphere_channels
self.basis_width_scalar = basis_width_scalar
self.distance_function = distance_function
self.device = torch.cuda.current_device()

# variables used for display purposes
self.counter = 0
Expand Down Expand Up @@ -226,13 +225,13 @@ def __init__(

@conditional_grad(torch.enable_grad())
def forward(self, data):
device = data.pos.device
self.batch_size = len(data.natoms)
self.dtype = data.pos.dtype

start_time = time.time()
atomic_numbers = data.atomic_numbers.long()
num_atoms = len(atomic_numbers)
pos = data.pos

(
edge_index,
Expand Down Expand Up @@ -269,7 +268,7 @@ def forward(self, data):
num_atoms,
self.lmax_list,
self.sphere_channels,
self.device,
device,
self.dtype,
)

Expand All @@ -285,7 +284,7 @@ def forward(self, data):

# This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer
mappingReduced = CoefficientMapping(
self.lmax_list, self.mmax_list, self.device
self.lmax_list, self.mmax_list, device
)

###############################################################
Expand Down Expand Up @@ -319,7 +318,7 @@ def forward(self, data):

# Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
# These values are fed into the output blocks.
x_pt = torch.tensor([], device=self.device)
x_pt = torch.tensor([], device=device)
offset = 0
# Compute the embedding values at every sampled point on the sphere
for i in range(self.num_resolutions):
Expand All @@ -343,7 +342,7 @@ def forward(self, data):
# Energy estimation
###############################################################
node_energy = self.energy_block(x_pt)
energy = torch.zeros(len(data.natoms), device=pos.device)
energy = torch.zeros(len(data.natoms), device=device)
energy.index_add_(0, data.batch, node_energy.view(-1))
# Scale energy to help balance numerical precision w.r.t. forces
energy = energy * 0.001
Expand Down

0 comments on commit ab7833d

Please sign in to comment.