Skip to content

Commit b0a8006

Browse files
committed
Merge pull request #298 from bcipolli/issue-207
MRG: Add an 'axis' parameter to concat_images, plus two tests Add ability to concatenate images over given axis, with tests. Closes #207
2 parents cf4f946 + 69aa16f commit b0a8006

File tree

2 files changed

+135
-46
lines changed

2 files changed

+135
-46
lines changed

nibabel/funcs.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,43 +88,70 @@ def squeeze_image(img):
8888
img.extra)
8989

9090

91-
def concat_images(images, check_affines=True):
92-
''' Concatenate images in list to single image, along last dimension
91+
def concat_images(images, check_affines=True, axis=None):
92+
''' Concatenate images in list to single image, along specified dimension
9393
9494
Parameters
9595
----------
9696
images : sequence
97-
sequence of ``SpatialImage`` or of filenames\s
97+
sequence of ``SpatialImage`` or filenames of the same dimensionality\s
9898
check_affines : {True, False}, optional
9999
If True, then check that all the affines for `images` are nearly
100100
the same, raising a ``ValueError`` otherwise. Default is True
101-
101+
axis : None or int, optional
102+
If None, concatenates on a new dimension. This requires all images to
103+
be the same shape. If not None, concatenates on the specified
104+
dimension. This requires all images to be the same shape, except on
105+
the specified dimension.
102106
Returns
103107
-------
104108
concat_img : ``SpatialImage``
105109
New image resulting from concatenating `images` across last
106110
dimension
107111
'''
112+
images = [load(img) if not hasattr(img, 'get_data')
113+
else img for img in images]
108114
n_imgs = len(images)
115+
if n_imgs == 0:
116+
raise ValueError("Cannot concatenate an empty list of images.")
109117
img0 = images[0]
110-
is_filename = False
111-
if not hasattr(img0, 'get_data'):
112-
img0 = load(img0)
113-
is_filename = True
114-
i0shape = img0.shape
115118
affine = img0.affine
116119
header = img0.header
117-
out_shape = (n_imgs, ) + i0shape
118-
out_data = np.empty(out_shape)
119-
for i, img in enumerate(images):
120-
if is_filename:
121-
img = load(img)
122-
if check_affines:
123-
if not np.all(img.affine == affine):
124-
raise ValueError('Affines do not match')
125-
out_data[i] = img.get_data()
126-
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
127120
klass = img0.__class__
121+
shape0 = img0.shape
122+
n_dim = len(shape0)
123+
if axis is None:
124+
# collect images in output array for efficiency
125+
out_shape = (n_imgs, ) + shape0
126+
out_data = np.empty(out_shape)
127+
else:
128+
# collect images in list for use with np.concatenate
129+
out_data = [None] * n_imgs
130+
# Get part of shape we need to check inside loop
131+
idx_mask = np.ones((n_dim,), dtype=bool)
132+
if axis is not None:
133+
idx_mask[axis] = False
134+
masked_shape = np.array(shape0)[idx_mask]
135+
for i, img in enumerate(images):
136+
if len(img.shape) != n_dim:
137+
raise ValueError(
138+
'Image {0} has {1} dimensions, image 0 has {2}'.format(
139+
i, len(img.shape), n_dim))
140+
if not np.all(np.array(img.shape)[idx_mask] == masked_shape):
141+
raise ValueError('shape {0} for image {1} not compatible with '
142+
'first image shape {2} with axis == {0}'.format(
143+
img.shape, i, shape0, axis))
144+
if check_affines and not np.all(img.affine == affine):
145+
raise ValueError('Affine for image {0} does not match affine '
146+
'for first image'.format(i))
147+
# Do not fill cache in image if it is empty
148+
out_data[i] = img.get_data(caching='unchanged')
149+
150+
if axis is None:
151+
out_data = np.rollaxis(out_data, 0, out_data.ndim)
152+
else:
153+
out_data = np.concatenate(out_data, axis=axis)
154+
128155
return klass(out_data, affine, header)
129156

130157

nibabel/tests/test_funcs.py

Lines changed: 89 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,34 +30,96 @@ def _as_fname(img):
3030

3131

3232
def test_concat():
33-
shape = (1,2,5)
34-
data0 = np.arange(10).reshape(shape)
33+
# Smoke test: concat empty list.
34+
assert_raises(ValueError, concat_images, [])
35+
36+
# Build combinations of 3D, 4D w/size[3] == 1, and 4D w/size[3] == 3
37+
all_shapes_5D = ((1, 4, 5, 3, 3),
38+
(7, 3, 1, 4, 5),
39+
(0, 2, 1, 4, 5))
40+
3541
affine = np.eye(4)
36-
img0_mem = Nifti1Image(data0, affine)
37-
data1 = data0 - 10
38-
img1_mem = Nifti1Image(data1, affine)
39-
img2_mem = Nifti1Image(data1, affine+1)
40-
img3_mem = Nifti1Image(data1.T, affine)
41-
all_data = np.concatenate(
42-
[data0[:,:,:,np.newaxis],data1[:,:,:,np.newaxis]],3)
43-
# Check filenames and in-memory images work
44-
with InTemporaryDirectory():
45-
imgs = [img0_mem, img1_mem, img2_mem, img3_mem]
46-
img_files = [_as_fname(img) for img in imgs]
47-
for img0, img1, img2, img3 in (imgs, img_files):
48-
all_imgs = concat_images([img0, img1])
49-
assert_array_equal(all_imgs.get_data(), all_data)
50-
assert_array_equal(all_imgs.affine, affine)
51-
# check that not-matching affines raise error
52-
assert_raises(ValueError, concat_images, [img0, img2])
53-
assert_raises(ValueError, concat_images, [img0, img3])
54-
# except if check_affines is False
55-
all_imgs = concat_images([img0, img1])
56-
assert_array_equal(all_imgs.get_data(), all_data)
57-
assert_array_equal(all_imgs.affine, affine)
58-
# Delete images as prophylaxis for windows access errors
59-
for img in imgs:
60-
del(img)
42+
for dim in range(2, 6):
43+
all_shapes_ND = tuple((shape[:dim] for shape in all_shapes_5D))
44+
all_shapes_N1D_unary = tuple((shape + (1,) for shape in all_shapes_ND))
45+
all_shapes = all_shapes_ND + all_shapes_N1D_unary
46+
47+
# Loop over all possible combinations of images, in first and
48+
# second position.
49+
for data0_shape in all_shapes:
50+
data0_numel = np.asarray(data0_shape).prod()
51+
data0 = np.arange(data0_numel).reshape(data0_shape)
52+
img0_mem = Nifti1Image(data0, affine)
53+
54+
for data1_shape in all_shapes:
55+
data1_numel = np.asarray(data1_shape).prod()
56+
data1 = np.arange(data1_numel).reshape(data1_shape)
57+
img1_mem = Nifti1Image(data1, affine)
58+
img2_mem = Nifti1Image(data1, affine+1) # bad affine
59+
60+
# Loop over every possible axis, including None (explicit and implied)
61+
for axis in (list(range(-(dim-2), (dim-1))) + [None, '__default__']):
62+
63+
# Allow testing default vs. passing explicit param
64+
if axis == '__default__':
65+
np_concat_kwargs = dict(axis=-1)
66+
concat_imgs_kwargs = dict()
67+
axis = None # Convert downstream
68+
elif axis is None:
69+
np_concat_kwargs = dict(axis=-1)
70+
concat_imgs_kwargs = dict(axis=axis)
71+
else:
72+
np_concat_kwargs = dict(axis=axis)
73+
concat_imgs_kwargs = dict(axis=axis)
74+
75+
# Create expected output
76+
try:
77+
# Error will be thrown if the np.concatenate fails.
78+
# However, when axis=None, the concatenate is possible
79+
# but our efficient logic (where all images are
80+
# 3D and the same size) fails, so we also
81+
# have to expect errors for those.
82+
if axis is None: # 3D from here and below
83+
all_data = np.concatenate([data0[..., np.newaxis],
84+
data1[..., np.newaxis]],
85+
**np_concat_kwargs)
86+
else: # both 3D, appending on final axis
87+
all_data = np.concatenate([data0, data1],
88+
**np_concat_kwargs)
89+
expect_error = False
90+
except ValueError:
91+
# Shapes are not combinable
92+
expect_error = True
93+
94+
# Check filenames and in-memory images work
95+
with InTemporaryDirectory():
96+
# Try mem-based, file-based, and mixed
97+
imgs = [img0_mem, img1_mem, img2_mem]
98+
img_files = [_as_fname(img) for img in imgs]
99+
imgs_mixed = [imgs[0], img_files[1], imgs[2]]
100+
for img0, img1, img2 in (imgs, img_files, imgs_mixed):
101+
try:
102+
all_imgs = concat_images([img0, img1],
103+
**concat_imgs_kwargs)
104+
except ValueError as ve:
105+
assert_true(expect_error, str(ve))
106+
else:
107+
assert_false(expect_error, "Expected a concatenation error, but got none.")
108+
assert_array_equal(all_imgs.get_data(), all_data)
109+
assert_array_equal(all_imgs.affine, affine)
110+
111+
# check that not-matching affines raise error
112+
assert_raises(ValueError, concat_images, [img0, img2], **concat_imgs_kwargs)
113+
114+
# except if check_affines is False
115+
try:
116+
all_imgs = concat_images([img0, img1], **concat_imgs_kwargs)
117+
except ValueError as ve:
118+
assert_true(expect_error, str(ve))
119+
else:
120+
assert_false(expect_error, "Expected a concatenation error, but got none.")
121+
assert_array_equal(all_imgs.get_data(), all_data)
122+
assert_array_equal(all_imgs.affine, affine)
61123

62124

63125
def test_closest_canonical():

0 commit comments

Comments
 (0)