Skip to content

Commit

Permalink
Exclude mtt::aux:: quantities from composition models
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Aug 16, 2024
1 parent fc71849 commit a743996
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ def forward(
# at evaluation, we also add the composition contributions
composition_contributions = self.composition_model(systems, outputs)
for name in return_dict:
if name.startswith("mtt::aux::"):
continue
return_dict[name] = metatensor.torch.add(
return_dict[name],
composition_contributions[name],
Expand Down
4 changes: 4 additions & 0 deletions src/metatrain/utils/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def forward(
device = systems[0].positions.device

for output_name in outputs:
if output_name.startswith("mtt::aux::"):
continue
if output_name not in self.output_to_output_index:
raise ValueError(
f"output key {output_name} is not supported by this composition "
Expand All @@ -210,6 +212,8 @@ def forward(
# number of atoms per atomic type.
targets_out: Dict[str, TensorMap] = {}
for target_key, target in outputs.items():
if target_key.startswith("mtt::aux::"):
continue
weights = self.weights[self.output_to_output_index[target_key]]
targets_list = []
sample_values: List[List[int]] = []
Expand Down

0 comments on commit a743996

Please sign in to comment.