Skip to content

Commit 3dce534

Browse files
author
Ubuntu
committed
Initial pass at client changes for prediction segmentation upload
1 parent aea3111 commit 3dce534

File tree

4 files changed

+102
-108
lines changed

4 files changed

+102
-108
lines changed

nucleus/__init__.py

Lines changed: 0 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -512,93 +512,6 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
512512
response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
513513
)
514514

515-
@deprecated("Use Dataset.upload_predictions instead.")
516-
def predict(
517-
self,
518-
annotations: List[
519-
Union[
520-
BoxPrediction,
521-
PolygonPrediction,
522-
CuboidPrediction,
523-
SegmentationPrediction,
524-
CategoryPrediction,
525-
]
526-
],
527-
model_run_id: Optional[str] = None,
528-
model_id: Optional[str] = None,
529-
dataset_id: Optional[str] = None,
530-
update: bool = False,
531-
batch_size: int = 5000,
532-
):
533-
if model_run_id is not None:
534-
assert model_id is None and dataset_id is None
535-
endpoint = f"modelRun/{model_run_id}/predict"
536-
else:
537-
assert (
538-
model_id is not None and dataset_id is not None
539-
), "Model ID and dataset ID are required if not using model run id."
540-
endpoint = (
541-
f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
542-
)
543-
segmentations = [
544-
ann
545-
for ann in annotations
546-
if isinstance(ann, SegmentationPrediction)
547-
]
548-
549-
other_predictions = [
550-
ann
551-
for ann in annotations
552-
if not isinstance(ann, SegmentationPrediction)
553-
]
554-
555-
s_batches = [
556-
segmentations[i : i + batch_size]
557-
for i in range(0, len(segmentations), batch_size)
558-
]
559-
560-
batches = [
561-
other_predictions[i : i + batch_size]
562-
for i in range(0, len(other_predictions), batch_size)
563-
]
564-
565-
errors = []
566-
predictions_processed = 0
567-
predictions_ignored = 0
568-
569-
tqdm_batches = self.tqdm_bar(batches)
570-
571-
for batch in tqdm_batches:
572-
batch_payload = construct_box_predictions_payload(
573-
batch,
574-
update,
575-
)
576-
response = self.make_request(batch_payload, endpoint)
577-
if STATUS_CODE_KEY in response:
578-
errors.append(response)
579-
else:
580-
predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
581-
predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
582-
if ERRORS_KEY in response:
583-
errors += response[ERRORS_KEY]
584-
585-
for s_batch in s_batches:
586-
payload = construct_segmentation_payload(s_batch, update)
587-
response = self.make_request(payload, endpoint)
588-
# pbar.update(1)
589-
if STATUS_CODE_KEY in response:
590-
errors.append(response)
591-
else:
592-
predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
593-
predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
594-
595-
return {
596-
MODEL_RUN_ID_KEY: model_run_id,
597-
PREDICTIONS_PROCESSED_KEY: predictions_processed,
598-
PREDICTIONS_IGNORED_KEY: predictions_ignored,
599-
ERRORS_KEY: errors,
600-
}
601-
602515
@deprecated(
603516
"Model runs have been deprecated and will be removed. Use a Model instead."
604517
)

nucleus/annotation_uploader.py

Lines changed: 33 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import TYPE_CHECKING, Iterable, List, Sequence
2+
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
33

