@@ -1,10 +1,9 @@
 import numpy as np
 
-from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.op import Op
 from pytensor.link.c.type import Generic
 from pytensor.tensor.type import tensor
-from pytensor.utils import key_to_cmp
 
 
 class LoadFromDisk(Op):
@@ -92,229 +91,4 @@ def load(path, dtype, shape, mmap_mode=None):
     return LoadFromDisk(dtype, shape, mmap_mode)(path)
 
 
-##########################
-# MPI
-##########################
-
-try:
-    from mpi4py import MPI
-except ImportError:
-    mpi_enabled = False
-else:
-    comm = MPI.COMM_WORLD
-    mpi_enabled = True
-
-
-class MPIRecv(Op):
-    """
-    An operation to asynchronously receive an array to a remote host using MPI.
-
-    See Also
-    --------
-    MPIRecv
-    MPIWait
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("source", "tag", "shape", "dtype")
-
-    def __init__(self, source, tag, shape, dtype):
-        self.source = source
-        self.tag = tag
-        self.shape = shape
-        self.dtype = np.dtype(dtype)  # turn "float64" into numpy.float64
-        self.static_shape = (None,) * len(shape)
-
-    def make_node(self):
-        return Apply(
-            self,
-            [],
-            [
-                Variable(Generic(), None),
-                tensor(self.dtype, shape=self.static_shape),
-            ],
-        )
-
-    def perform(self, node, inp, out):
-
-        data = np.zeros(self.shape, dtype=self.dtype)
-        request = comm.Irecv(data, self.source, self.tag)
-
-        out[0][0] = request
-        out[1][0] = data
-
-    def __str__(self):
-        return f"MPIRecv{{source: {int(self.source)}, tag: {int(self.tag)}, shape: {self.shape}, dtype: {self.dtype}}}"
-
-    def infer_shape(self, fgraph, node, shapes):
-        return [None, self.shape]
-
-    def do_constant_folding(self, fgraph, node):
-        return False
-
-
-class MPIRecvWait(Op):
-    """
-    An operation to wait on a previously received array using MPI.
-
-    See Also
-    --------
-    MPIRecv
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("tag",)
-
-    def __init__(self, tag):
-        self.tag = tag
-
-    def make_node(self, request, data):
-        return Apply(
-            self,
-            [request, data],
-            [tensor(data.dtype, shape=data.type.shape)],
-        )
-
-    def perform(self, node, inp, out):
-
-        request = inp[0]
-        data = inp[1]
-
-        request.wait()
-
-        out[0][0] = data
-
-    def infer_shape(self, fgraph, node, shapes):
-        return [shapes[1]]
-
-    view_map = {0: [1]}
-
-
-class MPISend(Op):
-    """
-    An operation to asynchronously Send an array to a remote host using MPI.
-
-    See Also
-    --------
-    MPIRecv
-    MPISendWait
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("dest", "tag")
-
-    def __init__(self, dest, tag):
-        self.dest = dest
-        self.tag = tag
-
-    def make_node(self, data):
-        return Apply(self, [data], [Variable(Generic(), None), data.type()])
-
-    view_map = {1: [0]}
-
-    def perform(self, node, inp, out):
-
-        data = inp[0]
-
-        request = comm.Isend(data, self.dest, self.tag)
-
-        out[0][0] = request
-        out[1][0] = data
-
-    def __str__(self):
-        return f"MPISend{{dest: {int(self.dest)}, tag: {int(self.tag)}}}"
-
-
-class MPISendWait(Op):
-    """
-    An operation to wait on a previously sent array using MPI.
-
-    See Also
-    --------
-    MPISend
-
-    Notes
-    -----
-    Non-differentiable.
-
-    """
-
-    __props__ = ("tag",)
-
-    def __init__(self, tag):
-        self.tag = tag
-
-    def make_node(self, request, data):
-        return Apply(self, [request, data], [Variable(Generic(), None)])
-
-    def perform(self, node, inp, out):
-        request = inp[0]
-        request.wait()
-        out[0][0] = True
-
-
-def isend(var, dest, tag):
-    """
-    Non blocking send.
-    """
-    return MPISend(dest, tag)(var)
-
-
-def send(var, dest, tag):
-    """
-    Blocking send.
-    """
-    return MPISendWait(tag)(*isend(var, dest, tag))
-
-
-def irecv(shape, dtype, source, tag):
-    """
-    Non-blocking receive.
-    """
-    return MPIRecv(source, tag, shape, dtype)()
-
-
-def recv(shape, dtype, source, tag):
-    """
-    Blocking receive.
-    """
-    return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
-
-
-# Ordering keys for scheduling
-def mpi_send_wait_key(a):
-    """Wait as long as possible on Waits, Start Send/Recvs early."""
-    if isinstance(a.op, (MPIRecvWait, MPISendWait)):
-        return 1
-    if isinstance(a.op, (MPIRecv, MPISend)):
-        return -1
-    return 0
-
-
-def mpi_tag_key(a):
-    """Break MPI ties by using the variable tag - prefer lower tags first."""
-    if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
-        return a.op.tag
-    else:
-        return 0
-
-
-mpi_send_wait_cmp = key_to_cmp(mpi_send_wait_key)
-mpi_tag_cmp = key_to_cmp(mpi_tag_key)
-
-mpi_keys = (mpi_send_wait_key, mpi_tag_key)
-mpi_cmps = (mpi_send_wait_cmp, mpi_tag_cmp)
-
 __all__ = ["load"]
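
For reference, the deleted module-level helpers composed these Ops into point-to-point transfers: isend/irecv built the asynchronous MPISend/MPIRecv nodes, and send/recv chained them with the matching wait Ops to get blocking semantics. The sketch below shows roughly how that API was used; it assumes a pre-removal snapshot that still exports these helpers (the pytensor.tensor.io import path, the two-rank mpirun launch, and the tag value are illustrative assumptions, not taken from this commit).

# Sketch only (hypothetical usage of the helpers removed above, not part of this commit).
# send = MPISend + MPISendWait (blocking); recv = MPIRecv + MPIRecvWait (blocking).
# Assumes mpi4py is installed and the script is launched with something like:
#     mpirun -np 2 python mpi_sketch.py
import numpy as np
import pytensor
import pytensor.tensor as pt
from mpi4py import MPI
from pytensor.tensor.io import recv, send  # assumed import path; gone after this commit

rank = MPI.COMM_WORLD.Get_rank()
shape, tag = (2, 2), 11

if rank == 0:
    x = pt.matrix("x", dtype="float64")
    done = send(x, dest=1, tag=tag)  # MPISendWait output; True once the send completes
    f = pytensor.function([x], done)
    f(np.ones(shape, dtype="float64"))
else:
    y = recv(shape, "float64", source=0, tag=tag)  # tensor output of MPIRecvWait
    g = pytensor.function([], y)
    print(g())  # the 2x2 array of ones sent by rank 0

The removed mpi_send_wait_key/mpi_tag_key functions were scheduling sort keys in the same spirit: they order apply nodes so that MPIRecv/MPISend start early (key -1) and the corresponding wait nodes run late (key 1), breaking ties by lower tag, with key_to_cmp turning each key into a comparator.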