Skip to content

Commit 1d192c5

Browse files
author
Benjamin Moody
committed
Add capability to write signal with unique samps_per_frame to wfdb.io.wrsamp (#510)
This PR adds the capability for writing signals with unique samples per frame (`samps_per_frame`) to `wfdb.io.wrsamp`. This is typically the function that is used to write WFDB files. This was previously only possible to do by creating a Record first and using its `wrsamp` method to do the write. I've added a couple of tests to check that this continues to work as expected.
2 parents 86028c4 + 6b0d317 commit 1d192c5

File tree

3 files changed

+281
-39
lines changed

3 files changed

+281
-39
lines changed

tests/test_record.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@ class TestRecord(unittest.TestCase):
2020
2121
"""
2222

23+
wrsamp_params = [
24+
"record_name",
25+
"fs",
26+
"units",
27+
"sig_name",
28+
"p_signal",
29+
"d_signal",
30+
"e_p_signal",
31+
"e_d_signal",
32+
"samps_per_frame",
33+
"fmt",
34+
"adc_gain",
35+
"baseline",
36+
"comments",
37+
"base_time",
38+
"base_date",
39+
"base_datetime",
40+
]
41+
2342
# ----------------------- 1. Basic Tests -----------------------#
2443

2544
def test_1a(self):
@@ -307,6 +326,172 @@ def test_read_write_flac_multifrequency(self):
307326
)
308327
assert record == record_write
309328

329+
def test_unique_samps_per_frame_e_p_signal(self):
330+
"""
331+
Test writing an e_p_signal with wfdb.io.wrsamp where the signals have different samples per frame. All other
332+
parameters which overlap between a Record object and wfdb.io.wrsamp are also checked.
333+
"""
334+
# Read in a record with different samples per frame
335+
record = wfdb.rdrecord(
336+
"sample-data/mixedsignals",
337+
smooth_frames=False,
338+
)
339+
340+
# Write the signals
341+
wfdb.io.wrsamp(
342+
"mixedsignals",
343+
fs=record.fs,
344+
units=record.units,
345+
sig_name=record.sig_name,
346+
base_date=record.base_date,
347+
base_time=record.base_time,
348+
comments=record.comments,
349+
p_signal=record.p_signal,
350+
d_signal=record.d_signal,
351+
e_p_signal=record.e_p_signal,
352+
e_d_signal=record.e_d_signal,
353+
samps_per_frame=record.samps_per_frame,
354+
baseline=record.baseline,
355+
adc_gain=record.adc_gain,
356+
fmt=record.fmt,
357+
write_dir=self.temp_path,
358+
)
359+
360+
# Check that the written record matches the original
361+
# Read in the original and written records
362+
record = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False)
363+
record_write = wfdb.rdrecord(
364+
os.path.join(self.temp_path, "mixedsignals"),
365+
smooth_frames=False,
366+
)
367+
368+
# Check that the signals match
369+
for n, name in enumerate(record.sig_name):
370+
np.testing.assert_array_equal(
371+
record.e_p_signal[n],
372+
record_write.e_p_signal[n],
373+
f"Mismatch in {name}",
374+
)
375+
376+
# Filter out the signal
377+
record_filtered = {
378+
k: getattr(record, k)
379+
for k in self.wrsamp_params
380+
if not (
381+
isinstance(getattr(record, k), np.ndarray)
382+
or (
383+
isinstance(getattr(record, k), list)
384+
and all(
385+
isinstance(item, np.ndarray)
386+
for item in getattr(record, k)
387+
)
388+
)
389+
)
390+
}
391+
392+
record_write_filtered = {
393+
k: getattr(record_write, k)
394+
for k in self.wrsamp_params
395+
if not (
396+
isinstance(getattr(record_write, k), np.ndarray)
397+
or (
398+
isinstance(getattr(record_write, k), list)
399+
and all(
400+
isinstance(item, np.ndarray)
401+
for item in getattr(record_write, k)
402+
)
403+
)
404+
)
405+
}
406+
407+
# Check that the arguments beyond the signals also match
408+
assert record_filtered == record_write_filtered
409+
410+
def test_unique_samps_per_frame_e_d_signal(self):
411+
"""
412+
Test writing an e_d_signal with wfdb.io.wrsamp where the signals have different samples per frame. All other
413+
parameters which overlap between a Record object and wfdb.io.wrsamp are also checked.
414+
"""
415+
# Read in a record with different samples per frame
416+
record = wfdb.rdrecord(
417+
"sample-data/mixedsignals",
418+
physical=False,
419+
smooth_frames=False,
420+
)
421+
422+
# Write the signals
423+
wfdb.io.wrsamp(
424+
"mixedsignals",
425+
fs=record.fs,
426+
units=record.units,
427+
sig_name=record.sig_name,
428+
base_date=record.base_date,
429+
base_time=record.base_time,
430+
comments=record.comments,
431+
p_signal=record.p_signal,
432+
d_signal=record.d_signal,
433+
e_p_signal=record.e_p_signal,
434+
e_d_signal=record.e_d_signal,
435+
samps_per_frame=record.samps_per_frame,
436+
baseline=record.baseline,
437+
adc_gain=record.adc_gain,
438+
fmt=record.fmt,
439+
write_dir=self.temp_path,
440+
)
441+
442+
# Check that the written record matches the original
443+
# Read in the original and written records
444+
record = wfdb.rdrecord(
445+
"sample-data/mixedsignals", physical=False, smooth_frames=False
446+
)
447+
record_write = wfdb.rdrecord(
448+
os.path.join(self.temp_path, "mixedsignals"),
449+
physical=False,
450+
smooth_frames=False,
451+
)
452+
453+
# Check that the signals match
454+
for n, name in enumerate(record.sig_name):
455+
np.testing.assert_array_equal(
456+
record.e_d_signal[n],
457+
record_write.e_d_signal[n],
458+
f"Mismatch in {name}",
459+
)
460+
461+
# Filter out the signal
462+
record_filtered = {
463+
k: getattr(record, k)
464+
for k in self.wrsamp_params
465+
if not (
466+
isinstance(getattr(record, k), np.ndarray)
467+
or (
468+
isinstance(getattr(record, k), list)
469+
and all(
470+
isinstance(item, np.ndarray)
471+
for item in getattr(record, k)
472+
)
473+
)
474+
)
475+
}
476+
477+
record_write_filtered = {
478+
k: getattr(record_write, k)
479+
for k in self.wrsamp_params
480+
if not (
481+
isinstance(getattr(record_write, k), np.ndarray)
482+
or (
483+
isinstance(getattr(record_write, k), list)
484+
and all(
485+
isinstance(item, np.ndarray)
486+
for item in getattr(record_write, k)
487+
)
488+
)
489+
)
490+
}
491+
492+
# Check that the arguments beyond the signals also match
493+
assert record_filtered == record_write_filtered
494+
310495
def test_read_write_flac_many_channels(self):
311496
"""
312497
Check we can read and write to format 516 with more than 8 channels.

wfdb/io/_signal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def set_d_features(self, do_adc=False, single_fmt=True, expanded=False):
433433
self.check_field("baseline", "all")
434434

435435
# All required fields are present and valid. Perform ADC
436-
self.d_signal = self.adc(expanded)
436+
self.e_d_signal = self.adc(expanded)
437437

438438
# Use e_d_signal to set fields
439439
self.check_field("e_d_signal", "all")

wfdb/io/record.py

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,6 +2822,9 @@ def wrsamp(
28222822
sig_name,
28232823
p_signal=None,
28242824
d_signal=None,
2825+
e_p_signal=None,
2826+
e_d_signal=None,
2827+
samps_per_frame=None,
28252828
fmt=None,
28262829
adc_gain=None,
28272830
baseline=None,
@@ -2860,6 +2863,14 @@ def wrsamp(
28602863
file(s). The dtype must be an integer type. Either p_signal or d_signal
28612864
must be set, but not both. In addition, if d_signal is set, fmt, gain
28622865
and baseline must also all be set.
2866+
e_p_signal : ndarray, optional
2867+
The expanded physical conversion of the signal. Either a 2d numpy
2868+
array or a list of 1d numpy arrays.
2869+
e_d_signal : ndarray, optional
2870+
The expanded digital conversion of the signal. Either a 2d numpy
2871+
array or a list of 1d numpy arrays.
2872+
samps_per_frame : int or list of ints, optional
2873+
The total number of samples per frame.
28632874
fmt : list, optional
28642875
A list of strings giving the WFDB format of each file used to store each
28652876
channel. Accepted formats are: '80','212','16','24', and '32'. There are
@@ -2911,59 +2922,105 @@ def wrsamp(
29112922
if "." in record_name:
29122923
raise Exception("Record name must not contain '.'")
29132924
# Check input field combinations
2914-
if p_signal is not None and d_signal is not None:
2925+
signal_list = [p_signal, d_signal, e_p_signal, e_d_signal]
2926+
signals_set = sum(1 for var in signal_list if var is not None)
2927+
if signals_set != 1:
29152928
raise Exception(
2916-
"Must only give one of the inputs: p_signal or d_signal"
2929+
"Must provide one and only one input signal: p_signal, d_signal, e_p_signal, or e_d_signal"
29172930
)
2918-
if d_signal is not None:
2931+
if d_signal is not None or e_d_signal is not None:
29192932
if fmt is None or adc_gain is None or baseline is None:
29202933
raise Exception(
2921-
"When using d_signal, must also specify 'fmt', 'gain', and 'baseline' fields."
2934+
"When using d_signal or e_d_signal, must also specify 'fmt', 'gain', and 'baseline' fields"
29222935
)
2923-
# Depending on whether d_signal or p_signal was used, set other
2924-
# required features.
2925-
if p_signal is not None:
2926-
# Create the Record object
2927-
record = Record(
2928-
record_name=record_name,
2929-
p_signal=p_signal,
2930-
fs=fs,
2931-
fmt=fmt,
2932-
units=units,
2933-
sig_name=sig_name,
2934-
adc_gain=adc_gain,
2935-
baseline=baseline,
2936-
comments=comments,
2937-
base_time=base_time,
2938-
base_date=base_date,
2939-
base_datetime=base_datetime,
2936+
if (
2937+
e_p_signal is not None or e_d_signal is not None
2938+
) and samps_per_frame is None:
2939+
raise Exception(
2940+
"When passing e_p_signal or e_d_signal, you also need to specify samples per frame for each channel"
2941+
)
2942+
2943+
# If samps_per_frame is provided, check that it aligns as expected with the channels in the signal
2944+
if samps_per_frame:
2945+
# Get the number of elements being passed in samps_per_frame
2946+
samps_per_frame_length = (
2947+
len(samps_per_frame) if isinstance(samps_per_frame, list) else 1
29402948
)
2949+
# Get properties of the signal being passed
2950+
first_valid_signal = next(
2951+
signal for signal in signal_list if signal is not None
2952+
)
2953+
if isinstance(first_valid_signal, np.ndarray):
2954+
num_sig_channels = first_valid_signal.shape[1]
2955+
channel_samples = [
2956+
first_valid_signal.shape[0]
2957+
] * first_valid_signal.shape[1]
2958+
elif isinstance(first_valid_signal, list):
2959+
num_sig_channels = len(first_valid_signal)
2960+
channel_samples = [len(channel) for channel in first_valid_signal]
2961+
else:
2962+
raise TypeError(
2963+
"Unsupported signal format. Must be ndarray or list of lists."
2964+
)
2965+
# Check that the number of channels matches the number of samps_per_frame entries
2966+
if num_sig_channels != samps_per_frame_length:
2967+
raise Exception(
2968+
"When passing samps_per_frame, it must have the same number of entries as the signal has channels"
2969+
)
2970+
# Check that the number of frames is the same across all channels
2971+
frames = [a / b for a, b in zip(channel_samples, samps_per_frame)]
2972+
if len(set(frames)) > 1:
2973+
raise Exception(
2974+
"The number of samples in a channel divided by the corresponding samples_per_frame entry must be uniform"
2975+
)
2976+
2977+
# Create the Record object
2978+
record = Record(
2979+
record_name=record_name,
2980+
p_signal=p_signal,
2981+
d_signal=d_signal,
2982+
e_p_signal=e_p_signal,
2983+
e_d_signal=e_d_signal,
2984+
samps_per_frame=samps_per_frame,
2985+
fs=fs,
2986+
fmt=fmt,
2987+
units=units,
2988+
sig_name=sig_name,
2989+
adc_gain=adc_gain,
2990+
baseline=baseline,
2991+
comments=comments,
2992+
base_time=base_time,
2993+
base_date=base_date,
2994+
base_datetime=base_datetime,
2995+
)
2996+
2997+
# Depending on which signal was used, set other required fields.
2998+
if p_signal is not None:
29412999
# Compute optimal fields to store the digital signal, carry out adc,
29423000
# and set the fields.
29433001
record.set_d_features(do_adc=1)
2944-
else:
2945-
# Create the Record object
2946-
record = Record(
2947-
record_name=record_name,
2948-
d_signal=d_signal,
2949-
fs=fs,
2950-
fmt=fmt,
2951-
units=units,
2952-
sig_name=sig_name,
2953-
adc_gain=adc_gain,
2954-
baseline=baseline,
2955-
comments=comments,
2956-
base_time=base_time,
2957-
base_date=base_date,
2958-
base_datetime=base_datetime,
2959-
)
3002+
elif d_signal is not None:
29603003
# Use d_signal to set the fields directly
29613004
record.set_d_features()
3005+
elif e_p_signal is not None:
3006+
# Compute optimal fields to store the digital signal, carry out adc,
3007+
# and set the fields.
3008+
record.set_d_features(do_adc=1, expanded=True)
3009+
elif e_d_signal is not None:
3010+
# Use e_d_signal to set the fields directly
3011+
record.set_d_features(expanded=True)
29623012

29633013
# Set default values of any missing field dependencies
29643014
record.set_defaults()
3015+
3016+
# Determine whether the signal is expanded
3017+
if (e_d_signal or e_p_signal) is not None:
3018+
expanded = True
3019+
else:
3020+
expanded = False
3021+
29653022
# Write the record files - header and associated dat
2966-
record.wrsamp(write_dir=write_dir)
3023+
record.wrsamp(write_dir=write_dir, expanded=expanded)
29673024

29683025

29693026
def dl_database(

0 commit comments

Comments
 (0)