Skip to content

Commit

Permalink
Merge pull request #784 from scap3yvt/783-rename-parseconfig-to-confi…
Browse files Browse the repository at this point in the history
…gmanager-for-consistency

Rename parseconfig for consistency
  • Loading branch information
sarthakpati authored Jan 26, 2024
2 parents b31d7fc + 2b2ef97 commit e8ddde5
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 437 deletions.
4 changes: 2 additions & 2 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import SimpleITK as sitk
import numpy as np

from GANDLF.parseConfig import parseConfig
from GANDLF.config_manager import ConfigManager
from GANDLF.utils import find_problem_type_from_parameters, one_hot
from GANDLF.metrics import (
overall_stats,
Expand Down Expand Up @@ -58,7 +58,7 @@ def generate_metrics_dict(input_csv: str, config: str, outputfile: str = None) -
assert column in headers, f"The input csv should have a column named {column}"

overall_stats_dict = {}
parameters = parseConfig(config)
parameters = ConfigManager(config)
problem_type = parameters.get("problem_type", None)
problem_type = (
find_problem_type_from_parameters(parameters)
Expand Down
4 changes: 2 additions & 2 deletions GANDLF/cli/main_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from GANDLF.training_manager import TrainingManager, TrainingManager_split
from GANDLF.inference_manager import InferenceManager
from GANDLF.parseConfig import parseConfig
from GANDLF.config_manager import ConfigManager
from GANDLF.utils import (
populate_header_in_parameters,
parseTrainingCSV,
Expand Down Expand Up @@ -34,7 +34,7 @@ def main_run(
file_data_full = data_csv
model_parameters = config_file
device = device
parameters = parseConfig(model_parameters)
parameters = ConfigManager(model_parameters)
parameters["device_id"] = -1

if train_mode:
Expand Down
4 changes: 2 additions & 2 deletions GANDLF/cli/post_training_model_optimization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from GANDLF.compute import create_pytorch_objects
from GANDLF.parseConfig import parseConfig
from GANDLF.config_manager import ConfigManager
from GANDLF.utils import version_check, load_model, optimize_and_save_model


Expand All @@ -21,7 +21,7 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool:

# If parameters are not available in the model file, parse them from the config file
parameters = (
parseConfig(config_path, version_check_flag=False)
ConfigManager(config_path, version_check_flag=False)
if parameters is None
else parameters
)
Expand Down
4 changes: 2 additions & 2 deletions GANDLF/cli/preprocess_and_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
get_dataframe,
get_correct_padding_size,
)
from GANDLF.parseConfig import parseConfig
from GANDLF.config_manager import ConfigManager
from GANDLF.data.ImagesFromDataFrame import ImagesFromDataFrame
from torch.utils.data import DataLoader
from tqdm import tqdm
Expand Down Expand Up @@ -45,7 +45,7 @@ def preprocess_and_save(
# read the csv
# don't care if the dataframe gets shuffled or not
dataframe, headers = parseTrainingCSV(data_csv, train=False)
parameters = parseConfig(config_file)
parameters = ConfigManager(config_file)

# save the parameters so that the same compute doesn't happen once again
parameter_file = os.path.join(output_dir, "parameters.pkl")
Expand Down
16 changes: 15 additions & 1 deletion GANDLF/parseConfig.py → GANDLF/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def initialize_key(parameters, key, value=None):
return parameters


def parseConfig(config_file_path, version_check_flag=True):
def _parseConfig(config_file_path, version_check_flag=True):
"""
This function parses the configuration file and returns a dictionary of parameters.
Expand Down Expand Up @@ -714,3 +714,17 @@ def parseConfig(config_file_path, version_check_flag=True):
params["inference_mechanism"] = inference_mechanism

return params


def ConfigManager(config_file_path, version_check_flag=True) -> None:
"""
This function parses the configuration file and returns a dictionary of parameters.
Args:
config_file_path (Union[str, dict]): The filename of the configuration file.
version_check_flag (bool, optional): Whether to check the version in configuration file. Defaults to True.
Returns:
dict: The parameter dictionary.
"""
return _parseConfig(config_file_path, version_check_flag)
2 changes: 1 addition & 1 deletion docs/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Before starting to work on the code-level on GaNDLF, please follow the [instruct
## Overall Architecture

- Command-line parsing: [gandlf_run](https://github.com/mlcommons/GaNDLF/blob/master/gandlf_run)
- Parameters from [training configuration](https://github.com/mlcommons/GaNDLF/blob/master/samples/config_all_options.yaml) get passed as a `dict` via [parameter parser](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/parseConfig.py)
- Parameters from [training configuration](https://github.com/mlcommons/GaNDLF/blob/master/samples/config_all_options.yaml) get passed as a `dict` via the [config manager](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/config_manager.py)
- [Training Manager](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/training_manager.py):
- Handles k-fold training
- Main entry point from CLI
Expand Down
Loading

0 comments on commit e8ddde5

Please sign in to comment.