From cda5de9f82ce948d63522df22174f93a5de431bc Mon Sep 17 00:00:00 2001
From: "Edwards, Brandon"
Date: Tue, 8 Oct 2024 10:24:41 -0700
Subject: [PATCH] moving validation from checkpoint and not both under function 'validate'

---
 .../fl_post/fl/project/src/runner_nnunetv1.py | 86 +++++++++++--------
 1 file changed, 50 insertions(+), 36 deletions(-)

diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py
index 299e13e4c..17a0624b1 100644
--- a/examples/fl_post/fl/project/src/runner_nnunetv1.py
+++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py
@@ -182,42 +182,58 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs):
 
         return self.convert_results_to_tensorkeys(col_name, round_num, metrics)
 
-    def validate(self, col_name, round_num, input_tensor_dict, **kwargs):
+    def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False, **kwargs):
         # TODO: Figure out the right name to use for this method and the default assigner
         """Perform validation."""
-        self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs)
-        # 1. Insert tensor_dict info into checkpoint
-        current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
-        self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}")
-        # 2. Train/val function existing externally
-        # Some todo inside function below
-        # TODO: test for off-by-one error
-
-        # FIXME: we need to understand how to use round_num instead of current_epoch
-        # this will matter in straggler handling cases
-        # TODO: Should we put this in a separate process?
-        train_completed, \
-        val_completed, \
-        this_ave_train_loss, \
-        this_ave_val_loss, \
-        this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs,
-                                             epochs=1,
-                                             current_epoch=current_epoch,
-                                             train_cutoff=0,
-                                             val_cutoff = self.val_cutoff,
-                                             task=self.data_loader.get_task_name(),
-                                             val_epoch=True,
-                                             train_epoch=False)
-        self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}")
-
-        # double check
-        if train_completed != 0.0:
-            raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.")
-
-        # 3. Prepare metrics
-        metrics = {'val_eval': this_val_eval_metrics}
-
+        def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001):
+            # Cheap consistency check: compare the sums of the per-tensor means.
+            # torch.as_tensor handles both numpy arrays and torch tensors in the dicts.
+            hash_1 = np.sum([float(torch.mean(torch.as_tensor(_value))) for _value in td_1.values()])
+            hash_2 = np.sum([float(torch.mean(torch.as_tensor(_value))) for _value in td_2.values()])
+            delta = np.abs(hash_1 - hash_2)
+            if delta > epsilon:
+                raise ValueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.")
+
+
+        if not from_checkpoint:
+            self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs)
+            # 1. Insert tensor_dict info into checkpoint
+            current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
+            self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}")
+            # 2. Train/val function existing externally
+            # Some todo inside function below
+            # TODO: test for off-by-one error
+
+            # FIXME: we need to understand how to use round_num instead of current_epoch
+            # this will matter in straggler handling cases
+            # TODO: Should we put this in a separate process?
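+            # The call below is moved verbatim from the pre-refactor body: with
+            # train_epoch=False and train_cutoff=0 it runs a single epoch of validation
+            # only, with val_cutoff limiting the validation work; the returned
+            # *_completed values report the fraction of each phase actually performed.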
+            train_completed, \
+            val_completed, \
+            this_ave_train_loss, \
+            this_ave_val_loss, \
+            this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs,
+                                                 epochs=1,
+                                                 current_epoch=current_epoch,
+                                                 train_cutoff=0,
+                                                 val_cutoff = self.val_cutoff,
+                                                 task=self.data_loader.get_task_name(),
+                                                 val_epoch=True,
+                                                 train_epoch=False)
+            self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}")
+
+            # double check
+            if train_completed != 0.0:
+                raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.")
+
+            # 3. Prepare metrics
+            metrics = {'val_eval': this_val_eval_metrics}
+        else:
+            checkpoint_dict = self.load_checkpoint()
+            # double check
+            compare_tensor_dicts(td_1=input_tensor_dict, td_2=checkpoint_dict['state_dict'])
+
+            (all_tr_losses, _, _, _) = checkpoint_dict['plot_stuff']
+            # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after
+            metrics = {'train_loss': all_tr_losses[-1]}
 
         ######################################################################################################
         # TODO: Provide val_completed to be incorporated into the collab weight computation
 
@@ -241,7 +257,5 @@ def load_metrics(self, filepath):
 
     # WORKING HERE
     def validate_by_reading_checkpoint(self, col_name, round_num, input_tensor_dict, **kwargs):
-        (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff']
-        # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after
-        metrics = {'train_loss': all_tr_losses[-1]}
-        return self.convert_results_to_tensorkeys(col_name, round_num, metrics)
+        # Checkpoint-based validation now lives in validate(); delegate with from_checkpoint=True.
+        return self.validate(col_name, round_num, input_tensor_dict, from_checkpoint=True, **kwargs)
\ No newline at end of file