Skip to content

Commit

Permalink
Add test for csv_to_wfdb().
Browse files Browse the repository at this point in the history
  • Loading branch information
tompollard committed Jul 9, 2024
1 parent f874b1c commit e6b3b69
Showing 1 changed file with 72 additions and 3 deletions.
75 changes: 72 additions & 3 deletions tests/io/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import os
import shutil
import unittest

import numpy as np

from wfdb.io.record import rdrecord
from wfdb.io.convert.edf import read_edf
from wfdb.io.convert.csv import csv_to_wfdb


class TestEdfToWfdb:
"""
Tests for the io.convert.edf module.
"""

class TestConvert:
def test_edf_uniform(self):
"""
EDF format conversion to MIT for uniform sample rates.
"""
# Uniform sample rates
record_MIT = rdrecord("sample-data/n16").__dict__
Expand Down Expand Up @@ -60,7 +68,6 @@ def test_edf_uniform(self):
def test_edf_non_uniform(self):
"""
EDF format conversion to MIT for non-uniform sample rates.
"""
# Non-uniform sample rates
record_MIT = rdrecord("sample-data/wave_4").__dict__
Expand Down Expand Up @@ -108,3 +115,65 @@ def test_edf_non_uniform(self):

target_results = len(fields) * [True]
assert np.array_equal(test_results, target_results)


class TestCsvToWfdb(unittest.TestCase):
"""
Tests for the io.convert.csv module.
"""

def setUp(self):
"""
Create a temporary directory containing data for testing.
Load 100.dat file for comparison to 100.csv file.
"""
self.test_dir = "test_output"
os.makedirs(self.test_dir, exist_ok=True)

self.record_100_csv = "sample-data/100.csv"
self.record_100_dat = rdrecord("sample-data/100", physical=True)

def tearDown(self):
"""
Remove the temporary directory after the test.
"""
if os.path.exists(self.test_dir):
shutil.rmtree(self.test_dir)

def test_write_dir(self):
"""
Call the function with the write_dir argument.
"""
csv_to_wfdb(
file_name=self.record_100_csv,
fs=360,
units="mV",
write_dir=self.test_dir,
)

# Check if the output files are created in the specified directory
base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0]
expected_dat_file = os.path.join(self.test_dir, f"{base_name}.dat")
expected_hea_file = os.path.join(self.test_dir, f"{base_name}.hea")

self.assertTrue(os.path.exists(expected_dat_file))
self.assertTrue(os.path.exists(expected_hea_file))

# Check that newly written file matches the 100.dat file
record_write = rdrecord(os.path.join(self.test_dir, base_name))

self.assertEqual(record_write.fs, 360)
self.assertEqual(record_write.fs, self.record_100_dat.fs)
self.assertEqual(record_write.units, ["mV", "mV"])
self.assertEqual(record_write.units, self.record_100_dat.units)
self.assertEqual(record_write.sig_name, ["MLII", "V5"])
self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name)
self.assertEqual(record_write.p_signal.size, 1300000)
self.assertEqual(
record_write.p_signal.size, self.record_100_dat.p_signal.size
)


if __name__ == "__main__":
unittest.main()

0 comments on commit e6b3b69

Please sign in to comment.