Skip to content

Commit 4b4e0a0

Browse files
committed
Adding replace_match
1 parent ed6bf65 commit 4b4e0a0

File tree

2 files changed

+102
-3
lines changed

2 files changed

+102
-3
lines changed

arangoasync/collection.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,7 @@ async def update_match(
11791179
Args:
11801180
filters (dict | None): Query filters.
11811181
body (dict): Full or partial document body with the updates.
1182-
limit (int | str | None): Maximum number of documents to return.
1182+
limit (int | str | None): Maximum number of documents to update.
11831183
keep_none (bool | None): If set to `True`, fields with value `None` are
11841184
retained in the document. Otherwise, they are removed completely.
11851185
wait_for_sync (bool | None): Wait until operation has been synced to disk.
@@ -1198,8 +1198,6 @@ async def update_match(
11981198
if not (self._is_none_or_int(limit) or limit == "null"):
11991199
raise ValueError("limit parameter must be a non-negative int")
12001200

1201-
# If the waitForSync parameter is not specified or set to false,
1202-
# then the collection’s default waitForSync behavior is applied.
12031201
sync = f", waitForSync: {wait_for_sync}" if wait_for_sync is not None else ""
12041202
query = f"""
12051203
FOR doc IN @@collection
@@ -1237,6 +1235,69 @@ def response_handler(resp: Response) -> int:
12371235

12381236
return await self._executor.execute(request, response_handler)
12391237

1238+
async def replace_match(
1239+
self,
1240+
filters: Json,
1241+
body: T,
1242+
limit: Optional[int | str] = None,
1243+
wait_for_sync: Optional[bool] = None,
1244+
allow_dirty_read: Optional[bool] = None,
1245+
) -> Result[int]:
1246+
"""Replace matching documents.
1247+
1248+
Args:
1249+
filters (dict | None): Query filters.
1250+
body (dict): New document body.
1251+
limit (int | str | None): Maximum number of documents to replace.
1252+
wait_for_sync (bool | None): Wait until operation has been synced to disk.
1253+
allow_dirty_read (bool | None): Allow reads from followers in a cluster.
1254+
1255+
Returns:
1256+
int: Number of documents that got replaced.
1257+
1258+
Raises:
1259+
DocumentReplaceError: If replace fails.
1260+
"""
1261+
if not self._is_none_or_dict(filters):
1262+
raise ValueError("filters parameter must be a dict")
1263+
if not (self._is_none_or_int(limit) or limit == "null"):
1264+
raise ValueError("limit parameter must be a non-negative int")
1265+
1266+
sync = f"waitForSync: {wait_for_sync}" if wait_for_sync is not None else ""
1267+
query = f"""
1268+
FOR doc IN @@collection
1269+
{self._build_filter_conditions(filters)}
1270+
{f"LIMIT {limit}" if limit is not None else ""}
1271+
REPLACE doc WITH @body IN @@collection
1272+
{f"OPTIONS {{ {sync} }}" if sync else ""}
1273+
""" # noqa: E201 E202
1274+
bind_vars = {
1275+
"@collection": self.name,
1276+
"body": body,
1277+
}
1278+
data = {"query": query, "bindVars": bind_vars}
1279+
headers: RequestHeaders = {}
1280+
if allow_dirty_read is not None:
1281+
if allow_dirty_read is True:
1282+
headers["x-arango-allow-dirty-read"] = "true"
1283+
else:
1284+
headers["x-arango-allow-dirty-read"] = "false"
1285+
1286+
request = Request(
1287+
method=Method.POST,
1288+
endpoint="/_api/cursor",
1289+
data=self.serializer.dumps(data),
1290+
headers=headers,
1291+
)
1292+
1293+
def response_handler(resp: Response) -> int:
1294+
if resp.is_success:
1295+
result = self.deserializer.loads(resp.raw_body)
1296+
return cast(int, result["extra"]["stats"]["writesExecuted"])
1297+
raise DocumentReplaceError(resp, request)
1298+
1299+
return await self._executor.execute(request, response_handler)
1300+
12401301
async def insert_many(
12411302
self,
12421303
documents: Sequence[T],

tests/test_document.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,41 @@ async def test_document_update_match(doc_col, bad_col, docs):
485485
async for doc in await doc_col.find():
486486
assert doc["val"] != -1
487487
assert count == 0
488+
489+
490+
@pytest.mark.asyncio
491+
async def test_document_replace_match(doc_col, bad_col, docs):
492+
# Check errors first
493+
with pytest.raises(DocumentReplaceError):
494+
await bad_col.replace_match({}, {})
495+
with pytest.raises(ValueError):
496+
await doc_col.replace_match({}, {}, limit=-1)
497+
with pytest.raises(ValueError):
498+
await doc_col.replace_match("abcd", {})
499+
500+
# Replace all documents
501+
await doc_col.insert_many(docs)
502+
count = await doc_col.replace_match({}, {"replacement": 42})
503+
async for doc in await doc_col.find():
504+
assert "replacement" in doc
505+
assert "val" not in doc
506+
assert count == len(docs)
507+
await doc_col.truncate()
508+
509+
# Replace documents partially
510+
await doc_col.insert_many(docs)
511+
count = await doc_col.replace_match({"text": "foo"}, {"replacement": 24})
512+
async for doc in await doc_col.find():
513+
if doc.get("text") == "bar":
514+
assert "replacement" not in doc
515+
else:
516+
assert "replacement" in doc
517+
assert count == sum([1 for doc in docs if doc["text"] == "foo"])
518+
await doc_col.truncate()
519+
520+
# No matching documents
521+
await doc_col.insert_many(docs)
522+
count = await doc_col.replace_match({"text": "no_matching"}, {"val": -1})
523+
async for doc in await doc_col.find():
524+
assert doc["val"] != -1
525+
assert count == 0

0 commit comments

Comments
 (0)