Skip to content

Commit 9ead3a8

Browse files
matthew-brettMarcCote
authored andcommitted
RF: refactor testing for trk header values
Also: slightly extend tests for version change.
1 parent fe3c9a5 commit 9ead3a8

File tree

2 files changed

+54
-41
lines changed

2 files changed

+54
-41
lines changed

nibabel/streamlines/tests/test_trk.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
2+
import sys
23
import copy
34
import unittest
45
import numpy as np
56
from os.path import join as pjoin
67

7-
from six import BytesIO
8+
from nibabel.externals.six import BytesIO
89

910
from nibabel.testing import data_path
1011
from nibabel.testing import clear_and_catch_warnings, assert_arr_dict_equal
@@ -99,82 +100,94 @@ def test_load_complex_file(self):
99100
trk = TrkFile.load(DATA['complex_trk_fname'], lazy_load=lazy_load)
100101
assert_tractogram_equal(trk.tractogram, DATA['complex_tractogram'])
101102

103+
def trk_with_bytes(self, trk_key='simple_trk_fname', endian='<'):
104+
""" Return example trk file bytes and struct view onto bytes """
105+
with open(DATA[trk_key], 'rb') as fobj:
106+
trk_bytes = fobj.read()
107+
dt = trk_module.header_2_dtype.newbyteorder(endian)
108+
trk_struct = np.ndarray((1,), dt, buffer=trk_bytes)
109+
trk_struct.flags.writeable = True
110+
return trk_struct, trk_bytes
111+
102112
def test_load_file_with_wrong_information(self):
103113
trk_file = open(DATA['simple_trk_fname'], 'rb').read()
104114

105115
# Simulate a TRK file where `count` was not provided.
106-
count = np.array(0, dtype="int32").tostring()
107-
new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:]
108-
trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False)
116+
trk_struct, trk_bytes = self.trk_with_bytes()
117+
trk_struct[Field.NB_STREAMLINES] = 0
118+
trk = TrkFile.load(BytesIO(trk_bytes), lazy_load=False)
109119
assert_tractogram_equal(trk.tractogram, DATA['simple_tractogram'])
110120

111121
# Simulate a TRK where `vox_to_ras` is not recorded (i.e. all zeros).
112-
vox_to_ras = np.zeros((4, 4), dtype=np.float32).tostring()
113-
new_trk_file = trk_file[:440] + vox_to_ras + trk_file[440+64:]
122+
trk_struct, trk_bytes = self.trk_with_bytes()
123+
trk_struct[Field.VOXEL_TO_RASMM] = np.zeros((4, 4))
114124
with clear_and_catch_warnings(record=True, modules=[trk_module]) as w:
115-
trk = TrkFile.load(BytesIO(new_trk_file))
125+
trk = TrkFile.load(BytesIO(trk_bytes))
116126
assert_equal(len(w), 1)
117127
assert_true(issubclass(w[0].category, HeaderWarning))
118128
assert_true("identity" in str(w[0].message))
119129
assert_array_equal(trk.affine, np.eye(4))
120130

121131
# Simulate a TRK where `vox_to_ras` is invalid.
122-
vox_to_ras = np.zeros((4, 4), dtype=np.float32)
123-
vox_to_ras[3, 3] = 1
124-
vox_to_ras = vox_to_ras.tostring()
125-
new_trk_file = trk_file[:440] + vox_to_ras + trk_file[440+64:]
132+
trk_struct, trk_bytes = self.trk_with_bytes()
133+
trk_struct[Field.VOXEL_TO_RASMM] = np.diag([0, 0, 0, 1])
126134
with clear_and_catch_warnings(record=True, modules=[trk_module]) as w:
127-
assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file))
135+
assert_raises(HeaderError, TrkFile.load, BytesIO(trk_bytes))
128136

129137
# Simulate a TRK file where `voxel_order` was not provided.
130-
voxel_order = np.zeros(1, dtype="|S3").tostring()
131-
new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:]
138+
trk_struct, trk_bytes = self.trk_with_bytes()
139+
trk_struct[Field.VOXEL_ORDER] = b''
132140
with clear_and_catch_warnings(record=True, modules=[trk_module]) as w:
133-
TrkFile.load(BytesIO(new_trk_file))
141+
TrkFile.load(BytesIO(trk_bytes))
134142
assert_equal(len(w), 1)
135143
assert_true(issubclass(w[0].category, HeaderWarning))
136144
assert_true("LPS" in str(w[0].message))
137145

138146
# Simulate a TRK file with an unsupported version.
139-
version = np.int32(123).tostring()
140-
new_trk_file = trk_file[:992] + version + trk_file[992+4:]
141-
assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file))
147+
trk_struct, trk_bytes = self.trk_with_bytes()
148+
trk_struct['version'] = 123
149+
assert_raises(HeaderError, TrkFile.load, BytesIO(trk_bytes))
142150

