diff --git a/GANDLF/utils/parameter_processing.py b/GANDLF/utils/parameter_processing.py
index f2e665053..e41ac7f7a 100644
--- a/GANDLF/utils/parameter_processing.py
+++ b/GANDLF/utils/parameter_processing.py
@@ -21,22 +21,22 @@ def populate_header_in_parameters(parameters, headers):
         parameters["model"]["num_classes"] = len(headers["predictionHeaders"])
 
     # initialize model type for processing: if not defined, default to torch
-    if not ("type" in parameters["model"]):
-        parameters["model"]["type"] = "torch"
+    parameters["model"]["type"] = parameters["model"].get("type", "torch")
 
     if parameters["model"]["type"] == "openvino" and parameters["model"][
         "architecture"
     ] in ["brain_age", "sdnet"]:
         print(
-            "Only PyTorch for inference is supported for the current model architecture: {0}.".format(
+            "Only PyTorch for inference is supported for the current network topology: {0}.".format(
                 parameters["model"]["architecture"]
            )
        )
         parameters["model"]["type"] = "torch"
 
     # initialize number of channels for processing
-    if not ("num_channels" in parameters["model"]):
-        parameters["model"]["num_channels"] = len(headers["channelHeaders"])
+    parameters["model"]["num_channels"] = parameters["model"].get(
+        "num_channels", len(headers["channelHeaders"])
+    )
 
     parameters["problem_type"] = find_problem_type(
         parameters, get_modelbase_final_layer(parameters["model"]["final_layer"])
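
Note on the refactor (not part of the diff itself): the change replaces the `if not (key in dict)` membership checks with `dict.get(key, default)`, which returns the existing value when the key is present and the default otherwise. A minimal sketch of the idiom, using a hypothetical `parameters` dict rather than GANDLF's real configuration:

# Minimal illustration of the dict.get(key, default) idiom adopted in the diff.
# The "parameters" dict here is a hypothetical stand-in, not GANDLF's real config.
parameters = {"model": {"architecture": "unet"}}

# Before: explicit membership check followed by a conditional assignment.
if not ("type" in parameters["model"]):
    parameters["model"]["type"] = "torch"

# After: single assignment with a fallback default; behavior is equivalent
# because the key is either kept as-is or filled with the default.
parameters["model"]["type"] = parameters["model"].get("type", "torch")

assert parameters["model"]["type"] == "torch"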