
Commit 9040e04

ardila and Ubuntu authored
Refactor a lot of segmentation local upload and async logic (#256)
* work in progress
* work in progress
* Big refactor to make things cleaner + enable retries properly on infra flakes for local upload
* work in progress refactor of annotation upload
* Fixed segmentation bugs
* Fix one more bug and remove use of annotate_segmentation endpoint
* refactor tests and add segmentation local upload test
* Tests passing
* Review feedback
* Initial pass at client changes for prediction segmentation upload
* relevant tests pass

Co-authored-by: Ubuntu <diego.ardila@scale.com>
1 parent d4e3d1b · commit 9040e04

17 files changed: +951 -524 lines changed
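The retry behavior called out in the commit message ("enable retries properly on infra flakes for local upload") lives in files not excerpted below. As a rough sketch of the general pattern, assuming transient network errors are the failures being retried (the helper and error type here are assumptions, not code from this commit):

```python
# Illustrative only: a generic retry-with-backoff helper of the kind the
# commit message describes for local uploads.
import time


def with_retries(fn, attempts: int = 3, backoff_seconds: float = 1.0):
    """Call fn(), retrying transient (infra-flake) failures with backoff."""
    for attempt in range(attempts):
        try:
            return fn()
        except ConnectionError:
            if attempt == attempts - 1:
                raise  # out of retries: surface the original error
            time.sleep(backoff_seconds * (2**attempt))
```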

conftest.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -28,6 +28,13 @@ def dataset(CLIENT):
     CLIENT.delete_dataset(ds.id)
 
 
+@pytest.fixture()
+def model(CLIENT):
+    model = CLIENT.create_model(TEST_DATASET_NAME, "fake_reference_id")
+    yield model
+    CLIENT.delete_model(model.id)
+
+
 if __name__ == "__main__":
     client = nucleus.NucleusClient(API_KEY)
     # ds = client.create_dataset("Test Dataset With Autotags")
```
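A hypothetical test consuming the new fixture could look like the following; pytest injects it by name, and the create/delete in the fixture brackets the test body. The test and its assertion are illustrative, not part of this commit:

```python
# Hypothetical usage of the `model` fixture added above. pytest creates
# the model before the test runs and deletes it afterwards via the yield.
def test_model_fixture(model):
    assert model.reference_id == "fake_reference_id"
```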

nucleus/__init__.py

Lines changed: 0 additions & 174 deletions
```diff
@@ -460,93 +460,6 @@ def populate_dataset(
             dataset_items, batch_size=batch_size, update=update
         )
 
-    def annotate_dataset(
-        self,
-        dataset_id: str,
-        annotations: Sequence[
-            Union[
-                BoxAnnotation,
-                PolygonAnnotation,
-                CuboidAnnotation,
-                CategoryAnnotation,
-                MultiCategoryAnnotation,
-                SegmentationAnnotation,
-            ]
-        ],
-        update: bool,
-        batch_size: int = 5000,
-    ) -> Dict[str, object]:
-        # TODO: deprecate in favor of Dataset.annotate invocation
-
-        # Split payload into segmentations and Box/Polygon
-        segmentations = [
-            ann
-            for ann in annotations
-            if isinstance(ann, SegmentationAnnotation)
-        ]
-        other_annotations = [
-            ann
-            for ann in annotations
-            if not isinstance(ann, SegmentationAnnotation)
-        ]
-
-        batches = [
-            other_annotations[i : i + batch_size]
-            for i in range(0, len(other_annotations), batch_size)
-        ]
-
-        semseg_batches = [
-            segmentations[i : i + batch_size]
-            for i in range(0, len(segmentations), batch_size)
-        ]
-
-        agg_response = {
-            DATASET_ID_KEY: dataset_id,
-            ANNOTATIONS_PROCESSED_KEY: 0,
-            ANNOTATIONS_IGNORED_KEY: 0,
-            ERRORS_KEY: [],
-        }
-
-        total_batches = len(batches) + len(semseg_batches)
-
-        tqdm_batches = self.tqdm_bar(batches)
-
-        with self.tqdm_bar(total=total_batches) as pbar:
-            for batch in tqdm_batches:
-                payload = construct_annotation_payload(batch, update)
-                response = self.make_request(
-                    payload, f"dataset/{dataset_id}/annotate"
-                )
-                pbar.update(1)
-                if STATUS_CODE_KEY in response:
-                    agg_response[ERRORS_KEY] = response
-                else:
-                    agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
-                        ANNOTATIONS_PROCESSED_KEY
-                    ]
-                    agg_response[ANNOTATIONS_IGNORED_KEY] += response[
-                        ANNOTATIONS_IGNORED_KEY
-                    ]
-                    agg_response[ERRORS_KEY] += response[ERRORS_KEY]
-
-            for s_batch in semseg_batches:
-                payload = construct_segmentation_payload(s_batch, update)
-                response = self.make_request(
-                    payload, f"dataset/{dataset_id}/annotate_segmentation"
-                )
-                pbar.update(1)
-                if STATUS_CODE_KEY in response:
-                    agg_response[ERRORS_KEY] = response
-                else:
-                    agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
-                        ANNOTATIONS_PROCESSED_KEY
-                    ]
-                    agg_response[ANNOTATIONS_IGNORED_KEY] += response[
-                        ANNOTATIONS_IGNORED_KEY
-                    ]
-
-        return agg_response
-
     @deprecated(msg="Use Dataset.ingest_tasks instead")
     def ingest_tasks(self, dataset_id: str, payload: dict):
         dataset = self.get_dataset(dataset_id)
@@ -599,93 +512,6 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
             response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
         )
 
-    @deprecated("Use Dataset.upload_predictions instead.")
-    def predict(
-        self,
-        annotations: List[
-            Union[
-                BoxPrediction,
-                PolygonPrediction,
-                CuboidPrediction,
-                SegmentationPrediction,
-                CategoryPrediction,
-            ]
-        ],
-        model_run_id: Optional[str] = None,
-        model_id: Optional[str] = None,
-        dataset_id: Optional[str] = None,
-        update: bool = False,
-        batch_size: int = 5000,
-    ):
-        if model_run_id is not None:
-            assert model_id is None and dataset_id is None
-            endpoint = f"modelRun/{model_run_id}/predict"
-        else:
-            assert (
-                model_id is not None and dataset_id is not None
-            ), "Model ID and dataset ID are required if not using model run id."
-            endpoint = (
-                f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
-            )
-        segmentations = [
-            ann
-            for ann in annotations
-            if isinstance(ann, SegmentationPrediction)
-        ]
-
-        other_predictions = [
-            ann
-            for ann in annotations
-            if not isinstance(ann, SegmentationPrediction)
-        ]
-
-        s_batches = [
-            segmentations[i : i + batch_size]
-            for i in range(0, len(segmentations), batch_size)
-        ]
-
-        batches = [
-            other_predictions[i : i + batch_size]
-            for i in range(0, len(other_predictions), batch_size)
-        ]
-
-        errors = []
-        predictions_processed = 0
-        predictions_ignored = 0
-
-        tqdm_batches = self.tqdm_bar(batches)
-
-        for batch in tqdm_batches:
-            batch_payload = construct_box_predictions_payload(
-                batch,
-                update,
-            )
-            response = self.make_request(batch_payload, endpoint)
-            if STATUS_CODE_KEY in response:
-                errors.append(response)
-            else:
-                predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
-                predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
-                if ERRORS_KEY in response:
-                    errors += response[ERRORS_KEY]
-
-        for s_batch in s_batches:
-            payload = construct_segmentation_payload(s_batch, update)
-            response = self.make_request(payload, endpoint)
-            # pbar.update(1)
-            if STATUS_CODE_KEY in response:
-                errors.append(response)
-            else:
-                predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
-                predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
-
-        return {
-            MODEL_RUN_ID_KEY: model_run_id,
-            PREDICTIONS_PROCESSED_KEY: predictions_processed,
-            PREDICTIONS_IGNORED_KEY: predictions_ignored,
-            ERRORS_KEY: errors,
-        }
-
     @deprecated(
         "Model runs have been deprecated and will be removed. Use a Model instead."
     )
```
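The deleted methods point at their replacements: the TODO in annotate_dataset names Dataset.annotate, and the @deprecated decorator on predict names Dataset.upload_predictions. A minimal, hypothetical migration sketch (the ids and example objects are placeholders, and exact signatures may vary by client version):

```python
# Hypothetical migration from the deleted NucleusClient methods to the
# Dataset-level APIs named in the deprecation messages above.
import nucleus
from nucleus import BoxAnnotation, BoxPrediction

client = nucleus.NucleusClient("YOUR_API_KEY")
dataset = client.get_dataset("ds_sample_id")  # placeholder dataset id

# Replaces client.annotate_dataset(dataset_id, annotations, update=...).
annotations = [
    BoxAnnotation(
        label="car", x=0, y=0, width=10, height=10, reference_id="img_1"
    )
]
dataset.annotate(annotations, update=False)

# Replaces client.predict(...), per its @deprecated message.
model = client.create_model("my_model", "my_model_ref")  # placeholder names
predictions = [
    BoxPrediction(
        label="car", x=0, y=0, width=10, height=10,
        reference_id="img_1", confidence=0.9,
    )
]
dataset.upload_predictions(model, predictions, update=False)
```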

nucleus/annotation.py

Lines changed: 32 additions & 2 deletions
```diff
@@ -1,4 +1,5 @@
 import json
+import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Sequence, Type, Union
@@ -70,6 +71,15 @@ def to_json(self) -> str:
         """Serializes annotation object to schematized JSON string."""
         return json.dumps(self.to_payload(), allow_nan=False)
 
+    def has_local_files_to_upload(self) -> bool:
+        """Returns True if annotation has local files that need to be uploaded.
+
+        Nearly all subclasses have no local files, so we default this to just return
+        false. If the subclass has local files, it should override this method (but
+        that is not the only thing required to get local upload of files to work.)
+        """
+        return False
+
 
 @dataclass # pylint: disable=R0902
 class BoxAnnotation(Annotation): # pylint: disable=R0902
@@ -578,6 +588,26 @@ def to_payload(self) -> dict:
 
         return payload
 
+    def has_local_files_to_upload(self) -> bool:
+        """Check if the mask url is local and needs to be uploaded."""
+        if is_local_path(self.mask_url):
+            if not os.path.isfile(self.mask_url):
+                raise Exception(f"Mask file {self.mask_url} does not exist.")
+            return True
+        return False
+
+    def __eq__(self, other):
+        if not isinstance(other, SegmentationAnnotation):
+            return False
+        self.annotations = sorted(self.annotations, key=lambda x: x.index)
+        other.annotations = sorted(other.annotations, key=lambda x: x.index)
+        return (
+            (self.annotation_id == other.annotation_id)
+            and (self.annotations == other.annotations)
+            and (self.mask_url == other.mask_url)
+            and (self.reference_id == other.reference_id)
+        )
+
 
 class AnnotationTypes(Enum):
     BOX = BOX_TYPE
@@ -737,12 +767,12 @@ def is_local_path(path: str) -> bool:
 
 
 def check_all_mask_paths_remote(
-    annotations: Sequence[Union[Annotation]],
+    annotations: Sequence[Annotation],
 ):
     for annotation in annotations:
         if hasattr(annotation, MASK_URL_KEY):
             if is_local_path(getattr(annotation, MASK_URL_KEY)):
                 raise ValueError(
                     "Found an annotation with a local path, which is not currently"
-                    f"supported. Use a remote path instead. {annotation}"
+                    f"supported for asynchronous upload. Use a remote path instead, or try synchronous upload. {annotation}"
                 )
```
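Together, the new has_local_files_to_upload hook and the softened error message draw the line between the two upload paths: synchronous upload can now handle local mask files, while asynchronous upload still requires remote paths. A sketch of how a caller might partition annotations with the new hook; the helper below is invented for illustration and is not part of this commit:

```python
# Illustrative helper, not from this commit: split annotations by whether
# their files (e.g. segmentation masks) must first be uploaded locally.
from typing import List, Sequence, Tuple

from nucleus.annotation import Annotation


def split_local_and_remote(
    annotations: Sequence[Annotation],
) -> Tuple[List[Annotation], List[Annotation]]:
    local = [a for a in annotations if a.has_local_files_to_upload()]
    remote = [a for a in annotations if not a.has_local_files_to_upload()]
    return local, remote
```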
