Skip to content

Commit f697e59

Browse files
author
Benjamin Moody
committed
Merge pull request #420 into main
Support FLAC signal formats in Record.wrsamp and wfdb.wrsamp. Additionally, if the signals cannot all be stored in a single file, generate multiple signal files as necessary.
2 parents f07bfff + b0a3d86 commit f697e59

File tree

4 files changed

+151
-27
lines changed

4 files changed

+151
-27
lines changed

tests/test_record.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_1f(self):
219219
"Mismatch in %s" % name,
220220
)
221221

222-
def test_read_flac(self):
222+
def test_read_write_flac(self):
223223
"""
224224
All FLAC formats, multiple signal files in one record.
225225
@@ -250,6 +250,28 @@ def test_read_flac(self):
250250
f"Mismatch in {name}",
251251
)
252252

253+
# Test file writing
254+
record.wrsamp()
255+
record_write = wfdb.rdrecord("flacformats", physical=False)
256+
assert record == record_write
257+
258+
def test_read_write_flac_multifrequency(self):
259+
"""
260+
Format 516 with multiple signal files and variable samples per frame.
261+
"""
262+
# Check that we can read a record and write it out again
263+
record = wfdb.rdrecord(
264+
"sample-data/mixedsignals",
265+
physical=False,
266+
smooth_frames=False,
267+
)
268+
record.wrsamp(expanded=True)
269+
270+
# Check that result matches the original
271+
record = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False)
272+
record_write = wfdb.rdrecord("mixedsignals", smooth_frames=False)
273+
assert record == record_write
274+
253275
def test_read_flac_longduration(self):
254276
"""
255277
Three signals multiplexed in a FLAC file, over 2**24 samples.
@@ -628,6 +650,14 @@ def tearDownClass(cls):
628650
"100_3chan.hea",
629651
"a103l.hea",
630652
"a103l.mat",
653+
"flacformats.d0",
654+
"flacformats.d1",
655+
"flacformats.d2",
656+
"flacformats.hea",
657+
"mixedsignals.hea",
658+
"mixedsignals_e.dat",
659+
"mixedsignals_p.dat",
660+
"mixedsignals_r.dat",
631661
"s0010_re.dat",
632662
"s0010_re.hea",
633663
"s0010_re.xyz",

wfdb/io/_header.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,35 @@ def get_write_fields(self):
361361

362362
return rec_write_fields, sig_write_fields
363363

