
Commit

moving checkpoint-based validation under the 'validate' function instead of a separate method
brandon-edwards committed Oct 8, 2024
1 parent 059d126 commit cda5de9
Showing 1 changed file with 50 additions and 36 deletions.
86 changes: 50 additions & 36 deletions examples/fl_post/fl/project/src/runner_nnunetv1.py
@@ -182,42 +182,58 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs):
return self.convert_results_to_tensorkeys(col_name, round_num, metrics)


def validate(self, col_name, round_num, input_tensor_dict, **kwargs):
def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False, **kwargs):
# TODO: Figure out the right name to use for this method and the default assigner
"""Perform validation."""

self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs)
# 1. Insert tensor_dict info into checkpoint
current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}")
# 2. Train/val function lives externally
# Some TODOs inside the function below
# TODO: test for off-by-one error

# FIXME: we need to understand how to use round_num instead of current_epoch
# this will matter in straggler handling cases
# TODO: Should we put this in a separate process?
train_completed, \
val_completed, \
this_ave_train_loss, \
this_ave_val_loss, \
this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs,
epochs=1,
current_epoch=current_epoch,
train_cutoff=0,
                                     val_cutoff=self.val_cutoff,
task=self.data_loader.get_task_name(),
val_epoch=True,
train_epoch=False)
self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}")

# double check
if train_completed != 0.0:
raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.")

# 3. Prepare metrics
metrics = {'val_eval': this_val_eval_metrics}

def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001):
    # Cheap fingerprint comparison: sum of per-tensor means, not an exact equality check
    hash_1 = np.sum([torch.mean(_value) for _value in td_1.values()])
    hash_2 = np.sum([torch.mean(_value) for _value in td_2.values()])
    delta = np.abs(hash_1 - hash_2)
    if delta > epsilon:
        raise ValueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.")

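# Usage sketch for the helper above (hypothetical values, shown for illustration only):
#   td_a = {'w': torch.ones(2, 2)}                 # fingerprint: 1.0
#   td_b = {'w': torch.ones(2, 2) + 0.01}          # fingerprint: 1.01
#   compare_tensor_dicts(td_a, td_a)               # passes, delta == 0.0
#   compare_tensor_dicts(td_a, td_b, tag="demo")   # raises ValueError: delta 0.01 > 0.0001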

if not from_checkpoint:
    self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs)
    # 1. Insert tensor_dict info into checkpoint
    current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
    self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}")
    # 2. Train/val function lives externally
    # Some TODOs inside the function below
    # TODO: test for off-by-one error

    # FIXME: we need to understand how to use round_num instead of current_epoch
    # this will matter in straggler handling cases
    # TODO: Should we put this in a separate process?
    train_completed, \
    val_completed, \
    this_ave_train_loss, \
    this_ave_val_loss, \
    this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs,
                                         epochs=1,
                                         current_epoch=current_epoch,
                                         train_cutoff=0,
                                         val_cutoff=self.val_cutoff,
                                         task=self.data_loader.get_task_name(),
                                         val_epoch=True,
                                         train_epoch=False)
    self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}")

    # double check: train_cutoff=0 and train_epoch=False request a validation-only pass
    if train_completed != 0.0:
        raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.")

    # 3. Prepare metrics
    metrics = {'val_eval': this_val_eval_metrics}
else:
    checkpoint_dict = self.load_checkpoint()
    # double check: the aggregator-provided tensors should match what train() last wrote to the checkpoint
    compare_tensor_dicts(td_1=input_tensor_dict, td_2=checkpoint_dict['state_dict'])

    (all_tr_losses, _, _, _) = checkpoint_dict['plot_stuff']
    # these metrics are appended to the checkpoint on each call to train, so it is critical that we grab them right after training
    metrics = {'train_loss': all_tr_losses[-1]}

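# Note (an assumption about the upstream library, based on nnU-Net v1's NetworkTrainer
# checkpoint format): 'plot_stuff' holds (all_tr_losses, all_val_losses,
# all_val_losses_tr_mode, all_val_eval_metrics); only the training-loss series is
# consumed above, and its last entry is the loss of the most recently trained epoch.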
######################################################################################################
# TODO: Provide val_completed to be incorporated into the collab weight computation
@@ -241,7 +257,5 @@ def load_metrics(self, filepath):
# WORKING HERE

def validate_by_reading_checkpoint(self, col_name, round_num, input_tensor_dict, **kwargs):
(all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff']
# these metrics are appended to the checkpoint on each call to train, so it is critical that we grab them right after training
metrics = {'train_loss': all_tr_losses[-1]}
return self.convert_results_to_tensorkeys(col_name, round_num, metrics)

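For context, a minimal caller sketch of the reworked entry point (hypothetical names: `runner` stands in for the task runner instance and `tensors` for the aggregator-supplied tensor dict; the real wiring lives in the OpenFL task configuration):

# Direct validation: load `tensors` into the model and run a validation-only epoch.
val_results = runner.validate(col_name="col1", round_num=3, input_tensor_dict=tensors)

# Checkpoint path: verify `tensors` against the checkpoint's state_dict, then
# report the last recorded training loss instead of re-running validation.
loss_results = runner.validate(col_name="col1", round_num=3,
                               input_tensor_dict=tensors, from_checkpoint=True)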