
Commit 01464a1

Ubuntu committed
Fixed segmentation bugs
1 parent 59ce11e commit 01464a1

File tree: 1 file changed (+201 −0 lines changed)


nucleus/annotation_uploader.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
import json
from typing import TYPE_CHECKING, Iterable, List, Sequence

from nucleus.annotation import Annotation, SegmentationAnnotation
from nucleus.async_utils import (
    FileFormField,
    FormDataContextHandler,
    make_many_form_data_requests_concurrently,
)
from nucleus.constants import ITEMS_KEY, SEGMENTATIONS_KEY
from nucleus.payload_constructor import (
    construct_annotation_payload,
    construct_segmentation_payload,
)

if TYPE_CHECKING:
    from . import NucleusClient


def accumulate_dict_values(dicts: Iterable[dict]):
    """
    Accumulate a list of dicts into a single dict using summation.
    """
    result = {}
    for d in dicts:
        for key, value in d.items():
            if key not in result:
                result[key] = value
            else:
                result[key] += value
    return result
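

# Illustrative note (not part of the original commit): accumulate_dict_values
# sums values key by key across the per-batch response dicts. With hypothetical
# key names:
#   accumulate_dict_values([{"annotations_processed": 5}, {"annotations_processed": 3, "errors": 1}])
#   == {"annotations_processed": 8, "errors": 1}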


class AnnotationUploader:
    def __init__(self, dataset_id: str, client: "NucleusClient"):  # noqa: F821
        self.dataset_id = dataset_id
        self._client = client

    def upload(
        self,
        annotations: Iterable[Annotation],
        batch_size: int = 5000,
        update: bool = False,
        remote_files_per_upload_request: int = 20,
        local_files_per_upload_request: int = 10,
        local_file_upload_concurrency: int = 30,
    ):
        if local_files_per_upload_request > 10:
            raise ValueError("local_files_per_upload_request must be <= 10")
        annotations_without_files: List[Annotation] = []
        segmentations_with_local_files: List[SegmentationAnnotation] = []
        segmentations_with_remote_files: List[SegmentationAnnotation] = []

        for annotation in annotations:
            if annotation.has_local_files():
                # Only segmentations have local files currently, and probably
                # for a long time to come.
                assert isinstance(annotation, SegmentationAnnotation)
                segmentations_with_local_files.append(annotation)
            elif isinstance(annotation, SegmentationAnnotation):
                segmentations_with_remote_files.append(annotation)
            else:
                annotations_without_files.append(annotation)

        responses = []
        if segmentations_with_local_files:
            responses.extend(
                self.make_batched_file_form_data_requests(
                    segmentations=segmentations_with_local_files,
                    update=update,
                    local_files_per_upload_request=local_files_per_upload_request,
                    local_file_upload_concurrency=local_file_upload_concurrency,
                )
            )
        if segmentations_with_remote_files:
            # Segmentations require an upload and must be batched differently since a single
            # segmentation will take a lot longer for the server to process than a single
            # annotation of any other kind.
            responses.extend(
                self.make_batched_annotate_requests(
                    segmentations_with_remote_files,
                    update,
                    batch_size=remote_files_per_upload_request,
                    segmentation=True,
                )
            )
        if annotations_without_files:
            responses.extend(
                self.make_batched_annotate_requests(
                    annotations_without_files,
                    update,
                    batch_size=batch_size,
                    segmentation=False,
                )
            )

        return accumulate_dict_values(responses)

    def make_batched_annotate_requests(
        self,
        annotations: Sequence[Annotation],
        update: bool,
        batch_size: int,
        segmentation: bool,
    ):
        batches = [
            annotations[i : i + batch_size]
            for i in range(0, len(annotations), batch_size)
        ]
        responses = []
        progress_bar_name = (
            "Segmentation batches" if segmentation else "Annotation batches"
        )
        for batch in self._client.tqdm_bar(batches, desc=progress_bar_name):
            if segmentation:
                payload = construct_segmentation_payload(batch, update)
                # TODO: remove validation checks in backend for /annotate
                # since it should work.
                route = f"dataset/{self.dataset_id}/annotate_segmentation"
            else:
                payload = construct_annotation_payload(batch, update)
                route = f"dataset/{self.dataset_id}/annotate"
            responses.append(self._client.make_request(payload, route))
        return responses
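
    # For illustration (not part of this commit): with 12,000 annotations and
    # batch_size=5000, the slicing above yields batches of 5,000, 5,000, and
    # 2,000 annotations, each sent to the server as one request.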

    def make_batched_file_form_data_requests(
        self,
        segmentations: Sequence[SegmentationAnnotation],
        update,
        local_files_per_upload_request: int,
        local_file_upload_concurrency: int,
    ):
        requests = []
        for i in range(0, len(segmentations), local_files_per_upload_request):
            batch = segmentations[i : i + local_files_per_upload_request]
            request = FormDataContextHandler(
                self.get_form_data_and_file_pointers_fn(batch, update)
            )
            requests.append(request)

        progressbar = self._client.tqdm_bar(
            total=len(requests),
            desc="Local segmentation mask file batches",
        )

        return make_many_form_data_requests_concurrently(
            client=self._client,
            requests=requests,
            route=f"dataset/{self.dataset_id}/annotate_segmentation_files",
            progressbar=progressbar,
            concurrency=local_file_upload_concurrency,
        )

    def get_form_data_and_file_pointers_fn(
        self,
        segmentations: Iterable[SegmentationAnnotation],
        update: bool,
    ):
        """Defines a function to be called on each retry.

        File pointers are also returned so whoever calls this function can
        appropriately close the files. This is intended for use with a
        FormDataContextHandler in order to make form data requests.
        """

        def fn():
            request_json = construct_segmentation_payload(
                segmentations, update
            )
            form_data = [
                FileFormField(
                    name=ITEMS_KEY,
                    filename=None,
                    value=json.dumps(request_json),
                    content_type="application/json",
                )
            ]
            file_pointers = []
            for segmentation in segmentations:
                # A `with` statement cannot be used here, since all files in
                # the request need to be open at the same time.
                # pylint: disable=consider-using-with
                mask_fp = open(segmentation.mask_url, "rb")
                # pylint: enable=consider-using-with
                file_type = segmentation.mask_url.split(".")[-1]
                if file_type != "png":
                    raise ValueError(
                        f"Only png files are supported. Got {file_type} for {segmentation.mask_url}"
                    )
                form_data.append(
                    FileFormField(
                        name=SEGMENTATIONS_KEY,
                        filename=segmentation.mask_url,
                        value=mask_fp,
                        content_type="image/png",
                    )
                )
                file_pointers.append(mask_fp)
            return form_data, file_pointers

        return fn
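
For reference, a minimal usage sketch of the AnnotationUploader added in this commit (not part of the diff; the API key, dataset ID, and annotation list below are placeholder assumptions):

    from nucleus import NucleusClient
    from nucleus.annotation_uploader import AnnotationUploader

    # Placeholder setup: the API key and dataset ID are hypothetical.
    client = NucleusClient("YOUR_SCALE_API_KEY")
    uploader = AnnotationUploader(dataset_id="ds_sample_dataset", client=client)

    # `annotations` is assumed to be built elsewhere as a list of Annotation
    # objects, e.g. SegmentationAnnotation instances with local PNG masks or
    # remote mask URLs.
    annotations = []
    response = uploader.upload(annotations, update=True)
    print(response)  # per-batch server responses merged by accumulate_dict_values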

0 commit comments
