Commit 0a61e9c: "initial feature commit" (1 parent: 9b3daf6)

File tree: 3 files changed, 101 additions and 18 deletions

3 files changed

+101
-18
lines changed

arango/collection.py

Lines changed: 32 additions & 13 deletions
@@ -42,7 +42,7 @@
 from arango.response import Response
 from arango.result import Result
 from arango.typings import Fields, Headers, Json, Params
-from arango.utils import get_doc_id, is_none_or_int, is_none_or_str
+from arango.utils import get_batches, get_doc_id, is_none_or_int, is_none_or_str


 class Collection(ApiGroup):
@@ -187,18 +187,27 @@ def _ensure_key_in_body(self, body: Json) -> Json:
             return body
         raise DocumentParseError('field "_key" or "_id" required')

-    def _ensure_key_from_id(self, body: Json) -> Json:
+    def _ensure_key_from_id(self, body: Json, index: Optional[int] = None) -> Json:
         """Return the body with "_key" field if it has "_id" field.
+        If it has neither, set the "_key" value to the document's
+        index position in the original sequence.

         :param body: Document body.
         :type body: dict
+        :param index: Document index value in the original list of documents.
+        :type index: int | None
         :return: Document body with "_key" field if it has "_id" field.
         :rtype: dict
         """
         if "_id" in body and "_key" not in body:
             doc_id = self._validate_id(body["_id"])
             body = body.copy()
             body["_key"] = doc_id[len(self._id_prefix) :]
+
+        if "_id" not in body and "_key" not in body:
+            body = body.copy()
+            body["_key"] = str(index)
+
         return body

     @property
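A quick sketch of the new fallback behavior in this private helper (hypothetical values, assuming `col` is the Collection wrapper for a collection named "students"):

.. code-block:: python

    # "_id" present: "_key" is derived by stripping the collection id prefix.
    col._ensure_key_from_id({'_id': 'students/alice'})
    # {'_id': 'students/alice', '_key': 'alice'}

    # Neither "_id" nor "_key" present: the positional index becomes the key.
    col._ensure_key_from_id({'name': 'bob'}, index=2)
    # {'name': 'bob', '_key': '2'}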
@@ -1934,7 +1943,8 @@ def import_bulk(
         overwrite: Optional[bool] = None,
         on_duplicate: Optional[str] = None,
         sync: Optional[bool] = None,
-    ) -> Result[Json]:
+        batch_size: Optional[int] = None,
+    ) -> Union[Result[Json], List[Result[Json]]]:
         """Insert multiple documents into the collection.

         .. note::
@@ -1984,11 +1994,16 @@
         :type on_duplicate: str
         :param sync: Block until operation is synchronized to disk.
         :type sync: bool | None
+        :param batch_size: Max number of documents to import at once. If
+            unspecified, will import all documents at once.
+        :type batch_size: int | None
         :return: Result of the bulk import.
         :rtype: dict
         :raise arango.exceptions.DocumentInsertError: If import fails.
         """
-        documents = [self._ensure_key_from_id(doc) for doc in documents]
+        documents = [
+            self._ensure_key_from_id(doc, i) for i, doc in enumerate(documents, 1)
+        ]

         params: Params = {"type": "array", "collection": self.name}
         if halt_on_error is not None:
@@ -2006,21 +2021,25 @@
         if sync is not None:
             params["waitForSync"] = sync

-        request = Request(
-            method="post",
-            endpoint="/_api/import",
-            data=documents,
-            params=params,
-            write=self.name,
-        )
-
         def response_handler(resp: Response) -> Json:
             if resp.is_success:
                 result: Json = resp.body
                 return result
             raise DocumentInsertError(resp, request)

-        return self._execute(request, response_handler)
+        result = []
+        for batch in get_batches(documents, batch_size):
+            request = Request(
+                method="post",
+                endpoint="/_api/import",
+                data=batch,
+                params=params,
+                write=self.name,
+            )
+
+            result.append(self._execute(request, response_handler))
+
+        return result[0] if len(result) == 1 else result


 class StandardCollection(Collection):
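
A minimal usage sketch for the new **batch_size** parameter (assuming `db` is a StandardDatabase wrapper and a collection named "students" exists; names and values are illustrative):

.. code-block:: python

    students = db.collection('students')
    docs = [{'_key': str(i), 'value': i} for i in range(250)]

    # batch_size unspecified: one request, a single result dict is returned.
    result = students.import_bulk(docs)

    # batch_size=100: three requests (100 + 100 + 50 documents), so a
    # list of three result dicts is returned.
    results = students.import_bulk(docs, batch_size=100)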

arango/database.py

Lines changed: 44 additions & 4 deletions
@@ -1170,6 +1170,7 @@ def create_graph(
         shard_count: Optional[int] = None,
         replication_factor: Optional[int] = None,
         write_concern: Optional[int] = None,
+        collections: Optional[Json] = None,
     ) -> Result[Graph]:
         """Create a new graph.
@@ -1217,18 +1218,49 @@
             parameter cannot be larger than that of **replication_factor**.
             Default value is 1. Used for clusters only.
         :type write_concern: int
+        :param collections: A dict of collection data objects to provision
+            the graph with. See below for an example.
+        :type collections: dict | None
         :return: Graph API wrapper.
         :rtype: arango.graph.Graph
         :raise arango.exceptions.GraphCreateError: If create fails.

         Here is an example entry for parameter **edge_definitions**:

+        .. code-block:: python
+
+            [
+                {
+                    'edge_collection': 'teach',
+                    'from_vertex_collections': ['teachers'],
+                    'to_vertex_collections': ['lectures']
+                }
+            ]
+
+        Here is an example entry for parameter **collections**:
+        TODO: Rework **collections** data structure?
+
         .. code-block:: python

             {
-                'edge_collection': 'teach',
-                'from_vertex_collections': ['teachers'],
-                'to_vertex_collections': ['lectures']
+                'teachers': {
+                    'docs': teacher_vertices_to_insert,
+                    'options': {
+                        'overwrite': True,
+                        'sync': True,
+                        'batch_size': 50
+                    }
+                },
+                'lectures': {
+                    'docs': lecture_vertices_to_insert,
+                    'options': {
+                        'overwrite': False,
+                        'sync': False,
+                        'batch_size': 4
+                    }
+                },
+                'teach': {
+                    'docs': teach_edges_to_insert
+                }
             }
         """
         data: Json = {"name": name, "options": dict()}
@@ -1263,7 +1295,15 @@ def response_handler(resp: Response) -> Graph:
                 return Graph(self._conn, self._executor, name)
             raise GraphCreateError(resp, request)

-        return self._execute(request, response_handler)
+        graph = self._execute(request, response_handler)
+
+        if collections is not None:
+            for col_name, col_data in collections.items():
+                self.collection(col_name).import_bulk(
+                    col_data["docs"], **col_data.get("options", {})
+                )
+
+        return graph

     def delete_graph(
         self,
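
An end-to-end sketch of provisioning a graph through the new **collections** parameter (assuming `db` is a StandardDatabase wrapper; the document lists are illustrative):

.. code-block:: python

    teacher_vertices_to_insert = [{'_key': 'jon'}, {'_key': 'amy'}]
    lecture_vertices_to_insert = [{'_key': 'CSC101'}, {'_key': 'MAT223'}]
    teach_edges_to_insert = [{'_from': 'teachers/jon', '_to': 'lectures/CSC101'}]

    graph = db.create_graph(
        'school',
        edge_definitions=[{
            'edge_collection': 'teach',
            'from_vertex_collections': ['teachers'],
            'to_vertex_collections': ['lectures'],
        }],
        collections={
            'teachers': {
                'docs': teacher_vertices_to_insert,
                'options': {'overwrite': True, 'batch_size': 50},
            },
            'lectures': {'docs': lecture_vertices_to_insert},
            'teach': {'docs': teach_edges_to_insert},
        },
    )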

arango/utils.py

Lines changed: 25 additions & 1 deletion
@@ -8,7 +8,7 @@

 import logging
 from contextlib import contextmanager
-from typing import Any, Iterator, Union
+from typing import Any, Iterator, List, Optional, Union

 from arango.exceptions import DocumentParseError
 from arango.typings import Json
@@ -82,3 +82,27 @@ def is_none_or_str(obj: Any) -> bool:
     :rtype: bool
     """
     return obj is None or isinstance(obj, str)
+
+
+def get_batches(
+    l: List[Any], batch_size: Optional[int] = None
+) -> Union[List[List[Any]], Iterator[List[Any]]]:
+    """Split a list into batches of (at most) **batch_size** elements each.
+    If **batch_size** is invalid, return the entire list as one batch.
+
+    :param l: The list of elements.
+    :type l: list
+    :param batch_size: Number of elements per batch.
+    :type batch_size: int | None
+    """
+    if batch_size is None or batch_size <= 0 or batch_size >= len(l):
+        return [l]
+
+    def generator() -> Iterator[List[Any]]:
+        n = int(batch_size)  # type: ignore # (false positive)
+        for i in range(0, len(l), n):
+            yield l[i : i + n]
+
+    return generator()
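
A short illustration of the batching semantics (values chosen arbitrarily):

.. code-block:: python

    from arango.utils import get_batches

    items = [1, 2, 3, 4, 5]

    # A valid batch_size smaller than len(items) yields a generator of slices.
    list(get_batches(items, batch_size=2))   # [[1, 2], [3, 4], [5]]

    # batch_size of None, <= 0, or >= len(items) returns one batch.
    get_batches(items)                        # [[1, 2, 3, 4, 5]]
    get_batches(items, batch_size=99)         # [[1, 2, 3, 4, 5]]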
