Skip to content

Commit

Permalink
add missing crossval_splits.py file
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Jan 9, 2024
1 parent 2b9cfc2 commit 8cb4084
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions nnunetv2/utilities/crossval_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import List

import numpy as np
from sklearn.model_selection import KFold


def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]:
splits = []
kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)):
train_keys = np.array(train_identifiers)[train_idx]
test_keys = np.array(train_identifiers)[test_idx]
splits.append({})
splits[-1]['train'] = list(train_keys)
splits[-1]['val'] = list(test_keys)
return splits

0 comments on commit 8cb4084

Please sign in to comment.