Skip to content

Commit

Permalink
Merge pull request #847 from mlcommons/fix_data_splitter
Browse files Browse the repository at this point in the history
Fixing data split logic
  • Loading branch information
sarthakpati authored Apr 18, 2024
2 parents a1f2d5f + 348df6c commit 4d751fd
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
6 changes: 2 additions & 4 deletions GANDLF/utils/data_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ def split_data(
"nested_training" in parameters
), "`nested_training` key missing in parameters"
# populate the headers
- _, parameters["headers"] = (
- parseTrainingCSV(full_dataset) if "headers" not in parameters else full_dataset,
- parameters["headers"],
- )
+ if "headers" not in parameters:
+ _, parameters["headers"] = parseTrainingCSV(full_dataset)

parameters = (
populate_header_in_parameters(parameters, parameters["headers"])
Expand Down
6 changes: 1 addition & 5 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -3124,11 +3124,7 @@ def test_generic_data_split():
)
parameters["nested_training"] = {"testing": 5, "validation": 5, "stratified": True}
# read and parse csv
- training_data, parameters["headers"] = parseTrainingCSV(
- inputDir + "/train_3d_rad_classification.csv"
- )
- parameters["model"]["num_channels"] = len(parameters["headers"]["channelHeaders"])
- parameters = populate_header_in_parameters(parameters, parameters["headers"])
+ training_data, _ = parseTrainingCSV(inputDir + "/train_3d_rad_classification.csv")
# duplicate the data to test stratified sampling
training_data_duplicate = training_data._append(training_data)
for _ in range(1):
Expand Down

0 comments on commit 4d751fd

Please sign in to comment.