diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 71fc9c48..9bf17c70 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -123,7 +123,7 @@ def __init__( # not yet implemented gradients = gradients loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()} - all_inputs = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} + all_inputs: dict[str | int, Any] = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} all_inputs.update( {k: v for k, v in loss_inputs.items() if v not in outputs.values()} )