|
| 1 | +import os |
| 2 | +import shutil |
| 3 | +import unittest |
| 4 | + |
1 | 5 | import numpy as np
|
2 | 6 |
|
3 | 7 | from wfdb.io.record import rdrecord
|
4 | 8 | from wfdb.io.convert.edf import read_edf
|
| 9 | +from wfdb.io.convert.csv import csv_to_wfdb |
| 10 | + |
5 | 11 |
|
| 12 | +class TestEdfToWfdb: |
| 13 | + """ |
| 14 | + Tests for the io.convert.edf module. |
| 15 | + """ |
6 | 16 |
|
7 |
| -class TestConvert: |
8 | 17 | def test_edf_uniform(self):
|
9 | 18 | """
|
10 | 19 | EDF format conversion to MIT for uniform sample rates.
|
11 |
| -
|
12 | 20 | """
|
13 | 21 | # Uniform sample rates
|
14 | 22 | record_MIT = rdrecord("sample-data/n16").__dict__
|
@@ -60,7 +68,6 @@ def test_edf_uniform(self):
|
60 | 68 | def test_edf_non_uniform(self):
|
61 | 69 | """
|
62 | 70 | EDF format conversion to MIT for non-uniform sample rates.
|
63 |
| -
|
64 | 71 | """
|
65 | 72 | # Non-uniform sample rates
|
66 | 73 | record_MIT = rdrecord("sample-data/wave_4").__dict__
|
@@ -108,3 +115,65 @@ def test_edf_non_uniform(self):
|
108 | 115 |
|
109 | 116 | target_results = len(fields) * [True]
|
110 | 117 | assert np.array_equal(test_results, target_results)
|
| 118 | + |
| 119 | + |
| 120 | +class TestCsvToWfdb(unittest.TestCase): |
| 121 | + """ |
| 122 | + Tests for the io.convert.csv module. |
| 123 | + """ |
| 124 | + |
| 125 | + def setUp(self): |
| 126 | + """ |
| 127 | + Create a temporary directory containing data for testing. |
| 128 | +
|
| 129 | + Load 100.dat file for comparison to 100.csv file. |
| 130 | + """ |
| 131 | + self.test_dir = "test_output" |
| 132 | + os.makedirs(self.test_dir, exist_ok=True) |
| 133 | + |
| 134 | + self.record_100_csv = "sample-data/100.csv" |
| 135 | + self.record_100_dat = rdrecord("sample-data/100", physical=True) |
| 136 | + |
| 137 | + def tearDown(self): |
| 138 | + """ |
| 139 | + Remove the temporary directory after the test. |
| 140 | + """ |
| 141 | + if os.path.exists(self.test_dir): |
| 142 | + shutil.rmtree(self.test_dir) |
| 143 | + |
| 144 | + def test_write_dir(self): |
| 145 | + """ |
| 146 | + Call the function with the write_dir argument. |
| 147 | + """ |
| 148 | + csv_to_wfdb( |
| 149 | + file_name=self.record_100_csv, |
| 150 | + fs=360, |
| 151 | + units="mV", |
| 152 | + write_dir=self.test_dir, |
| 153 | + ) |
| 154 | + |
| 155 | + # Check if the output files are created in the specified directory |
| 156 | + base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0] |
| 157 | + expected_dat_file = os.path.join(self.test_dir, f"{base_name}.dat") |
| 158 | + expected_hea_file = os.path.join(self.test_dir, f"{base_name}.hea") |
| 159 | + |
| 160 | + self.assertTrue(os.path.exists(expected_dat_file)) |
| 161 | + self.assertTrue(os.path.exists(expected_hea_file)) |
| 162 | + |
| 163 | + # Check that newly written file matches the 100.dat file |
| 164 | + record_write = rdrecord(os.path.join(self.test_dir, base_name)) |
| 165 | + |
| 166 | + self.assertEqual(record_write.fs, 360) |
| 167 | + self.assertEqual(record_write.fs, self.record_100_dat.fs) |
| 168 | + self.assertEqual(record_write.units, ["mV", "mV"]) |
| 169 | + self.assertEqual(record_write.units, self.record_100_dat.units) |
| 170 | + self.assertEqual(record_write.sig_name, ["MLII", "V5"]) |
| 171 | + self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name) |
| 172 | + self.assertEqual(record_write.p_signal.size, 1300000) |
| 173 | + self.assertEqual( |
| 174 | + record_write.p_signal.size, self.record_100_dat.p_signal.size |
| 175 | + ) |
| 176 | + |
| 177 | + |
| 178 | +if __name__ == "__main__": |
| 179 | + unittest.main() |
0 commit comments