Skip to content

Commit

Permalink
added hybrid loss for test
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Jul 25, 2023
1 parent e40d8f6 commit 9246426
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,9 +1226,6 @@ def get_parameters_after_alteration(loss_type: str) -> dict:
parameters = parseConfig(file_config_temp, version_check_flag=True)
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
Expand All @@ -1246,7 +1243,14 @@ def get_parameters_after_alteration(loss_type: str) -> dict:
parameters = populate_header_in_parameters(parameters, parameters["headers"])
return parameters, training_data
# loop through selected models and train for single epoch
for loss_type in ["dc", "dc_log", "dcce", "dcce_logits", "tversky", "focal"]:
for loss_type in [
"dc",
"dc_log",
"dcce",
"dcce_logits",
"tversky",
"focal",
"dc_focal"]:
parameters, training_data = get_parameters_after_alteration(loss_type)
sanitize_outputDir()
TrainingManager(
Expand Down

0 comments on commit 9246426

Please sign in to comment.