From e6b3b695f27b7a995d65106d80490e11d6c72154 Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Tue, 9 Jul 2024 16:45:48 -0400 Subject: [PATCH] Add test for csv_to_wfdb(). --- tests/io/test_convert.py | 75 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/io/test_convert.py b/tests/io/test_convert.py index aa7ba78a..cf97f700 100644 --- a/tests/io/test_convert.py +++ b/tests/io/test_convert.py @@ -1,14 +1,22 @@ +import os +import shutil +import unittest + import numpy as np from wfdb.io.record import rdrecord from wfdb.io.convert.edf import read_edf +from wfdb.io.convert.csv import csv_to_wfdb + +class TestEdfToWfdb: + """ + Tests for the io.convert.edf module. + """ -class TestConvert: def test_edf_uniform(self): """ EDF format conversion to MIT for uniform sample rates. - """ # Uniform sample rates record_MIT = rdrecord("sample-data/n16").__dict__ @@ -60,7 +68,6 @@ def test_edf_uniform(self): def test_edf_non_uniform(self): """ EDF format conversion to MIT for non-uniform sample rates. - """ # Non-uniform sample rates record_MIT = rdrecord("sample-data/wave_4").__dict__ @@ -108,3 +115,65 @@ def test_edf_non_uniform(self): target_results = len(fields) * [True] assert np.array_equal(test_results, target_results) + + +class TestCsvToWfdb(unittest.TestCase): + """ + Tests for the io.convert.csv module. + """ + + def setUp(self): + """ + Create a temporary directory containing data for testing. + + Load 100.dat file for comparison to 100.csv file. + """ + self.test_dir = "test_output" + os.makedirs(self.test_dir, exist_ok=True) + + self.record_100_csv = "sample-data/100.csv" + self.record_100_dat = rdrecord("sample-data/100", physical=True) + + def tearDown(self): + """ + Remove the temporary directory after the test. + """ + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_write_dir(self): + """ + Call the function with the write_dir argument. + """ + csv_to_wfdb( + file_name=self.record_100_csv, + fs=360, + units="mV", + write_dir=self.test_dir, + ) + + # Check if the output files are created in the specified directory + base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0] + expected_dat_file = os.path.join(self.test_dir, f"{base_name}.dat") + expected_hea_file = os.path.join(self.test_dir, f"{base_name}.hea") + + self.assertTrue(os.path.exists(expected_dat_file)) + self.assertTrue(os.path.exists(expected_hea_file)) + + # Check that newly written file matches the 100.dat file + record_write = rdrecord(os.path.join(self.test_dir, base_name)) + + self.assertEqual(record_write.fs, 360) + self.assertEqual(record_write.fs, self.record_100_dat.fs) + self.assertEqual(record_write.units, ["mV", "mV"]) + self.assertEqual(record_write.units, self.record_100_dat.units) + self.assertEqual(record_write.sig_name, ["MLII", "V5"]) + self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name) + self.assertEqual(record_write.p_signal.size, 1300000) + self.assertEqual( + record_write.p_signal.size, self.record_100_dat.p_signal.size + ) + + +if __name__ == "__main__": + unittest.main()