Skip to content

Commit 894a389

Browse files
committed
Fixing examples
1 parent 8be1a0f commit 894a389

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

examples/darray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
z[:] = MPI.COMM_WORLD.Get_rank()
5151
g0 = z.get((0, slice(None), 0))
5252
z2 = z.redistribute(2)
53-
z = z2.redistribute(darray=z)
53+
z = z2.redistribute(out=z)
5454
g1 = z.get((0, slice(None), 0))
5555
assert np.all(g0 == g1)
5656
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(z)**2)
@@ -69,14 +69,14 @@
6969
if MPI.COMM_WORLD.Get_rank() == 0:
7070
assert abs(s0-s1) < 1e-12
7171

72-
z1 = z0.redistribute(darray=z1)
73-
z0 = z1.redistribute(darray=z0)
72+
z1 = z0.redistribute(out=z1)
73+
z0 = z1.redistribute(out=z0)
7474

7575
N = (6, 6, 6, 6, 6)
7676
m0 = DistArray(N, dtype=float, alignment=2)
7777
m0[:] = MPI.COMM_WORLD.Get_rank()
7878
m1 = m0.redistribute(4)
79-
m0 = m1.redistribute(darray=m0)
79+
m0 = m1.redistribute(out=m0)
8080
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(m0)**2)
8181
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(m1)**2)
8282
if MPI.COMM_WORLD.Get_rank() == 0:

examples/transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import numpy as np
33
from mpi4py import MPI
4-
from mpi4py_fft import PFFT, DistArray
4+
from mpi4py_fft import PFFT, newDistArray
55
from mpi4py_fft.fftw import dctn, idctn
66

77
# Set global size of the computational box
@@ -17,16 +17,16 @@
1717

1818
assert fft.axes == pfft.axes
1919

20-
u = DistArray(pfft=fft, forward_output=False)
20+
u = newDistArray(fft, forward_output=False)
2121
u[:] = np.random.random(u.shape).astype(u.dtype)
2222

23-
u_hat = DistArray(pfft=fft, forward_output=True)
23+
u_hat = newDistArray(fft, forward_output=True)
2424
u_hat = fft.forward(u, u_hat)
2525
uj = np.zeros_like(u)
2626
uj = fft.backward(u_hat, uj)
2727
assert np.allclose(uj, u)
2828

29-
u_padded = DistArray(pfft=pfft, forward_output=False)
29+
u_padded = newDistArray(pfft, forward_output=False)
3030
uc = u_hat.copy()
3131
u_padded = pfft.backward(u_hat, u_padded)
3232
u_hat = pfft.forward(u_padded, u_hat)

0 commit comments

Comments
 (0)