44
from nucleus.annotation import Annotation, SegmentationAnnotation
55
from nucleus.async_utils import (
@@ -34,12 +34,14 @@ def accumulate_dict_values(dicts: Iterable[dict]):
3434
class AnnotationUploader:
3535
"""This is a helper class not intended for direct use. Please use dataset.annotate.
3636
37-
This class is purely a helper class for implementing dataset.annotate.
37+
This class is purely a helper class for implementing dataset.annotate/dataset.predict.
3838
"""
3939

40-
def __init__(self, dataset_id: str, client: "NucleusClient"): # noqa: F821
41-
self.dataset_id = dataset_id
40+
def __init__(
41+
self, dataset_id: Optional[str], client: "NucleusClient"
42+
): # noqa: F821
4243
self._client = client
44+
self._route = f"dataset/{dataset_id}/annotate"
4345

4446
def upload(
4547
self,
@@ -83,7 +85,7 @@ def upload(
8385
# segmentation will take a lot longer for the server to process than a single
8486
# annotation of any other kind.
8587
responses.extend(
86-
self.make_batched_annotate_requests(
88+
self.make_batched_requests(
8789
segmentations_with_remote_files,
8890
update,
8991
batch_size=remote_files_per_upload_request,
@@ -92,7 +94,7 @@ def upload(
9294
)
9395
if annotations_without_files:
9496
responses.extend(
95-
self.make_batched_annotate_requests(
97+
self.make_batched_requests(
9698
annotations_without_files,
9799
update,
98100
batch_size=batch_size,
@@ -102,7 +104,7 @@ def upload(
102104

103105
return accumulate_dict_values(responses)
104106

105-
def make_batched_annotate_requests(
107+
def make_batched_requests(
106108
self,
107109
annotations: Sequence[Annotation],
108110
update: bool,
@@ -120,9 +122,7 @@ def make_batched_annotate_requests(
120122
for batch in self._client.tqdm_bar(batches, desc=progress_bar_name):
121123
payload = construct_annotation_payload(batch, update)
122124
responses.append(
123-
self._client.make_request(
124-
payload, route=f"dataset/{self.dataset_id}/annotate"
125-
)
125+
self._client.make_request(payload, route=self._route)
126126
)
127127
return responses
128128

@@ -149,7 +149,7 @@ def make_batched_file_form_data_requests(
149149
return make_many_form_data_requests_concurrently(
150150
client=self._client,
151151
requests=requests,
152-
route=f"dataset/{self.dataset_id}/annotate",
152+
route=self._route,
153153
progressbar=progressbar,
154154
concurrency=local_file_upload_concurrency,
155155
)
@@ -202,3 +202,25 @@ def fn():
202202
return form_data, file_pointers
203203

204204
return fn
205+
206+
207+
class PredictionUploader(AnnotationUploader):
208+
def __init__(
209+
self,
210+
client: "NucleusClient",
211+
dataset_id: Optional[str] = None,
212+
model_id: Optional[str] = None,
213+
model_run_id: Optional[str] = None,
214+
):
215+
super().__init__(dataset_id, client)
216+
self._client = client
217+
if model_run_id is not None:
218+
assert model_id is None and dataset_id is None
219+
self._route = f"modelRun/{model_run_id}/predict"
220+
else:
221+
assert (
222+
model_id is not None and dataset_id is not None
223+
), "Model ID and dataset ID are required if not using model run id."
224+
self._route = (
225+
f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
226+
)

nucleus/dataset.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
import requests
55

6-
from nucleus.annotation_uploader import AnnotationUploader
6+
from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader
77
from nucleus.job import AsyncJob
88
from nucleus.prediction import (
99
BoxPrediction,
@@ -347,6 +347,7 @@ def annotate(
347347
request. Segmentations have either local or remote files, if you are
348348
getting timeouts while uploading segmentations with local files, you
349349
should lower this value from its default of 10. The maximum is 10.
350+
local_file_upload_concurrency: Number of concurrent local file uploads.
350351
351352
352353
Returns:
@@ -1283,6 +1284,10 @@ def upload_predictions(
12831284
],
12841285
update: bool = False,
12851286
asynchronous: bool = False,
1287+
batch_size: int = 5000,
1288+
remote_files_per_upload_request: int = 20,
1289+
local_files_per_upload_request: int = 10,
1290+
local_file_upload_concurrency: int = 30,
12861291
):
12871292
"""Uploads predictions and associates them with an existing :class:`Model`.
12881293
@@ -1325,6 +1330,21 @@ def upload_predictions(
13251330
collision. Default is False.
13261331
asynchronous: Whether or not to process the upload asynchronously (and
13271332
return an :class:`AsyncJob` object). Default is False.
1333+
batch_size: Number of predictions processed in each concurrent batch.
1334+
Default is 5000. If you get timeouts when uploading geometric predictions,
1335+
you can try lowering this batch size. This is only relevant for
1336+
asynchronous=False
1337+
remote_files_per_upload_request: Number of remote files to upload in each
1338+
request. Segmentations have either local or remote files, if you are
1339+
getting timeouts while uploading segmentations with remote urls, you
1340+
should lower this value from its default of 20. This is only relevant for
1341+
asynchronous=False.
1342+
local_files_per_upload_request: Number of local files to upload in each
1343+
request. Segmentations have either local or remote files, if you are
1344+
getting timeouts while uploading segmentations with local files, you
1345+
should lower this value from its default of 10. The maximum is 10.
1346+
This is only relevant for asynchronous=False
1347+
local_file_upload_concurrency: Number of concurrent local file uploads.
13281348
13291349
Returns:
13301350
Payload describing the synchronous upload::
@@ -1348,12 +1368,19 @@ def upload_predictions(
13481368
)
13491369
return AsyncJob.from_json(response, self._client)
13501370
else:
1351-
return self._client.predict(
1371+
uploader = PredictionUploader(
13521372
model_run_id=None,
13531373
dataset_id=self.id,
13541374
model_id=model.id,
1375+
client=self._client,
1376+
)
1377+
return uploader.upload(
13551378
annotations=predictions,
1379+
batch_size=batch_size,
13561380
update=update,
1381+
remote_files_per_upload_request=remote_files_per_upload_request,
1382+
local_files_per_upload_request=local_files_per_upload_request,
1383+
local_file_upload_concurrency=local_file_upload_concurrency,
13571384
)
13581385

13591386
def predictions_iloc(self, model, index):

nucleus/model_run.py

Lines changed: 40 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import requests
1919

2020
from nucleus.annotation import check_all_mask_paths_remote
21+
from nucleus.annotation_uploader import PredictionUploader
2122
from nucleus.job import AsyncJob
2223
from nucleus.utils import (
2324
format_prediction_response,
@@ -114,12 +115,38 @@ def predict(
114115
SegmentationPrediction,
115116
]
116117
],
117-
update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
118+
update: bool = DEFAULT_ANNOTATION_UPDATE_MODE,
118119
asynchronous: bool = False,
120+
batch_size: int = 5000,
121+
remote_files_per_upload_request: int = 20,
122+
local_files_per_upload_request: int = 10,
123+
local_file_upload_concurrency: int = 30,
119124
) -> Union[dict, AsyncJob]:
120125
"""
121126
Uploads model outputs as predictions for a model_run. Returns info about the upload.
122-
:param annotations: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
127+
128+
Args:
129+
annotations: Predictions to upload for this model run,
130+
update: If True, existing predictions for the same (reference_id, annotation_id)
131+
will be overwritten. If False, existing predictions will be skipped.
132+
asynchronous: Whether or not to process the upload asynchronously (and
133+
return an :class:`AsyncJob` object). Default is False.
134+
batch_size: Number of predictions processed in each concurrent batch.
135+
Default is 5000. If you get timeouts when uploading geometric annotations,
136+
you can try lowering this batch size. This is only relevant for
137+
asynchronous=False.
138+
remote_files_per_upload_request: Number of remote files to upload in each
139+
request. Segmentations have either local or remote files, if you are
140+
getting timeouts while uploading segmentations with remote urls, you
141+
should lower this value from its default of 20. This is only relevant for
142+
asynchronous=False
143+
local_files_per_upload_request: Number of local files to upload in each
144+
request. Segmentations have either local or remote files, if you are
145+
getting timeouts while uploading segmentations with local files, you
146+
should lower this value from its default of 10. The maximum is 10.
147+
This is only relevant for asynchronous=False
148+
local_file_upload_concurrency: Number of concurrent local file uploads.
149+
This is only relevant for asynchronous=False
123150
:return:
124151
{
125152
"model_run_id": str,
@@ -138,12 +165,17 @@ def predict(
138165
route=f"modelRun/{self.model_run_id}/predict?async=1",
139166
)
140167
return AsyncJob.from_json(response, self._client)
141-
else:
142-
return self._client.predict(
143-
model_run_id=self.model_run_id,
144-
annotations=annotations,
145-
update=update,
146-
)
168+
uploader = PredictionUploader(
169+
model_run_id=self.model_run_id, client=self._client
170+
)
171+
return uploader.upload(
172+
annotations=annotations,
173+
update=update,
174+
batch_size=batch_size,
175+
remote_files_per_upload_request=remote_files_per_upload_request,
176+
local_files_per_upload_request=local_files_per_upload_request,
177+
local_file_upload_concurrency=local_file_upload_concurrency,
178+
)
147179

148180
def iloc(self, i: int):
149181
"""

0 commit comments

Comments
 (0)