Skip to content

Commit c6c6a1d

Browse files
author
Ben Cipollini
committed
Add an 'axis' parameter to concat_images, plus two tests.
1 parent 96d474c commit c6c6a1d

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

nibabel/funcs.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def squeeze_image(img):
8888
img.extra)
8989

9090

91-
def concat_images(images, check_affines=True):
92-
''' Concatenate images in list to single image, along last dimension
91+
def concat_images(images, check_affines=True, axis=None):
92+
''' Concatenate images in list to single image, along specified dimension
9393
9494
Parameters
9595
----------
@@ -98,7 +98,9 @@ def concat_images(images, check_affines=True):
9898
check_affines : {True, False}, optional
9999
If True, then check that all the affines for `images` are nearly
100100
the same, raising a ``ValueError`` otherwise. Default is True
101-
101+
axis : int, optional
102+
If None, concatenates on the last dimension.
103+
If not None, concatenates on the specified dimension.
102104
Returns
103105
-------
104106
concat_img : ``SpatialImage``
@@ -122,8 +124,13 @@ def concat_images(images, check_affines=True):
122124
if check_affines:
123125
if not np.all(img.affine == affine):
124126
raise ValueError('Affines do not match')
125-
out_data[i] = img.get_data()
126-
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
127+
out_data[i] = img.get_data().copy()
128+
if axis is not None:
129+
out_data = np.concatenate(out_data, axis=axis)
130+
elif np.all([d.shape[-1] == 1 for d in out_data]):
131+
out_data = np.concatenate(out_data, axis=d.ndim-1)
132+
else:
133+
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
127134
klass = img0.__class__
128135
return klass(out_data, affine, header)
129136

nibabel/tests/test_funcs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ def test_concat():
5959
for img in imgs:
6060
del(img)
6161

62+
# Test axis parameter and trailing unary dimension
63+
shape_4D = np.asarray(shape + (1,))
64+
data0 = np.arange(10).reshape(shape_4D)
65+
affine = np.eye(4)
66+
img0_mem = Nifti1Image(data0, affine)
67+
img1_mem = Nifti1Image(data0 - 10, affine)
68+
69+
concat_img1 = concat_images([img0_mem, img1_mem])
70+
expected_shape1 = shape_4D.copy()
71+
expected_shape1[-1] *= 2
72+
assert_array_equal(concat_img1.shape, expected_shape1)
73+
74+
concat_img2 = concat_images([img0_mem, img1_mem], axis=0)
75+
expected_shape2 = shape_4D.copy()
76+
expected_shape2[0] *= 2
77+
assert_array_equal(concat_img2.shape, expected_shape2)
78+
6279

6380
def test_closest_canonical():
6481
arr = np.arange(24).reshape((2,3,4,1))

0 commit comments

Comments
 (0)