@@ -162,9 +162,14 @@ class EstimateResponseSHInputSpec(DipyBaseInterfaceInputSpec):
162
162
in_mask = File (
163
163
exists = True , desc = ('input mask in which we find single fibers' ))
164
164
fa_thresh = traits .Float (
165
- 0.7 , usedefault = True , desc = ('default FA threshold' ))
165
+ 0.7 , usedefault = True , desc = ('FA threshold' ))
166
+ roi_radius = traits .Int (
167
+ 10 , usedefault = True , desc = ('ROI radius to be used in auto_response' ))
168
+ auto = traits .Bool (
169
+ True , usedefault = True , xor = ['recursive' ],
170
+ desc = 'use the auto_response estimator from dipy' )
166
171
recursive = traits .Bool (
167
- False , usedefault = True ,
172
+ False , usedefault = True , xor = [ 'auto' ],
168
173
desc = 'use the recursive response estimator from dipy' )
169
174
response = File (
170
175
'response.txt' , usedefault = True , desc = ('the output response file' ))
@@ -203,7 +208,7 @@ class EstimateResponseSH(DipyBaseInterface):
203
208
def _run_interface (self , runtime ):
204
209
from dipy .core .gradients import GradientTable
205
210
from dipy .reconst .dti import fractional_anisotropy , mean_diffusivity
206
- from dipy .reconst .csdeconv import recursive_response
211
+ from dipy .reconst .csdeconv import recursive_response , auto_response
207
212
208
213
img = nb .load (self .inputs .in_file )
209
214
affine = img .get_affine ()
@@ -218,23 +223,18 @@ def _run_interface(self, runtime):
218
223
data = img .get_data ().astype (np .float32 )
219
224
gtab = self ._get_gradient_table ()
220
225
221
- evals = nb .load (self .inputs .in_evals ).get_data ()
226
+ evals = np . nan_to_num ( nb .load (self .inputs .in_evals ).get_data () )
222
227
FA = np .nan_to_num (fractional_anisotropy (evals )) * msk
223
-
224
- if not self .inputs .recursive :
225
- indices = np .where (FA > self .inputs .fa_thresh )
226
- lambdas = evals [indices ][:, :2 ]
227
- S0s = data [indices ][:, np .nonzero (gtab .b0s_mask )[0 ]]
228
- S0 = np .mean (S0s )
229
- l01 = np .mean (lambdas , axis = 0 )
230
- respev = np .array ([l01 [0 ], l01 [1 ], l01 [1 ]])
231
- response = np .array (respev .tolist () + [S0 ]).reshape (- 1 )
232
-
233
- ratio = abs (respev [1 ] / respev [0 ])
234
- if ratio > 0.25 :
235
- iflogger .warn (('Estimated response is not prolate enough. '
236
- 'Ratio=%0.3f.' ) % ratio )
237
- else :
228
+ indices = np .where (FA > self .inputs .fa_thresh )
229
+ S0s = data [indices ][:, np .nonzero (gtab .b0s_mask )[0 ]]
230
+ S0 = np .mean (S0s )
231
+
232
+ if self .inputs .auto :
233
+ response , ratio = auto_response (gtab , data ,
234
+ roi_radius = self .inputs .roi_radius ,
235
+ fa_thr = self .inputs .fa_thresh )
236
+ response = response [0 ].tolist () + [S0 ]
237
+ elif self .inputs .recursive :
238
238
MD = np .nan_to_num (mean_diffusivity (evals )) * msk
239
239
indices = np .logical_or (
240
240
FA >= 0.4 , (np .logical_and (FA >= 0.15 , MD >= 0.0011 )))
@@ -244,6 +244,23 @@ def _run_interface(self, runtime):
244
244
init_trace = 0.0021 , iter = 8 ,
245
245
convergence = 0.001 ,
246
246
parallel = True )
247
+ ratio = abs (response [1 ] / response [0 ])
248
+ else :
249
+ lambdas = evals [indices ]
250
+ l01 = np .sort (np .mean (lambdas , axis = 0 ))
251
+
252
+ response = np .array ([l01 [- 1 ], l01 [- 2 ], l01 [- 2 ], S0 ])
253
+ ratio = abs (response [1 ] / response [0 ])
254
+
255
+ if ratio > 0.25 :
256
+ iflogger .warn (('Estimated response is not prolate enough. '
257
+ 'Ratio=%0.3f.' ) % ratio )
258
+ elif ratio < 1.e-5 or np .any (np .isnan (response )):
259
+ response = np .array ([1.8e-3 , 3.6e-4 , 3.6e-4 , S0 ])
260
+ iflogger .warn (
261
+ ('Estimated response is not valid, using a default one' ))
262
+ else :
263
+ iflogger .info (('Estimated response: %s' ) % str (response [:3 ]))
247
264
248
265
np .savetxt (op .abspath (self .inputs .response ), response )
249
266
@@ -252,7 +269,6 @@ def _run_interface(self, runtime):
252
269
nb .Nifti1Image (
253
270
wm_mask .astype (np .uint8 ), affine ,
254
271
None ).to_filename (op .abspath (self .inputs .out_mask ))
255
-
256
272
return runtime
257
273
258
274
def _list_outputs (self ):
0 commit comments