from typing import Dict, List

import numpy as np
from sklearn.model_selection import KFold


def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[Dict[str, List[str]]]:
    """Partition ``train_identifiers`` into shuffled K-fold train/validation splits.

    The folds are produced by scikit-learn's ``KFold`` with ``shuffle=True`` and
    ``random_state=seed``.

    Args:
        train_identifiers: Identifiers to distribute across the folds.
        seed: Passed to ``KFold`` as ``random_state``.
        n_splits: Number of folds, passed to ``KFold`` as ``n_splits``.

    Returns:
        A list with one dict per fold. Each dict has the keys ``'train'`` and
        ``'val'``, holding the identifiers selected by the train and test index
        arrays that ``KFold.split`` yields for that fold.

    Raises:
        ValueError: Propagated from ``KFold`` (per the scikit-learn
            documentation, e.g. when ``n_splits`` is smaller than 2 or larger
            than the number of identifiers).
    """
    # Convert once up front: the array is only needed so the integer index
    # arrays yielded by KFold can be used for fancy indexing.
    identifiers = np.asarray(train_identifiers)
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # ndarray.tolist() yields native Python objects, whereas list(ndarray)
    # would leak numpy scalar types (np.str_) into the returned structure.
    return [
        {'train': identifiers[train_idx].tolist(), 'val': identifiers[val_idx].tolist()}
        for train_idx, val_idx in kfold.split(train_identifiers)
    ]