From 9246426c750db83a6f58d7e55b43ad54c91bbe82 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Tue, 25 Jul 2023 13:34:08 -0400 Subject: [PATCH] added hybrid loss for test --- testing/test_full.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/testing/test_full.py b/testing/test_full.py index 86004edfd..73ad8af65 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -1226,9 +1226,6 @@ def get_parameters_after_alteration(loss_type: str) -> dict: parameters = parseConfig(file_config_temp, version_check_flag=True) parameters["nested_training"]["testing"] = -5 parameters["nested_training"]["validation"] = -5 - parameters = parseConfig( - testingDir + "/config_segmentation.yaml", version_check_flag=False - ) training_data, parameters["headers"] = parseTrainingCSV( inputDir + "/train_2d_rad_segmentation.csv" ) @@ -1246,7 +1243,14 @@ def get_parameters_after_alteration(loss_type: str) -> dict: parameters = populate_header_in_parameters(parameters, parameters["headers"]) return parameters, training_data # loop through selected models and train for single epoch - for loss_type in ["dc", "dc_log", "dcce", "dcce_logits", "tversky", "focal"]: + for loss_type in [ + "dc", + "dc_log", + "dcce", + "dcce_logits", + "tversky", + "focal", + "dc_focal"]: parameters, training_data = get_parameters_after_alteration(loss_type) sanitize_outputDir() TrainingManager(