Skip to content

Commit bc9c814

Browse files
committed
Minor change to paper
1 parent 073b566 commit bc9c814

File tree

4 files changed

+51
-20
lines changed

4 files changed

+51
-20
lines changed

docs/paper/paper.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ bibliography: paper.bib
2727
# Summary
2828

2929
The fast Fourier transform (FFT) is an algorithm that efficiently
30-
computes the discrete Fourier transform. The FFT is a ubiquitous
30+
computes the discrete Fourier transform. The FFT is a celebrated
3131
algorithm utilized throughout science and engineering. Since the dawn
3232
of our digital society, the FFT has permeated to the heart of everyday
3333
life applications involving audio, image, and video processing. The
@@ -84,7 +84,7 @@ unknowns.
8484
which allows for further reuse in applications beyond the FFT. In
8585
fact, the distributed array interface can be used for boosting
8686
performance through MPI-based parallelism in any algorithm able to
87-
operate on local arrays by processing undivided axes one at the time.
87+
operate on local arrays by processing undivided axes.
8888

8989
# Acknowledgements
9090

examples/darray.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
s0, s1 = np.linalg.norm(z2), np.linalg.norm(z2c)
2525
assert abs(s0-s1) < 1e-12, s0-s1
2626

27-
print(z3.get((5, 4, 5)))
28-
print(z3.local_slice(), z3.substart, z3.commsizes)
29-
3027
v0 = newDistArray(fft, forward_output=False, rank=1)
3128
#v0 = Function(fft, forward_output=False, rank=1)
3229
v0[:] = np.random.random(v0.shape)
@@ -40,8 +37,6 @@
4037
s0, s1 = np.linalg.norm(v0c), np.linalg.norm(v0)
4138
assert abs(s0-s1) < 1e-12
4239

43-
print(v0.substart, v0.commsizes)
44-
4540
nfft = PFFT(MPI.COMM_WORLD, darray=v0[0], axes=(0, 2, 1))
4641
for i in range(3):
4742
v1[i] = nfft.forward(v0[i], v1[i])
@@ -53,17 +48,11 @@
5348
N = (6, 6, 6)
5449
z = DistArray(N, dtype=float, alignment=0)
5550
z[:] = MPI.COMM_WORLD.Get_rank()
56-
g = z.get((0, slice(None), 0))
57-
if MPI.COMM_WORLD.Get_rank() == 0:
58-
print(g)
59-
60-
z2 = DistArray(N, dtype=float, alignment=2)
61-
z.redistribute(darray=z2)
62-
63-
g = z2.get((0, slice(None), 0))
64-
if MPI.COMM_WORLD.Get_rank() == 0:
65-
print(g)
66-
51+
g0 = z.get((0, slice(None), 0))
52+
z2 = z.redistribute(2)
53+
z = z2.redistribute(darray=z)
54+
g1 = z.get((0, slice(None), 0))
55+
assert np.all(g0 == g1)
6756
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(z)**2)
6857
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
6958
if MPI.COMM_WORLD.Get_rank() == 0:
@@ -79,3 +68,16 @@
7968
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(z0)**2)
8069
if MPI.COMM_WORLD.Get_rank() == 0:
8170
assert abs(s0-s1) < 1e-12
71+
72+
z1 = z0.redistribute(darray=z1)
73+
z0 = z1.redistribute(darray=z0)
74+
75+
N = (6, 6, 6, 6, 6)
76+
m0 = DistArray(N, dtype=float, alignment=2)
77+
m0[:] = MPI.COMM_WORLD.Get_rank()
78+
m1 = m0.redistribute(4)
79+
m0 = m1.redistribute(darray=m0)
80+
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(m0)**2)
81+
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(m1)**2)
82+
if MPI.COMM_WORLD.Get_rank() == 0:
83+
assert abs(s0-s1) < 1e-12

mpi4py_fft/distarray.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from numbers import Number
33
import numpy as np
44
from mpi4py import MPI
5-
from .pencil import Pencil, Subcomm
5+
from .pencil import Pencil, Subcomm, Transfer
66

77
comm = MPI.COMM_WORLD
88

@@ -276,6 +276,10 @@ def get_pencil_and_transfer(self, axis):
276276
def redistribute(self, axis=None, darray=None):
277277
"""Global redistribution of local ``self`` array
278278
279+
Note
280+
----
281+
Use either ``axis`` or ``darray``, not both.
282+
279283
Parameters
280284
----------
281285
axis : int, optional
@@ -290,17 +294,40 @@ def redistribute(self, axis=None, darray=None):
290294
None then a new DistArray (aligned along ``axis``) is created
291295
and returned. Otherwise the provided darray is returned.
292296
"""
293-
if axis is None:
297+
# Take care of some trivial cases first
298+
if axis == self.alignment:
299+
return self
300+
301+
# Check if self is already aligned along axis. In that case just switch
302+
# axis of pencil (both axes are undivided) and return
303+
if axis is not None:
304+
if self.commsizes[self.rank+axis] == 1:
305+
self._p0.axis = axis
306+
return self
307+
308+
if axis is None: # darray interface
294309
assert isinstance(darray, np.ndarray)
295310
assert self.global_shape == darray.global_shape
296311
axis = darray.alignment
312+
if self.commsizes == darray.commsizes:
313+
# Just a copy required. Should probably not be here
314+
darray[:] = self
315+
return darray
316+
317+
# Check that arrays are compatible
318+
for i in range(len(self._p0.shape)):
319+
if i != self._p0.axis and i != darray._p0.axis:
320+
assert self._p0.subcomm[i] == darray._p0.subcomm[i]
321+
assert self._p0.subshape[i] == darray._p0.subshape[i]
322+
297323
p1, transfer = self.get_pencil_and_transfer(axis)
298324
if darray is None:
299325
darray = DistArray(self.global_shape,
300326
subcomm=p1.subcomm,
301327
dtype=self.dtype,
302328
alignment=axis,
303329
rank=self.rank)
330+
304331
if self.rank == 0:
305332
transfer.forward(self, darray)
306333
elif self.rank == 1:

tests/test_darray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def test_2Darray():
4545
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(b)**2)
4646
if MPI.COMM_WORLD.Get_rank() == 0:
4747
assert abs(s0-s1) < 1e-1
48+
c = a.redistribute(a.alignment)
49+
assert c is a
4850

4951
def test_3Darray():
5052
N = (8, 8, 8)

0 commit comments

Comments
 (0)