diff --git a/nipype/algorithms/confounds.py b/nipype/algorithms/confounds.py index 8f7e31061b..c6378fe8a3 100644 --- a/nipype/algorithms/confounds.py +++ b/nipype/algorithms/confounds.py @@ -269,6 +269,16 @@ class FramewiseDisplacementInputSpec(BaseInterfaceInputSpec): desc="Source of movement parameters", mandatory=True, ) + metric = traits.Enum( + "L1", + "riemannian", + usedefault=True, + mandatory=True, + desc="Distance metric to apply: " + "L1 = Manhattan distance (original definition); " + "riemannian = Riemannian distance on the " + "Special Euclidean group in 3D (geodesic)", + ) radius = traits.Float( 50, usedefault=True, @@ -342,9 +352,38 @@ def _run_interface(self, runtime): arr=mpars, source=self.inputs.parameter_source, ) - diff = mpars[:-1, :6] - mpars[1:, :6] - diff[:, 3:6] *= self.inputs.radius - fd_res = np.abs(diff).sum(axis=1) + + if self.inputs.metric == "L1": + diff = mpars[:-1, :6] - mpars[1:, :6] + diff[:, 3:6] *= self.inputs.radius + fd_res = np.abs(diff).sum(axis=1) + + elif self.inputs.metric == "riemannian": + from geomstats.invariant_metric import InvariantMetric + from geomstats.special_euclidean_group import SpecialEuclideanGroup + + SE3_GROUP = SpecialEuclideanGroup(n=3) + SO3_GROUP = SE3_GROUP.rotations + DIM_TRANSLATIONS = SE3_GROUP.translations.dimension + DIM_ROTATIONS = SE3_GROUP.rotations.dimension + + so3pars = mpars[:, DIM_TRANSLATIONS:] + so3pars = SO3_GROUP.rotation_vector_from_tait_bryan_angles( + so3pars, extrinsic_or_intrinsic="extrinsic", order="zyx" + ) + + se3pars = np.hstack([so3pars, mpars[:, :DIM_TRANSLATIONS]]) + + diag_rotations = self.inputs.radius * np.ones(DIM_ROTATIONS) + diag_translations = np.ones(DIM_TRANSLATIONS) + diag = np.concatenate([diag_rotations, diag_translations]) + inner_product = np.diag(diag) + metric = InvariantMetric( + group=SE3_GROUP, + inner_product_mat_at_identity=inner_product, + left_or_right="left", + ) + fd_res = metric.dist(se3pars[:-1], se3pars[1:]) self._results = { "out_file": op.abspath(self.inputs.out_file), diff --git a/nipype/algorithms/tests/test_auto_FramewiseDisplacement.py b/nipype/algorithms/tests/test_auto_FramewiseDisplacement.py index 1bc46fba64..82f9ee7143 100644 --- a/nipype/algorithms/tests/test_auto_FramewiseDisplacement.py +++ b/nipype/algorithms/tests/test_auto_FramewiseDisplacement.py @@ -7,6 +7,7 @@ def test_FramewiseDisplacement_inputs(): figdpi=dict(usedefault=True,), figsize=dict(usedefault=True,), in_file=dict(extensions=None, mandatory=True,), + metric=dict(mandatory=True, usedefault=True,), normalize=dict(usedefault=True,), out_figure=dict(extensions=None, usedefault=True,), out_file=dict(extensions=None, usedefault=True,), diff --git a/nipype/algorithms/tests/test_confounds.py b/nipype/algorithms/tests/test_confounds.py index 29f18c9221..f242cb241a 100644 --- a/nipype/algorithms/tests/test_confounds.py +++ b/nipype/algorithms/tests/test_confounds.py @@ -15,14 +15,48 @@ except ImportError: pass +nogeomstats = True +try: + import geomstats + + nogeomstats = False +except ImportError: + pass + def test_fd(tmpdir): tempdir = tmpdir.strpath ground_truth = np.loadtxt(example_data("fsl_motion_outliers_fd.txt")) + + fdisplacement = FramewiseDisplacement( + in_file=example_data("fsl_mcflirt_movpar.txt"), + out_file=tempdir + "/fd.txt", + parameter_source="FSL", + ) + res = fdisplacement.run() + + with open(res.outputs.out_file) as all_lines: + for line in all_lines: + assert "FramewiseDisplacement" in line + break + + assert np.allclose( + ground_truth, np.loadtxt(res.outputs.out_file, skiprows=1), atol=0.16 + ) + assert np.abs(ground_truth.mean() - res.outputs.fd_average) < 1e-2 + + +@pytest.mark.skipif(nogeomstats, reason="geomstats is not installed") +def test_fd_riemannian(tmpdir): + tempdir = tmpdir.strpath + # TODO(nina): Adapt ground_truth w. SPM Euler angles convention + ground_truth = np.loadtxt(example_data("fsl_motion_outliers_fd.txt")) + fdisplacement = FramewiseDisplacement( in_file=example_data("fsl_mcflirt_movpar.txt"), out_file=tempdir + "/fd.txt", parameter_source="FSL", + metric="riemannian", ) res = fdisplacement.run() diff --git a/nipype/info.py b/nipype/info.py index 7a2e4ae70e..5d3730f146 100644 --- a/nipype/info.py +++ b/nipype/info.py @@ -178,7 +178,7 @@ def get_nipype_gitversion(): "sphinxcontrib-napoleon", ], "duecredit": ["duecredit"], - "nipy": ["nitime", "nilearn<0.5.0", "dipy", "nipy", "matplotlib"], + "nipy": ["nitime", "nilearn<0.5.0", "dipy", "nipy", "matplotlib", "geomstats>=1.1.0"], "profiler": ["psutil>=5.0"], "pybids": ["pybids>=0.7.0"], "specs": ["black"],