diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index 26a642c0..1229aa53 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -426,10 +426,10 @@ def calibrate(self, valid_loader: DataLoader): all_predictions[name] = [] all_targets[name] = [] all_uncertainties[uncertainty_name] = [] - all_predictions[name].append(outputs[name].block().values) + all_predictions[name].append(outputs[name].block().values.detach()) all_targets[name].append(target.block().values) all_uncertainties[uncertainty_name].append( - outputs[uncertainty_name].block().values + outputs[uncertainty_name].block().values.detach() ) for name in all_predictions: