Skip to content

Commit 74e719f

Browse files
committed
TEST: Test Grid methods
1 parent af4e08a commit 74e719f

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

nibabel/tests/test_pointset.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from math import prod
12
from pathlib import Path
23
from unittest import skipUnless
34

@@ -7,8 +8,10 @@
78
from nibabel import pointset as ps
89
from nibabel.affines import apply_affine
910
from nibabel.arrayproxy import ArrayProxy
11+
from nibabel.fileslice import strided_scalar
1012
from nibabel.onetime import auto_attr
1113
from nibabel.optpkg import optional_package
14+
from nibabel.spatialimages import SpatialImage
1215
from nibabel.tests.nibabel_data import get_nibabel_data
1316

1417
h5, has_h5py, _ = optional_package('h5py')
@@ -136,6 +139,57 @@ def test_GridIndices():
136139
assert np.array_equal(gi_arr, np.mgrid[:2, :3, :4].reshape(3, -1).T)
137140

138141

142+
class TestGrids(TestPointsets):
143+
@pytest.mark.parametrize('shape', [(5, 5, 5), (5, 5, 5, 5), (5, 5, 5, 5, 5)])
144+
def test_from_image(self, shape):
145+
# Check image is generates voxel coordinates
146+
affine = np.diag([2, 3, 4, 1])
147+
img = SpatialImage(strided_scalar(shape), affine)
148+
grid = ps.Grid.from_image(img)
149+
grid_coords = grid.get_coords()
150+
151+
assert grid.shape == (prod(shape[:3]), 3)
152+
assert np.allclose(grid.affine, affine)
153+
154+
assert np.allclose(grid_coords[0], [0, 0, 0])
155+
# Final index is [4, 4, 4], scaled by affine
156+
assert np.allclose(grid_coords[-1], [8, 12, 16])
157+
158+
def test_from_mask(self):
159+
affine = np.diag([2, 3, 4, 1])
160+
mask = np.zeros((3, 3, 3))
161+
mask[1, 1, 1] = 1
162+
img = SpatialImage(mask, affine)
163+
164+
grid = ps.Grid.from_mask(img)
165+
grid_coords = grid.get_coords()
166+
167+
assert grid.shape == (1, 3)
168+
assert np.array_equal(grid_coords, [[2, 3, 4]])
169+
170+
def test_to_mask(self):
171+
coords = np.array([[1, 1, 1]])
172+
173+
grid = ps.Grid(coords)
174+
175+
mask_img = grid.to_mask()
176+
assert mask_img.shape == (2, 2, 2)
177+
assert np.array_equal(mask_img.get_fdata(), [[[0, 0], [0, 0]], [[0, 0], [0, 1]]])
178+
assert np.array_equal(mask_img.affine, np.eye(4))
179+
180+
mask_img = grid.to_mask(shape=(3, 3, 3))
181+
assert mask_img.shape == (3, 3, 3)
182+
assert np.array_equal(
183+
mask_img.get_fdata(),
184+
[
185+
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
186+
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
187+
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
188+
],
189+
)
190+
assert np.array_equal(mask_img.affine, np.eye(4))
191+
192+
139193
class H5ArrayProxy:
140194
def __init__(self, file_like, dataset_name):
141195
self.file_like = file_like

0 commit comments

Comments
 (0)