Skip to content

Commit 7272302

Browse files
committed
Simplify tSNR, fix division warning
1 parent 1020c80 commit 7272302

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

nipype/algorithms/misc.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,11 @@ class TSNRInputSpec(BaseInterfaceInputSpec):
262262
in_file = InputMultiPath(File(exists=True), mandatory=True,
263263
desc='realigned 4D file or a list of 3D files')
264264
regress_poly = traits.Range(low=1, desc='Remove polynomials')
265+
tsnr_file = File('tsnr.nii.gz', usedefault=True, desc='output tSNR file')
266+
mean_file = File('mean.nii.gz', usedefault=True, desc='output mean file')
267+
stddev_file = File('stdev.nii.gz', usedefault=True, desc='output tSNR file')
268+
detrended_file = File('detrend.nii.gz', usedefault=True,
269+
desc='input file after detrending')
265270

266271

267272
class TSNROutputSpec(TraitedSpec):
@@ -287,24 +292,17 @@ class TSNR(BaseInterface):
287292
input_spec = TSNRInputSpec
288293
output_spec = TSNROutputSpec
289294

290-
def _gen_output_file_name(self, suffix=None):
291-
_, base, ext = split_filename(self.inputs.in_file[0])
292-
if suffix in ['mean', 'stddev']:
293-
return os.path.abspath(base + "_tsnr_" + suffix + ext)
294-
elif suffix in ['detrended']:
295-
return os.path.abspath(base + "_" + suffix + ext)
296-
else:
297-
return os.path.abspath(base + "_tsnr" + ext)
298-
299295
def _run_interface(self, runtime):
300296
img = nb.load(self.inputs.in_file[0])
301297
header = img.get_header().copy()
302298
vollist = [nb.load(filename) for filename in self.inputs.in_file]
303299
data = np.concatenate([vol.get_data().reshape(
304300
vol.get_shape()[:3] + (-1,)) for vol in vollist], axis=3)
301+
data = data.nan_to_num()
305302
if data.dtype.kind == 'i':
306303
header.set_data_dtype(np.float32)
307304
data = data.astype(np.float32)
305+
308306
if isdefined(self.inputs.regress_poly):
309307
timepoints = img.get_shape()[-1]
310308
X = np.ones((timepoints, 1))
@@ -318,16 +316,18 @@ def _run_interface(self, runtime):
318316
0, 4)
319317
data = data - datahat
320318
img = nb.Nifti1Image(data, img.get_affine(), header)
321-
nb.save(img, self._gen_output_file_name('detrended'))
319+
nb.save(img, op.abspath(self.inputs.detrended_file))
320+
322321
meanimg = np.mean(data, axis=3)
323322
stddevimg = np.std(data, axis=3)
324-
tsnr = meanimg / stddevimg
323+
tsnr = np.zeros_like(meanimg)
324+
tsnr[stddevimg > 1.e-3] = meanimg[stddevimg > 1.e-3] / stddevimg[stddevimg > 1.e-3]
325325
img = nb.Nifti1Image(tsnr, img.get_affine(), header)
326-
nb.save(img, self._gen_output_file_name())
326+
nb.save(img, op.abspath(self.inputs.tsnr_file))
327327
img = nb.Nifti1Image(meanimg, img.get_affine(), header)
328-
nb.save(img, self._gen_output_file_name('mean'))
328+
nb.save(img, op.abspath(self.inputs.mean_file))
329329
img = nb.Nifti1Image(stddevimg, img.get_affine(), header)
330-
nb.save(img, self._gen_output_file_name('stddev'))
330+
nb.save(img, op.abspath(self.inputs.stddev_file))
331331
return runtime
332332

333333
def _list_outputs(self):

0 commit comments

Comments
 (0)