Skip to content

Commit

Permalink
Fix PET
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 1, 2024
1 parent 18c87ad commit 603daa0
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/metatrain/experimental/pet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@
from pet.pet import PET, SelfContributionsWrapper
from pet.train_model import fit_pet

from ...utils.additive import remove_additive
from ...utils.data import Dataset, check_datasets, collate_fn
from ...utils.data.system_to_ase import system_to_ase
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from . import PET as WrappedPET


Expand Down Expand Up @@ -94,6 +99,12 @@ def train(

ase_train_dataset = []
for (system,), targets in train_dataloader:
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
for additive_model in model.additive_models:
targets = remove_additive(
[system], targets, additive_model, model.dataset_info.targets
)
ase_atoms = system_to_ase(system)
ase_atoms.info["energy"] = float(
targets[target_name].block().values.squeeze(-1).detach().cpu().numpy()
Expand Down

0 comments on commit 603daa0

Please sign in to comment.