Commit 917f699

Add batch_size parameter in import_bulk method (#207)
1 parent 21c9e5d commit 917f699

4 files changed: +68 -18 lines changed

arango/collection.py
Lines changed: 36 additions & 12 deletions

@@ -42,7 +42,7 @@
 from arango.response import Response
 from arango.result import Result
 from arango.typings import Fields, Headers, Json, Params
-from arango.utils import get_doc_id, is_none_or_int, is_none_or_str
+from arango.utils import get_batches, get_doc_id, is_none_or_int, is_none_or_str


 class Collection(ApiGroup):

@@ -1934,7 +1934,8 @@ def import_bulk(
         overwrite: Optional[bool] = None,
         on_duplicate: Optional[str] = None,
         sync: Optional[bool] = None,
-    ) -> Result[Json]:
+        batch_size: Optional[int] = None,
+    ) -> Union[Result[Json], List[Result[Json]]]:
         """Insert multiple documents into the collection.

         .. note::

@@ -1984,8 +1985,17 @@ def import_bulk(
         :type on_duplicate: str
         :param sync: Block until operation is synchronized to disk.
         :type sync: bool | None
+        :param batch_size: Split up **documents** into batches of max length
+            **batch_size** and import them in a loop on the client side. If
+            **batch_size** is specified, the return type of this method
+            changes from a result object to a list of result objects.
+            IMPORTANT NOTE: this parameter may go through breaking changes
+            in the future where the return type may not be a list of result
+            objects anymore. Use it at your own risk, and avoid
+            depending on the return value if possible.
+        :type batch_size: int
         :return: Result of the bulk import.
-        :rtype: dict
+        :rtype: dict | list[dict]
         :raise arango.exceptions.DocumentInsertError: If import fails.
         """
         documents = [self._ensure_key_from_id(doc) for doc in documents]

@@ -2006,21 +2016,35 @@ def import_bulk(
         if sync is not None:
             params["waitForSync"] = sync

-        request = Request(
-            method="post",
-            endpoint="/_api/import",
-            data=documents,
-            params=params,
-            write=self.name,
-        )
-
         def response_handler(resp: Response) -> Json:
             if resp.is_success:
                 result: Json = resp.body
                 return result
             raise DocumentInsertError(resp, request)

-        return self._execute(request, response_handler)
+        if batch_size is None:
+            request = Request(
+                method="post",
+                endpoint="/_api/import",
+                data=documents,
+                params=params,
+                write=self.name,
+            )
+
+            return self._execute(request, response_handler)
+        else:
+            results = []
+            for batch in get_batches(documents, batch_size):
+                request = Request(
+                    method="post",
+                    endpoint="/_api/import",
+                    data=batch,
+                    params=params,
+                    write=self.name,
+                )
+                results.append(self._execute(request, response_handler))
+
+            return results


 class StandardCollection(Collection):

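As a quick illustration of how the new parameter might be used from the caller's side (the connection details, database, and collection names below are assumptions for this sketch, not part of the commit):

    from arango import ArangoClient

    # Hypothetical connection and collection; adjust hosts, credentials,
    # and names to your own setup.
    client = ArangoClient(hosts="http://localhost:8529")
    db = client.db("test", username="root", password="passwd")
    students = db.collection("students")

    docs = [{"value": i} for i in range(10)]

    # Default behaviour: one /_api/import request, a single result dict.
    result = students.import_bulk(docs)

    # With batch_size, the documents are split client-side into batches of
    # at most 4, one request per batch, and a list of result dicts
    # (one per request) is returned instead.
    results = students.import_bulk(docs, batch_size=4)
    assert isinstance(results, list) and len(results) == 3  # ceil(10 / 4)

Since the docstring explicitly flags the list return type as subject to change, it is safer not to depend on the return value when using batch_size.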
arango/database.py
Lines changed: 7 additions & 5 deletions

@@ -1225,11 +1225,13 @@ def create_graph(

        .. code-block:: python

-            {
-                'edge_collection': 'teach',
-                'from_vertex_collections': ['teachers'],
-                'to_vertex_collections': ['lectures']
-            }
+            [
+                {
+                    'edge_collection': 'teach',
+                    'from_vertex_collections': ['teachers'],
+                    'to_vertex_collections': ['lectures']
+                }
+            ]
        """
        data: Json = {"name": name, "options": dict()}
        if edge_definitions is not None:

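The corrected docstring reflects that edge_definitions is passed to create_graph as a list of edge definition dicts, not a single dict. A minimal sketch (the graph and collection names are illustrative, and db is assumed to be a database handle as above):

    graph = db.create_graph(
        "school",
        edge_definitions=[
            {
                "edge_collection": "teach",
                "from_vertex_collections": ["teachers"],
                "to_vertex_collections": ["lectures"],
            }
        ],
    )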
arango/utils.py
Lines changed: 14 additions & 1 deletion

@@ -8,7 +8,7 @@

 import logging
 from contextlib import contextmanager
-from typing import Any, Iterator, Union
+from typing import Any, Iterator, Sequence, Union

 from arango.exceptions import DocumentParseError
 from arango.typings import Json

@@ -82,3 +82,16 @@ def is_none_or_str(obj: Any) -> bool:
     :rtype: bool
     """
     return obj is None or isinstance(obj, str)
+
+
+def get_batches(elements: Sequence[Json], batch_size: int) -> Iterator[Sequence[Json]]:
+    """Generator to split a list in batches
+    of (maximum) **batch_size** elements each.
+
+    :param elements: The list of elements.
+    :type elements: Sequence[Json]
+    :param batch_size: Max number of elements per batch.
+    :type batch_size: int
+    """
+    for index in range(0, len(elements), batch_size):
+        yield elements[index : index + batch_size]

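The helper is a plain generator over list slices, so a list of n elements yields ceil(n / batch_size) batches, all full except possibly the last. A small standalone check of that arithmetic (the numeric example is ours, not from the commit):

    import math

    from arango.utils import get_batches  # available after this change

    elements = [{"value": i} for i in range(10)]
    batches = list(get_batches(elements, batch_size=4))

    # ceil(10 / 4) == 3 batches of sizes 4, 4, 2.
    assert len(batches) == math.ceil(len(elements) / 4)
    assert [len(b) for b in batches] == [4, 4, 2]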
tests/test_document.py
Lines changed: 11 additions & 0 deletions

@@ -1832,6 +1832,17 @@ def test_document_import_bulk(col, bad_col, docs):
     assert col[doc_key]["loc"] == doc["loc"]
     empty_collection(col)

+    # Test import bulk with batch_size
+    results = col.import_bulk(docs, batch_size=len(docs) // 2)
+    assert type(results) is list
+    assert len(results) == 2
+    empty_collection(col)
+
+    result = col.import_bulk(docs, batch_size=len(docs) * 2)
+    assert type(result) is list
+    assert len(result) == 1
+    empty_collection(col)
+
     # Test import bulk on_duplicate actions
     doc = docs[0]
     doc_key = doc["_key"]
