Refactor a lot of segmentation local upload and async logic #256

Merged
merged 11 commits on Mar 16, 2022
7 changes: 7 additions & 0 deletions conftest.py
@@ -28,6 +28,13 @@ def dataset(CLIENT):
CLIENT.delete_dataset(ds.id)


@pytest.fixture()
def model(CLIENT):
model = CLIENT.create_model(TEST_DATASET_NAME, "fake_reference_id")
yield model
CLIENT.delete_model(model.id)


if __name__ == "__main__":
client = nucleus.NucleusClient(API_KEY)
# ds = client.create_dataset("Test Dataset With Autotags")
174 changes: 0 additions & 174 deletions nucleus/__init__.py
@@ -460,93 +460,6 @@ def populate_dataset(
dataset_items, batch_size=batch_size, update=update
)

def annotate_dataset(
self,
dataset_id: str,
annotations: Sequence[
Union[
BoxAnnotation,
PolygonAnnotation,
CuboidAnnotation,
CategoryAnnotation,
MultiCategoryAnnotation,
SegmentationAnnotation,
]
],
update: bool,
batch_size: int = 5000,
) -> Dict[str, object]:
# TODO: deprecate in favor of Dataset.annotate invocation

# Split payload into segmentations and Box/Polygon
segmentations = [
ann
for ann in annotations
if isinstance(ann, SegmentationAnnotation)
]
other_annotations = [
ann
for ann in annotations
if not isinstance(ann, SegmentationAnnotation)
]

batches = [
other_annotations[i : i + batch_size]
for i in range(0, len(other_annotations), batch_size)
]

semseg_batches = [
segmentations[i : i + batch_size]
for i in range(0, len(segmentations), batch_size)
]

agg_response = {
DATASET_ID_KEY: dataset_id,
ANNOTATIONS_PROCESSED_KEY: 0,
ANNOTATIONS_IGNORED_KEY: 0,
ERRORS_KEY: [],
}

total_batches = len(batches) + len(semseg_batches)

tqdm_batches = self.tqdm_bar(batches)

with self.tqdm_bar(total=total_batches) as pbar:
for batch in tqdm_batches:
payload = construct_annotation_payload(batch, update)
response = self.make_request(
payload, f"dataset/{dataset_id}/annotate"
)
pbar.update(1)
if STATUS_CODE_KEY in response:
agg_response[ERRORS_KEY] = response
else:
agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
ANNOTATIONS_PROCESSED_KEY
]
agg_response[ANNOTATIONS_IGNORED_KEY] += response[
ANNOTATIONS_IGNORED_KEY
]
agg_response[ERRORS_KEY] += response[ERRORS_KEY]

for s_batch in semseg_batches:
payload = construct_segmentation_payload(s_batch, update)
response = self.make_request(
payload, f"dataset/{dataset_id}/annotate_segmentation"
)
pbar.update(1)
if STATUS_CODE_KEY in response:
agg_response[ERRORS_KEY] = response
else:
agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
ANNOTATIONS_PROCESSED_KEY
]
agg_response[ANNOTATIONS_IGNORED_KEY] += response[
ANNOTATIONS_IGNORED_KEY
]

return agg_response

@deprecated(msg="Use Dataset.ingest_tasks instead")
def ingest_tasks(self, dataset_id: str, payload: dict):
dataset = self.get_dataset(dataset_id)
@@ -599,93 +512,6 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
)

@deprecated("Use Dataset.upload_predictions instead.")
def predict(
self,
annotations: List[
Union[
BoxPrediction,
PolygonPrediction,
CuboidPrediction,
SegmentationPrediction,
CategoryPrediction,
]
],
model_run_id: Optional[str] = None,
model_id: Optional[str] = None,
dataset_id: Optional[str] = None,
update: bool = False,
batch_size: int = 5000,
):
if model_run_id is not None:
assert model_id is None and dataset_id is None
endpoint = f"modelRun/{model_run_id}/predict"
else:
assert (
model_id is not None and dataset_id is not None
), "Model ID and dataset ID are required if not using model run id."
endpoint = (
f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
)
segmentations = [
ann
for ann in annotations
if isinstance(ann, SegmentationPrediction)
]

other_predictions = [
ann
for ann in annotations
if not isinstance(ann, SegmentationPrediction)
]

s_batches = [
segmentations[i : i + batch_size]
for i in range(0, len(segmentations), batch_size)
]

batches = [
other_predictions[i : i + batch_size]
for i in range(0, len(other_predictions), batch_size)
]

errors = []
predictions_processed = 0
predictions_ignored = 0

tqdm_batches = self.tqdm_bar(batches)

for batch in tqdm_batches:
batch_payload = construct_box_predictions_payload(
batch,
update,
)
response = self.make_request(batch_payload, endpoint)
if STATUS_CODE_KEY in response:
errors.append(response)
else:
predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
if ERRORS_KEY in response:
errors += response[ERRORS_KEY]

for s_batch in s_batches:
payload = construct_segmentation_payload(s_batch, update)
response = self.make_request(payload, endpoint)
# pbar.update(1)
if STATUS_CODE_KEY in response:
errors.append(response)
else:
predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
predictions_ignored += response[PREDICTIONS_IGNORED_KEY]

return {
MODEL_RUN_ID_KEY: model_run_id,
PREDICTIONS_PROCESSED_KEY: predictions_processed,
PREDICTIONS_IGNORED_KEY: predictions_ignored,
ERRORS_KEY: errors,
}

@deprecated(
"Model runs have been deprecated and will be removed. Use a Model instead."
)
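
Both removed client methods point at their Dataset-level replacements: the deleted annotate_dataset carried a TODO deferring to Dataset.annotate, and predict was decorated with "Use Dataset.upload_predictions instead." A rough sketch of the replacement call paths, assuming the Dataset methods accept the same keyword arguments the removed client methods did (API_KEY, dataset_id, model_id, annotations, and predictions are placeholders):

import nucleus

client = nucleus.NucleusClient(API_KEY)
dataset = client.get_dataset(dataset_id)

# Replacement for client.annotate_dataset(dataset_id, annotations, update=...).
dataset.annotate(annotations, update=True, batch_size=5000)

# Replacement for client.predict(...). get_model is assumed here; adjust the
# model lookup to whatever your client version provides.
model = client.get_model(model_id)
dataset.upload_predictions(model, predictions, update=False)
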
34 changes: 32 additions & 2 deletions nucleus/annotation.py
@@ -1,4 +1,5 @@
import json
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Sequence, Type, Union
@@ -70,6 +71,15 @@ def to_json(self) -> str:
"""Serializes annotation object to schematized JSON string."""
return json.dumps(self.to_payload(), allow_nan=False)

def has_local_files_to_upload(self) -> bool:
"""Returns True if annotation has local files that need to be uploaded.

Nearly all subclasses have no local files, so we default this to just return
false. If the subclass has local files, it should override this method (but
that is not the only thing required to get local upload of files to work.)
"""
return False


@dataclass # pylint: disable=R0902
class BoxAnnotation(Annotation): # pylint: disable=R0902
@@ -578,6 +588,26 @@ def to_payload(self) -> dict:

return payload

def has_local_files_to_upload(self) -> bool:
"""Check if the mask url is local and needs to be uploaded."""
if is_local_path(self.mask_url):
if not os.path.isfile(self.mask_url):
raise Exception(f"Mask file {self.mask_url} does not exist.")
return True
return False

def __eq__(self, other):
if not isinstance(other, SegmentationAnnotation):
return False
# Compare segment lists order-independently without mutating either object.
self_segments = sorted(self.annotations, key=lambda x: x.index)
other_segments = sorted(other.annotations, key=lambda x: x.index)
return (
(self.annotation_id == other.annotation_id)
and (self_segments == other_segments)
and (self.mask_url == other.mask_url)
and (self.reference_id == other.reference_id)
)
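
The new has_local_files_to_upload hook gives upload code one place to ask whether an annotation references a local mask file. A minimal sketch of how calling code might branch on it; the split and the variable names are illustrative, not part of this diff:

# Hypothetical caller: separate annotations that need a local-file upload path
# from those whose mask URLs are already remote.
local_upload, remote_only = [], []
for ann in annotations:
    if ann.has_local_files_to_upload():
        local_upload.append(ann)
    else:
        remote_only.append(ann)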


class AnnotationTypes(Enum):
BOX = BOX_TYPE
@@ -737,12 +767,12 @@ def is_local_path(path: str) -> bool:


def check_all_mask_paths_remote(
annotations: Sequence[Union[Annotation]],
annotations: Sequence[Annotation],
):
for annotation in annotations:
if hasattr(annotation, MASK_URL_KEY):
if is_local_path(getattr(annotation, MASK_URL_KEY)):
raise ValueError(
"Found an annotation with a local path, which is not currently"
f"supported. Use a remote path instead. {annotation}"
f"supported for asynchronous upload. Use a remote path instead, or try synchronous upload. {annotation}"
)
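
With the reworded error, check_all_mask_paths_remote gives callers an actionable ValueError when they try to queue an asynchronous upload with local mask paths. A minimal usage sketch, with annotations as a placeholder list; the fallback branch is illustrative, not part of this change:

from nucleus.annotation import check_all_mask_paths_remote

try:
    check_all_mask_paths_remote(annotations)
except ValueError:
    # At least one segmentation mask points at a local file; asynchronous upload
    # rejects it, so route these annotations through a synchronous upload instead.
    ...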