Add batch_size parameter in import_bulk method #207

Merged · 17 commits · Jun 24, 2022

arango/collection.py (36 additions, 12 deletions)
@@ -42,7 +42,7 @@
 from arango.response import Response
 from arango.result import Result
 from arango.typings import Fields, Headers, Json, Params
-from arango.utils import get_doc_id, is_none_or_int, is_none_or_str
+from arango.utils import get_batches, get_doc_id, is_none_or_int, is_none_or_str


 class Collection(ApiGroup):
@@ -1934,7 +1934,8 @@ def import_bulk(
         overwrite: Optional[bool] = None,
         on_duplicate: Optional[str] = None,
         sync: Optional[bool] = None,
-    ) -> Result[Json]:
+        batch_size: Optional[int] = None,
+    ) -> Union[Result[Json], List[Result[Json]]]:
         """Insert multiple documents into the collection.

         .. note::
@@ -1984,8 +1985,17 @@
         :type on_duplicate: str
         :param sync: Block until operation is synchronized to disk.
         :type sync: bool | None
+        :param batch_size: Split up **documents** into batches of max length
+            **batch_size** and import them in a loop on the client side. If
+            **batch_size** is specified, the return type of this method
+            changes from a result object to a list of result objects.
+            IMPORTANT NOTE: this parameter may go through breaking changes
+            in the future where the return type may not be a list of result
+            objects anymore. Use it at your own risk, and avoid
+            depending on the return value if possible.
+        :type batch_size: int
         :return: Result of the bulk import.
-        :rtype: dict
+        :rtype: dict | list[dict]
         :raise arango.exceptions.DocumentInsertError: If import fails.
         """
         documents = [self._ensure_key_from_id(doc) for doc in documents]
@@ -2006,21 +2016,35 @@
         if sync is not None:
             params["waitForSync"] = sync

-        request = Request(
-            method="post",
-            endpoint="/_api/import",
-            data=documents,
-            params=params,
-            write=self.name,
-        )
-
         def response_handler(resp: Response) -> Json:
             if resp.is_success:
                 result: Json = resp.body
                 return result
             raise DocumentInsertError(resp, request)

-        return self._execute(request, response_handler)
+        if batch_size is None:
+            request = Request(
+                method="post",
+                endpoint="/_api/import",
+                data=documents,
+                params=params,
+                write=self.name,
+            )
+
+            return self._execute(request, response_handler)
+        else:
+            results = []
+            for batch in get_batches(documents, batch_size):
+                request = Request(
+                    method="post",
+                    endpoint="/_api/import",
+                    data=batch,
+                    params=params,
+                    write=self.name,
+                )
+                results.append(self._execute(request, response_handler))
+
+            return results


class StandardCollection(Collection):
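
For context, a minimal usage sketch of the new parameter (the connection details and the `students` collection are illustrative, not part of this PR):

```python
from arango import ArangoClient

client = ArangoClient(hosts="http://localhost:8529")
db = client.db("test", username="root", password="passwd")
col = db.collection("students")

docs = [{"_key": str(i), "value": i} for i in range(10)]

# batch_size=5 splits the 10 documents into two client-side requests;
# the return value becomes a list with one result dict per batch.
results = col.import_bulk(docs, batch_size=5)
assert isinstance(results, list) and len(results) == 2
```

Note that batching is purely client-side: each batch is sent as a separate `/_api/import` request, so a failure in one batch does not undo documents imported by earlier batches.
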
arango/database.py (7 additions, 5 deletions)
@@ -1225,11 +1225,13 @@ def create_graph(

         .. code-block:: python

-            {
-                'edge_collection': 'teach',
-                'from_vertex_collections': ['teachers'],
-                'to_vertex_collections': ['lectures']
-            }
+            [
+                {
+                    'edge_collection': 'teach',
+                    'from_vertex_collections': ['teachers'],
+                    'to_vertex_collections': ['lectures']
+                }
+            ]
         """
         data: Json = {"name": name, "options": dict()}
         if edge_definitions is not None:
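
A call matching the corrected docstring might look like this (reusing `db` from the sketch above; the graph and collection names are made up):

```python
db.create_graph(
    "school",
    edge_definitions=[
        {
            "edge_collection": "teach",
            "from_vertex_collections": ["teachers"],
            "to_vertex_collections": ["lectures"],
        }
    ],
)
```
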
arango/utils.py (14 additions, 1 deletion)
@@ -8,7 +8,7 @@

 import logging
 from contextlib import contextmanager
-from typing import Any, Iterator, Union
+from typing import Any, Iterator, Sequence, Union

 from arango.exceptions import DocumentParseError
 from arango.typings import Json
@@ -82,3 +82,16 @@ def is_none_or_str(obj: Any) -> bool:
     :rtype: bool
     """
     return obj is None or isinstance(obj, str)
+
+
+def get_batches(elements: Sequence[Json], batch_size: int) -> Iterator[Sequence[Json]]:
+    """Generator to split a list in batches
+    of (maximum) **batch_size** elements each.
+
+    :param elements: The list of elements.
+    :type elements: Sequence[Json]
+    :param batch_size: Max number of elements per batch.
+    :type batch_size: int
+    """
+    for index in range(0, len(elements), batch_size):
+        yield elements[index : index + batch_size]
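
A quick check of the helper's edge behavior — when `len(elements)` is not a multiple of `batch_size`, the final batch is simply shorter:

```python
from arango.utils import get_batches

docs = [{"_key": str(i)} for i in range(5)]

# Two full batches of two documents plus one trailing batch of one.
assert [len(batch) for batch in get_batches(docs, batch_size=2)] == [2, 2, 1]
```
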
tests/test_document.py (11 additions)
@@ -1832,6 +1832,17 @@ def test_document_import_bulk(col, bad_col, docs):
         assert col[doc_key]["loc"] == doc["loc"]
     empty_collection(col)

+    # Test import bulk with batch_size
+    results = col.import_bulk(docs, batch_size=len(docs) // 2)
+    assert type(results) is list
+    assert len(results) == 2
+    empty_collection(col)
+
+    result = col.import_bulk(docs, batch_size=len(docs) * 2)
+    assert type(result) is list
+    assert len(result) == 1
+    empty_collection(col)
+
     # Test import bulk on_duplicate actions
     doc = docs[0]
     doc_key = doc["_key"]