Skip to content

Commit

Permalink
round instead of current_epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-edwards committed Oct 21, 2024
1 parent 227cac2 commit b90f732
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions examples/fl_post/fl/project/src/runner_nnunetv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs):

self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs)
# 1. Insert tensor_dict info into checkpoint
current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
self.logger.info(f"In col train method, loaded checkpoint with current epoch: {current_epoch}")
# 2. Train/val function existing externally
# Some todo inside function below
# TODO: test for off-by-one error

# FIXME: we need to understand how to use round_num instead of current_epoch
# this will matter in straggler handling cases
# TODO: Should we put this in a separate process?
self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
self.logger.info(f"Training for round:{round_num}")
train_completed, \
val_completed, \
this_ave_train_loss, \
Expand All @@ -174,7 +167,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs):
this_val_eval_metrics_C3, \
this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs,
epochs=epochs,
current_epoch=current_epoch,
round=round_num,
train_cutoff=self.train_cutoff,
val_cutoff = self.val_cutoff,
task=self.data_loader.get_task_name(),
Expand Down Expand Up @@ -213,15 +206,10 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True):
if not from_checkpoint:
self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs)
# 1. Insert tensor_dict info into checkpoint
current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}")
self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False)
self.logger.info(f"Validating for round:{round_num}")
# 2. Train/val function existing externally
# Some todo inside function below
# TODO: test for off-by-one error

# FIXME: we need to understand how to use round_num instead of current_epoch
# this will matter in straggler handling cases
# TODO: Should we put this in a separate process?
train_completed, \
val_completed, \
this_ave_train_loss, \
Expand All @@ -232,7 +220,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True):
this_val_eval_metrics_C3, \
this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs,
epochs=1,
current_epoch=current_epoch,
round=round_num,
train_cutoff=0,
val_cutoff = self.val_cutoff,
task=self.data_loader.get_task_name(),
Expand Down

0 comments on commit b90f732

Please sign in to comment.