From 8b460152d91c6cb380fb0e45de59a1db3ae57215 Mon Sep 17 00:00:00 2001
From: Sasha Harrison
Date: Tue, 15 Feb 2022 11:15:24 -0800
Subject: [PATCH] shard large async uploads across multiple presigned URLs

---
 nucleus/constants.py |  2 ++
 nucleus/dataset.py   | 16 ++++++++++------
 nucleus/model_run.py |  8 ++++----
 nucleus/utils.py     | 44 +++++++++++++++++++++++++++++++-------------
 4 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/nucleus/constants.py b/nucleus/constants.py
index de29b482..52d61c7d 100644
--- a/nucleus/constants.py
+++ b/nucleus/constants.py
@@ -24,6 +24,7 @@
 AUTOTAG_SCORE_THRESHOLD = "score_threshold"
 EXPORTED_ROWS = "exportedRows"
 CAMERA_PARAMS_KEY = "camera_params"
+CHUNK_SIZE = 500000
 CLASS_PDF_KEY = "class_pdf"
 CONFIDENCE_KEY = "confidence"
 CX_KEY = "cx"
@@ -88,6 +89,7 @@
 REFERENCE_IDS_KEY = "reference_ids"
 REFERENCE_ID_KEY = "reference_id"
 REQUEST_ID_KEY = "requestId"
+REQUEST_IDS_KEY = "requestIds"
 SCENES_KEY = "scenes"
 SEGMENTATIONS_KEY = "segmentations"
 SLICE_ID_KEY = "slice_id"
diff --git a/nucleus/dataset.py b/nucleus/dataset.py
index df44e270..7d866b4d 100644
--- a/nucleus/dataset.py
+++ b/nucleus/dataset.py
@@ -45,6 +45,7 @@
     NAME_KEY,
     REFERENCE_IDS_KEY,
     REQUEST_ID_KEY,
+    REQUEST_IDS_KEY,
     SLICE_ID_KEY,
     UPDATE_KEY,
 )
@@ -365,11 +366,11 @@ def annotate(
         check_all_mask_paths_remote(annotations)
 
         if asynchronous:
-            request_id = serialize_and_write_to_presigned_url(
-                annotations, self.id, self._client
+            request_ids = serialize_and_write_to_presigned_url(
+                annotations, self.id, self._client, can_shard=True
             )
             response = self._client.make_request(
-                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update},
                 route=f"dataset/{self.id}/annotate?async=1",
             )
             return AsyncJob.from_json(response, self._client)
@@ -1266,11 +1267,14 @@ def upload_predictions(
 
         if asynchronous:
             check_all_mask_paths_remote(predictions)
-            request_id = serialize_and_write_to_presigned_url(
-                predictions, self.id, self._client
+            request_ids = serialize_and_write_to_presigned_url(
+                predictions,
+                self.id,
+                self._client,
+                can_shard=True,
             )
             response = self._client.make_request(
-                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update},
                 route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
             )
             return AsyncJob.from_json(response, self._client)
diff --git a/nucleus/model_run.py b/nucleus/model_run.py
index 137d7cf8..87e5097b 100644
--- a/nucleus/model_run.py
+++ b/nucleus/model_run.py
@@ -27,7 +27,7 @@
 from .constants import (
     ANNOTATIONS_KEY,
     DEFAULT_ANNOTATION_UPDATE_MODE,
-    REQUEST_ID_KEY,
+    REQUEST_IDS_KEY,
     UPDATE_KEY,
 )
 from .prediction import (
@@ -130,11 +130,11 @@ def predict(
 
         if asynchronous:
             check_all_mask_paths_remote(annotations)
-            request_id = serialize_and_write_to_presigned_url(
-                annotations, self.dataset_id, self._client
+            request_ids = serialize_and_write_to_presigned_url(
+                annotations, self.dataset_id, self._client, can_shard=True
             )
             response = self._client.make_request(
-                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update},
                 route=f"modelRun/{self.model_run_id}/predict?async=1",
             )
             return AsyncJob.from_json(response, self._client)
diff --git a/nucleus/utils.py b/nucleus/utils.py
index 98071acf..d34e2975 100644
--- a/nucleus/utils.py
+++ b/nucleus/utils.py
@@ -24,6 +24,7 @@
     ANNOTATIONS_KEY,
     BOX_TYPE,
     CATEGORY_TYPE,
+    CHUNK_SIZE,
     CUBOID_TYPE,
     ITEM_KEY,
     MULTICATEGORY_TYPE,
@@ -48,6 +49,12 @@
 }
 
 
+def chunks(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i : i + n]
+
+
 class KeyErrorDict(dict):
     """Wrapper for response dicts with deprecated keys.
 
@@ -257,20 +264,31 @@ def serialize_and_write_to_presigned_url(
     upload_units: Sequence[Union[DatasetItem, Annotation, LidarScene]],
     dataset_id: str,
     client,
-):
+    can_shard: bool = False,
+) -> Union[Sequence[str], str]:
     """This helper function can be used to serialize a list of API objects to NDJSON."""
-    request_id = uuid.uuid4().hex
-    response = client.make_request(
-        payload={},
-        route=f"dataset/{dataset_id}/signedUrl/{request_id}",
-        requests_command=requests.get,
-    )
-
-    strio = io.StringIO()
-    serialize_and_write(upload_units, strio)
-    strio.seek(0)
-    upload_to_presigned_url(response["signed_url"], strio)
-    return request_id
+
+    def upload(items):
+        request_id = uuid.uuid4().hex
+        response = client.make_request(
+            payload={},
+            route=f"dataset/{dataset_id}/signedUrl/{request_id}",
+            requests_command=requests.get,
+        )
+
+        strio = io.StringIO()
+        serialize_and_write(items, strio)
+        strio.seek(0)
+        upload_to_presigned_url(response["signed_url"], strio)
+        return request_id
+
+    if can_shard:
+        request_ids = []
+        for chunk in list(chunks(upload_units, CHUNK_SIZE)):
+            request_ids.append(upload(chunk))
+        return request_ids
+    else:
+        return upload(upload_units)
 
 
 def replace_double_slashes(s: str) -> str: