Commit cd942e4

Ubuntu committed: work in progress
1 parent 5fe4e6b commit cd942e4

2 files changed: +306 -0 lines changed

nucleus/async_utils.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING, BinaryIO, Callable, Sequence, Tuple

import aiohttp
import nest_asyncio

from nucleus.constants import DEFAULT_NETWORK_TIMEOUT_SEC
from nucleus.errors import NucleusAPIError
from nucleus.retry_strategy import RetryStrategy

from .logger import logger

if TYPE_CHECKING:
    from . import NucleusClient


@dataclass
class FileFormField:
    name: str
    filename: str
    value: BinaryIO
    content_type: str


FileFormData = Sequence[FileFormField]


class FormDataContextHandler:
    """A context handler for file form data that handles closing all files in a request.

    Why do I need to wrap my requests in such a funny way?

    1. Form data must be regenerated on each request to avoid errors
       (see https://github.com/Rapptz/discord.py/issues/6531).
    2. Files must be properly opened and closed for each request.
    3. We need to be able to do 1 and 2 above multiple times so that we can
       implement retries properly.

    Write a function that returns a tuple of form data and file pointers, then pass
    it to the constructor of this class, and this class will handle the rest for you.
    """

    def __init__(
        self,
        form_data_and_file_pointers_fn: Callable[
            ..., Tuple[FileFormData, Sequence[BinaryIO]]
        ],
    ):
        self._form_data_and_file_pointer_fn = form_data_and_file_pointers_fn
        self._file_pointers = None

    def __enter__(self):
        (
            file_form_data,
            self._file_pointers,
        ) = self._form_data_and_file_pointer_fn()
        form = aiohttp.FormData()
        for field in file_form_data:
            form.add_field(
                name=field.name,
                filename=field.filename,
                value=field.value,
                content_type=field.content_type,
            )
        return form

    def __exit__(self, exc_type, exc_val, exc_tb):
        for file_pointer in self._file_pointers:
            file_pointer.close()


def get_event_loop():
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:  # no event loop is running
        loop = asyncio.new_event_loop()
    else:
        # A loop already exists (e.g. in a notebook), so patch it to allow
        # nested run_until_complete calls.
        nest_asyncio.apply(loop)
    return loop


def make_many_form_data_requests_concurrently(
    client: "NucleusClient",
    requests: Sequence[FormDataContextHandler],
    route: str,
):
    """
    Makes many concurrent async POST requests with form data to a Nucleus endpoint.

    Args:
        client: The client to use for the requests.
        requests: Each request should be a FormDataContextHandler object which will
            handle generating form data, and opening/closing files for each request.
        route: route for the requests.
    """
    loop = get_event_loop()
    return loop.run_until_complete(
        form_data_request_helper(client, requests, route)
    )


async def form_data_request_helper(
    client: "NucleusClient",
    requests: Sequence[FormDataContextHandler],
    route: str,
):
    """
    Makes many concurrent async POST requests with files to a Nucleus endpoint.

    Args:
        client: The client to use for the requests.
        requests: Each request should be a FormDataContextHandler object which will
            handle generating form data, and opening/closing files for each request.
        route: route for the requests.
    """
    async with aiohttp.ClientSession() as session:
        tasks = [
            asyncio.ensure_future(
                _post_form_data(
                    client=client,
                    request=request,
                    route=route,
                    session=session,
                )
            )
            for request in requests
        ]
        return await asyncio.gather(*tasks)


async def _post_form_data(
    client: "NucleusClient",
    request: FormDataContextHandler,
    route: str,
    session: aiohttp.ClientSession,
):
    """
    Makes an async POST request with files to a Nucleus endpoint.

    Args:
        client: The client to use for the request.
        request: The request to make (see FormDataContextHandler for more details).
        route: route for the request.
        session: The session to use for the request.
    """
    endpoint = f"{client.endpoint}/{route}"

    logger.info("Posting to %s", endpoint)

    for sleep_time in RetryStrategy.sleep_times + [-1]:
        with request as form:
            async with session.post(
                endpoint,
                data=form,
                auth=aiohttp.BasicAuth(client.api_key, ""),
                timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
            ) as response:
                logger.info(
                    "API request has response code %s", response.status
                )

                try:
                    data = await response.json()
                except aiohttp.client_exceptions.ContentTypeError:
                    # In case of 404, the server returns text
                    data = await response.text()

                if (
                    response.status in RetryStrategy.statuses
                    and sleep_time != -1
                ):
                    # Sleep asynchronously so the other in-flight uploads are
                    # not blocked while this request waits to retry.
                    await asyncio.sleep(sleep_time)
                    continue

                if response.status == 503:
                    raise TimeoutError(
                        "The request to upload your mask is timing out, please lower the batch size."
                    )

                if not response.ok:
                    raise NucleusAPIError(
                        endpoint,
                        session.post,
                        aiohttp_response=(
                            response.status,
                            response.reason,
                            data,
                        ),
                    )

                return data
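
For reference, a minimal usage sketch (not part of this commit) of how a caller might drive this module. It assumes an already-constructed NucleusClient bound to client; the local file mask.png, the form-field name "file", and the dataset id "ds_123" are illustrative only (the route shape is borrowed from segmentation_uploader.py below).

def _make_form_data():
    # Called once per attempt, so every retry gets fresh form data and a
    # freshly opened file handle (see FormDataContextHandler's docstring).
    fp = open("mask.png", "rb")  # hypothetical local file
    fields = [
        FileFormField(
            name="file",  # illustrative field name
            filename="mask.png",
            value=fp,
            content_type="image/png",
        )
    ]
    # Return the file pointers too so __exit__ can close them after each attempt.
    return fields, [fp]

requests = [FormDataContextHandler(_make_form_data)]
responses = make_many_form_data_requests_concurrently(
    client, requests, "dataset/ds_123/files"
)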

nucleus/segmentation_uploader.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# import asyncio
# import json
# import os
# from typing import TYPE_CHECKING, Any, List

# from nucleus.annotation import SegmentationAnnotation, is_local_path
# from nucleus.async_utils import get_event_loop
# from nucleus.constants import DATASET_ID_KEY, MASK_TYPE, SEGMENTATIONS_KEY
# from nucleus.errors import NotFoundError
# from nucleus.payload_constructor import construct_segmentation_payload
# from nucleus.upload_response import UploadResponse

# if TYPE_CHECKING:
#     from . import NucleusClient


# class SegmentationUploader:
#     def __init__(self, dataset_id: str, client: "NucleusClient"):  # noqa: F821
#         self.dataset_id = dataset_id
#         self._client = client

#     def annotate(
#         self,
#         segmentations: List[SegmentationAnnotation],
#         batch_size: int = 20,
#         update: bool = False,
#     ):
#         remote_segmentations = []
#         local_segmentations = []
#         for segmentation in segmentations:
#             if is_local_path(segmentation.mask_url):
#                 if not segmentation.local_file_exists():
#                     raise NotFoundError(
#                         f"Could not find {segmentation.mask_url}"
#                     )
#                 local_segmentations.append(segmentation)
#             else:
#                 remote_segmentations.append(segmentation)

#         local_batches = [
#             local_segmentations[i : i + batch_size]
#             for i in range(0, len(local_segmentations), batch_size)
#         ]

#         remote_batches = [
#             remote_segmentations[i : i + batch_size]
#             for i in range(0, len(remote_segmentations), batch_size)
#         ]

#         agg_response = UploadResponse(json={DATASET_ID_KEY: self.dataset_id})

#         async_responses: List[Any] = []

#         if local_batches:
#             tqdm_local_batches = self._client.tqdm_bar(
#                 local_batches, desc="Local file batches"
#             )
#             for batch in tqdm_local_batches:
#                 responses = self._process_annotate_requests_local(
#                     self.dataset_id, batch
#                 )
#                 async_responses.extend(responses)

#     def _process_annotate_requests_local(
#         self,
#         dataset_id: str,
#         segmentations: List[SegmentationAnnotation],
#         local_batch_size: int = 10,
#     ):
#         requests = []
#         file_pointers = []
#         for i in range(0, len(segmentations), local_batch_size):
#             batch = segmentations[i : i + local_batch_size]
#             request, request_file_pointers = self.construct_files_request(
#                 batch
#             )
#             requests.append(request)
#             file_pointers.extend(request_file_pointers)

#         future = self.make_many_files_requests_asynchronously(
#             requests, f"dataset/{dataset_id}/files"
#         )

#         loop = get_event_loop()

#         responses = loop.run_until_complete(future)
#         for fp in file_pointers:
#             fp.close()
#         return responses

#     def construct_files_request(
#         self,
#         segmentations: List[SegmentationAnnotation],
#         update: bool = False,
#     ):
#         request_json = construct_segmentation_payload(
#             segmentations, update
#         )
#         request_payload = [
#             (
#                 SEGMENTATIONS_KEY,
#                 (None, json.dumps(request_json), "application/json"),
#             )
#         ]
#         file_pointers = []
#         for segmentation in segmentations:
#             # Validate the extension before opening so a bad file does not
#             # leak an open handle.
#             file_type = segmentation.mask_url.split(".")[-1]
#             if file_type != "png":
#                 raise ValueError(
#                     f"Only png files are supported. Got {file_type} for {segmentation.mask_url}"
#                 )
#             mask_fp = open(segmentation.mask_url, "rb")
#             # Track the handle so the caller can close it after the upload.
#             file_pointers.append(mask_fp)
#             filename = os.path.basename(segmentation.mask_url)
#             request_payload.append(
#                 (MASK_TYPE, (filename, mask_fp, "image/png"))
#             )
#         return request_payload, file_pointers
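
Because this work-in-progress file still builds requests-style tuple payloads rather than the FileFormField objects that nucleus/async_utils.py consumes, one possible glue sketch for a single batch is below. It is hypothetical and not part of this commit: form_data_for_batch is an invented name, and passing a JSON string as a FileFormField value relies on aiohttp.FormData.add_field accepting str values even though the dataclass annotates BinaryIO.

import json
import os

from nucleus.async_utils import FileFormField, FormDataContextHandler
from nucleus.constants import MASK_TYPE, SEGMENTATIONS_KEY
from nucleus.payload_constructor import construct_segmentation_payload

def form_data_for_batch(batch, update=False):
    """Wrap one batch of segmentations for the retry-safe async uploader."""

    def _fn():
        # Regenerated on every retry attempt, exactly as
        # FormDataContextHandler requires.
        payload = construct_segmentation_payload(batch, update)
        fields = [
            FileFormField(
                name=SEGMENTATIONS_KEY,
                filename=None,
                value=json.dumps(payload),  # str value, see caveat above
                content_type="application/json",
            )
        ]
        file_pointers = []
        for segmentation in batch:
            fp = open(segmentation.mask_url, "rb")
            file_pointers.append(fp)
            fields.append(
                FileFormField(
                    name=MASK_TYPE,
                    filename=os.path.basename(segmentation.mask_url),
                    value=fp,
                    content_type="image/png",
                )
            )
        return fields, file_pointers

    return FormDataContextHandler(_fn)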
