Skip to content

Commit 1e026e9

Browse files
committed
several improvements in Estimator interface
1 parent 465ad27 commit 1e026e9

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

nipype/interfaces/dipy/reconstruction.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,14 @@ class EstimateResponseSHInputSpec(DipyBaseInterfaceInputSpec):
162162
in_mask = File(
163163
exists=True, desc=('input mask in which we find single fibers'))
164164
fa_thresh = traits.Float(
165-
0.7, usedefault=True, desc=('default FA threshold'))
165+
0.7, usedefault=True, desc=('FA threshold'))
166+
roi_radius = traits.Int(
167+
10, usedefault=True, desc=('ROI radius to be used in auto_response'))
168+
auto = traits.Bool(
169+
True, usedefault=True, xor=['recursive'],
170+
desc='use the auto_response estimator from dipy')
166171
recursive = traits.Bool(
167-
False, usedefault=True,
172+
False, usedefault=True, xor=['auto'],
168173
desc='use the recursive response estimator from dipy')
169174
response = File(
170175
'response.txt', usedefault=True, desc=('the output response file'))
@@ -203,7 +208,7 @@ class EstimateResponseSH(DipyBaseInterface):
203208
def _run_interface(self, runtime):
204209
from dipy.core.gradients import GradientTable
205210
from dipy.reconst.dti import fractional_anisotropy, mean_diffusivity
206-
from dipy.reconst.csdeconv import recursive_response
211+
from dipy.reconst.csdeconv import recursive_response, auto_response
207212

208213
img = nb.load(self.inputs.in_file)
209214
affine = img.get_affine()
@@ -218,23 +223,18 @@ def _run_interface(self, runtime):
218223
data = img.get_data().astype(np.float32)
219224
gtab = self._get_gradient_table()
220225

221-
evals = nb.load(self.inputs.in_evals).get_data()
226+
evals = np.nan_to_num(nb.load(self.inputs.in_evals).get_data())
222227
FA = np.nan_to_num(fractional_anisotropy(evals)) * msk
223-
224-
if not self.inputs.recursive:
225-
indices = np.where(FA > self.inputs.fa_thresh)
226-
lambdas = evals[indices][:, :2]
227-
S0s = data[indices][:, np.nonzero(gtab.b0s_mask)[0]]
228-
S0 = np.mean(S0s)
229-
l01 = np.mean(lambdas, axis=0)
230-
respev = np.array([l01[0], l01[1], l01[1]])
231-
response = np.array(respev.tolist() + [S0]).reshape(-1)
232-
233-
ratio = abs(respev[1] / respev[0])
234-
if ratio > 0.25:
235-
iflogger.warn(('Estimated response is not prolate enough. '
236-
'Ratio=%0.3f.') % ratio)
237-
else:
228+
indices = np.where(FA > self.inputs.fa_thresh)
229+
S0s = data[indices][:, np.nonzero(gtab.b0s_mask)[0]]
230+
S0 = np.mean(S0s)
231+
232+
if self.inputs.auto:
233+
response, ratio = auto_response(gtab, data,
234+
roi_radius=self.inputs.roi_radius,
235+
fa_thr=self.inputs.fa_thresh)
236+
response = response[0].tolist() + [S0]
237+
elif self.inputs.recursive:
238238
MD = np.nan_to_num(mean_diffusivity(evals)) * msk
239239
indices = np.logical_or(
240240
FA >= 0.4, (np.logical_and(FA >= 0.15, MD >= 0.0011)))
@@ -244,6 +244,23 @@ def _run_interface(self, runtime):
244244
init_trace=0.0021, iter=8,
245245
convergence=0.001,
246246
parallel=True)
247+
ratio = abs(response[1] / response[0])
248+
else:
249+
lambdas = evals[indices]
250+
l01 = np.sort(np.mean(lambdas, axis=0))
251+
252+
response = np.array([l01[-1], l01[-2], l01[-2], S0])
253+
ratio = abs(response[1] / response[0])
254+
255+
if ratio > 0.25:
256+
iflogger.warn(('Estimated response is not prolate enough. '
257+
'Ratio=%0.3f.') % ratio)
258+
elif ratio < 1.e-5 or np.any(np.isnan(response)):
259+
response = np.array([1.8e-3, 3.6e-4, 3.6e-4, S0])
260+
iflogger.warn(
261+
('Estimated response is not valid, using a default one'))
262+
else:
263+
iflogger.info(('Estimated response: %s') % str(response[:3]))
247264

248265
np.savetxt(op.abspath(self.inputs.response), response)
249266

@@ -252,7 +269,6 @@ def _run_interface(self, runtime):
252269
nb.Nifti1Image(
253270
wm_mask.astype(np.uint8), affine,
254271
None).to_filename(op.abspath(self.inputs.out_mask))
255-
256272
return runtime
257273

258274
def _list_outputs(self):

0 commit comments

Comments
 (0)