Commit 59ce11e

Author: Ubuntu
Message: work in progress refactor of annotation upload
1 parent 4106f8e · commit 59ce11e

File tree

7 files changed: +145 −271 lines changed

nucleus/__init__.py

Lines changed: 0 additions & 87 deletions
@@ -460,93 +460,6 @@ def populate_dataset(
             dataset_items, batch_size=batch_size, update=update
         )
 
-    def annotate_dataset(
-        self,
-        dataset_id: str,
-        annotations: Sequence[
-            Union[
-                BoxAnnotation,
-                PolygonAnnotation,
-                CuboidAnnotation,
-                CategoryAnnotation,
-                MultiCategoryAnnotation,
-                SegmentationAnnotation,
-            ]
-        ],
-        update: bool,
-        batch_size: int = 5000,
-    ) -> Dict[str, object]:
-        # TODO: deprecate in favor of Dataset.annotate invocation
-
-        # Split payload into segmentations and Box/Polygon
-        segmentations = [
-            ann
-            for ann in annotations
-            if isinstance(ann, SegmentationAnnotation)
-        ]
-        other_annotations = [
-            ann
-            for ann in annotations
-            if not isinstance(ann, SegmentationAnnotation)
-        ]
-
-        batches = [
-            other_annotations[i : i + batch_size]
-            for i in range(0, len(other_annotations), batch_size)
-        ]
-
-        semseg_batches = [
-            segmentations[i : i + batch_size]
-            for i in range(0, len(segmentations), batch_size)
-        ]
-
-        agg_response = {
-            DATASET_ID_KEY: dataset_id,
-            ANNOTATIONS_PROCESSED_KEY: 0,
-            ANNOTATIONS_IGNORED_KEY: 0,
-            ERRORS_KEY: [],
-        }
-
-        total_batches = len(batches) + len(semseg_batches)
-
-        tqdm_batches = self.tqdm_bar(batches)
-
-        with self.tqdm_bar(total=total_batches) as pbar:
-            for batch in tqdm_batches:
-                payload = construct_annotation_payload(batch, update)
-                response = self.make_request(
-                    payload, f"dataset/{dataset_id}/annotate"
-                )
-                pbar.update(1)
-                if STATUS_CODE_KEY in response:
-                    agg_response[ERRORS_KEY] = response
-                else:
-                    agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
-                        ANNOTATIONS_PROCESSED_KEY
-                    ]
-                    agg_response[ANNOTATIONS_IGNORED_KEY] += response[
-                        ANNOTATIONS_IGNORED_KEY
-                    ]
-                    agg_response[ERRORS_KEY] += response[ERRORS_KEY]
-
-            for s_batch in semseg_batches:
-                payload = construct_segmentation_payload(s_batch, update)
-                response = self.make_request(
-                    payload, f"dataset/{dataset_id}/annotate_segmentation"
-                )
-                pbar.update(1)
-                if STATUS_CODE_KEY in response:
-                    agg_response[ERRORS_KEY] = response
-                else:
-                    agg_response[ANNOTATIONS_PROCESSED_KEY] += response[
-                        ANNOTATIONS_PROCESSED_KEY
-                    ]
-                    agg_response[ANNOTATIONS_IGNORED_KEY] += response[
-                        ANNOTATIONS_IGNORED_KEY
-                    ]
-
-        return agg_response
-
     @deprecated(msg="Use Dataset.ingest_tasks instead")
     def ingest_tasks(self, dataset_id: str, payload: dict):
         dataset = self.get_dataset(dataset_id)
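
The deleted client.annotate_dataset was the synchronous batching path whose own TODO already marked it for deprecation in favor of Dataset.annotate. A minimal usage sketch of the replacement call path, assuming Dataset.annotate accepts the same annotations/update/batch_size parameters as the removed method; the API key, dataset ID, and reference ID below are placeholders:

    import nucleus
    from nucleus import BoxAnnotation

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
    dataset = client.get_dataset("YOUR_DATASET_ID")  # placeholder dataset id

    box = BoxAnnotation(
        label="car",
        x=10, y=20, width=30, height=40,
        reference_id="image_1",  # placeholder dataset item
    )

    # Rough equivalent of the removed client.annotate_dataset(...) invocation:
    response = dataset.annotate(annotations=[box], update=False, batch_size=5000)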

nucleus/annotation.py

Lines changed: 13 additions & 2 deletions
@@ -1,4 +1,5 @@
 import json
+import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Sequence, Type, Union
@@ -70,6 +71,9 @@ def to_json(self) -> str:
         """Serializes annotation object to schematized JSON string."""
         return json.dumps(self.to_payload(), allow_nan=False)
 
+    def has_local_files(self) -> bool:
+        return False
+
 
 @dataclass  # pylint: disable=R0902
 class BoxAnnotation(Annotation):  # pylint: disable=R0902
@@ -578,6 +582,13 @@ def to_payload(self) -> dict:
 
         return payload
 
+    def has_files(self) -> bool:
+        if is_local_path(self.mask_url):
+            if not os.path.isfile(self.mask_url):
+                raise Exception(f"Mask file {self.mask_url} does not exist.")
+            return True
+        return False
+
 
 class AnnotationTypes(Enum):
     BOX = BOX_TYPE
@@ -737,12 +748,12 @@ def is_local_path(path: str) -> bool:
 
 
 def check_all_mask_paths_remote(
-    annotations: Sequence[Union[Annotation]],
+    annotations: Sequence[Annotation],
 ):
     for annotation in annotations:
         if hasattr(annotation, MASK_URL_KEY):
             if is_local_path(getattr(annotation, MASK_URL_KEY)):
                 raise ValueError(
                     "Found an annotation with a local path, which is not currently"
-                    f"supported. Use a remote path instead. {annotation}"
+                    f" supported for asynchronous upload. Use a remote path instead, or try synchronous upload. {annotation}"
                 )
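
A WIP wrinkle worth noting: the base Annotation class gains has_local_files, but the SegmentationAnnotation method is named has_files, so it does not actually override the new base hook yet. A short sketch of the added check, assuming the public Segment/SegmentationAnnotation constructors; the mask path and reference ID are placeholders:

    from nucleus.annotation import Segment, SegmentationAnnotation

    ann = SegmentationAnnotation(
        mask_url="/tmp/masks/image_1.png",  # placeholder local path
        annotations=[Segment(label="road", index=1)],
        reference_id="image_1",  # placeholder dataset item
    )

    # Per the diff above: True for an existing local mask file, an exception
    # if the local file is missing, False for a remote URL.
    print(ann.has_files())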

nucleus/async_utils.py

Lines changed: 23 additions & 4 deletions
@@ -5,6 +5,7 @@
 
 import aiohttp
 import nest_asyncio
+from tqdm import tqdm
 
 from nucleus.constants import DEFAULT_NETWORK_TIMEOUT_SEC
 from nucleus.errors import NucleusAPIError
@@ -27,6 +28,16 @@ class FileFormField:
 FileFormData = Sequence[FileFormField]
 
 
+async def gather_with_concurrency(n, *tasks):
+    semaphore = asyncio.Semaphore(n)
+
+    async def sem_task(task):
+        async with semaphore:
+            return await task
+
+    return await asyncio.gather(*(sem_task(task) for task in tasks))
+
+
 class FormDataContextHandler:
     """A context handler for file form data that handles closing all files in a request.
 
@@ -85,6 +96,8 @@ def make_many_form_data_requests_concurrently(
     client: "NucleusClient",
     requests: Sequence[FormDataContextHandler],
     route: str,
+    progressbar: tqdm,
+    concurrency: int = 30,
 ):
     """
     Makes an async post request with form data to a Nucleus endpoint.
@@ -97,14 +110,18 @@ def make_many_form_data_requests_concurrently(
     """
     loop = get_event_loop()
     return loop.run_until_complete(
-        form_data_request_helper(client, requests, route)
+        form_data_request_helper(
+            client, requests, route, progressbar, concurrency
+        )
     )
 
 
 async def form_data_request_helper(
     client: "NucleusClient",
     requests: Sequence[FormDataContextHandler],
     route: str,
+    progressbar: tqdm,
+    concurrency: int = 30,
 ):
     """
     Makes an async post request with files to a Nucleus endpoint.
@@ -123,18 +140,20 @@ async def form_data_request_helper(
                 request=request,
                 route=route,
                 session=session,
+                progressbar=progressbar,
             )
         )
         for request in requests
     ]
-    return await asyncio.gather(*tasks)
+    return await gather_with_concurrency(concurrency, *tasks)
 
 
 async def _post_form_data(
     client: "NucleusClient",
     request: FormDataContextHandler,
     route: str,
    session: aiohttp.ClientSession,
+    progressbar: tqdm,
 ):
     """
     Makes an async post request with files to a Nucleus endpoint.
@@ -175,7 +194,7 @@ async def _post_form_data(
 
     if response.status == 503:
         raise TimeoutError(
-            "The request to upload your mask is timing out, please lower the batch size."
+            "The request to upload your mask is timing out, please lower local_files_per_upload_request in your API call."
         )
 
     if not response.ok:
@@ -188,5 +207,5 @@ async def _post_form_data(
                 data,
             ),
         )
-
+    progressbar.update(1)
     return data
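
gather_with_concurrency is the usual semaphore-throttled wrapper around asyncio.gather: every awaitable is wrapped in sem_task, and at most n of them may proceed past the semaphore at once, capping in-flight uploads at the new concurrency default of 30. One caveat: the semaphore only delays awaitables that have not started yet, so it throttles bare coroutines, whereas tasks already scheduled with asyncio.create_task (as in form_data_request_helper above) begin running regardless. A self-contained demo of the pattern; fake_upload is a stand-in for _post_form_data:

    import asyncio


    async def gather_with_concurrency(n, *tasks):
        # At most n of the given awaitables are awaited concurrently.
        semaphore = asyncio.Semaphore(n)

        async def sem_task(task):
            async with semaphore:
                return await task

        return await asyncio.gather(*(sem_task(task) for task in tasks))


    async def fake_upload(i):
        await asyncio.sleep(0.1)  # stand-in for one form-data POST
        return i


    async def main():
        # 100 pending "uploads", at most 30 in flight at any moment.
        results = await gather_with_concurrency(
            30, *(fake_upload(i) for i in range(100))
        )
        print(len(results))  # -> 100


    asyncio.run(main())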
