Skip to content

Commit e320bb5

Browse files
committed
Adding find method
1 parent 19307d5 commit e320bb5

File tree

4 files changed

+225
-3
lines changed

4 files changed

+225
-3
lines changed

arangoasync/aql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def response_handler(resp: Response) -> Cursor:
327327
if not resp.is_success:
328328
raise AQLQueryExecuteError(resp, request)
329329
if self._executor.context == "async":
330-
# We cannot have a cursor getting back async jobs
330+
# We cannot have a cursor giving back async jobs
331331
executor: NonAsyncExecutor = DefaultApiExecutor(
332332
self._executor.connection
333333
)

arangoasync/collection.py

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
__all__ = ["Collection", "StandardCollection"]
22

33

4-
from typing import Generic, List, Optional, Sequence, Tuple, TypeVar, cast
4+
from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar, cast
55

6+
from arangoasync.cursor import Cursor
67
from arangoasync.errno import (
78
DOCUMENT_NOT_FOUND,
89
HTTP_BAD_PARAMETER,
@@ -25,8 +26,9 @@
2526
IndexGetError,
2627
IndexListError,
2728
IndexLoadError,
29+
SortValidationError,
2830
)
29-
from arangoasync.executor import ApiExecutor
31+
from arangoasync.executor import ApiExecutor, DefaultApiExecutor, NonAsyncExecutor
3032
from arangoasync.request import Method, Request
3133
from arangoasync.response import Response
3234
from arangoasync.result import Result
@@ -156,6 +158,90 @@ def _prep_from_doc(
156158
else:
157159
return doc_id, {"If-Match": rev}
158160

161+
def _build_filter_conditions(self, filters: Optional[Json]) -> str:
162+
"""Build filter conditions for an AQL query.
163+
164+
Args:
165+
filters (dict | None): Document filters.
166+
167+
Returns:
168+
str: The complete AQL filter condition.
169+
"""
170+
if not filters:
171+
return ""
172+
173+
conditions = []
174+
for k, v in filters.items():
175+
field = k if "." in k else f"`{k}`"
176+
conditions.append(f"doc.{field} == {self.serializer.dumps(v)}")
177+
178+
return "FILTER " + " AND ".join(conditions)
179+
180+
@staticmethod
181+
def _is_none_or_int(obj: Any) -> bool:
182+
"""Check if obj is `None` or a positive integer.
183+
184+
Args:
185+
obj: Object to check.
186+
187+
Returns:
188+
bool: `True` if object is `None` or a positive integer.
189+
"""
190+
return obj is None or isinstance(obj, int) and obj >= 0
191+
192+
@staticmethod
193+
def _is_none_or_dict(obj: Any) -> bool:
194+
"""Check if obj is `None` or a dict.
195+
196+
Args:
197+
obj: Object to check.
198+
199+
Returns:
200+
bool: `True` if object is `None` or a dict.
201+
"""
202+
return obj is None or isinstance(obj, dict)
203+
204+
@staticmethod
205+
def _validate_sort_parameters(sort: Optional[Jsons]) -> None:
206+
"""Validate sort parameters for an AQL query.
207+
208+
Args:
209+
sort (list | None): Document sort parameters.
210+
211+
Raises:
212+
SortValidationError: If sort parameters are invalid.
213+
"""
214+
if not sort:
215+
return
216+
217+
for param in sort:
218+
if "sort_by" not in param or "sort_order" not in param:
219+
raise SortValidationError(
220+
"Each sort parameter must have 'sort_by' and 'sort_order'."
221+
)
222+
if param["sort_order"].upper() not in ["ASC", "DESC"]:
223+
raise SortValidationError("'sort_order' must be either 'ASC' or 'DESC'")
224+
225+
@staticmethod
226+
def _build_sort_expression(sort: Optional[Jsons]) -> str:
227+
"""Build a sort condition for an AQL query.
228+
229+
Args:
230+
sort (list | None): Document sort parameters.
231+
232+
Returns:
233+
str: The complete AQL sort condition.
234+
"""
235+
if not sort:
236+
return ""
237+
238+
sort_chunks = []
239+
for sort_param in sort:
240+
chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}"
241+
sort_chunks.append(chunk)
242+
243+
return "SORT " + ", ".join(sort_chunks)
244+
159245
@property
160246
def name(self) -> str:
161247
"""Return the name of the collection.
@@ -1006,3 +1092,74 @@ def response_handler(resp: Response) -> V:
10061092
return self._doc_deserializer.loads_many(resp.raw_body)
10071093

10081094
return await self._executor.execute(request, response_handler)
1095+
1096+
async def find(
1097+
self,
1098+
filters: Optional[Json] = None,
1099+
skip: Optional[int] = None,
1100+
limit: Optional[int | str] = None,
1101+
allow_dirty_read: Optional[bool] = False,
1102+
sort: Optional[Jsons] = None,
1103+
) -> Result[Cursor]:
1104+
"""Return all documents that match the given filters.
1105+
1106+
Args:
1107+
filters (dict | None): Query filters.
1108+
skip (int | None): Number of documents to skip.
1109+
limit (int | str | None): Maximum number of documents to return.
1110+
allow_dirty_read (bool): Allow reads from followers in a cluster.
1111+
sort (list | None): Document sort parameters.
1112+
1113+
Returns:
1114+
Cursor: Document cursor.
1115+
1116+
Raises:
1117+
DocumentGetError: If retrieval fails.
1118+
SortValidationError: If sort parameters are invalid.
1119+
"""
1120+
if not self._is_none_or_dict(filters):
1121+
raise ValueError("filters parameter must be a dict")
1122+
self._validate_sort_parameters(sort)
1123+
if not self._is_none_or_int(skip):
1124+
raise ValueError("skip parameter must be a non-negative int")
1125+
if not (self._is_none_or_int(limit) or limit == "null"):
1126+
raise ValueError("limit parameter must be a non-negative int")
1127+
1128+
skip = skip if skip is not None else 0
1129+
limit = limit if limit is not None else "null"
1130+
query = f"""
1131+
FOR doc IN @@collection
1132+
{self._build_filter_conditions(filters)}
1133+
LIMIT {skip}, {limit}
1134+
{self._build_sort_expression(sort)}
1135+
RETURN doc
1136+
"""
1137+
bind_vars = {"@collection": self.name}
1138+
data: Json = {"query": query, "bindVars": bind_vars, "count": True}
1139+
headers: RequestHeaders = {}
1140+
if allow_dirty_read is not None:
1141+
if allow_dirty_read is True:
1142+
headers["x-arango-allow-dirty-read"] = "true"
1143+
else:
1144+
headers["x-arango-allow-dirty-read"] = "false"
1145+
1146+
request = Request(
1147+
method=Method.POST,
1148+
endpoint="/_api/cursor",
1149+
data=self.serializer.dumps(data),
1150+
headers=headers,
1151+
)
1152+
1153+
def response_handler(resp: Response) -> Cursor:
1154+
if not resp.is_success:
1155+
raise DocumentGetError(resp, request)
1156+
if self._executor.context == "async":
1157+
# We cannot have a cursor giving back async jobs
1158+
executor: NonAsyncExecutor = DefaultApiExecutor(
1159+
self._executor.connection
1160+
)
1161+
else:
1162+
executor = cast(NonAsyncExecutor, self._executor)
1163+
return Cursor(executor, self.deserializer.loads(resp.raw_body))
1164+
1165+
return await self._executor.execute(request, response_handler)

arangoasync/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ class ServerVersionError(ArangoServerError):
327327
"""Failed to retrieve server version."""
328328

329329

330+
class SortValidationError(ArangoClientError):
331+
"""Invalid sort parameters."""
332+
333+
330334
class TransactionAbortError(ArangoServerError):
331335
"""Failed to abort transaction."""
332336

tests/test_document.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
DocumentReplaceError,
1111
DocumentRevisionError,
1212
DocumentUpdateError,
13+
SortValidationError,
1314
)
1415
from tests.helpers import generate_col_name
1516

@@ -234,3 +235,63 @@ async def test_document_get_many(doc_col, bad_col, docs):
234235
# Empty list
235236
many = await doc_col.get_many([])
236237
assert len(many) == 0
238+
239+
240+
@pytest.mark.asyncio
241+
async def test_document_find(doc_col, bad_col, docs):
242+
# Check errors first
243+
with pytest.raises(DocumentGetError):
244+
await bad_col.find()
245+
with pytest.raises(ValueError):
246+
await doc_col.find(limit=-1)
247+
with pytest.raises(ValueError):
248+
await doc_col.find(skip="abcd")
249+
with pytest.raises(ValueError):
250+
await doc_col.find(filters="abcd")
251+
with pytest.raises(SortValidationError):
252+
await doc_col.find(sort="abcd")
253+
with pytest.raises(SortValidationError):
254+
await doc_col.find(sort=[{"x": "text", "sort_order": "ASC"}])
255+
256+
# Insert all documents
257+
await asyncio.gather(*[doc_col.insert(doc) for doc in docs])
258+
259+
# Empty find
260+
filter_docs = []
261+
async for doc in await doc_col.find():
262+
filter_docs.append(doc)
263+
assert len(filter_docs) == len(docs)
264+
265+
# Test with filter
266+
filter_docs = []
267+
async for doc in await doc_col.find(filters={"val": 42}):
268+
filter_docs.append(doc)
269+
assert len(filter_docs) == 0
270+
async for doc in await doc_col.find(filters={"text": "foo"}):
271+
filter_docs.append(doc)
272+
assert len(filter_docs) == 3
273+
filter_docs = []
274+
async for doc in await doc_col.find(filters={"text": "foo", "val": 1}):
275+
filter_docs.append(doc)
276+
assert len(filter_docs) == 1
277+
278+
# Test with limit
279+
filter_docs = []
280+
async for doc in await doc_col.find(limit=2):
281+
filter_docs.append(doc)
282+
assert len(filter_docs) == 2
283+
284+
# Test with skip
285+
filter_docs = []
286+
async for doc in await doc_col.find(skip=2, allow_dirty_read=True):
287+
filter_docs.append(doc)
288+
assert len(filter_docs) == len(docs) - 2
289+
290+
# Test with sort
291+
filter_docs = []
292+
async for doc in await doc_col.find(
293+
{}, sort=[{"sort_by": "text", "sort_order": "ASC"}]
294+
):
295+
filter_docs.append(doc)
296+
for idx in range(len(filter_docs) - 1):
297+
assert filter_docs[idx]["text"] <= filter_docs[idx + 1]["text"]

0 commit comments

Comments
 (0)