From a743996c496081d8934098876512392645a5d1e8 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 16 Aug 2024 20:27:57 +0200 Subject: [PATCH] Exclude `mtt::aux::` quantities from composition models --- src/metatrain/experimental/soap_bpnn/model.py | 2 ++ src/metatrain/utils/composition.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index c0c2a585..a3d0d4ea 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -285,6 +285,8 @@ def forward( # at evaluation, we also add the composition contributions composition_contributions = self.composition_model(systems, outputs) for name in return_dict: + if name.startswith("mtt::aux::"): + continue return_dict[name] = metatensor.torch.add( return_dict[name], composition_contributions[name], diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py index 76a63b47..cf419056 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/composition.py @@ -194,6 +194,8 @@ def forward( device = systems[0].positions.device for output_name in outputs: + if output_name.startswith("mtt::aux::"): + continue if output_name not in self.output_to_output_index: raise ValueError( f"output key {output_name} is not supported by this composition " @@ -210,6 +212,8 @@ def forward( # number of atoms per atomic type. targets_out: Dict[str, TensorMap] = {} for target_key, target in outputs.items(): + if target_key.startswith("mtt::aux::"): + continue weights = self.weights[self.output_to_output_index[target_key]] targets_list = [] sample_values: List[List[int]] = []