Skip to content

Commit f564d18

Browse files
committed
add average error to ErrorMap
1 parent fcae401 commit f564d18

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

nipype/algorithms/metrics.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,12 +506,13 @@ class ErrorMap(BaseInterface):
506506
_out_file = ''
507507

508508
def _run_interface(self, runtime):
509-
from scipy.spatial.distance import cdist, pdist
509+
# Get two numpy data matrices
510510
nii_ref = nb.load(self.inputs.in_ref)
511511
ref_data = np.squeeze(nii_ref.get_data())
512512
tst_data = np.squeeze(nb.load(self.inputs.in_tst).get_data())
513513
assert(ref_data.ndim == tst_data.ndim)
514514

515+
# Load mask
515516
comps = 1
516517
mapshape = ref_data.shape
517518

@@ -528,12 +529,14 @@ def _run_interface(self, runtime):
528529
else:
529530
msk = np.ones(shape=mapshape)
530531

532+
# Vectorise both volumes and make the pixel differennce
531533
mskvector = msk.reshape(-1)
532534
msk_idxs = np.where(mskvector==1)
533535
refvector = ref_data.reshape(-1,comps)[msk_idxs].astype(np.float32)
534536
tstvector = tst_data.reshape(-1,comps)[msk_idxs].astype(np.float32)
535537
diffvector = (refvector-tstvector)
536538

539+
# scale the diffrernce
537540
if self.inputs.metric == 'sqeuclidean':
538541
errvector = diffvector**2
539542
elif self.inputs.metric == 'euclidean':
@@ -548,6 +551,12 @@ def _run_interface(self, runtime):
548551
errvectorexp = np.zeros_like(mskvector)
549552
errvectorexp[msk_idxs] = errvector
550553

554+
# Get averaged error
555+
if self.inputs.metric == 'sqeuclidean':
556+
self._distance = np.sqrt(np.sum(errvectorexp))
557+
elif self.inputs.metric == 'euclidean':
558+
self._distance = np.average(errvectorexp)
559+
551560
errmap = errvectorexp.reshape(mapshape)
552561

553562
hdr = nii_ref.get_header().copy()
@@ -567,11 +576,12 @@ def _run_interface(self, runtime):
567576
nb.Nifti1Image(errmap.astype(np.float32), nii_ref.get_affine(),
568577
hdr).to_filename(self._out_file)
569578

570-
return runtime
579+
return runtime
571580

572581
def _list_outputs(self):
573582
outputs = self.output_spec().get()
574583
outputs['out_map'] = self._out_file
584+
outputs['distance'] = self._distance
575585
return outputs
576586

577587

0 commit comments

Comments
 (0)