Skip to content

Commit

Permalink
consolidated a few checks
Browse files Browse the repository at this point in the history
  • Loading branch information
scap3yvt authored Jan 17, 2024
1 parent d42167a commit d70a4fb
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions GANDLF/utils/parameter_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ def populate_header_in_parameters(parameters, headers):
parameters["model"]["num_classes"] = len(headers["predictionHeaders"])

# initialize model type for processing: if not defined, default to torch
if not ("type" in parameters["model"]):
parameters["model"]["type"] = "torch"
parameters["model"]["type"] = parameters["model"].get("type", "torch")

if parameters["model"]["type"] == "openvino" and parameters["model"][
"architecture"
] in ["brain_age", "sdnet"]:
print(
"Only PyTorch for inference is supported for the current model architecture: {0}.".format(
"Only PyTorch for inference is supported for the current network topology: {0}.".format(
parameters["model"]["architecture"]
)
)
parameters["model"]["type"] = "torch"

# initialize number of channels for processing
if not ("num_channels" in parameters["model"]):
parameters["model"]["num_channels"] = len(headers["channelHeaders"])
parameters["model"]["num_channels"] = parameters["model"].get(
"num_channels", len(headers["channelHeaders"])
)

parameters["problem_type"] = find_problem_type(
parameters, get_modelbase_final_layer(parameters["model"]["final_layer"])
Expand Down

0 comments on commit d70a4fb

Please sign in to comment.