143151
# Simulate a TRK file with a wrong hdr_size.
144-
hdr_size = np.int32(1234).tostring()
145-
new_trk_file = trk_file[:996] + hdr_size + trk_file[996+4:]
146-
assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file))
152+
trk_struct, trk_bytes = self.trk_with_bytes()
153+
trk_struct['hdr_size'] = 1234
154+
assert_raises(HeaderError, TrkFile.load, BytesIO(trk_bytes))
147155

148156
# Simulate a TRK file with a wrong scalar_name.
149-
trk_file = open(DATA['complex_trk_fname'], 'rb').read()
150-
noise = np.int32(42).tostring()
151-
new_trk_file = trk_file[:47] + noise + trk_file[47+4:]
152-
assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file))
157+
trk_struct, trk_bytes = self.trk_with_bytes('complex_trk_fname')
158+
trk_struct['scalar_name'][0, 0] = b'colors\x003\x004'
159+
assert_raises(HeaderError, TrkFile.load, BytesIO(trk_bytes))
153160

154161
# Simulate a TRK file with a wrong property_name.
155-
noise = np.int32(42).tostring()
156-
new_trk_file = trk_file[:254] + noise + trk_file[254+4:]
157-
assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file))
162+
trk_struct, trk_bytes = self.trk_with_bytes('complex_trk_fname')
163+
trk_struct['property_name'][0, 0] = b'colors\x003\x004'
164+
assert_raises(HeaderError, TrkFile.load, BytesIO(trk_bytes))
158165

159166
def test_load_trk_version_1(self):
160-
trk_file = open(DATA['simple_trk_fname'], 'rb').read()
161-
162-
# Simulate a TRK (version 1).
163-
version = np.array(1, dtype=np.int32).tostring()
164-
new_trk_file = trk_file[:992] + version + trk_file[992+4:]
167+
# Simulate and test a TRK (version 1).
168+
# First check that setting the RAS affine works in version 2.
169+
trk_struct, trk_bytes = self.trk_with_bytes()
170+
trk_struct[Field.VOXEL_TO_RASMM] = np.diag([2, 3, 4, 1])
171+
trk = TrkFile.load(BytesIO(trk_bytes))
172+
assert_array_equal(trk.affine, np.diag([2, 3, 4, 1]))
173+
# Next check that affine assumed identity if version 1.
174+
trk_struct['version'] = 1
165175
with clear_and_catch_warnings(record=True, modules=[trk_module]) as w:
166-
trk = TrkFile.load(BytesIO(new_trk_file))
176+
trk = TrkFile.load(BytesIO(trk_bytes))
167177
assert_equal(len(w), 1)
168178
assert_true(issubclass(w[0].category, HeaderWarning))
169179
assert_true("identity" in str(w[0].message))
170180
assert_array_equal(trk.affine, np.eye(4))
171181
assert_array_equal(trk.header['version'], 1)
172182

173183
def test_load_complex_file_in_big_endian(self):
174-
trk_file = open(DATA['complex_trk_big_endian_fname'], 'rb').read()
184+
trk_struct, trk_bytes = self.trk_with_bytes(
185+
'complex_trk_big_endian_fname', endian='>')
175186
# We use hdr_size as an indicator of little vs big endian.
176-
hdr_size_big_endian = np.array(1000, dtype=">i4").tostring()
177-
assert_equal(trk_file[996:996+4], hdr_size_big_endian)
187+
good_orders = '>' if sys.byteorder == 'little' else '>='
188+
hdr_size = trk_struct['hdr_size']
189+
assert_true(hdr_size.dtype.byteorder in good_orders)
190+
assert_equal(hdr_size, 1000)
178191

179192
for lazy_load in [False, True]:
180193
trk = TrkFile.load(DATA['complex_trk_big_endian_fname'],

nibabel/streamlines/trk.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,8 @@ def load(cls, fileobj, lazy_load=False):
315315
data_per_point_slice = {}
316316
if hdr[Field.NB_SCALARS_PER_POINT] > 0:
317317
cpt = 0
318-
for scalar_name in hdr['scalar_name']:
319-
scalar_name, nb_scalars = decode_value_from_name(scalar_name)
318+
for scalar_field in hdr['scalar_name']:
319+
scalar_name, nb_scalars = decode_value_from_name(scalar_field)
320320

321321
if nb_scalars == 0:
322322
continue
@@ -332,8 +332,8 @@ def load(cls, fileobj, lazy_load=False):
332332
data_per_streamline_slice = {}
333333
if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0:
334334
cpt = 0
335-
for property_name in hdr['property_name']:
336-
results = decode_value_from_name(property_name)
335+
for property_field in hdr['property_name']:
336+
results = decode_value_from_name(property_field)
337337
property_name, nb_properties = results
338338

339339
if nb_properties == 0:

0 commit comments

Comments
 (0)