Skip to content

Commit 80c5984

Browse files
authored
Merge pull request #1647 from shoshber/partialvoluming
ENH signal extraction interface
2 parents c4c724b + 850ed41 commit 80c5984

File tree

2 files changed

+340
-0
lines changed

2 files changed

+340
-0
lines changed

nipype/interfaces/nilearn.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# -*- coding: utf-8 -*-
2+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
3+
# vi: set ft=python sts=4 ts=4 sw=4 et:
4+
'''
5+
Algorithms to compute statistics on :abbr:`fMRI (functional MRI)`
6+
7+
Change directory to provide relative paths for doctests
8+
>>> import os
9+
>>> filepath = os.path.dirname(os.path.realpath(__file__))
10+
>>> datadir = os.path.realpath(os.path.join(filepath, '../testing/data'))
11+
>>> os.chdir(datadir)
12+
13+
'''
14+
from __future__ import (print_function, division, unicode_literals,
15+
absolute_import)
16+
17+
import numpy as np
18+
import nibabel as nb
19+
20+
from .. import logging
21+
from ..interfaces.base import (traits, TraitedSpec, BaseInterface,
22+
BaseInterfaceInputSpec, File, InputMultiPath)
23+
IFLOG = logging.getLogger('interface')
24+
25+
class SignalExtractionInputSpec(BaseInterfaceInputSpec):
26+
in_file = File(exists=True, mandatory=True, desc='4-D fMRI nii file')
27+
label_files = InputMultiPath(File(exists=True), mandatory=True,
28+
desc='a 3-D label image, with 0 denoting '
29+
'background, or a list of 3-D probability '
30+
'maps (one per label) or the equivalent 4D '
31+
'file.')
32+
class_labels = traits.List(mandatory=True,
33+
desc='Human-readable labels for each segment '
34+
'in the label file, in order. The length of '
35+
'class_labels must be equal to the number of '
36+
'segments (background excluded). This list '
37+
'corresponds to the class labels in label_file '
38+
'in ascending order')
39+
out_file = File('signals.tsv', usedefault=True, exists=False,
40+
mandatory=False, desc='The name of the file to output to. '
41+
'signals.tsv by default')
42+
incl_shared_variance = traits.Bool(True, usedefault=True, mandatory=False, desc='By default '
43+
'(True), returns simple time series calculated from each '
44+
'region independently (e.g., for noise regression). If '
45+
'False, returns unique signals for each region, discarding '
46+
'shared variance (e.g., for connectivity. Only has effect '
47+
'with 4D probability maps.')
48+
include_global = traits.Bool(False, usedefault=True, mandatory=False,
49+
desc='If True, include an extra column '
50+
'labeled "global", with values calculated from the entire brain '
51+
'(instead of just regions).')
52+
detrend = traits.Bool(False, usedefault=True, mandatory=False,
53+
desc='If True, perform detrending using nilearn.')
54+
55+
class SignalExtractionOutputSpec(TraitedSpec):
56+
out_file = File(exists=True, desc='tsv file containing the computed '
57+
'signals, with as many columns as there are labels and as '
58+
'many rows as there are timepoints in in_file, plus a '
59+
'header row with values from class_labels')
60+
61+
class SignalExtraction(BaseInterface):
62+
'''
63+
Extracts signals over tissue classes or brain regions
64+
65+
>>> seinterface = SignalExtraction()
66+
>>> seinterface.inputs.in_file = 'functional.nii'
67+
>>> seinterface.inputs.label_files = 'segmentation0.nii.gz'
68+
>>> seinterface.inputs.out_file = 'means.tsv'
69+
>>> segments = ['CSF', 'gray', 'white']
70+
>>> seinterface.inputs.class_labels = segments
71+
>>> seinterface.inputs.detrend = True
72+
>>> seinterface.inputs.include_global = True
73+
'''
74+
input_spec = SignalExtractionInputSpec
75+
output_spec = SignalExtractionOutputSpec
76+
77+
def _run_interface(self, runtime):
78+
maskers = self._process_inputs()
79+
80+
signals = []
81+
for masker in maskers:
82+
signals.append(masker.fit_transform(self.inputs.in_file))
83+
region_signals = np.hstack(signals)
84+
85+
output = np.vstack((self.inputs.class_labels, region_signals.astype(str)))
86+
87+
# save output
88+
np.savetxt(self.inputs.out_file, output, fmt=b'%s', delimiter='\t')
89+
return runtime
90+
91+
def _process_inputs(self):
92+
''' validate and process inputs into useful form.
93+
Returns a list of nilearn maskers and the list of corresponding label names.'''
94+
import nilearn.input_data as nl
95+
import nilearn.image as nli
96+
97+
label_data = nli.concat_imgs(self.inputs.label_files)
98+
maskers = []
99+
100+
# determine form of label files, choose appropriate nilearn masker
101+
if np.amax(label_data.get_data()) > 1: # 3d label file
102+
n_labels = np.amax(label_data.get_data())
103+
maskers.append(nl.NiftiLabelsMasker(label_data))
104+
else: # 4d labels
105+
n_labels = label_data.get_data().shape[3]
106+
if self.inputs.incl_shared_variance: # 4d labels, independent computation
107+
for img in nli.iter_img(label_data):
108+
maskers.append(nl.NiftiMapsMasker(self._4d(img.get_data(), img.affine)))
109+
else: # 4d labels, one computation fitting all
110+
maskers.append(nl.NiftiMapsMasker(label_data))
111+
112+
# check label list size
113+
if len(self.inputs.class_labels) != n_labels:
114+
raise ValueError('The length of class_labels {} does not '
115+
'match the number of regions {} found in '
116+
'label_files {}'.format(self.inputs.class_labels,
117+
n_labels,
118+
self.inputs.label_files))
119+
120+
if self.inputs.include_global:
121+
global_label_data = label_data.get_data().sum(axis=3) # sum across all regions
122+
global_label_data = np.rint(global_label_data).astype(int).clip(0, 1) # binarize
123+
global_label_data = self._4d(global_label_data, label_data.affine)
124+
global_masker = nl.NiftiLabelsMasker(global_label_data, detrend=self.inputs.detrend)
125+
maskers.insert(0, global_masker)
126+
self.inputs.class_labels.insert(0, 'global')
127+
128+
for masker in maskers:
129+
masker.set_params(detrend=self.inputs.detrend)
130+
131+
return maskers
132+
133+
def _4d(self, array, affine):
134+
''' takes a 3-dimensional numpy array and an affine,
135+
returns the equivalent 4th dimensional nifti file '''
136+
return nb.Nifti1Image(array[:, :, :, np.newaxis], affine)
137+
138+
def _list_outputs(self):
139+
outputs = self._outputs().get()
140+
outputs['out_file'] = self.inputs.out_file
141+
return outputs
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
import unittest
4+
import os
5+
import tempfile
6+
import shutil
7+
8+
import numpy as np
9+
10+
from ...testing import (assert_equal, utils, assert_almost_equal, raises,
11+
skipif)
12+
from .. import nilearn as iface
13+
14+
no_nilearn = True
15+
try:
16+
__import__('nilearn')
17+
no_nilearn = False
18+
except ImportError:
19+
pass
20+
21+
class TestSignalExtraction(unittest.TestCase):
22+
23+
filenames = {
24+
'in_file': 'fmri.nii',
25+
'label_files': 'labels.nii',
26+
'4d_label_file': '4dlabels.nii',
27+
'out_file': 'signals.tsv'
28+
}
29+
labels = ['csf', 'gray', 'white']
30+
global_labels = ['global'] + labels
31+
32+
def setUp(self):
33+
self.orig_dir = os.getcwd()
34+
self.temp_dir = tempfile.mkdtemp()
35+
os.chdir(self.temp_dir)
36+
37+
utils.save_toy_nii(self.fake_fmri_data, self.filenames['in_file'])
38+
utils.save_toy_nii(self.fake_label_data, self.filenames['label_files'])
39+
40+
@skipif(no_nilearn)
41+
def test_signal_extract_no_shared(self):
42+
# run
43+
iface.SignalExtraction(in_file=self.filenames['in_file'],
44+
label_files=self.filenames['label_files'],
45+
class_labels=self.labels,
46+
incl_shared_variance=False).run()
47+
# assert
48+
self.assert_expected_output(self.labels, self.base_wanted)
49+
50+
51+
@skipif(no_nilearn)
52+
@raises(ValueError)
53+
def test_signal_extr_bad_label_list(self):
54+
# run
55+
iface.SignalExtraction(in_file=self.filenames['in_file'],
56+
label_files=self.filenames['label_files'],
57+
class_labels=['bad'],
58+
incl_shared_variance=False).run()
59+
60+
@skipif(no_nilearn)
61+
def test_signal_extr_equiv_4d_no_shared(self):
62+
self._test_4d_label(self.base_wanted, self.fake_equiv_4d_label_data,
63+
incl_shared_variance=False)
64+
65+
@skipif(no_nilearn)
66+
def test_signal_extr_4d_no_shared(self):
67+
# set up & run & assert
68+
self._test_4d_label(self.fourd_wanted, self.fake_4d_label_data, incl_shared_variance=False)
69+
70+
@skipif(no_nilearn)
71+
def test_signal_extr_global_no_shared(self):
72+
# set up
73+
wanted_global = [[-4./6], [-1./6], [3./6], [-1./6], [-7./6]]
74+
for i, vals in enumerate(self.base_wanted):
75+
wanted_global[i].extend(vals)
76+
77+
# run
78+
iface.SignalExtraction(in_file=self.filenames['in_file'],
79+
label_files=self.filenames['label_files'],
80+
class_labels=self.labels,
81+
include_global=True,
82+
incl_shared_variance=False).run()
83+
84+
# assert
85+
self.assert_expected_output(self.global_labels, wanted_global)
86+
87+
@skipif(no_nilearn)
88+
def test_signal_extr_4d_global_no_shared(self):
89+
# set up
90+
wanted_global = [[3./8], [-3./8], [1./8], [-7./8], [-9./8]]
91+
for i, vals in enumerate(self.fourd_wanted):
92+
wanted_global[i].extend(vals)
93+
94+
# run & assert
95+
self._test_4d_label(wanted_global, self.fake_4d_label_data,
96+
include_global=True, incl_shared_variance=False)
97+
98+
@skipif(no_nilearn)
99+
def test_signal_extr_shared(self):
100+
# set up
101+
wanted = []
102+
for vol in range(self.fake_fmri_data.shape[3]):
103+
volume = self.fake_fmri_data[:, :, :, vol].flatten()
104+
wanted_row = []
105+
for reg in range(self.fake_4d_label_data.shape[3]):
106+
region = self.fake_4d_label_data[:, :, :, reg].flatten()
107+
wanted_row.append((volume*region).sum()/(region*region).sum())
108+
109+
wanted.append(wanted_row)
110+
# run & assert
111+
self._test_4d_label(wanted, self.fake_4d_label_data)
112+
113+
def _test_4d_label(self, wanted, fake_labels, include_global=False, incl_shared_variance=True):
114+
# set up
115+
utils.save_toy_nii(fake_labels, self.filenames['4d_label_file'])
116+
117+
# run
118+
iface.SignalExtraction(in_file=self.filenames['in_file'],
119+
label_files=self.filenames['4d_label_file'],
120+
class_labels=self.labels,
121+
incl_shared_variance=incl_shared_variance,
122+
include_global=include_global).run()
123+
124+
wanted_labels = self.global_labels if include_global else self.labels
125+
126+
# assert
127+
self.assert_expected_output(wanted_labels, wanted)
128+
129+
def assert_expected_output(self, labels, wanted):
130+
with open(self.filenames['out_file'], 'r') as output:
131+
got = [line.split() for line in output]
132+
labels_got = got.pop(0) # remove header
133+
assert_equal(labels_got, labels)
134+
assert_equal(len(got), self.fake_fmri_data.shape[3],
135+
'num rows and num volumes')
136+
# convert from string to float
137+
got = [[float(num) for num in row] for row in got]
138+
for i, time in enumerate(got):
139+
assert_equal(len(labels), len(time))
140+
for j, segment in enumerate(time):
141+
assert_almost_equal(segment, wanted[i][j], decimal=1)
142+
143+
144+
def tearDown(self):
145+
os.chdir(self.orig_dir)
146+
shutil.rmtree(self.temp_dir)
147+
148+
fake_fmri_data = np.array([[[[2, -1, 4, -2, 3],
149+
[4, -2, -5, -1, 0]],
150+
151+
[[-2, 0, 1, 4, 4],
152+
[-5, 3, -3, 1, -5]]],
153+
154+
155+
[[[2, -2, -1, -2, -5],
156+
[3, 0, 3, -5, -2]],
157+
158+
[[-4, -2, -2, 1, -2],
159+
[3, 1, 4, -3, -2]]]])
160+
161+
fake_label_data = np.array([[[1, 0],
162+
[3, 1]],
163+
164+
[[2, 0],
165+
[1, 3]]])
166+
167+
fake_equiv_4d_label_data = np.array([[[[1., 0., 0.],
168+
[0., 0., 0.]],
169+
[[0., 0., 1.],
170+
[1., 0., 0.]]],
171+
[[[0., 1., 0.],
172+
[0., 0., 0.]],
173+
[[1., 0., 0.],
174+
[0., 0., 1.]]]])
175+
176+
base_wanted = [[-2.33333, 2, .5],
177+
[0, -2, .5],
178+
[-.3333333, -1, 2.5],
179+
[0, -2, .5],
180+
[-1.3333333, -5, 1]]
181+
182+
fake_4d_label_data = np.array([[[[0.2, 0.3, 0.5],
183+
[0.1, 0.1, 0.8]],
184+
185+
[[0.1, 0.3, 0.6],
186+
[0.3, 0.4, 0.3]]],
187+
188+
[[[0.2, 0.2, 0.6],
189+
[0., 0.3, 0.7]],
190+
191+
[[0.3, 0.3, 0.4],
192+
[0.3, 0.4, 0.3]]]])
193+
194+
195+
fourd_wanted = [[-5.0652173913, -5.44565217391, 5.50543478261],
196+
[-7.02173913043, 11.1847826087, -4.33152173913],
197+
[-19.0869565217, 21.2391304348, -4.57608695652],
198+
[5.19565217391, -3.66304347826, -1.51630434783],
199+
[-12.0, 3., 0.5]]

0 commit comments

Comments
 (0)