From ad7712e1927c9f4418b95602d0bccd414df98302 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 19:52:51 -0800 Subject: [PATCH 01/10] WIP --- test/asynchronous/test_gridfs.py | 632 +++++++++++++++++++++++++++++++ test/test_gridfs.py | 252 +++++++----- test/utils.py | 7 + tools/synchro.py | 3 + 4 files changed, 799 insertions(+), 95 deletions(-) create mode 100644 test/asynchronous/test_gridfs.py diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py new file mode 100644 index 0000000000..ae6088ae68 --- /dev/null +++ b/test/asynchronous/test_gridfs.py @@ -0,0 +1,632 @@ +# +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the gridfs package.""" +from __future__ import annotations + +import asyncio +import datetime +import sys +import threading +import time +from io import BytesIO +from unittest.mock import patch + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import asyncjoinall, joinall, one + +import gridfs +from bson.binary import Binary +from gridfs.asynchronous.grid_file import DEFAULT_CHUNK_SIZE, AsyncGridOutCursor +from gridfs.errors import CorruptGridFile, FileExists, NoFile +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, +) +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False + +if _IS_SYNC: + + class JustWrite(threading.Thread): + def __init__(self, fs, n): + threading.Thread.__init__(self) + self.fs = fs + self.n = n + self.daemon = True + + def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + file.write(b"hello") + file.close() + + class JustRead(threading.Thread): + def __init__(self, fs, n, results): + threading.Thread.__init__(self) + self.fs = fs + self.n = n + self.results = results + self.daemon = True + + def run(self): + for _ in range(self.n): + file = self.fs.get("test") + data = file.read() + self.results.append(data) + assert data == b"hello" +else: + + class JustWrite(asyncio.Task): + def __init__(self, fs, n): + async def run(): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + await file.write(b"hello") + await file.close() + + asyncio.Task.__init__(self, run()) + self.fs = fs + self.n = n + self.daemon = True + + class JustRead(asyncio.Task): + def __init__(self, fs, n, results): + async def run(): + for _ in range(self.n): + file = await self.fs.get("test") + data = await file.read() + self.results.append(data) + 
assert data == b"hello" + + asyncio.Task.__init__(self, run()) + self.fs = fs + self.n = n + self.results = results + self.daemon = True + + +class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase): + db: AsyncDatabase + + async def asyncSetUp(self): + await super().asyncSetUp() + self.db = AsyncMongoClient(connect=False).pymongo_test + + async def test_gridfs(self): + self.assertRaises(TypeError, gridfs.AsyncGridFS, "foo") + self.assertRaises(TypeError, gridfs.AsyncGridFS, self.db, 5) + + +class TestGridfs(AsyncIntegrationTest): + fs: gridfs.AsyncGridFS + alt: gridfs.AsyncGridFS + + async def asyncSetUp(self): + await super().asyncSetUp() + self.fs = gridfs.AsyncGridFS(self.db) + self.alt = gridfs.AsyncGridFS(self.db, "alt") + await self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) + + async def test_basic(self): + oid = await self.fs.put(b"hello world") + self.assertEqual(b"hello world", await (await self.fs.get(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + await self.fs.delete(oid) + with self.assertRaises(NoFile): + await self.fs.get(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.fs.get("foo") + oid = await self.fs.put(b"hello world", _id="foo") + self.assertEqual("foo", oid) + self.assertEqual(b"hello world", await (await self.fs.get("foo")).read()) + + async def test_multi_chunk_delete(self): + await self.db.fs.drop() + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + gfs = gridfs.AsyncGridFS(self.db) + oid = await gfs.put(b"hello", chunkSize=1) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(5, await self.db.fs.chunks.count_documents({})) + 
await gfs.delete(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_list(self): + self.assertEqual([], await self.fs.list()) + await self.fs.put(b"hello world") + self.assertEqual([], await self.fs.list()) + + # PYTHON-598: in server versions before 2.5.x, creating an index on + # filename, uploadDate causes list() to include None. + await self.fs.get_last_version() + self.assertEqual([], await self.fs.list()) + + await self.fs.put(b"", filename="mike") + await self.fs.put(b"foo", filename="test") + await self.fs.put(b"", filename="hello world") + + self.assertEqual({"mike", "test", "hello world"}, set(await self.fs.list())) + + async def test_empty_file(self): + oid = await self.fs.put(b"") + self.assertEqual(b"", await (await self.fs.get(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + raw = await self.db.fs.files.find_one() + assert raw is not None + self.assertEqual(0, raw["length"]) + self.assertEqual(oid, raw["_id"]) + self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) + self.assertEqual(255 * 1024, raw["chunkSize"]) + self.assertNotIn("md5", raw) + + async def test_corrupt_chunk(self): + files_id = await self.fs.put(b"foobar") + await self.db.fs.chunks.update_one( + {"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}} + ) + try: + out = await self.fs.get(files_id) + with self.assertRaises(CorruptGridFile): + await out.read() + + out = await self.fs.get(files_id) + with self.assertRaises(CorruptGridFile): + await out.readline() + finally: + await self.fs.delete(files_id) + + async def test_put_ensures_index(self): + chunks = self.db.fs.chunks + files = self.db.fs.files + # Ensure the collections are removed. 
+ await chunks.drop() + await files.drop() + await self.fs.put(b"junk") + + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in (await chunks.index_information()).values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + + async def test_alt_collection(self): + oid = await self.alt.put(b"hello world") + self.assertEqual(b"hello world", await (await self.alt.get(oid)).read()) + self.assertEqual(1, await self.db.alt.files.count_documents({})) + self.assertEqual(1, await self.db.alt.chunks.count_documents({})) + + await self.alt.delete(oid) + with self.assertRaises(NoFile): + await self.alt.get(oid) + self.assertEqual(0, await self.db.alt.files.count_documents({})) + self.assertEqual(0, await self.db.alt.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.alt.get("foo") + oid = await self.alt.put(b"hello world", _id="foo") + self.assertEqual("foo", oid) + self.assertEqual(b"hello world", await (await self.alt.get("foo")).read()) + + await self.alt.put(b"", filename="mike") + await self.alt.put(b"foo", filename="test") + await self.alt.put(b"", filename="hello world") + + self.assertEqual({"mike", "test", "hello world"}, set(await self.alt.list())) + + async def test_threaded_reads(self): + await self.fs.put(b"hello", _id="test") + + threads = [] + results: list = [] + for i in range(10): + threads.append(JustRead(self.fs, 10, results)) + if _IS_SYNC: + threads[i].start() + + await asyncjoinall(threads) + + self.assertEqual(100 * [b"hello"], results) + + async def test_threaded_writes(self): + threads = [] + for i in range(10): + threads.append(JustWrite(self.fs, 10)) + if _IS_SYNC: + threads[i].start() + + await asyncjoinall(threads) + + f = await self.fs.get_last_version("test") + self.assertEqual(await f.read(), b"hello") + + # Should have created 100 versions of 'test' file + self.assertEqual(100, 
await self.db.fs.files.count_documents({"filename": "test"})) + + async def test_get_last_version(self): + one = await self.fs.put(b"foo", filename="test") + await asyncio.sleep(0.01) + two = self.fs.new_file(filename="test") + await two.write(b"bar") + await two.close() + await asyncio.sleep(0.01) + two = two._id + three = await self.fs.put(b"baz", filename="test") + + self.assertEqual(b"baz", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(three) + self.assertEqual(b"bar", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(one) + with self.assertRaises(NoFile): + await self.fs.get_last_version("test") + + async def test_get_last_version_with_metadata(self): + one = await self.fs.put(b"foo", filename="test", author="author") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author") + + self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author")).read()) + await self.fs.delete(one) + + one = await self.fs.put(b"foo", filename="test", author="author1") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author2") + + self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author1")).read()) + self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author2")).read()) + self.assertEqual(b"bar", await (await self.fs.get_last_version(filename="test")).read()) + + with self.assertRaises(NoFile): + await self.fs.get_last_version(author="author3") + with self.assertRaises(NoFile): + await self.fs.get_last_version(filename="nottest", author="author1") + + await self.fs.delete(one) + await self.fs.delete(two) + + async def test_get_version(self): + await self.fs.put(b"foo", 
filename="test") + await asyncio.sleep(0.01) + await self.fs.put(b"bar", filename="test") + await asyncio.sleep(0.01) + await self.fs.put(b"baz", filename="test") + await asyncio.sleep(0.01) + + self.assertEqual(b"foo", await (await self.fs.get_version("test", 0)).read()) + self.assertEqual(b"bar", await (await self.fs.get_version("test", 1)).read()) + self.assertEqual(b"baz", await (await self.fs.get_version("test", 2)).read()) + + self.assertEqual(b"baz", await (await self.fs.get_version("test", -1)).read()) + self.assertEqual(b"bar", await (await self.fs.get_version("test", -2)).read()) + self.assertEqual(b"foo", await (await self.fs.get_version("test", -3)).read()) + + with self.assertRaises(NoFile): + await self.fs.get_version("test", 3) + with self.assertRaises(NoFile): + await self.fs.get_version("test", -4) + + async def test_get_version_with_metadata(self): + one = await self.fs.put(b"foo", filename="test", author="author1") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author1") + await asyncio.sleep(0.01) + three = await self.fs.put(b"baz", filename="test", author="author2") + + self.assertEqual( + b"foo", + await (await self.fs.get_version(filename="test", author="author1", version=-2)).read(), + ) + self.assertEqual( + b"bar", + await (await self.fs.get_version(filename="test", author="author1", version=-1)).read(), + ) + self.assertEqual( + b"foo", + await (await self.fs.get_version(filename="test", author="author1", version=0)).read(), + ) + self.assertEqual( + b"bar", + await (await self.fs.get_version(filename="test", author="author1", version=1)).read(), + ) + self.assertEqual( + b"baz", + await (await self.fs.get_version(filename="test", author="author2", version=0)).read(), + ) + self.assertEqual( + b"baz", await (await self.fs.get_version(filename="test", version=-1)).read() + ) + self.assertEqual( + b"baz", await (await self.fs.get_version(filename="test", version=2)).read() + ) + + with 
self.assertRaises(NoFile): + await self.fs.get_version(filename="test", author="author3") + with self.assertRaises(NoFile): + await self.fs.get_version(filename="test", author="author1", version=2) + + await self.fs.delete(one) + await self.fs.delete(two) + await self.fs.delete(three) + + async def test_put_filelike(self): + oid = await self.fs.put(BytesIO(b"hello world"), chunk_size=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + self.assertEqual(b"hello world", await (await self.fs.get(oid)).read()) + + async def test_file_exists(self): + oid = await self.fs.put(b"hello") + with self.assertRaises(FileExists): + await self.fs.put(b"world", _id=oid) + + one = self.fs.new_file(_id=123) + await one.write(b"some content") + await one.close() + + # Attempt to upload a file with more chunks to the same _id. + with patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE): + two = self.fs.new_file(_id=123) + with self.assertRaises(FileExists): + await two.write(b"x" * DEFAULT_CHUNK_SIZE * 3) + # Original file is still readable (no extra chunks were uploaded). + self.assertEqual(await (await self.fs.get(123)).read(), b"some content") + + two = self.fs.new_file(_id=123) + await two.write(b"some content") + with self.assertRaises(FileExists): + await two.close() + # Original file is still readable. 
+ self.assertEqual(await (await self.fs.get(123)).read(), b"some content") + + async def test_exists(self): + oid = await self.fs.put(b"hello") + self.assertTrue(await self.fs.exists(oid)) + self.assertTrue(await self.fs.exists({"_id": oid})) + self.assertTrue(await self.fs.exists(_id=oid)) + + self.assertFalse(await self.fs.exists(filename="mike")) + self.assertFalse(await self.fs.exists("mike")) + + oid = await self.fs.put(b"hello", filename="mike", foo=12) + self.assertTrue(await self.fs.exists(oid)) + self.assertTrue(await self.fs.exists({"_id": oid})) + self.assertTrue(await self.fs.exists(_id=oid)) + self.assertTrue(await self.fs.exists(filename="mike")) + self.assertTrue(await self.fs.exists({"filename": "mike"})) + self.assertTrue(await self.fs.exists(foo=12)) + self.assertTrue(await self.fs.exists({"foo": 12})) + self.assertTrue(await self.fs.exists(foo={"$gt": 11})) + self.assertTrue(await self.fs.exists({"foo": {"$gt": 11}})) + + self.assertFalse(await self.fs.exists(foo=13)) + self.assertFalse(await self.fs.exists({"foo": 13})) + self.assertFalse(await self.fs.exists(foo={"$gt": 12})) + self.assertFalse(await self.fs.exists({"foo": {"$gt": 12}})) + + async def test_put_unicode(self): + with self.assertRaises(TypeError): + await self.fs.put("hello") + + oid = await self.fs.put("hello", encoding="utf-8") + self.assertEqual(b"hello", await (await self.fs.get(oid)).read()) + self.assertEqual("utf-8", (await self.fs.get(oid)).encoding) + + oid = await self.fs.put("aé", encoding="iso-8859-1") + self.assertEqual("aé".encode("iso-8859-1"), await (await self.fs.get(oid)).read()) + self.assertEqual("iso-8859-1", (await self.fs.get(oid)).encoding) + + async def test_missing_length_iter(self): + # Test fix that guards against PHP-237 + await self.fs.put(b"", filename="empty") + doc = await self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None + doc.pop("length") + await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) + f = await 
self.fs.get_last_version(filename="empty") + + async def iterate_file(grid_file): + async for _chunk in grid_file: + pass + return True + + self.assertTrue(await iterate_file(f)) + + async def test_gridfs_lazy_connect(self): + client = await self.async_single_client( + "badhost", connect=False, serverSelectionTimeoutMS=10 + ) + db = client.db + gfs = gridfs.AsyncGridFS(db) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.list() + + fs = gridfs.AsyncGridFS(db) + f = fs.new_file() + with self.assertRaises(ServerSelectionTimeoutError): + await f.close() + + async def test_gridfs_find(self): + await self.fs.put(b"test2", filename="two") + await asyncio.sleep(0.01) + await self.fs.put(b"test2+", filename="two") + await asyncio.sleep(0.01) + await self.fs.put(b"test1", filename="one") + await asyncio.sleep(0.01) + await self.fs.put(b"test2++", filename="two") + files = self.db.fs.files + self.assertEqual(3, await files.count_documents({"filename": "two"})) + self.assertEqual(4, await files.count_documents({})) + cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + await cursor.rewind() + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + gout = await cursor.next() + self.assertEqual(b"test2+", await gout.read()) + with self.assertRaises(StopAsyncIteration): + await cursor.__anext__() + await cursor.rewind() + items = await cursor.to_list() + self.assertEqual(len(items), 2) + await cursor.rewind() + items = await cursor.to_list(1) + self.assertEqual(len(items), 1) + await cursor.close() + self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) + + async def test_delete_not_initialized(self): + # Creating a cursor with invalid arguments will not run __init__ + # but will still call __del__. 
+ cursor = AsyncGridOutCursor.__new__(AsyncGridOutCursor) # Skip calling __init__ + with self.assertRaises(TypeError): + cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore + cursor.__del__() # no error + + async def test_gridfs_find_one(self): + self.assertEqual(None, await self.fs.find_one()) + + id1 = await self.fs.put(b"test1", filename="file1") + res = await self.fs.find_one() + assert res is not None + self.assertEqual(b"test1", await res.read()) + + id2 = await self.fs.put(b"test2", filename="file2", meta="data") + res1 = await self.fs.find_one(id1) + assert res1 is not None + self.assertEqual(b"test1", await res1.read()) + res2 = await self.fs.find_one(id2) + assert res2 is not None + self.assertEqual(b"test2", await res2.read()) + + res3 = await self.fs.find_one({"filename": "file1"}) + assert res3 is not None + self.assertEqual(b"test1", await res3.read()) + + res4 = await self.fs.find_one(id2) + assert res4 is not None + self.assertEqual("data", res4.meta) + + async def test_grid_in_non_int_chunksize(self): + # Lua, and perhaps other buggy AsyncGridFS clients, store size as a float. + data = b"data" + await self.fs.put(data, filename="f") + await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) + + self.assertEqual(data, await (await self.fs.get_version("f")).read()) + + async def test_unacknowledged(self): + # w=0 is prohibited. 
+ with self.assertRaises(ConfigurationError): + gridfs.AsyncGridFS((await self.async_rs_or_single_client(w=0)).pymongo_test) + + async def test_md5(self): + gin = self.fs.new_file() + await gin.write(b"no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.get(gin._id) + self.assertIsNone(gout.md5) + + _id = await self.fs.put(b"still no md5 sum") + gout = await self.fs.get(_id) + self.assertIsNone(gout.md5) + + +class TestGridfsReplicaSet(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + + @classmethod + @async_client_context.require_connection + async def asyncTearDownClass(cls): + await async_client_context.client.drop_database("gfsreplica") + + async def test_gridfs_replica_set(self): + rsc = await self.async_rs_client( + w=async_client_context.w, read_preference=ReadPreference.SECONDARY + ) + + fs = gridfs.AsyncGridFS(rsc.gfsreplica, "gfsreplicatest") + + gin = fs.new_file() + self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) + + oid = await fs.put(b"foo") + content = await (await fs.get(oid)).read() + self.assertEqual(b"foo", content) + + async def test_gridfs_secondary(self): + secondary_host, secondary_port = one(await self.client.secondaries) + secondary_connection = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) + + # Should detect it's connected to secondary and not attempt to + # create index + fs = gridfs.AsyncGridFS(secondary_connection.gfsreplica, "gfssecondarytest") + + # This won't detect secondary, raises error + with self.assertRaises(NotPrimaryError): + await fs.put(b"foo") + + async def test_gridfs_secondary_lazy(self): + # Should detect it's connected to secondary and not attempt to + # create index. 
+ secondary_host, secondary_port = one(await self.client.secondaries) + client = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) + + # Still no connection. + fs = gridfs.AsyncGridFS(client.gfsreplica, "gfssecondarylazytest") + + # Connects, doesn't create index. + with self.assertRaises(NoFile): + await fs.get_last_version() + with self.assertRaises(NotPrimaryError): + await fs.put("data", encoding="utf-8") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_gridfs.py b/test/test_gridfs.py index ab8950250b..905f6a2c6a 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -16,6 +16,7 @@ """Tests for the gridfs package.""" from __future__ import annotations +import asyncio import datetime import sys import threading @@ -41,35 +42,66 @@ from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient - -class JustWrite(threading.Thread): - def __init__(self, fs, n): - threading.Thread.__init__(self) - self.fs = fs - self.n = n - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.new_file(filename="test") - file.write(b"hello") - file.close() - - -class JustRead(threading.Thread): - def __init__(self, fs, n, results): - threading.Thread.__init__(self) - self.fs = fs - self.n = n - self.results = results - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.get("test") - data = file.read() - self.results.append(data) - assert data == b"hello" +_IS_SYNC = True + +if _IS_SYNC: + + class JustWrite(threading.Thread): + def __init__(self, fs, n): + threading.Thread.__init__(self) + self.fs = fs + self.n = n + self.daemon = True + + def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + file.write(b"hello") + file.close() + + class JustRead(threading.Thread): + def __init__(self, fs, n, results): + threading.Thread.__init__(self) + self.fs = 
fs + self.n = n + self.results = results + self.daemon = True + + def run(self): + for _ in range(self.n): + file = self.fs.get("test") + data = file.read() + self.results.append(data) + assert data == b"hello" +else: + + class JustWrite(asyncio.Task): + def __init__(self, fs, n): + def run(): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + file.write(b"hello") + file.close() + + asyncio.Task.__init__(self, run()) + self.fs = fs + self.n = n + self.daemon = True + + class JustRead(asyncio.Task): + def __init__(self, fs, n, results): + def run(): + for _ in range(self.n): + file = self.fs.get("test") + data = file.read() + self.results.append(data) + assert data == b"hello" + + asyncio.Task.__init__(self, run()) + self.fs = fs + self.n = n + self.results = results + self.daemon = True class TestGridfsNoConnect(unittest.TestCase): @@ -98,19 +130,21 @@ def setUp(self): def test_basic(self): oid = self.fs.put(b"hello world") - self.assertEqual(b"hello world", self.fs.get(oid).read()) + self.assertEqual(b"hello world", (self.fs.get(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) - self.assertRaises(NoFile, self.fs.get, oid) + with self.assertRaises(NoFile): + self.fs.get(oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) - self.assertRaises(NoFile, self.fs.get, "foo") + with self.assertRaises(NoFile): + self.fs.get("foo") oid = self.fs.put(b"hello world", _id="foo") self.assertEqual("foo", oid) - self.assertEqual(b"hello world", self.fs.get("foo").read()) + self.assertEqual(b"hello world", (self.fs.get("foo")).read()) def test_multi_chunk_delete(self): self.db.fs.drop() @@ -142,7 +176,7 @@ def test_list(self): def test_empty_file(self): oid = self.fs.put(b"") - self.assertEqual(b"", self.fs.get(oid).read()) + self.assertEqual(b"", (self.fs.get(oid)).read()) self.assertEqual(1, 
self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -159,10 +193,12 @@ def test_corrupt_chunk(self): self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.get(files_id) - self.assertRaises(CorruptGridFile, out.read) + with self.assertRaises(CorruptGridFile): + out.read() out = self.fs.get(files_id) - self.assertRaises(CorruptGridFile, out.readline) + with self.assertRaises(CorruptGridFile): + out.readline() finally: self.fs.delete(files_id) @@ -177,31 +213,33 @@ def test_put_ensures_index(self): self.assertTrue( any( info.get("key") == [("files_id", 1), ("n", 1)] - for info in chunks.index_information().values() + for info in (chunks.index_information()).values() ) ) self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) def test_alt_collection(self): oid = self.alt.put(b"hello world") - self.assertEqual(b"hello world", self.alt.get(oid).read()) + self.assertEqual(b"hello world", (self.alt.get(oid)).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) self.alt.delete(oid) - self.assertRaises(NoFile, self.alt.get, oid) + with self.assertRaises(NoFile): + self.alt.get(oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) - self.assertRaises(NoFile, self.alt.get, "foo") + with self.assertRaises(NoFile): + self.alt.get("foo") oid = self.alt.put(b"hello world", _id="foo") self.assertEqual("foo", oid) - self.assertEqual(b"hello world", self.alt.get("foo").read()) + self.assertEqual(b"hello world", (self.alt.get("foo")).read()) self.alt.put(b"", filename="mike") self.alt.put(b"foo", filename="test") @@ -216,7 +254,8 @@ def test_threaded_reads(self): results: list = [] for i in range(10): 
threads.append(JustRead(self.fs, 10, results)) - threads[i].start() + if _IS_SYNC: + threads[i].start() joinall(threads) @@ -226,7 +265,8 @@ def test_threaded_writes(self): threads = [] for i in range(10): threads.append(JustWrite(self.fs, 10)) - threads[i].start() + if _IS_SYNC: + threads[i].start() joinall(threads) @@ -246,34 +286,37 @@ def test_get_last_version(self): two = two._id three = self.fs.put(b"baz", filename="test") - self.assertEqual(b"baz", self.fs.get_last_version("test").read()) + self.assertEqual(b"baz", (self.fs.get_last_version("test")).read()) self.fs.delete(three) - self.assertEqual(b"bar", self.fs.get_last_version("test").read()) + self.assertEqual(b"bar", (self.fs.get_last_version("test")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.get_last_version("test").read()) + self.assertEqual(b"foo", (self.fs.get_last_version("test")).read()) self.fs.delete(one) - self.assertRaises(NoFile, self.fs.get_last_version, "test") + with self.assertRaises(NoFile): + self.fs.get_last_version("test") def test_get_last_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author") - self.assertEqual(b"bar", self.fs.get_last_version(author="author").read()) + self.assertEqual(b"bar", (self.fs.get_last_version(author="author")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.get_last_version(author="author").read()) + self.assertEqual(b"foo", (self.fs.get_last_version(author="author")).read()) self.fs.delete(one) one = self.fs.put(b"foo", filename="test", author="author1") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author2") - self.assertEqual(b"foo", self.fs.get_last_version(author="author1").read()) - self.assertEqual(b"bar", self.fs.get_last_version(author="author2").read()) - self.assertEqual(b"bar", self.fs.get_last_version(filename="test").read()) + self.assertEqual(b"foo", 
(self.fs.get_last_version(author="author1")).read()) + self.assertEqual(b"bar", (self.fs.get_last_version(author="author2")).read()) + self.assertEqual(b"bar", (self.fs.get_last_version(filename="test")).read()) - self.assertRaises(NoFile, self.fs.get_last_version, author="author3") - self.assertRaises(NoFile, self.fs.get_last_version, filename="nottest", author="author1") + with self.assertRaises(NoFile): + self.fs.get_last_version(author="author3") + with self.assertRaises(NoFile): + self.fs.get_last_version(filename="nottest", author="author1") self.fs.delete(one) self.fs.delete(two) @@ -286,16 +329,18 @@ def test_get_version(self): self.fs.put(b"baz", filename="test") time.sleep(0.01) - self.assertEqual(b"foo", self.fs.get_version("test", 0).read()) - self.assertEqual(b"bar", self.fs.get_version("test", 1).read()) - self.assertEqual(b"baz", self.fs.get_version("test", 2).read()) + self.assertEqual(b"foo", (self.fs.get_version("test", 0)).read()) + self.assertEqual(b"bar", (self.fs.get_version("test", 1)).read()) + self.assertEqual(b"baz", (self.fs.get_version("test", 2)).read()) - self.assertEqual(b"baz", self.fs.get_version("test", -1).read()) - self.assertEqual(b"bar", self.fs.get_version("test", -2).read()) - self.assertEqual(b"foo", self.fs.get_version("test", -3).read()) + self.assertEqual(b"baz", (self.fs.get_version("test", -1)).read()) + self.assertEqual(b"bar", (self.fs.get_version("test", -2)).read()) + self.assertEqual(b"foo", (self.fs.get_version("test", -3)).read()) - self.assertRaises(NoFile, self.fs.get_version, "test", 3) - self.assertRaises(NoFile, self.fs.get_version, "test", -4) + with self.assertRaises(NoFile): + self.fs.get_version("test", 3) + with self.assertRaises(NoFile): + self.fs.get_version("test", -4) def test_get_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author1") @@ -305,25 +350,32 @@ def test_get_version_with_metadata(self): three = self.fs.put(b"baz", filename="test", author="author2") 
self.assertEqual( - b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read() + b"foo", + (self.fs.get_version(filename="test", author="author1", version=-2)).read(), ) self.assertEqual( - b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read() + b"bar", + (self.fs.get_version(filename="test", author="author1", version=-1)).read(), ) self.assertEqual( - b"foo", self.fs.get_version(filename="test", author="author1", version=0).read() + b"foo", + (self.fs.get_version(filename="test", author="author1", version=0)).read(), ) self.assertEqual( - b"bar", self.fs.get_version(filename="test", author="author1", version=1).read() + b"bar", + (self.fs.get_version(filename="test", author="author1", version=1)).read(), ) self.assertEqual( - b"baz", self.fs.get_version(filename="test", author="author2", version=0).read() + b"baz", + (self.fs.get_version(filename="test", author="author2", version=0)).read(), ) - self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read()) - self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read()) + self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=-1)).read()) + self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=2)).read()) - self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author3") - self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2) + with self.assertRaises(NoFile): + self.fs.get_version(filename="test", author="author3") + with self.assertRaises(NoFile): + self.fs.get_version(filename="test", author="author1", version=2) self.fs.delete(one) self.fs.delete(two) @@ -332,11 +384,12 @@ def test_get_version_with_metadata(self): def test_put_filelike(self): oid = self.fs.put(BytesIO(b"hello world"), chunk_size=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) - self.assertEqual(b"hello world", self.fs.get(oid).read()) + 
self.assertEqual(b"hello world", (self.fs.get(oid)).read()) def test_file_exists(self): oid = self.fs.put(b"hello") - self.assertRaises(FileExists, self.fs.put, b"world", _id=oid) + with self.assertRaises(FileExists): + self.fs.put(b"world", _id=oid) one = self.fs.new_file(_id=123) one.write(b"some content") @@ -345,15 +398,17 @@ def test_file_exists(self): # Attempt to upload a file with more chunks to the same _id. with patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE): two = self.fs.new_file(_id=123) - self.assertRaises(FileExists, two.write, b"x" * DEFAULT_CHUNK_SIZE * 3) + with self.assertRaises(FileExists): + two.write(b"x" * DEFAULT_CHUNK_SIZE * 3) # Original file is still readable (no extra chunks were uploaded). - self.assertEqual(self.fs.get(123).read(), b"some content") + self.assertEqual((self.fs.get(123)).read(), b"some content") two = self.fs.new_file(_id=123) two.write(b"some content") - self.assertRaises(FileExists, two.close) + with self.assertRaises(FileExists): + two.close() # Original file is still readable. 
- self.assertEqual(self.fs.get(123).read(), b"some content") + self.assertEqual((self.fs.get(123)).read(), b"some content") def test_exists(self): oid = self.fs.put(b"hello") @@ -381,15 +436,16 @@ def test_exists(self): self.assertFalse(self.fs.exists({"foo": {"$gt": 12}})) def test_put_unicode(self): - self.assertRaises(TypeError, self.fs.put, "hello") + with self.assertRaises(TypeError): + self.fs.put("hello") oid = self.fs.put("hello", encoding="utf-8") - self.assertEqual(b"hello", self.fs.get(oid).read()) - self.assertEqual("utf-8", self.fs.get(oid).encoding) + self.assertEqual(b"hello", (self.fs.get(oid)).read()) + self.assertEqual("utf-8", (self.fs.get(oid)).encoding) oid = self.fs.put("aé", encoding="iso-8859-1") - self.assertEqual("aé".encode("iso-8859-1"), self.fs.get(oid).read()) - self.assertEqual("iso-8859-1", self.fs.get(oid).encoding) + self.assertEqual("aé".encode("iso-8859-1"), (self.fs.get(oid)).read()) + self.assertEqual("iso-8859-1", (self.fs.get(oid)).encoding) def test_missing_length_iter(self): # Test fix that guards against PHP-237 @@ -411,11 +467,13 @@ def test_gridfs_lazy_connect(self): client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) - self.assertRaises(ServerSelectionTimeoutError, gfs.list) + with self.assertRaises(ServerSelectionTimeoutError): + gfs.list() fs = gridfs.GridFS(db) f = fs.new_file() - self.assertRaises(ServerSelectionTimeoutError, f.close) + with self.assertRaises(ServerSelectionTimeoutError): + f.close() def test_gridfs_find(self): self.fs.put(b"test2", filename="two") @@ -429,14 +487,15 @@ def test_gridfs_find(self): self.assertEqual(3, files.count_documents({"filename": "two"})) self.assertEqual(4, files.count_documents({})) cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) cursor.rewind() - gout = next(cursor) + gout = 
cursor.next() self.assertEqual(b"test1", gout.read()) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test2+", gout.read()) - self.assertRaises(StopIteration, cursor.__next__) + with self.assertRaises(StopIteration): + cursor.__next__() cursor.rewind() items = cursor.to_list() self.assertEqual(len(items), 2) @@ -484,12 +543,12 @@ def test_grid_in_non_int_chunksize(self): self.fs.put(data, filename="f") self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, self.fs.get_version("f").read()) + self.assertEqual(data, (self.fs.get_version("f")).read()) def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test) + gridfs.GridFS((self.rs_or_single_client(w=0)).pymongo_test) def test_md5(self): gin = self.fs.new_file() @@ -524,7 +583,7 @@ def test_gridfs_replica_set(self): self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) oid = fs.put(b"foo") - content = fs.get(oid).read() + content = (fs.get(oid)).read() self.assertEqual(b"foo", content) def test_gridfs_secondary(self): @@ -538,7 +597,8 @@ def test_gridfs_secondary(self): fs = gridfs.GridFS(secondary_connection.gfsreplica, "gfssecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, fs.put, b"foo") + with self.assertRaises(NotPrimaryError): + fs.put(b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to @@ -552,8 +612,10 @@ def test_gridfs_secondary_lazy(self): fs = gridfs.GridFS(client.gfsreplica, "gfssecondarylazytest") # Connects, doesn't create index. 
- self.assertRaises(NoFile, fs.get_last_version) - self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8") + with self.assertRaises(NoFile): + fs.get_last_version() + with self.assertRaises(NotPrimaryError): + fs.put("data", encoding="utf-8") if __name__ == "__main__": diff --git a/test/utils.py b/test/utils.py index 69154bc63b..ad2aa95429 100644 --- a/test/utils.py +++ b/test/utils.py @@ -608,6 +608,13 @@ def joinall(threads): assert not t.is_alive(), "Thread %s hung" % t +async def asyncjoinall(tasks): + """Join tasks with a 5-minute timeout, assert joins succeeded""" + for t in tasks: + await asyncio.wait_for(t, 300) + assert t.done(), "Task %s hung" % t + + def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true. diff --git a/tools/synchro.py b/tools/synchro.py index 897e5e8018..c4cdd4defa 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,8 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "StopAsyncIteration": "StopIteration", + "asyncjoinall": "joinall", } docstring_replacements: dict[tuple[str, str], str] = { @@ -207,6 +209,7 @@ def async_only_test(f: str) -> bool: "test_data_lake.py", "test_encryption.py", "test_grid_file.py", + "test_gridfs.py", "test_logger.py", "test_monitoring.py", "test_raw_bson.py", From cb0ede9224bdd9034e1a1b660416054687c8cfa6 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 08:50:41 -0800 Subject: [PATCH 02/10] catch StopAsyncIteration --- gridfs/asynchronous/grid_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index a49d51d304..00f02c9220 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -231,7 +231,7 @@ async def get_version( try: doc = await anext(cursor) return AsyncGridOut(self._collection, file_document=doc, 
session=session) - except StopIteration: + except (StopIteration, StopAsyncIteration): raise NoFile("no version %d for filename %r" % (version, filename)) from None async def get_last_version( From fd5c8debeab841e1ca87b3be09dd8a2d963e7863 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 13:21:30 -0800 Subject: [PATCH 03/10] address review --- gridfs/asynchronous/grid_file.py | 2 +- test/asynchronous/test_gridfs.py | 8 ++++---- test/test_gridfs.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 00f02c9220..92a4b748fe 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -231,7 +231,7 @@ async def get_version( try: doc = await anext(cursor) return AsyncGridOut(self._collection, file_document=doc, session=session) - except (StopIteration, StopAsyncIteration): + except StopAsyncIteration: raise NoFile("no version %d for filename %r" % (version, filename)) from None async def get_last_version( diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index ae6088ae68..6ed1633461 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -75,7 +75,7 @@ def run(self): assert data == b"hello" else: - class JustWrite(asyncio.Task): + class JustWrite: def __init__(self, fs, n): async def run(): for _ in range(self.n): @@ -83,12 +83,12 @@ async def run(): await file.write(b"hello") await file.close() - asyncio.Task.__init__(self, run()) + self.task = asyncio.create_task(run()) self.fs = fs self.n = n self.daemon = True - class JustRead(asyncio.Task): + class JustRead: def __init__(self, fs, n, results): async def run(): for _ in range(self.n): @@ -97,7 +97,7 @@ async def run(): self.results.append(data) assert data == b"hello" - asyncio.Task.__init__(self, run()) + self.task = asyncio.create_task(run()) self.fs = fs self.n = n self.results = results diff --git 
a/test/test_gridfs.py b/test/test_gridfs.py index 905f6a2c6a..a4b62a9eee 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -75,7 +75,7 @@ def run(self): assert data == b"hello" else: - class JustWrite(asyncio.Task): + class JustWrite: def __init__(self, fs, n): def run(): for _ in range(self.n): @@ -83,12 +83,12 @@ def run(): file.write(b"hello") file.close() - asyncio.Task.__init__(self, run()) + self.task = asyncio.create_task(run()) self.fs = fs self.n = n self.daemon = True - class JustRead(asyncio.Task): + class JustRead: def __init__(self, fs, n, results): def run(): for _ in range(self.n): @@ -97,7 +97,7 @@ def run(): self.results.append(data) assert data == b"hello" - asyncio.Task.__init__(self, run()) + self.task = asyncio.create_task(run()) self.fs = fs self.n = n self.results = results From 662b7e13557ab40ea24a24f97d940e1437a6aea0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 13:34:48 -0800 Subject: [PATCH 04/10] fix --- test/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/utils.py b/test/utils.py index ad2aa95429..857f699782 100644 --- a/test/utils.py +++ b/test/utils.py @@ -611,8 +611,8 @@ def joinall(threads): async def asyncjoinall(tasks): """Join tasks with a 5-minute timeout, assert joins succeeded""" for t in tasks: - await asyncio.wait_for(t, 300) - assert t.done(), "Task %s hung" % t + await asyncio.wait_for(t.task, 300) + assert t.task.done(), "Task %s hung" % t def wait_until(predicate, success_description, timeout=10): From a21c7e5f61d9be72453f3718ffafe079ceaf4b63 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 13:58:19 -0800 Subject: [PATCH 05/10] address review --- test/asynchronous/test_gridfs.py | 30 +++++++++++++++--------------- test/test_gridfs.py | 30 +++++++++++++++--------------- test/utils.py | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index 
6ed1633461..ecfcfd94e1 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -77,32 +77,32 @@ def run(self): class JustWrite: def __init__(self, fs, n): - async def run(): - for _ in range(self.n): - file = self.fs.new_file(filename="test") - await file.write(b"hello") - await file.close() - - self.task = asyncio.create_task(run()) + self.task = asyncio.create_task(self.run()) self.fs = fs self.n = n self.daemon = True + async def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + await file.write(b"hello") + await file.close() + class JustRead: def __init__(self, fs, n, results): - async def run(): - for _ in range(self.n): - file = await self.fs.get("test") - data = await file.read() - self.results.append(data) - assert data == b"hello" - - self.task = asyncio.create_task(run()) + self.task = asyncio.create_task(self.run()) self.fs = fs self.n = n self.results = results self.daemon = True + async def run(self): + for _ in range(self.n): + file = await self.fs.get("test") + data = await file.read() + self.results.append(data) + assert data == b"hello" + class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase): db: AsyncDatabase diff --git a/test/test_gridfs.py b/test/test_gridfs.py index a4b62a9eee..7f460406c1 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -77,32 +77,32 @@ def run(self): class JustWrite: def __init__(self, fs, n): - def run(): - for _ in range(self.n): - file = self.fs.new_file(filename="test") - file.write(b"hello") - file.close() - - self.task = asyncio.create_task(run()) + self.task = asyncio.create_task(self.run()) self.fs = fs self.n = n self.daemon = True + def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + file.write(b"hello") + file.close() + class JustRead: def __init__(self, fs, n, results): - def run(): - for _ in range(self.n): - file = self.fs.get("test") - data = file.read() - self.results.append(data) - assert data 
== b"hello" - - self.task = asyncio.create_task(run()) + self.task = asyncio.create_task(self.run()) self.fs = fs self.n = n self.results = results self.daemon = True + def run(self): + for _ in range(self.n): + file = self.fs.get("test") + data = file.read() + self.results.append(data) + assert data == b"hello" + class TestGridfsNoConnect(unittest.TestCase): db: Database diff --git a/test/utils.py b/test/utils.py index 541cab217f..f692186099 100644 --- a/test/utils.py +++ b/test/utils.py @@ -668,7 +668,7 @@ def joinall(threads): async def asyncjoinall(tasks): """Join tasks with a 5-minute timeout, assert joins succeeded""" for t in tasks: - await asyncio.wait_for(t.task, 300) + await asyncio.wait([t.task], timeout=300) assert t.task.done(), "Task %s hung" % t From 6d4da369e9e70593e5f54b8860f49afff03934d3 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 5 Feb 2025 18:34:19 -0800 Subject: [PATCH 06/10] use ConcurrentRunner --- test/asynchronous/test_gridfs.py | 114 ++++++++++++------------------- test/test_gridfs.py | 112 ++++++++++++------------------ test/utils.py | 7 -- 3 files changed, 89 insertions(+), 144 deletions(-) diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index ecfcfd94e1..f8d6c5f7c0 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -22,12 +22,13 @@ import threading import time from io import BytesIO +from test.asynchronous.helpers import ConcurrentRunner from unittest.mock import patch sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import asyncjoinall, joinall, one +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -44,64 +45,35 @@ _IS_SYNC = False -if _IS_SYNC: - - class JustWrite(threading.Thread): - def __init__(self, fs, n): - threading.Thread.__init__(self) - self.fs = fs - self.n = n - self.daemon = True - - def run(self): - for _ in range(self.n): - file = 
self.fs.new_file(filename="test") - file.write(b"hello") - file.close() - - class JustRead(threading.Thread): - def __init__(self, fs, n, results): - threading.Thread.__init__(self) - self.fs = fs - self.n = n - self.results = results - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.get("test") - data = file.read() - self.results.append(data) - assert data == b"hello" -else: - - class JustWrite: - def __init__(self, fs, n): - self.task = asyncio.create_task(self.run()) - self.fs = fs - self.n = n - self.daemon = True - - async def run(self): - for _ in range(self.n): - file = self.fs.new_file(filename="test") - await file.write(b"hello") - await file.close() - - class JustRead: - def __init__(self, fs, n, results): - self.task = asyncio.create_task(self.run()) - self.fs = fs - self.n = n - self.results = results - self.daemon = True - - async def run(self): - for _ in range(self.n): - file = await self.fs.get("test") - data = await file.read() - self.results.append(data) - assert data == b"hello" + +class JustWrite(ConcurrentRunner): + def __init__(self, fs, n): + super().__init__() + self.fs = fs + self.n = n + self.daemon = True + + async def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + await file.write(b"hello") + await file.close() + + +class JustRead(ConcurrentRunner): + def __init__(self, fs, n, results): + super().__init__() + self.fs = fs + self.n = n + self.results = results + self.daemon = True + + async def run(self): + for _ in range(self.n): + file = await self.fs.get("test") + data = await file.read() + self.results.append(data) + assert data == b"hello" class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase): @@ -252,25 +224,29 @@ async def test_alt_collection(self): async def test_threaded_reads(self): await self.fs.put(b"hello", _id="test") - threads = [] + tasks = [] results: list = [] for i in range(10): - threads.append(JustRead(self.fs, 10, results)) - if _IS_SYNC: - 
threads[i].start() + tasks.append(JustRead(self.fs, 10, results)) + await tasks[i].start() - await asyncjoinall(threads) + if _IS_SYNC: + joinall(tasks) + else: + await asyncio.wait([t.task for t in tasks]) self.assertEqual(100 * [b"hello"], results) async def test_threaded_writes(self): - threads = [] + tasks = [] for i in range(10): - threads.append(JustWrite(self.fs, 10)) - if _IS_SYNC: - threads[i].start() + tasks.append(JustWrite(self.fs, 10)) + await tasks[i].start() - await asyncjoinall(threads) + if _IS_SYNC: + joinall(tasks) + else: + await asyncio.wait([t.task for t in tasks]) f = await self.fs.get_last_version("test") self.assertEqual(await f.read(), b"hello") diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 7f460406c1..c745af38b8 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -22,6 +22,7 @@ import threading import time from io import BytesIO +from test.helpers import ConcurrentRunner from unittest.mock import patch sys.path[0:0] = [""] @@ -44,64 +45,35 @@ _IS_SYNC = True -if _IS_SYNC: - - class JustWrite(threading.Thread): - def __init__(self, fs, n): - threading.Thread.__init__(self) - self.fs = fs - self.n = n - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.new_file(filename="test") - file.write(b"hello") - file.close() - - class JustRead(threading.Thread): - def __init__(self, fs, n, results): - threading.Thread.__init__(self) - self.fs = fs - self.n = n - self.results = results - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.get("test") - data = file.read() - self.results.append(data) - assert data == b"hello" -else: - - class JustWrite: - def __init__(self, fs, n): - self.task = asyncio.create_task(self.run()) - self.fs = fs - self.n = n - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.new_file(filename="test") - file.write(b"hello") - file.close() - - class JustRead: - def __init__(self, fs, n, results): - self.task = 
asyncio.create_task(self.run()) - self.fs = fs - self.n = n - self.results = results - self.daemon = True - - def run(self): - for _ in range(self.n): - file = self.fs.get("test") - data = file.read() - self.results.append(data) - assert data == b"hello" + +class JustWrite(ConcurrentRunner): + def __init__(self, fs, n): + super().__init__() + self.fs = fs + self.n = n + self.daemon = True + + def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + file.write(b"hello") + file.close() + + +class JustRead(ConcurrentRunner): + def __init__(self, fs, n, results): + super().__init__() + self.fs = fs + self.n = n + self.results = results + self.daemon = True + + def run(self): + for _ in range(self.n): + file = self.fs.get("test") + data = file.read() + self.results.append(data) + assert data == b"hello" class TestGridfsNoConnect(unittest.TestCase): @@ -250,25 +222,29 @@ def test_alt_collection(self): def test_threaded_reads(self): self.fs.put(b"hello", _id="test") - threads = [] + tasks = [] results: list = [] for i in range(10): - threads.append(JustRead(self.fs, 10, results)) - if _IS_SYNC: - threads[i].start() + tasks.append(JustRead(self.fs, 10, results)) + tasks[i].start() - joinall(threads) + if _IS_SYNC: + joinall(tasks) + else: + asyncio.wait([t.task for t in tasks]) self.assertEqual(100 * [b"hello"], results) def test_threaded_writes(self): - threads = [] + tasks = [] for i in range(10): - threads.append(JustWrite(self.fs, 10)) - if _IS_SYNC: - threads[i].start() + tasks.append(JustWrite(self.fs, 10)) + tasks[i].start() - joinall(threads) + if _IS_SYNC: + joinall(tasks) + else: + asyncio.wait([t.task for t in tasks]) f = self.fs.get_last_version("test") self.assertEqual(f.read(), b"hello") diff --git a/test/utils.py b/test/utils.py index f692186099..91000a636a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -665,13 +665,6 @@ def joinall(threads): assert not t.is_alive(), "Thread %s hung" % t -async def asyncjoinall(tasks): - 
"""Join tasks with a 5-minute timeout, assert joins succeeded""" - for t in tasks: - await asyncio.wait([t.task], timeout=300) - assert t.task.done(), "Task %s hung" % t - - def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true. From 6b7f09b8ce1037ae1727d0c6eda11ce1b3453ed6 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 5 Feb 2025 18:35:16 -0800 Subject: [PATCH 07/10] remove asyncjoinall from synchro --- tools/synchro.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/synchro.py b/tools/synchro.py index 165a38cfa9..6addfd99c4 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -123,7 +123,6 @@ "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", "StopAsyncIteration": "StopIteration", - "asyncjoinall": "joinall", } docstring_replacements: dict[tuple[str, str], str] = { From 0d60657ace4a31a1418bea6a5e42de0912200c84 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 5 Feb 2025 18:48:05 -0800 Subject: [PATCH 08/10] fix typing --- test/asynchronous/test_gridfs.py | 4 ++-- test/test_gridfs.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index f8d6c5f7c0..ba140bb1f5 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -233,7 +233,7 @@ async def test_threaded_reads(self): if _IS_SYNC: joinall(tasks) else: - await asyncio.wait([t.task for t in tasks]) + await asyncio.wait([t.task for t in tasks if t.task is not None]) self.assertEqual(100 * [b"hello"], results) @@ -246,7 +246,7 @@ async def test_threaded_writes(self): if _IS_SYNC: joinall(tasks) else: - await asyncio.wait([t.task for t in tasks]) + await asyncio.wait([t.task for t in tasks if t.task is not None]) f = await self.fs.get_last_version("test") self.assertEqual(await f.read(), b"hello") diff --git a/test/test_gridfs.py b/test/test_gridfs.py index c745af38b8..2650c086c5 100644 --- 
a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -231,7 +231,7 @@ def test_threaded_reads(self): if _IS_SYNC: joinall(tasks) else: - asyncio.wait([t.task for t in tasks]) + asyncio.wait([t.task for t in tasks if t.task is not None]) self.assertEqual(100 * [b"hello"], results) @@ -244,7 +244,7 @@ def test_threaded_writes(self): if _IS_SYNC: joinall(tasks) else: - asyncio.wait([t.task for t in tasks]) + asyncio.wait([t.task for t in tasks if t.task is not None]) f = self.fs.get_last_version("test") self.assertEqual(f.read(), b"hello") From 6827f71744fdb1c521acf4216c03a9516b311451 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 6 Feb 2025 10:23:36 -0800 Subject: [PATCH 09/10] use async_joinall --- test/asynchronous/test_gridfs.py | 12 +++--------- test/test_gridfs.py | 10 ++-------- test/utils.py | 5 +++++ tools/synchro.py | 1 + 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index ba140bb1f5..b1c1e754ff 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import joinall, one +from test.utils import async_joinall, one import gridfs from bson.binary import Binary @@ -230,10 +230,7 @@ async def test_threaded_reads(self): tasks.append(JustRead(self.fs, 10, results)) await tasks[i].start() - if _IS_SYNC: - joinall(tasks) - else: - await asyncio.wait([t.task for t in tasks if t.task is not None]) + await async_joinall(tasks) self.assertEqual(100 * [b"hello"], results) @@ -243,10 +240,7 @@ async def test_threaded_writes(self): tasks.append(JustWrite(self.fs, 10)) await tasks[i].start() - if _IS_SYNC: - joinall(tasks) - else: - await asyncio.wait([t.task for t in tasks if t.task is not None]) + await async_joinall(tasks) f = await self.fs.get_last_version("test") self.assertEqual(await f.read(), b"hello") diff --git 
a/test/test_gridfs.py b/test/test_gridfs.py index 2650c086c5..47e38141b2 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -228,10 +228,7 @@ def test_threaded_reads(self): tasks.append(JustRead(self.fs, 10, results)) tasks[i].start() - if _IS_SYNC: - joinall(tasks) - else: - asyncio.wait([t.task for t in tasks if t.task is not None]) + joinall(tasks) self.assertEqual(100 * [b"hello"], results) @@ -241,10 +238,7 @@ def test_threaded_writes(self): tasks.append(JustWrite(self.fs, 10)) tasks[i].start() - if _IS_SYNC: - joinall(tasks) - else: - asyncio.wait([t.task for t in tasks if t.task is not None]) + joinall(tasks) f = self.fs.get_last_version("test") self.assertEqual(f.read(), b"hello") diff --git a/test/utils.py b/test/utils.py index 5c1e0bfb7c..40eec01cb4 100644 --- a/test/utils.py +++ b/test/utils.py @@ -666,6 +666,11 @@ def joinall(threads): assert not t.is_alive(), "Thread %s hung" % t +async def async_joinall(tasks): + """Wait up to 5 minutes for all tasks to complete""" + await asyncio.wait([t.task for t in tasks if t.task is not None], timeout=300) + def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true.
diff --git a/tools/synchro.py b/tools/synchro.py index 48b3103c88..f38da72d51 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -124,6 +124,7 @@ "AsyncMockPool": "MockPool", "StopAsyncIteration": "StopIteration", "create_async_event": "create_event", + "async_joinall": "joinall", } docstring_replacements: dict[tuple[str, str], str] = { From aa2a26dea337f0b010e98e2202db2098be80ac34 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Feb 2025 09:08:51 -0800 Subject: [PATCH 10/10] remove duplicated file names in synchro --- tools/synchro.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tools/synchro.py b/tools/synchro.py index 37cde9a8b7..8fec7627ea 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -214,14 +214,13 @@ def async_only_test(f: str) -> bool: "test_dns.py", "test_encryption.py", "test_examples.py", - "test_heartbeat_monitoring.py", - "test_index_management.py", "test_grid_file.py", "test_gridfs.py", - "test_load_balancer.py", - "test_json_util_integration.py", "test_gridfs_spec.py", + "test_heartbeat_monitoring.py", + "test_index_management.py", "test_json_util_integration.py", + "test_load_balancer.py", "test_logger.py", "test_max_staleness.py", "test_monitoring.py",