From 2d52a4fc46274df87cab9f890a4a232fd8065bde Mon Sep 17 00:00:00 2001 From: Jihan Yin Date: Tue, 5 Apr 2022 15:48:50 -0700 Subject: [PATCH] make annotation requests in batches --- nucleus/constants.py | 1 + nucleus/dataset.py | 10 ++++++---- nucleus/model_run.py | 8 ++++---- nucleus/utils.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/nucleus/constants.py b/nucleus/constants.py index 8358742d..a900e7d5 100644 --- a/nucleus/constants.py +++ b/nucleus/constants.py @@ -102,6 +102,7 @@ REFERENCE_ID_KEY = "reference_id" BACKEND_REFERENCE_ID_KEY = "ref_id" # TODO(355762): Our backend returns this instead of the "proper" key sometimes. REQUEST_ID_KEY = "requestId" +REQUEST_IDS_KEY = "requestIds" SCENES_KEY = "scenes" SERIALIZED_REQUEST_KEY = "serialized_request" SEGMENTATIONS_KEY = "segmentations" diff --git a/nucleus/dataset.py b/nucleus/dataset.py index 82b4f1e0..4ecc0620 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -20,6 +20,7 @@ format_prediction_response, paginate_generator, serialize_and_write_to_presigned_url, + serialize_and_write_to_presigned_urls_in_batches, ) from .annotation import Annotation, check_all_mask_paths_remote @@ -39,6 +40,7 @@ NAME_KEY, REFERENCE_IDS_KEY, REQUEST_ID_KEY, + REQUEST_IDS_KEY, SLICE_ID_KEY, UPDATE_KEY, VIDEO_UPLOAD_TYPE_KEY, @@ -390,11 +392,11 @@ def annotate( """ if asynchronous: check_all_mask_paths_remote(annotations) - request_id = serialize_and_write_to_presigned_url( + request_ids = serialize_and_write_to_presigned_urls_in_batches( annotations, self.id, self._client ) response = self._client.make_request( - payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update}, + payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update}, route=f"dataset/{self.id}/annotate?async=1", ) return AsyncJob.from_json(response, self._client) @@ -1384,11 +1386,11 @@ def upload_predictions( if asynchronous: check_all_mask_paths_remote(predictions) - request_id = 
serialize_and_write_to_presigned_url( + request_ids = serialize_and_write_to_presigned_urls_in_batches( predictions, self.id, self._client ) response = self._client.make_request( - payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update}, + payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update}, route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1", ) return AsyncJob.from_json(response, self._client) diff --git a/nucleus/model_run.py b/nucleus/model_run.py index 993eda7e..81e96724 100644 --- a/nucleus/model_run.py +++ b/nucleus/model_run.py @@ -22,13 +22,13 @@ from nucleus.job import AsyncJob from nucleus.utils import ( format_prediction_response, - serialize_and_write_to_presigned_url, + serialize_and_write_to_presigned_urls_in_batches, ) from .constants import ( ANNOTATIONS_KEY, DEFAULT_ANNOTATION_UPDATE_MODE, - REQUEST_ID_KEY, + REQUEST_IDS_KEY, UPDATE_KEY, ) from .prediction import ( @@ -157,11 +157,11 @@ def predict( if asynchronous: check_all_mask_paths_remote(annotations) - request_id = serialize_and_write_to_presigned_url( + request_ids = serialize_and_write_to_presigned_urls_in_batches( annotations, self.dataset_id, self._client ) response = self._client.make_request( - payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update}, + payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update}, route=f"modelRun/{self.model_run_id}/predict?async=1", ) return AsyncJob.from_json(response, self._client) diff --git a/nucleus/utils.py b/nucleus/utils.py index 0c37f42a..7108144b 100644 --- a/nucleus/utils.py +++ b/nucleus/utils.py @@ -273,6 +273,34 @@ def upload_to_presigned_url(presigned_url: str, file_pointer: IO): ) + + +def serialize_and_write_to_presigned_urls_in_batches( + upload_units: Sequence[ + Union[DatasetItem, Annotation, LidarScene, VideoScene] + ], + dataset_id: str, + client, + batch_size: int = 10000, +): + """Serialize upload_units to NDJSON in chunks of at most batch_size items, upload each chunk to its own presigned URL under a fresh uuid4 request id, and return the list of request ids in chunk order.""" + request_ids = [] + for i in 
range(0, len(upload_units), batch_size): + upload_units_chunk = upload_units[i : i + batch_size] + request_id = uuid.uuid4().hex + response = client.make_request( + payload={}, + route=f"dataset/{dataset_id}/signedUrl/{request_id}", + requests_command=requests.get, + ) + + strio = io.StringIO() + serialize_and_write(upload_units_chunk, strio) + strio.seek(0) + upload_to_presigned_url(response["signed_url"], strio) + + request_ids.append(request_id) + return request_ids + + def serialize_and_write_to_presigned_url( upload_units: Sequence[ Union[DatasetItem, Annotation, LidarScene, VideoScene]