364+
def _auto_signal_file_names(self):
365+
fmt = self.fmt or [None] * self.n_sig
366+
spf = self.samps_per_frame or [None] * self.n_sig
367+
num_groups = 0
368+
group_number = []
369+
prev_fmt = prev_spf = None
370+
channels_in_group = 0
371+
372+
for ch_fmt, ch_spf in zip(fmt, spf):
373+
if ch_fmt != prev_fmt:
374+
num_groups += 1
375+
channels_in_group = 0
376+
elif ch_fmt in ("508", "516", "524"):
377+
if channels_in_group >= 8 or ch_spf != prev_spf:
378+
num_groups += 1
379+
channels_in_group = 0
380+
group_number.append(num_groups)
381+
prev_fmt = ch_fmt
382+
prev_spf = ch_spf
383+
384+
if num_groups < 2:
385+
return [self.record_name + ".dat"] * self.n_sig
386+
else:
387+
digits = len(str(group_number[-1]))
388+
return [
389+
self.record_name + "_" + str(g).rjust(digits, "0") + ".dat"
390+
for g in group_number
391+
]
392+
364393
def set_default(self, field):
365394
"""
366395
Set the object's attribute to its default value if it is missing
@@ -394,7 +423,7 @@ def set_default(self, field):
394423

395424
# Specific dynamic case
396425
if field == "file_name" and self.file_name is None:
397-
self.file_name = self.n_sig * [self.record_name + ".dat"]
426+
self.file_name = self._auto_signal_file_names()
398427
return
399428

400429
item = getattr(self, field)

wfdb/io/_signal.py

Lines changed: 82 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -956,12 +956,11 @@ def wr_dat_files(self, expanded=False, write_dir=""):
956956
dat_offsets[fn],
957957
True,
958958
[self.e_d_signal[ch] for ch in dat_channels[fn]],
959-
self.samps_per_frame,
959+
[self.samps_per_frame[ch] for ch in dat_channels[fn]],
960960
write_dir=write_dir,
961961
)
962962
else:
963-
# Create a copy to prevent overwrite
964-
dsig = self.d_signal.copy()
963+
dsig = self.d_signal
965964
for fn in file_names:
966965
wr_dat_file(
967966
fn,
@@ -2273,16 +2272,15 @@ def wr_dat_file(
22732272
fmt : str
22742273
WFDB fmt of the dat file.
22752274
d_signal : ndarray
2276-
The digital conversion of the signal. Either a 2d numpy
2277-
array or a list of 1d numpy arrays.
2275+
The digital conversion of the signal, as a 2d numpy array.
22782276
byte_offset : int
22792277
The byte offset of the dat file.
22802278
expanded : bool, optional
22812279
Whether to transform the `e_d_signal` attribute (True) or
22822280
the `d_signal` attribute (False).
2283-
d_signal : ndarray, optional
2284-
The expanded digital conversion of the signal. Either a 2d numpy
2285-
array or a list of 1d numpy arrays.
2281+
e_d_signal : ndarray, optional
2282+
The expanded digital conversion of the signal, as a list of 1d
2283+
numpy arrays.
22862284
samps_per_frame : list, optional
22872285
The samples/frame for each signal of the dat file.
22882286
write_dir : str, optional
@@ -2293,10 +2291,19 @@ def wr_dat_file(
22932291
N/A
22942292
22952293
"""
2294+
file_path = os.path.join(write_dir, file_name)
2295+
22962296
# Combine list of arrays into single array
22972297
if expanded:
22982298
n_sig = len(e_d_signal)
2299-
sig_len = int(len(e_d_signal[0]) / samps_per_frame[0])
2299+
if len(samps_per_frame) != n_sig:
2300+
raise ValueError("mismatch between samps_per_frame and e_d_signal")
2301+
2302+
sig_len = len(e_d_signal[0]) // samps_per_frame[0]
2303+
for sig, spf in zip(e_d_signal, samps_per_frame):
2304+
if len(sig) != sig_len * spf:
2305+
raise ValueError("mismatch in lengths of expanded signals")
2306+
23002307
# Effectively create MxN signal, with extra frame samples acting
23012308
# like extra channels
23022309
d_signal = np.zeros((sig_len, sum(samps_per_frame)), dtype="int64")
@@ -2307,10 +2314,17 @@ def wr_dat_file(
23072314
for framenum in range(spf):
23082315
d_signal[:, expand_ch] = e_d_signal[ch][framenum::spf]
23092316
expand_ch = expand_ch + 1
2317+
else:
2318+
# Create a copy to prevent overwrite
2319+
d_signal = d_signal.copy()
23102320

2311-
# This n_sig is used for making list items.
2312-
# Does not necessarily represent number of signals (ie. for expanded=True)
2313-
n_sig = d_signal.shape[1]
2321+
# Non-expanded format always has 1 sample per frame
2322+
n_sig = d_signal.shape[1]
2323+
samps_per_frame = [1] * n_sig
2324+
2325+
# Total number of samples per frame (equal to number of signals if
2326+
# expanded=False, but may be greater for expanded=True)
2327+
tsamps_per_frame = d_signal.shape[1]
23142328

23152329
if fmt == "80":
23162330
# convert to 8 bit offset binary form
@@ -2368,8 +2382,8 @@ def wr_dat_file(
23682382
# convert to 16 bit two's complement
23692383
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 65536
23702384
# Split samples into separate bytes using binary masks
2371-
b1 = d_signal & [255] * n_sig
2372-
b2 = (d_signal & [65280] * n_sig) >> 8
2385+
b1 = d_signal & [255] * tsamps_per_frame
2386+
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
23732387
# Interweave the bytes so that the same samples' bytes are consecutive
23742388
b1 = b1.reshape((-1, 1))
23752389
b2 = b2.reshape((-1, 1))
@@ -2381,9 +2395,9 @@ def wr_dat_file(
23812395
# convert to 24 bit two's complement
23822396
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 16777216
23832397
# Split samples into separate bytes using binary masks
2384-
b1 = d_signal & [255] * n_sig
2385-
b2 = (d_signal & [65280] * n_sig) >> 8
2386-
b3 = (d_signal & [16711680] * n_sig) >> 16
2398+
b1 = d_signal & [255] * tsamps_per_frame
2399+
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
2400+
b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16
23872401
# Interweave the bytes so that the same samples' bytes are consecutive
23882402
b1 = b1.reshape((-1, 1))
23892403
b2 = b2.reshape((-1, 1))
@@ -2397,10 +2411,10 @@ def wr_dat_file(
23972411
# convert to 32 bit two's complement
23982412
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 4294967296
23992413
# Split samples into separate bytes using binary masks
2400-
b1 = d_signal & [255] * n_sig
2401-
b2 = (d_signal & [65280] * n_sig) >> 8
2402-
b3 = (d_signal & [16711680] * n_sig) >> 16
2403-
b4 = (d_signal & [4278190080] * n_sig) >> 24
2414+
b1 = d_signal & [255] * tsamps_per_frame
2415+
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
2416+
b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16
2417+
b4 = (d_signal & [4278190080] * tsamps_per_frame) >> 24
24042418
# Interweave the bytes so that the same samples' bytes are consecutive
24052419
b1 = b1.reshape((-1, 1))
24062420
b2 = b2.reshape((-1, 1))
@@ -2410,9 +2424,54 @@ def wr_dat_file(
24102424
b_write = b_write.reshape((1, -1))[0]
24112425
# Convert to un_signed 8 bit dtype to write
24122426
b_write = b_write.astype("uint8")
2427+
2428+
elif fmt in ("508", "516", "524"):
2429+
import soundfile
2430+
2431+
if any(spf != samps_per_frame[0] for spf in samps_per_frame):
2432+
raise ValueError(
2433+
"All channels in a FLAC signal file must have the same "
2434+
"sampling rate and samples per frame"
2435+
)
2436+
if n_sig > 8:
2437+
raise ValueError(
2438+
"A single FLAC signal file cannot contain more than 8 channels"
2439+
)
2440+
2441+
d_signal = d_signal.reshape(-1, n_sig, samps_per_frame[0])
2442+
d_signal = d_signal.transpose(0, 2, 1)
2443+
d_signal = d_signal.reshape(-1, n_sig)
2444+
2445+
if fmt == "508":
2446+
d_signal = d_signal.astype("int16")
2447+
np.left_shift(d_signal, 8, out=d_signal)
2448+
subtype = "PCM_S8"
2449+
elif fmt == "516":
2450+
d_signal = d_signal.astype("int16")
2451+
subtype = "PCM_16"
2452+
elif fmt == "524":
2453+
d_signal = d_signal.astype("int32")
2454+
np.left_shift(d_signal, 8, out=d_signal)
2455+
subtype = "PCM_24"
2456+
else:
2457+
raise ValueError(f"unknown format ({fmt})")
2458+
2459+
sf = soundfile.SoundFile(
2460+
file_path,
2461+
mode="w",
2462+
samplerate=96000,
2463+
channels=n_sig,
2464+
subtype=subtype,
2465+
format="FLAC",
2466+
)
2467+
with sf:
2468+
sf.write(d_signal)
2469+
return
2470+
24132471
else:
24142472
raise ValueError(
2415-
"This library currently only supports writing the following formats: 80, 16, 24, 32"
2473+
"This library currently only supports writing the "
2474+
"following formats: 80, 16, 24, 32, 508, 516, 524"
24162475
)
24172476

24182477
# Byte offset in the file
@@ -2427,7 +2486,7 @@ def wr_dat_file(
24272486
b_write = np.append(np.zeros(byte_offset, dtype="uint8"), b_write)
24282487

24292488
# Write the bytes to the file
2430-
with open(os.path.join(write_dir, file_name), "wb") as f:
2489+
with open(file_path, "wb") as f:
24312490
b_write.tofile(f)
24322491

24332492

wfdb/io/record.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,9 +520,15 @@ def check_field(self, field, required_channels="all"):
520520
"block_size values must be non-negative integers"
521521
)
522522
elif field == "sig_name":
523-
if re.search(r"\s", item[ch]):
523+
if item[ch][:1].isspace() or item[ch][-1:].isspace():
524+
raise ValueError(
525+
"sig_name strings may not begin or end with "
526+
"whitespace."
527+
)
528+
if re.search(r"[\x00-\x1f\x7f-\x9f]", item[ch]):
524529
raise ValueError(
525-
"sig_name strings may not contain whitespaces."
530+
"sig_name strings may not contain "
531+
"control characters."
526532
)
527533
if len(set(item)) != len(item):
528534
raise ValueError("sig_name strings must be unique.")

0 commit comments

Comments
 (0)