diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index b2387e55b..7a9c9039a 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -124,7 +124,7 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): # get device for correct placement of tensors device = self.device - checkpoint_dict = self.load_checkpoint(map_location=device) + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load, map_location=device) epoch = checkpoint_dict['epoch'] new_state = {} # grabbing keys from the checkpoint state dict, poping from the tensor_dict