Skip to content

Commit 6d4da36

Browse files
committed
use ConcurrentRunner
1 parent 955b48f commit 6d4da36

File tree

3 files changed

+89
-144
lines changed

3 files changed

+89
-144
lines changed

test/asynchronous/test_gridfs.py

Lines changed: 45 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
import threading
2323
import time
2424
from io import BytesIO
25+
from test.asynchronous.helpers import ConcurrentRunner
2526
from unittest.mock import patch
2627

2728
sys.path[0:0] = [""]
2829

2930
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
30-
from test.utils import asyncjoinall, joinall, one
31+
from test.utils import joinall, one
3132

3233
import gridfs
3334
from bson.binary import Binary
@@ -44,64 +45,35 @@
4445

4546
_IS_SYNC = False
4647

47-
if _IS_SYNC:
48-
49-
class JustWrite(threading.Thread):
50-
def __init__(self, fs, n):
51-
threading.Thread.__init__(self)
52-
self.fs = fs
53-
self.n = n
54-
self.daemon = True
55-
56-
def run(self):
57-
for _ in range(self.n):
58-
file = self.fs.new_file(filename="test")
59-
file.write(b"hello")
60-
file.close()
61-
62-
class JustRead(threading.Thread):
63-
def __init__(self, fs, n, results):
64-
threading.Thread.__init__(self)
65-
self.fs = fs
66-
self.n = n
67-
self.results = results
68-
self.daemon = True
69-
70-
def run(self):
71-
for _ in range(self.n):
72-
file = self.fs.get("test")
73-
data = file.read()
74-
self.results.append(data)
75-
assert data == b"hello"
76-
else:
77-
78-
class JustWrite:
79-
def __init__(self, fs, n):
80-
self.task = asyncio.create_task(self.run())
81-
self.fs = fs
82-
self.n = n
83-
self.daemon = True
84-
85-
async def run(self):
86-
for _ in range(self.n):
87-
file = self.fs.new_file(filename="test")
88-
await file.write(b"hello")
89-
await file.close()
90-
91-
class JustRead:
92-
def __init__(self, fs, n, results):
93-
self.task = asyncio.create_task(self.run())
94-
self.fs = fs
95-
self.n = n
96-
self.results = results
97-
self.daemon = True
98-
99-
async def run(self):
100-
for _ in range(self.n):
101-
file = await self.fs.get("test")
102-
data = await file.read()
103-
self.results.append(data)
104-
assert data == b"hello"
48+
49+
class JustWrite(ConcurrentRunner):
50+
def __init__(self, fs, n):
51+
super().__init__()
52+
self.fs = fs
53+
self.n = n
54+
self.daemon = True
55+
56+
async def run(self):
57+
for _ in range(self.n):
58+
file = self.fs.new_file(filename="test")
59+
await file.write(b"hello")
60+
await file.close()
61+
62+
63+
class JustRead(ConcurrentRunner):
64+
def __init__(self, fs, n, results):
65+
super().__init__()
66+
self.fs = fs
67+
self.n = n
68+
self.results = results
69+
self.daemon = True
70+
71+
async def run(self):
72+
for _ in range(self.n):
73+
file = await self.fs.get("test")
74+
data = await file.read()
75+
self.results.append(data)
76+
assert data == b"hello"
10577

10678

10779
class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase):
@@ -252,25 +224,29 @@ async def test_alt_collection(self):
252224
async def test_threaded_reads(self):
253225
await self.fs.put(b"hello", _id="test")
254226

255-
threads = []
227+
tasks = []
256228
results: list = []
257229
for i in range(10):
258-
threads.append(JustRead(self.fs, 10, results))
259-
if _IS_SYNC:
260-
threads[i].start()
230+
tasks.append(JustRead(self.fs, 10, results))
231+
await tasks[i].start()
261232

262-
await asyncjoinall(threads)
233+
if _IS_SYNC:
234+
joinall(tasks)
235+
else:
236+
await asyncio.wait([t.task for t in tasks])
263237

264238
self.assertEqual(100 * [b"hello"], results)
265239

266240
async def test_threaded_writes(self):
267-
threads = []
241+
tasks = []
268242
for i in range(10):
269-
threads.append(JustWrite(self.fs, 10))
270-
if _IS_SYNC:
271-
threads[i].start()
243+
tasks.append(JustWrite(self.fs, 10))
244+
await tasks[i].start()
272245

273-
await asyncjoinall(threads)
246+
if _IS_SYNC:
247+
joinall(tasks)
248+
else:
249+
await asyncio.wait([t.task for t in tasks])
274250

275251
f = await self.fs.get_last_version("test")
276252
self.assertEqual(await f.read(), b"hello")

test/test_gridfs.py

