Skip to content

Commit e6b3b69

Browse files
committed
Add test for csv_to_wfdb().
1 parent f874b1c commit e6b3b69

File tree

1 file changed

+72
-3
lines changed

1 file changed

+72
-3
lines changed

tests/io/test_convert.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1+
import os
2+
import shutil
3+
import unittest
4+
15
import numpy as np
26

37
from wfdb.io.record import rdrecord
48
from wfdb.io.convert.edf import read_edf
9+
from wfdb.io.convert.csv import csv_to_wfdb
10+
511

12+
class TestEdfToWfdb:
13+
"""
14+
Tests for the io.convert.edf module.
15+
"""
616

7-
class TestConvert:
817
def test_edf_uniform(self):
918
"""
1019
EDF format conversion to MIT for uniform sample rates.
11-
1220
"""
1321
# Uniform sample rates
1422
record_MIT = rdrecord("sample-data/n16").__dict__
@@ -60,7 +68,6 @@ def test_edf_uniform(self):
6068
def test_edf_non_uniform(self):
6169
"""
6270
EDF format conversion to MIT for non-uniform sample rates.
63-
6471
"""
6572
# Non-uniform sample rates
6673
record_MIT = rdrecord("sample-data/wave_4").__dict__
@@ -108,3 +115,65 @@ def test_edf_non_uniform(self):
108115

109116
target_results = len(fields) * [True]
110117
assert np.array_equal(test_results, target_results)
118+
119+
120+
class TestCsvToWfdb(unittest.TestCase):
121+
"""
122+
Tests for the io.convert.csv module.
123+
"""
124+
125+
def setUp(self):
126+
"""
127+
Create a temporary directory containing data for testing.
128+
129+
Load 100.dat file for comparison to 100.csv file.
130+
"""
131+
self.test_dir = "test_output"
132+
os.makedirs(self.test_dir, exist_ok=True)
133+
134+
self.record_100_csv = "sample-data/100.csv"
135+
self.record_100_dat = rdrecord("sample-data/100", physical=True)
136+
137+
def tearDown(self):
138+
"""
139+
Remove the temporary directory after the test.
140+
"""
141+
if os.path.exists(self.test_dir):
142+
shutil.rmtree(self.test_dir)
143+
144+
def test_write_dir(self):
145+
"""
146+
Call the function with the write_dir argument.
147+
"""
148+
csv_to_wfdb(
149+
file_name=self.record_100_csv,
150+
fs=360,
151+
units="mV",
152+
write_dir=self.test_dir,
153+
)
154+
155+
# Check if the output files are created in the specified directory
156+
base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0]
157+
expected_dat_file = os.path.join(self.test_dir, f"{base_name}.dat")
158+
expected_hea_file = os.path.join(self.test_dir, f"{base_name}.hea")
159+
160+
self.assertTrue(os.path.exists(expected_dat_file))
161+
self.assertTrue(os.path.exists(expected_hea_file))
162+
163+
# Check that newly written file matches the 100.dat file
164+
record_write = rdrecord(os.path.join(self.test_dir, base_name))
165+
166+
self.assertEqual(record_write.fs, 360)
167+
self.assertEqual(record_write.fs, self.record_100_dat.fs)
168+
self.assertEqual(record_write.units, ["mV", "mV"])
169+
self.assertEqual(record_write.units, self.record_100_dat.units)
170+
self.assertEqual(record_write.sig_name, ["MLII", "V5"])
171+
self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name)
172+
self.assertEqual(record_write.p_signal.size, 1300000)
173+
self.assertEqual(
174+
record_write.p_signal.size, self.record_100_dat.p_signal.size
175+
)
176+
177+
178+
if __name__ == "__main__":
179+
unittest.main()

0 commit comments

Comments
 (0)