Skip to content

Add an 'axis' parameter to concat_images, plus two tests. #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 46 additions & 19 deletions nibabel/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,43 +88,70 @@ def squeeze_image(img):
img.extra)


def concat_images(images, check_affines=True):
''' Concatenate images in list to single image, along last dimension
def concat_images(images, check_affines=True, axis=None):
''' Concatenate images in list to single image, along specified dimension
Parameters
----------
images : sequence
sequence of ``SpatialImage`` or of filenames\s
sequence of ``SpatialImage`` or filenames of the same dimensionality\s
check_affines : {True, False}, optional
If True, then check that all the affines for `images` are nearly
the same, raising a ``ValueError`` otherwise. Default is True
axis : None or int, optional
If None, concatenates on a new dimension. This requires all images to
be the same shape. If not None, concatenates on the specified
dimension. This requires all images to be the same shape, except on
the specified dimension.
Returns
-------
concat_img : ``SpatialImage``
New image resulting from concatenating `images` across last
dimension
'''
images = [load(img) if not hasattr(img, 'get_data')
else img for img in images]
n_imgs = len(images)
if n_imgs == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop this check? I guess if they pass in an empty list they can expect an empty list back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the past, and currently, this throws an error. I added the check because the error did not indicate the issue clearly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.

raise ValueError("Cannot concatenate an empty list of images.")
img0 = images[0]
is_filename = False
if not hasattr(img0, 'get_data'):
img0 = load(img0)
is_filename = True
i0shape = img0.shape
affine = img0.affine
header = img0.header
out_shape = (n_imgs, ) + i0shape
out_data = np.empty(out_shape)
for i, img in enumerate(images):
if is_filename:
img = load(img)
if check_affines:
if not np.all(img.affine == affine):
raise ValueError('Affines do not match')
out_data[i] = img.get_data()
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
klass = img0.__class__
shape0 = img0.shape
n_dim = len(shape0)
if axis is None:
# collect images in output array for efficiency
out_shape = (n_imgs, ) + shape0
out_data = np.empty(out_shape)
else:
# collect images in list for use with np.concatenate
out_data = [None] * n_imgs
# Get part of shape we need to check inside loop
idx_mask = np.ones((n_dim,), dtype=bool)
if axis is not None:
idx_mask[axis] = False
masked_shape = np.array(shape0)[idx_mask]
for i, img in enumerate(images):
if len(img.shape) != n_dim:
raise ValueError(
'Image {0} has {1} dimensions, image 0 has {2}'.format(
i, len(img.shape), n_dim))
if not np.all(np.array(img.shape)[idx_mask] == masked_shape):
raise ValueError('shape {0} for image {1} not compatible with '
'first image shape {2} with axis == {0}'.format(
img.shape, i, shape0, axis))
if check_affines and not np.all(img.affine == affine):
raise ValueError('Affine for image {0} does not match affine '
'for first image'.format(i))
# Do not fill cache in image if it is empty
out_data[i] = img.get_data(caching='unchanged')

if axis is None:
out_data = np.rollaxis(out_data, 0, out_data.ndim)
else:
out_data = np.concatenate(out_data, axis=axis)

return klass(out_data, affine, header)


Expand Down
116 changes: 89 additions & 27 deletions nibabel/tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,96 @@ def _as_fname(img):


def test_concat():
shape = (1,2,5)
data0 = np.arange(10).reshape(shape)
# Smoke test: concat empty list.
assert_raises(ValueError, concat_images, [])

# Build combinations of 3D, 4D w/size[3] == 1, and 4D w/size[3] == 3
all_shapes_5D = ((1, 4, 5, 3, 3),
(7, 3, 1, 4, 5),
(0, 2, 1, 4, 5))

affine = np.eye(4)
img0_mem = Nifti1Image(data0, affine)
data1 = data0 - 10
img1_mem = Nifti1Image(data1, affine)
img2_mem = Nifti1Image(data1, affine+1)
img3_mem = Nifti1Image(data1.T, affine)
all_data = np.concatenate(
[data0[:,:,:,np.newaxis],data1[:,:,:,np.newaxis]],3)
# Check filenames and in-memory images work
with InTemporaryDirectory():
imgs = [img0_mem, img1_mem, img2_mem, img3_mem]
img_files = [_as_fname(img) for img in imgs]
for img0, img1, img2, img3 in (imgs, img_files):
all_imgs = concat_images([img0, img1])
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)
# check that not-matching affines raise error
assert_raises(ValueError, concat_images, [img0, img2])
assert_raises(ValueError, concat_images, [img0, img3])
# except if check_affines is False
all_imgs = concat_images([img0, img1])
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)
# Delete images as prophylaxis for windows access errors
for img in imgs:
del(img)
for dim in range(2, 6):
all_shapes_ND = tuple((shape[:dim] for shape in all_shapes_5D))
all_shapes_N1D_unary = tuple((shape + (1,) for shape in all_shapes_ND))
all_shapes = all_shapes_ND + all_shapes_N1D_unary

# Loop over all possible combinations of images, in first and
# second position.
for data0_shape in all_shapes:
data0_numel = np.asarray(data0_shape).prod()
data0 = np.arange(data0_numel).reshape(data0_shape)
img0_mem = Nifti1Image(data0, affine)

for data1_shape in all_shapes:
data1_numel = np.asarray(data1_shape).prod()
data1 = np.arange(data1_numel).reshape(data1_shape)
img1_mem = Nifti1Image(data1, affine)
img2_mem = Nifti1Image(data1, affine+1) # bad affine

# Loop over every possible axis, including None (explicit and implied)
for axis in (list(range(-(dim-2), (dim-1))) + [None, '__default__']):

# Allow testing default vs. passing explicit param
if axis == '__default__':
np_concat_kwargs = dict(axis=-1)
concat_imgs_kwargs = dict()
axis = None # Convert downstream
elif axis is None:
np_concat_kwargs = dict(axis=-1)
concat_imgs_kwargs = dict(axis=axis)
else:
np_concat_kwargs = dict(axis=axis)
concat_imgs_kwargs = dict(axis=axis)

# Create expected output
try:
# Error will be thrown if the np.concatenate fails.
# However, when axis=None, the concatenate is possible
# but our efficient logic (where all images are
# 3D and the same size) fails, so we also
# have to expect errors for those.
if axis is None: # 3D from here and below
all_data = np.concatenate([data0[..., np.newaxis],
data1[..., np.newaxis]],
**np_concat_kwargs)
else: # both 3D, appending on final axis
all_data = np.concatenate([data0, data1],
**np_concat_kwargs)
expect_error = False
except ValueError:
# Shapes are not combinable
expect_error = True

# Check filenames and in-memory images work
with InTemporaryDirectory():
# Try mem-based, file-based, and mixed
imgs = [img0_mem, img1_mem, img2_mem]
img_files = [_as_fname(img) for img in imgs]
imgs_mixed = [imgs[0], img_files[1], imgs[2]]
for img0, img1, img2 in (imgs, img_files, imgs_mixed):
try:
all_imgs = concat_images([img0, img1],
**concat_imgs_kwargs)
except ValueError as ve:
assert_true(expect_error, str(ve))
else:
assert_false(expect_error, "Expected a concatenation error, but got none.")
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)

# check that not-matching affines raise error
assert_raises(ValueError, concat_images, [img0, img2], **concat_imgs_kwargs)

# except if check_affines is False
try:
all_imgs = concat_images([img0, img1], **concat_imgs_kwargs)
except ValueError as ve:
assert_true(expect_error, str(ve))
else:
assert_false(expect_error, "Expected a concatenation error, but got none.")
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)


def test_closest_canonical():
Expand Down