Skip to content

Commit 6827f71

Browse files
committed
use async_joinall
1 parent 39f6721 commit 6827f71

File tree

4 files changed

+11
-17
lines changed

4 files changed

+11
-17
lines changed

test/asynchronous/test_gridfs.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
sys.path[0:0] = [""]
2929

3030
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
31-
from test.utils import joinall, one
31+
from test.utils import async_joinall, one
3232

3333
import gridfs
3434
from bson.binary import Binary
@@ -230,10 +230,7 @@ async def test_threaded_reads(self):
230230
tasks.append(JustRead(self.fs, 10, results))
231231
await tasks[i].start()
232232

233-
if _IS_SYNC:
234-
joinall(tasks)
235-
else:
236-
await asyncio.wait([t.task for t in tasks if t.task is not None])
233+
await async_joinall(tasks)
237234

238235
self.assertEqual(100 * [b"hello"], results)
239236

@@ -243,10 +240,7 @@ async def test_threaded_writes(self):
243240
tasks.append(JustWrite(self.fs, 10))
244241
await tasks[i].start()
245242

246-
if _IS_SYNC:
247-
joinall(tasks)
248-
else:
249-
await asyncio.wait([t.task for t in tasks if t.task is not None])
243+
await async_joinall(tasks)
250244

251245
f = await self.fs.get_last_version("test")
252246
self.assertEqual(await f.read(), b"hello")

test/test_gridfs.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,7 @@ def test_threaded_reads(self):
228228
tasks.append(JustRead(self.fs, 10, results))
229229
tasks[i].start()
230230

231-
if _IS_SYNC:
232-
joinall(tasks)
233-
else:
234-
asyncio.wait([t.task for t in tasks if t.task is not None])
231+
joinall(tasks)
235232

236233
self.assertEqual(100 * [b"hello"], results)
237234

@@ -241,10 +238,7 @@ def test_threaded_writes(self):
241238
tasks.append(JustWrite(self.fs, 10))
242239
tasks[i].start()
243240

244-
if _IS_SYNC:
245-
joinall(tasks)
246-
else:
247-
asyncio.wait([t.task for t in tasks if t.task is not None])
241+
joinall(tasks)
248242

249243
f = self.fs.get_last_version("test")
250244
self.assertEqual(f.read(), b"hello")

test/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,11 @@ def joinall(threads):
666666
assert not t.is_alive(), "Thread %s hung" % t
667667

668668

669+
async def async_joinall(tasks):
670+
"""Join threads with a 5-minute timeout, assert joins succeeded"""
671+
await asyncio.wait([t.task for t in tasks if t is not None], timeout=300)
672+
673+
669674
def wait_until(predicate, success_description, timeout=10):
670675
"""Wait up to 10 seconds (by default) for predicate to be true.
671676

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@
124124
"AsyncMockPool": "MockPool",
125125
"StopAsyncIteration": "StopIteration",
126126
"create_async_event": "create_event",
127+
"async_joinall": "joinall",
127128
}
128129

129130
docstring_replacements: dict[tuple[str, str], str] = {

0 commit comments

Comments
 (0)