19
19
from scipy .ndimage .measurements import center_of_mass , label
20
20
21
21
from .. import config , logging
22
- from ..utils .misc import package_check
23
22
24
- from ..interfaces .base import (BaseInterface , traits , TraitedSpec , File ,
25
- InputMultiPath , BaseInterfaceInputSpec ,
26
- isdefined )
27
- from ..utils import NUMPY_MMAP
23
+ from ..interfaces .base import (
24
+ SimpleInterface , BaseInterface , traits , TraitedSpec , File ,
25
+ InputMultiPath , BaseInterfaceInputSpec ,
26
+ isdefined )
27
+ from ..interfaces .nipy .base import NipyBaseInterface
28
28
29
29
iflogger = logging .getLogger ('interface' )
30
30
@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
383
383
File (exists = True ),
384
384
mandatory = True ,
385
385
desc = 'Test image. Requires the same dimensions as in_ref.' )
386
+ in_mask = File (exists = True , desc = 'calculate overlap only within mask' )
386
387
weighting = traits .Enum (
387
388
'none' ,
388
389
'volume' ,
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
403
404
class FuzzyOverlapOutputSpec (TraitedSpec ):
404
405
jaccard = traits .Float (desc = 'Fuzzy Jaccard Index (fJI), all the classes' )
405
406
dice = traits .Float (desc = 'Fuzzy Dice Index (fDI), all the classes' )
406
- diff_file = File (
407
- exists = True ,
408
- desc =
409
- 'resulting difference-map of all classes, using the chosen weighting' )
410
407
class_fji = traits .List (
411
408
traits .Float (),
412
409
desc = 'Array containing the fJIs of each computed class' )
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
415
412
desc = 'Array containing the fDIs of each computed class' )
416
413
417
414
418
- class FuzzyOverlap (BaseInterface ):
415
+ class FuzzyOverlap (SimpleInterface ):
419
416
"""Calculates various overlap measures between two maps, using the fuzzy
420
417
definition proposed in: Crum et al., Generalized Overlap Measures for
421
418
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
@@ -439,77 +436,75 @@ class FuzzyOverlap(BaseInterface):
439
436
output_spec = FuzzyOverlapOutputSpec
440
437
441
438
def _run_interface (self , runtime ):
442
- ncomp = len (self .inputs .in_ref )
443
- assert (ncomp == len (self .inputs .in_tst ))
444
- weights = np .ones (shape = ncomp )
445
-
446
- img_ref = np .array ([
447
- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
448
- for fname in self .inputs .in_ref
449
- ])
450
- img_tst = np .array ([
451
- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
452
- for fname in self .inputs .in_tst
453
- ])
454
-
455
- msk = np .sum (img_ref , axis = 0 )
456
- msk [msk > 0 ] = 1.0
457
- tst_msk = np .sum (img_tst , axis = 0 )
458
- tst_msk [tst_msk > 0 ] = 1.0
459
-
460
- # check that volumes are normalized
461
- # img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
462
- # img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]
463
-
464
- self ._jaccards = []
465
- volumes = []
466
-
467
- diff_im = np .zeros (img_ref .shape )
468
-
469
- for ref_comp , tst_comp , diff_comp in zip (img_ref , img_tst , diff_im ):
470
- num = np .minimum (ref_comp , tst_comp )
471
- ddr = np .maximum (ref_comp , tst_comp )
472
- diff_comp [ddr > 0 ] += 1.0 - (num [ddr > 0 ] / ddr [ddr > 0 ])
473
- self ._jaccards .append (np .sum (num ) / np .sum (ddr ))
474
- volumes .append (np .sum (ref_comp ))
475
-
476
- self ._dices = 2.0 * (np .array (self ._jaccards ) /
477
- (np .array (self ._jaccards ) + 1.0 ))
439
+ # Load data
440
+ refdata = nb .concat_images (self .inputs .in_ref ).get_data ()
441
+ tstdata = nb .concat_images (self .inputs .in_tst ).get_data ()
442
+
443
+ # Data must have same shape
444
+ if not refdata .shape == tstdata .shape :
445
+ raise RuntimeError (
446
+ 'Size of "in_tst" %s must match that of "in_ref" %s.' %
447
+ (tstdata .shape , refdata .shape ))
478
448
449
+ ncomp = refdata .shape [- 1 ]
450
+
451
+ # Load mask
452
+ mask = np .ones_like (refdata , dtype = bool )
453
+ if isdefined (self .inputs .in_mask ):
454
+ mask = nb .load (self .inputs .in_mask ).get_data ()
455
+ mask = mask > 0
456
+ mask = np .repeat (mask [..., np .newaxis ], ncomp , - 1 )
457
+ assert mask .shape == refdata .shape
458
+
459
+ # Drop data outside mask
460
+ refdata = refdata [mask ]
461
+ tstdata = tstdata [mask ]
462
+
463
+ if np .any (refdata < 0.0 ):
464
+ iflogger .warning ('Negative values encountered in "in_ref" input, '
465
+ 'taking absolute values.' )
466
+ refdata = np .abs (refdata )
467
+
468
+ if np .any (tstdata < 0.0 ):
469
+ iflogger .warning ('Negative values encountered in "in_tst" input, '
470
+ 'taking absolute values.' )
471
+ tstdata = np .abs (tstdata )
472
+
473
+ if np .any (refdata > 1.0 ):
474
+ iflogger .warning ('Values greater than 1.0 found in "in_ref" input, '
475
+ 'scaling values.' )
476
+ refdata /= refdata .max ()
477
+
478
+ if np .any (tstdata > 1.0 ):
479
+ iflogger .warning ('Values greater than 1.0 found in "in_tst" input, '
480
+ 'scaling values.' )
481
+ tstdata /= tstdata .max ()
482
+
483
+ numerators = np .atleast_2d (
484
+ np .minimum (refdata , tstdata ).reshape ((- 1 , ncomp )))
485
+ denominators = np .atleast_2d (
486
+ np .maximum (refdata , tstdata ).reshape ((- 1 , ncomp )))
487
+
488
+ jaccards = numerators .sum (axis = 0 ) / denominators .sum (axis = 0 )
489
+
490
+ # Calculate weights
491
+ weights = np .ones_like (jaccards , dtype = float )
479
492
if self .inputs .weighting != "none" :
480
- weights = 1.0 / np .array (volumes )
493
+ volumes = np .sum ((refdata + tstdata ) > 0 , axis = 1 ).reshape ((- 1 , ncomp ))
494
+ weights = 1.0 / volumes
481
495
if self .inputs .weighting == "squared_vol" :
482
496
weights = weights ** 2
483
497
484
498
weights = weights / np .sum (weights )
499
+ dices = 2.0 * jaccards / (jaccards + 1.0 )
485
500
486
- setattr (self , '_jaccard' , np .sum (weights * self ._jaccards ))
487
- setattr (self , '_dice' , np .sum (weights * self ._dices ))
488
-
489
- diff = np .zeros (diff_im [0 ].shape )
490
-
491
- for w , ch in zip (weights , diff_im ):
492
- ch [msk == 0 ] = 0
493
- diff += w * ch
494
-
495
- nb .save (
496
- nb .Nifti1Image (diff ,
497
- nb .load (self .inputs .in_ref [0 ]).affine ,
498
- nb .load (self .inputs .in_ref [0 ]).header ),
499
- self .inputs .out_file )
500
-
501
+ # Fill-in the results object
502
+ self ._results ['jaccard' ] = float (weights .dot (jaccards ))
503
+ self ._results ['dice' ] = float (weights .dot (dices ))
504
+ self ._results ['class_fji' ] = [float (v ) for v in jaccards ]
505
+ self ._results ['class_fdi' ] = [float (v ) for v in dices ]
501
506
return runtime
502
507
503
- def _list_outputs (self ):
504
- outputs = self ._outputs ().get ()
505
- for method in ("dice" , "jaccard" ):
506
- outputs [method ] = getattr (self , '_' + method )
507
- # outputs['volume_difference'] = self._volume
508
- outputs ['diff_file' ] = os .path .abspath (self .inputs .out_file )
509
- outputs ['class_fji' ] = np .array (self ._jaccards ).astype (float ).tolist ()
510
- outputs ['class_fdi' ] = self ._dices .astype (float ).tolist ()
511
- return outputs
512
-
513
508
514
509
class ErrorMapInputSpec (BaseInterfaceInputSpec ):
515
510
in_ref = File (
@@ -651,7 +646,7 @@ class SimilarityOutputSpec(TraitedSpec):
651
646
traits .Float (desc = "Similarity between volume 1 and 2, frame by frame" ))
652
647
653
648
654
- class Similarity (BaseInterface ):
649
+ class Similarity (NipyBaseInterface ):
655
650
"""Calculates similarity between two 3D or 4D volumes. Both volumes have to be in
656
651
the same coordinate system, same space within that coordinate system and
657
652
with the same voxel dimensions.
@@ -674,19 +669,8 @@ class Similarity(BaseInterface):
674
669
675
670
input_spec = SimilarityInputSpec
676
671
output_spec = SimilarityOutputSpec
677
- _have_nipy = True
678
-
679
- def __init__ (self , ** inputs ):
680
- try :
681
- package_check ('nipy' )
682
- except Exception :
683
- self ._have_nipy = False
684
- super (Similarity , self ).__init__ (** inputs )
685
672
686
673
def _run_interface (self , runtime ):
687
- if not self ._have_nipy :
688
- raise RuntimeError ('nipy is not installed' )
689
-
690
674
from nipy .algorithms .registration .histogram_registration import HistogramRegistration
691
675
from nipy .algorithms .registration .affine import Affine
692
676
0 commit comments