Lines changed: 44 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import threading
2323
import time
2424
from io import BytesIO
25+
from test.helpers import ConcurrentRunner
2526
from unittest.mock import patch
2627

2728
sys.path[0:0] = [""]
@@ -44,64 +45,35 @@
4445

4546
_IS_SYNC = True
4647

47-
if _IS_SYNC:
48-
49-
class JustWrite(threading.Thread):
50-
def __init__(self, fs, n):
51-
threading.Thread.__init__(self)
52-
self.fs = fs
53-
self.n = n
54-
self.daemon = True
55-
56-
def run(self):
57-
for _ in range(self.n):
58-
file = self.fs.new_file(filename="test")
59-
file.write(b"hello")
60-
file.close()
61-
62-
class JustRead(threading.Thread):
63-
def __init__(self, fs, n, results):
64-
threading.Thread.__init__(self)
65-
self.fs = fs
66-
self.n = n
67-
self.results = results
68-
self.daemon = True
69-
70-
def run(self):
71-
for _ in range(self.n):
72-
file = self.fs.get("test")
73-
data = file.read()
74-
self.results.append(data)
75-
assert data == b"hello"
76-
else:
77-
78-
class JustWrite:
79-
def __init__(self, fs, n):
80-
self.task = asyncio.create_task(self.run())
81-
self.fs = fs
82-
self.n = n
83-
self.daemon = True
84-
85-
def run(self):
86-
for _ in range(self.n):
87-
file = self.fs.new_file(filename="test")
88-
file.write(b"hello")
89-
file.close()
90-
91-
class JustRead:
92-
def __init__(self, fs, n, results):
93-
self.task = asyncio.create_task(self.run())
94-
self.fs = fs
95-
self.n = n
96-
self.results = results
97-
self.daemon = True
98-
99-
def run(self):
100-
for _ in range(self.n):
101-
file = self.fs.get("test")
102-
data = file.read()
103-
self.results.append(data)
104-
assert data == b"hello"
48+
49+
class JustWrite(ConcurrentRunner):
50+
def __init__(self, fs, n):
51+
super().__init__()
52+
self.fs = fs
53+
self.n = n
54+
self.daemon = True
55+
56+
def run(self):
57+
for _ in range(self.n):
58+
file = self.fs.new_file(filename="test")
59+
file.write(b"hello")
60+
file.close()
61+
62+
63+
class JustRead(ConcurrentRunner):
64+
def __init__(self, fs, n, results):
65+
super().__init__()
66+
self.fs = fs
67+
self.n = n
68+
self.results = results
69+
self.daemon = True
70+
71+
def run(self):
72+
for _ in range(self.n):
73+
file = self.fs.get("test")
74+
data = file.read()
75+
self.results.append(data)
76+
assert data == b"hello"
10577

10678

10779
class TestGridfsNoConnect(unittest.TestCase):
@@ -250,25 +222,29 @@ def test_alt_collection(self):
250222
def test_threaded_reads(self):
251223
self.fs.put(b"hello", _id="test")
252224

253-
threads = []
225+
tasks = []
254226
results: list = []
255227
for i in range(10):
256-
threads.append(JustRead(self.fs, 10, results))
257-
if _IS_SYNC:
258-
threads[i].start()
228+
tasks.append(JustRead(self.fs, 10, results))
229+
tasks[i].start()
259230

260-
joinall(threads)
231+
if _IS_SYNC:
232+
joinall(tasks)
233+
else:
234+
asyncio.wait([t.task for t in tasks])
261235

262236
self.assertEqual(100 * [b"hello"], results)
263237

264238
def test_threaded_writes(self):
265-
threads = []
239+
tasks = []
266240
for i in range(10):
267-
threads.append(JustWrite(self.fs, 10))
268-
if _IS_SYNC:
269-
threads[i].start()
241+
tasks.append(JustWrite(self.fs, 10))
242+
tasks[i].start()
270243

271-
joinall(threads)
244+
if _IS_SYNC:
245+
joinall(tasks)
246+
else:
247+
asyncio.wait([t.task for t in tasks])
272248

273249
f = self.fs.get_last_version("test")
274250
self.assertEqual(f.read(), b"hello")

test/utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,6 @@ def joinall(threads):
665665
assert not t.is_alive(), "Thread %s hung" % t
666666

667667

668-
async def asyncjoinall(tasks):
669-
"""Join tasks with a 5-minute timeout, assert joins succeeded"""
670-
for t in tasks:
671-
await asyncio.wait([t.task], timeout=300)
672-
assert t.task.done(), "Task %s hung" % t
673-
674-
675668
def wait_until(predicate, success_description, timeout=10):
676669
"""Wait up to 10 seconds (by default) for predicate to be true.
677670

0 commit comments

Comments
 (0)