|
1 | 1 | __all__ = ["Collection", "StandardCollection"]
|
2 | 2 |
|
3 | 3 |
|
4 |
| -from typing import Generic, List, Optional, Sequence, Tuple, TypeVar, cast |
| 4 | +from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar, cast |
5 | 5 |
|
| 6 | +from arangoasync.cursor import Cursor |
6 | 7 | from arangoasync.errno import (
|
7 | 8 | DOCUMENT_NOT_FOUND,
|
8 | 9 | HTTP_BAD_PARAMETER,
|
|
25 | 26 | IndexGetError,
|
26 | 27 | IndexListError,
|
27 | 28 | IndexLoadError,
|
| 29 | + SortValidationError, |
28 | 30 | )
|
29 |
| -from arangoasync.executor import ApiExecutor |
| 31 | +from arangoasync.executor import ApiExecutor, DefaultApiExecutor, NonAsyncExecutor |
30 | 32 | from arangoasync.request import Method, Request
|
31 | 33 | from arangoasync.response import Response
|
32 | 34 | from arangoasync.result import Result
|
@@ -156,6 +158,90 @@ def _prep_from_doc(
|
156 | 158 | else:
|
157 | 159 | return doc_id, {"If-Match": rev}
|
158 | 160 |
|
| 161 | + def _build_filter_conditions(self, filters: Optional[Json]) -> str: |
| 162 | + """Build filter conditions for an AQL query. |
| 163 | +
|
| 164 | + Args: |
| 165 | + filters (dict | None): Document filters. |
| 166 | +
|
| 167 | + Returns: |
| 168 | + str: The complete AQL filter condition. |
| 169 | + """ |
| 170 | + if not filters: |
| 171 | + return "" |
| 172 | + |
| 173 | + conditions = [] |
| 174 | + for k, v in filters.items(): |
| 175 | + field = k if "." in k else f"`{k}`" |
| 176 | + conditions.append(f"doc.{field} == {self.serializer.dumps(v)}") |
| 177 | + |
| 178 | + return "FILTER " + " AND ".join(conditions) |
| 179 | + |
| 180 | + @staticmethod |
| 181 | + def _is_none_or_int(obj: Any) -> bool: |
| 182 | + """Check if obj is `None` or a positive integer. |
| 183 | +
|
| 184 | + Args: |
| 185 | + obj: Object to check. |
| 186 | +
|
| 187 | + Returns: |
| 188 | + bool: `True` if object is `None` or a positive integer. |
| 189 | + """ |
| 190 | + return obj is None or isinstance(obj, int) and obj >= 0 |
| 191 | + |
| 192 | + @staticmethod |
| 193 | + def _is_none_or_dict(obj: Any) -> bool: |
| 194 | + """Check if obj is `None` or a dict. |
| 195 | +
|
| 196 | + Args: |
| 197 | + obj: Object to check. |
| 198 | +
|
| 199 | + Returns: |
| 200 | + bool: `True` if object is `None` or a dict. |
| 201 | + """ |
| 202 | + return obj is None or isinstance(obj, dict) |
| 203 | + |
| 204 | + @staticmethod |
| 205 | + def _validate_sort_parameters(sort: Optional[Jsons]) -> None: |
| 206 | + """Validate sort parameters for an AQL query. |
| 207 | +
|
| 208 | + Args: |
| 209 | + sort (list | None): Document sort parameters. |
| 210 | +
|
| 211 | + Raises: |
| 212 | + SortValidationError: If sort parameters are invalid. |
| 213 | + """ |
| 214 | + if not sort: |
| 215 | + return |
| 216 | + |
| 217 | + for param in sort: |
| 218 | + if "sort_by" not in param or "sort_order" not in param: |
| 219 | + raise SortValidationError( |
| 220 | + "Each sort parameter must have 'sort_by' and 'sort_order'." |
| 221 | + ) |
| 222 | + if param["sort_order"].upper() not in ["ASC", "DESC"]: |
| 223 | + raise SortValidationError("'sort_order' must be either 'ASC' or 'DESC'") |
| 224 | + |
| 225 | + @staticmethod |
| 226 | + def _build_sort_expression(sort: Optional[Jsons]) -> str: |
| 227 | + """Build a sort condition for an AQL query. |
| 228 | +
|
| 229 | + Args: |
| 230 | + sort (list | None): Document sort parameters. |
| 231 | +
|
| 232 | + Returns: |
| 233 | + str: The complete AQL sort condition. |
| 234 | + """ |
| 235 | + if not sort: |
| 236 | + return "" |
| 237 | + |
| 238 | + sort_chunks = [] |
| 239 | + for sort_param in sort: |
| 240 | + chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}" |
| 241 | + sort_chunks.append(chunk) |
| 242 | + |
| 243 | + return "SORT " + ", ".join(sort_chunks) |
| 244 | + |
159 | 245 | @property
|
160 | 246 | def name(self) -> str:
|
161 | 247 | """Return the name of the collection.
|
@@ -1006,3 +1092,74 @@ def response_handler(resp: Response) -> V:
|
1006 | 1092 | return self._doc_deserializer.loads_many(resp.raw_body)
|
1007 | 1093 |
|
1008 | 1094 | return await self._executor.execute(request, response_handler)
|
| 1095 | + |
| 1096 | + async def find( |
| 1097 | + self, |
| 1098 | + filters: Optional[Json] = None, |
| 1099 | + skip: Optional[int] = None, |
| 1100 | + limit: Optional[int | str] = None, |
| 1101 | + allow_dirty_read: Optional[bool] = False, |
| 1102 | + sort: Optional[Jsons] = None, |
| 1103 | + ) -> Result[Cursor]: |
| 1104 | + """Return all documents that match the given filters. |
| 1105 | +
|
| 1106 | + Args: |
| 1107 | + filters (dict | None): Query filters. |
| 1108 | + skip (int | None): Number of documents to skip. |
| 1109 | + limit (int | str | None): Maximum number of documents to return. |
| 1110 | + allow_dirty_read (bool): Allow reads from followers in a cluster. |
| 1111 | + sort (list | None): Document sort parameters. |
| 1112 | +
|
| 1113 | + Returns: |
| 1114 | + Cursor: Document cursor. |
| 1115 | +
|
| 1116 | + Raises: |
| 1117 | + DocumentGetError: If retrieval fails. |
| 1118 | + SortValidationError: If sort parameters are invalid. |
| 1119 | + """ |
| 1120 | + if not self._is_none_or_dict(filters): |
| 1121 | + raise ValueError("filters parameter must be a dict") |
| 1122 | + self._validate_sort_parameters(sort) |
| 1123 | + if not self._is_none_or_int(skip): |
| 1124 | + raise ValueError("skip parameter must be a non-negative int") |
| 1125 | + if not (self._is_none_or_int(limit) or limit == "null"): |
| 1126 | + raise ValueError("limit parameter must be a non-negative int") |
| 1127 | + |
| 1128 | + skip = skip if skip is not None else 0 |
| 1129 | + limit = limit if limit is not None else "null" |
| 1130 | + query = f""" |
| 1131 | + FOR doc IN @@collection |
| 1132 | + {self._build_filter_conditions(filters)} |
| 1133 | + LIMIT {skip}, {limit} |
| 1134 | + {self._build_sort_expression(sort)} |
| 1135 | + RETURN doc |
| 1136 | + """ |
| 1137 | + bind_vars = {"@collection": self.name} |
| 1138 | + data: Json = {"query": query, "bindVars": bind_vars, "count": True} |
| 1139 | + headers: RequestHeaders = {} |
| 1140 | + if allow_dirty_read is not None: |
| 1141 | + if allow_dirty_read is True: |
| 1142 | + headers["x-arango-allow-dirty-read"] = "true" |
| 1143 | + else: |
| 1144 | + headers["x-arango-allow-dirty-read"] = "false" |
| 1145 | + |
| 1146 | + request = Request( |
| 1147 | + method=Method.POST, |
| 1148 | + endpoint="/_api/cursor", |
| 1149 | + data=self.serializer.dumps(data), |
| 1150 | + headers=headers, |
| 1151 | + ) |
| 1152 | + |
| 1153 | + def response_handler(resp: Response) -> Cursor: |
| 1154 | + if not resp.is_success: |
| 1155 | + raise DocumentGetError(resp, request) |
| 1156 | + if self._executor.context == "async": |
| 1157 | + # We cannot have a cursor giving back async jobs |
| 1158 | + executor: NonAsyncExecutor = DefaultApiExecutor( |
| 1159 | + self._executor.connection |
| 1160 | + ) |
| 1161 | + else: |
| 1162 | + executor = cast(NonAsyncExecutor, self._executor) |
| 1163 | + return Cursor(executor, self.deserializer.loads(resp.raw_body)) |
| 1164 | + |
| 1165 | + return await self._executor.execute(request, response_handler) |
0 commit comments