diff --git a/mne_nirs/io/snirf/_snirf.py b/mne_nirs/io/snirf/_snirf.py index e21554de0..83e2f3ff9 100644 --- a/mne_nirs/io/snirf/_snirf.py +++ b/mne_nirs/io/snirf/_snirf.py @@ -98,7 +98,12 @@ def _add_metadata_tags(raw, nirs): metadata_tags.create_dataset("MeasurementTime", data=_str_encode(timestr)) # Store demographic info - subject_id = raw.info["subject_info"]["first_name"] + subject_id = "_".join( + raw.info["subject_info"][key] + for key in ["first_name", "middle_name", "last_name"] + if key in raw.info["subject_info"] + ) + subject_id = raw.info["subject_info"].get("his_id", subject_id) metadata_tags.create_dataset("SubjectID", data=_str_encode(subject_id)) # Store the units of measurement @@ -111,12 +116,10 @@ def _add_metadata_tags(raw, nirs): birthday = datetime.date(*raw.info["subject_info"]["birthday"]) birthstr = birthday.strftime("%Y-%m-%d") metadata_tags.create_dataset("DateOfBirth", data=[_str_encode(birthstr)]) - if "middle_name" in raw.info["subject_info"]: - middle_name = raw.info["subject_info"]["middle_name"] - metadata_tags.create_dataset("middleName", data=[_str_encode(middle_name)]) - if "last_name" in raw.info["subject_info"]: - last_name = raw.info["subject_info"]["last_name"] - metadata_tags.create_dataset("lastName", data=[_str_encode(last_name)]) + for key in ("first", "middle", "last"): + name = raw.info["subject_info"].get(f"{key}_name", None) + if name is not None: + metadata_tags.create_dataset(f"{key}Name", data=[_str_encode(name)]) if "sex" in raw.info["subject_info"]: sex = str(int(raw.info["subject_info"]["sex"])) metadata_tags.create_dataset("sex", data=[_str_encode(sex)]) diff --git a/mne_nirs/io/snirf/tests/test_snirf.py b/mne_nirs/io/snirf/tests/test_snirf.py index 18f4df109..ff36c8daf 100644 --- a/mne_nirs/io/snirf/tests/test_snirf.py +++ b/mne_nirs/io/snirf/tests/test_snirf.py @@ -10,7 +10,7 @@ from mne.datasets.testing import data_path, requires_testing_data from mne.io import read_raw_nirx, read_raw_snirf from mne.preprocessing.nirs import beer_lambert_law, optical_density -from mne.utils import object_diff +from mne.utils import check_version, object_diff from numpy.testing import assert_allclose, assert_array_equal from snirf import Snirf, validateSnirf @@ -30,6 +30,8 @@ pytest.importorskip("h5py") +MNE_1_7 = check_version("mne", "1.7") + @requires_testing_data @pytest.mark.parametrize( @@ -38,9 +40,23 @@ def test_snirf_write_raw(fname, tmpdir): """Test reading NIRX files.""" raw_orig = read_raw_nirx(fname, preload=True) + subj_info = raw_orig.info["subject_info"] + his_id = "_".join( # this is how MNE-Python creates his_id for NIRX files + subj_info[key] + for key in ("first_name", "middle_name", "last_name") + if key in subj_info + ) + assert raw_orig.info["subject_info"]["his_id"] == his_id test_file = tmpdir.join("test_raw.snirf") write_raw_snirf(raw_orig, test_file) - raw = read_raw_snirf(test_file) + raw = read_raw_snirf(test_file, preload=True) + # Correct MNE bug with reading + subj_info = raw.info["subject_info"] + if not MNE_1_7: + subj_info["first_name"] = subj_info["first_name"].split("_")[0] + assert "his_id" not in subj_info + subj_info["his_id"] = his_id + assert subj_info["his_id"] == his_id result = validateSnirf(str(test_file)) if result.is_valid(): @@ -81,6 +97,15 @@ def test_snirf_write_raw(fname, tmpdir): _verify_snirf_required_fields(test_file) _verify_snirf_version_str(test_file) + # make sure one can be written and read with no his_id (e.g., old MNE-Python) + del raw.info["subject_info"]["his_id"] + write_raw_snirf(raw, test_file) + raw = read_raw_snirf(test_file) + if MNE_1_7: + assert raw.info["subject_info"]["his_id"] == his_id + else: + assert raw.info["subject_info"]["first_name"] == his_id + @requires_testing_data @pytest.mark.parametrize(