diff --git a/conftest.py b/conftest.py index b98074ea..d956c050 100644 --- a/conftest.py +++ b/conftest.py @@ -28,6 +28,13 @@ def dataset(CLIENT): CLIENT.delete_dataset(ds.id) +@pytest.fixture() +def model(CLIENT): + model = CLIENT.create_model(TEST_DATASET_NAME, "fake_reference_id") + yield model + CLIENT.delete_model(model.id) + + if __name__ == "__main__": client = nucleus.NucleusClient(API_KEY) # ds = client.create_dataset("Test Dataset With Autotags") diff --git a/nucleus/__init__.py b/nucleus/__init__.py index c66292f1..50016905 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -460,93 +460,6 @@ def populate_dataset( dataset_items, batch_size=batch_size, update=update ) - def annotate_dataset( - self, - dataset_id: str, - annotations: Sequence[ - Union[ - BoxAnnotation, - PolygonAnnotation, - CuboidAnnotation, - CategoryAnnotation, - MultiCategoryAnnotation, - SegmentationAnnotation, - ] - ], - update: bool, - batch_size: int = 5000, - ) -> Dict[str, object]: - # TODO: deprecate in favor of Dataset.annotate invocation - - # Split payload into segmentations and Box/Polygon - segmentations = [ - ann - for ann in annotations - if isinstance(ann, SegmentationAnnotation) - ] - other_annotations = [ - ann - for ann in annotations - if not isinstance(ann, SegmentationAnnotation) - ] - - batches = [ - other_annotations[i : i + batch_size] - for i in range(0, len(other_annotations), batch_size) - ] - - semseg_batches = [ - segmentations[i : i + batch_size] - for i in range(0, len(segmentations), batch_size) - ] - - agg_response = { - DATASET_ID_KEY: dataset_id, - ANNOTATIONS_PROCESSED_KEY: 0, - ANNOTATIONS_IGNORED_KEY: 0, - ERRORS_KEY: [], - } - - total_batches = len(batches) + len(semseg_batches) - - tqdm_batches = self.tqdm_bar(batches) - - with self.tqdm_bar(total=total_batches) as pbar: - for batch in tqdm_batches: - payload = construct_annotation_payload(batch, update) - response = self.make_request( - payload, f"dataset/{dataset_id}/annotate" - ) - pbar.update(1) - if STATUS_CODE_KEY in response: - agg_response[ERRORS_KEY] = response - else: - agg_response[ANNOTATIONS_PROCESSED_KEY] += response[ - ANNOTATIONS_PROCESSED_KEY - ] - agg_response[ANNOTATIONS_IGNORED_KEY] += response[ - ANNOTATIONS_IGNORED_KEY - ] - agg_response[ERRORS_KEY] += response[ERRORS_KEY] - - for s_batch in semseg_batches: - payload = construct_segmentation_payload(s_batch, update) - response = self.make_request( - payload, f"dataset/{dataset_id}/annotate_segmentation" - ) - pbar.update(1) - if STATUS_CODE_KEY in response: - agg_response[ERRORS_KEY] = response - else: - agg_response[ANNOTATIONS_PROCESSED_KEY] += response[ - ANNOTATIONS_PROCESSED_KEY - ] - agg_response[ANNOTATIONS_IGNORED_KEY] += response[ - ANNOTATIONS_IGNORED_KEY - ] - - return agg_response - @deprecated(msg="Use Dataset.ingest_tasks instead") def ingest_tasks(self, dataset_id: str, payload: dict): dataset = self.get_dataset(dataset_id) @@ -599,93 +512,6 @@ def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun: response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self ) - @deprecated("Use Dataset.upload_predictions instead.") - def predict( - self, - annotations: List[ - Union[ - BoxPrediction, - PolygonPrediction, - CuboidPrediction, - SegmentationPrediction, - CategoryPrediction, - ] - ], - model_run_id: Optional[str] = None, - model_id: Optional[str] = None, - dataset_id: Optional[str] = None, - update: bool = False, - batch_size: int = 5000, - ): - if model_run_id is not None: - assert model_id is None and 
dataset_id is None - endpoint = f"modelRun/{model_run_id}/predict" - else: - assert ( - model_id is not None and dataset_id is not None - ), "Model ID and dataset ID are required if not using model run id." - endpoint = ( - f"dataset/{dataset_id}/model/{model_id}/uploadPredictions" - ) - segmentations = [ - ann - for ann in annotations - if isinstance(ann, SegmentationPrediction) - ] - - other_predictions = [ - ann - for ann in annotations - if not isinstance(ann, SegmentationPrediction) - ] - - s_batches = [ - segmentations[i : i + batch_size] - for i in range(0, len(segmentations), batch_size) - ] - - batches = [ - other_predictions[i : i + batch_size] - for i in range(0, len(other_predictions), batch_size) - ] - - errors = [] - predictions_processed = 0 - predictions_ignored = 0 - - tqdm_batches = self.tqdm_bar(batches) - - for batch in tqdm_batches: - batch_payload = construct_box_predictions_payload( - batch, - update, - ) - response = self.make_request(batch_payload, endpoint) - if STATUS_CODE_KEY in response: - errors.append(response) - else: - predictions_processed += response[PREDICTIONS_PROCESSED_KEY] - predictions_ignored += response[PREDICTIONS_IGNORED_KEY] - if ERRORS_KEY in response: - errors += response[ERRORS_KEY] - - for s_batch in s_batches: - payload = construct_segmentation_payload(s_batch, update) - response = self.make_request(payload, endpoint) - # pbar.update(1) - if STATUS_CODE_KEY in response: - errors.append(response) - else: - predictions_processed += response[PREDICTIONS_PROCESSED_KEY] - predictions_ignored += response[PREDICTIONS_IGNORED_KEY] - - return { - MODEL_RUN_ID_KEY: model_run_id, - PREDICTIONS_PROCESSED_KEY: predictions_processed, - PREDICTIONS_IGNORED_KEY: predictions_ignored, - ERRORS_KEY: errors, - } - @deprecated( "Model runs have been deprecated and will be removed. Use a Model instead." ) diff --git a/nucleus/annotation.py b/nucleus/annotation.py index 347e156d..46483ac1 100644 --- a/nucleus/annotation.py +++ b/nucleus/annotation.py @@ -1,4 +1,5 @@ import json +import os from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Sequence, Type, Union @@ -70,6 +71,15 @@ def to_json(self) -> str: """Serializes annotation object to schematized JSON string.""" return json.dumps(self.to_payload(), allow_nan=False) + def has_local_files_to_upload(self) -> bool: + """Returns True if annotation has local files that need to be uploaded. + + Nearly all subclasses have no local files, so we default this to just return + false. If the subclass has local files, it should override this method (but + that is not the only thing required to get local upload of files to work.) 
+ """ + return False + @dataclass # pylint: disable=R0902 class BoxAnnotation(Annotation): # pylint: disable=R0902 @@ -578,6 +588,26 @@ def to_payload(self) -> dict: return payload + def has_local_files_to_upload(self) -> bool: + """Check if the mask url is local and needs to be uploaded.""" + if is_local_path(self.mask_url): + if not os.path.isfile(self.mask_url): + raise Exception(f"Mask file {self.mask_url} does not exist.") + return True + return False + + def __eq__(self, other): + if not isinstance(other, SegmentationAnnotation): + return False + self.annotations = sorted(self.annotations, key=lambda x: x.index) + other.annotations = sorted(other.annotations, key=lambda x: x.index) + return ( + (self.annotation_id == other.annotation_id) + and (self.annotations == other.annotations) + and (self.mask_url == other.mask_url) + and (self.reference_id == other.reference_id) + ) + class AnnotationTypes(Enum): BOX = BOX_TYPE @@ -737,12 +767,12 @@ def is_local_path(path: str) -> bool: def check_all_mask_paths_remote( - annotations: Sequence[Union[Annotation]], + annotations: Sequence[Annotation], ): for annotation in annotations: if hasattr(annotation, MASK_URL_KEY): if is_local_path(getattr(annotation, MASK_URL_KEY)): raise ValueError( "Found an annotation with a local path, which is not currently" - f"supported. Use a remote path instead. {annotation}" + f"supported for asynchronous upload. Use a remote path instead, or try synchronous upload. {annotation}" ) diff --git a/nucleus/annotation_uploader.py b/nucleus/annotation_uploader.py new file mode 100644 index 00000000..6926eb30 --- /dev/null +++ b/nucleus/annotation_uploader.py @@ -0,0 +1,231 @@ +import json +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence + +from nucleus.annotation import Annotation, SegmentationAnnotation +from nucleus.async_utils import ( + FileFormField, + FormDataContextHandler, + make_many_form_data_requests_concurrently, +) +from nucleus.constants import MASK_TYPE, SERIALIZED_REQUEST_KEY +from nucleus.payload_constructor import ( + construct_annotation_payload, + construct_segmentation_payload, +) + +if TYPE_CHECKING: + from . import NucleusClient + + +def accumulate_dict_values(dicts: Iterable[dict]): + """ + Accumulate a list of dicts into a single dict using summation. + """ + result = {} + for d in dicts: + for key, value in d.items(): + if ( + key not in result + or key == "dataset_id" + or key == "model_run_id" + ): + result[key] = value + else: + result[key] += value + return result + + +class AnnotationUploader: + """This is a helper class not intended for direct use. Please use dataset.annotate + or dataset.upload_predictions. + + This class is purely a helper class for implementing dataset.annotate/dataset.predict. 
+    """
+
+    def __init__(
+        self, dataset_id: Optional[str], client: "NucleusClient"
+    ):  # noqa: F821
+        self._client = client
+        self._route = f"dataset/{dataset_id}/annotate"
+
+    def upload(
+        self,
+        annotations: Iterable[Annotation],
+        batch_size: int = 5000,
+        update: bool = False,
+        remote_files_per_upload_request: int = 20,
+        local_files_per_upload_request: int = 10,
+        local_file_upload_concurrency: int = 30,
+    ):
+        """For more details on parameters and functionality, see dataset.annotate."""
+        if local_files_per_upload_request > 10:
+            raise ValueError("local_files_per_upload_request must be <= 10")
+        annotations_without_files: List[Annotation] = []
+        segmentations_with_local_files: List[SegmentationAnnotation] = []
+        segmentations_with_remote_files: List[SegmentationAnnotation] = []
+
+        for annotation in annotations:
+            if annotation.has_local_files_to_upload():
+                # Only segmentations have local files currently, and that is likely
+                # to stay true for a long time to come.
+                assert isinstance(annotation, SegmentationAnnotation)
+                segmentations_with_local_files.append(annotation)
+            elif isinstance(annotation, SegmentationAnnotation):
+                segmentations_with_remote_files.append(annotation)
+            else:
+                annotations_without_files.append(annotation)
+
+        responses = []
+        if segmentations_with_local_files:
+            responses.extend(
+                self.make_batched_file_form_data_requests(
+                    segmentations=segmentations_with_local_files,
+                    update=update,
+                    local_files_per_upload_request=local_files_per_upload_request,
+                    local_file_upload_concurrency=local_file_upload_concurrency,
+                )
+            )
+        if segmentations_with_remote_files:
+            # Segmentations require an upload and must be batched differently since a single
+            # segmentation will take a lot longer for the server to process than a single
+            # annotation of any other kind.
+ responses.extend( + self.make_batched_requests( + segmentations_with_remote_files, + update, + batch_size=remote_files_per_upload_request, + segmentation=True, + ) + ) + if annotations_without_files: + responses.extend( + self.make_batched_requests( + annotations_without_files, + update, + batch_size=batch_size, + segmentation=False, + ) + ) + + return accumulate_dict_values(responses) + + def make_batched_requests( + self, + annotations: Sequence[Annotation], + update: bool, + batch_size: int, + segmentation: bool, + ): + batches = [ + annotations[i : i + batch_size] + for i in range(0, len(annotations), batch_size) + ] + responses = [] + progress_bar_name = ( + "Segmentation batches" if segmentation else "Annotation batches" + ) + for batch in self._client.tqdm_bar(batches, desc=progress_bar_name): + payload = construct_annotation_payload(batch, update) + responses.append( + self._client.make_request(payload, route=self._route) + ) + return responses + + def make_batched_file_form_data_requests( + self, + segmentations: Sequence[SegmentationAnnotation], + update, + local_files_per_upload_request: int, + local_file_upload_concurrency: int, + ): + requests = [] + for i in range(0, len(segmentations), local_files_per_upload_request): + batch = segmentations[i : i + local_files_per_upload_request] + request = FormDataContextHandler( + self.get_form_data_and_file_pointers_fn(batch, update) + ) + requests.append(request) + + progressbar = self._client.tqdm_bar( + total=len(requests), + desc="Local segmentation mask file batches", + ) + + return make_many_form_data_requests_concurrently( + client=self._client, + requests=requests, + route=self._route, + progressbar=progressbar, + concurrency=local_file_upload_concurrency, + ) + + def get_form_data_and_file_pointers_fn( + self, + segmentations: Iterable[SegmentationAnnotation], + update: bool, + ): + """Defines a function to be called on each retry. + + File pointers are also returned so whoever calls this function can + appropriately close the files. This is intended for use with a + FormDataContextHandler in order to make form data requests. + """ + + def fn(): + request_json = construct_segmentation_payload( + segmentations, update + ) + form_data = [ + FileFormField( + name=SERIALIZED_REQUEST_KEY, + filename=None, + value=json.dumps(request_json), + content_type="application/json", + ) + ] + file_pointers = [] + for segmentation in segmentations: + # I don't know of a way to use with, since all files in the request + # need to be opened at the same time. + # pylint: disable=consider-using-with + mask_fp = open(segmentation.mask_url, "rb") + # pylint: enable=consider-using-with + file_type = segmentation.mask_url.split(".")[-1] + if file_type != "png": + raise ValueError( + f"Only png files are supported. 
Got {file_type} for {segmentation.mask_url}" + ) + form_data.append( + FileFormField( + name=MASK_TYPE, + filename=segmentation.mask_url, + value=mask_fp, + content_type="image/png", + ) + ) + file_pointers.append(mask_fp) + return form_data, file_pointers + + return fn + + +class PredictionUploader(AnnotationUploader): + def __init__( + self, + client: "NucleusClient", + dataset_id: Optional[str] = None, + model_id: Optional[str] = None, + model_run_id: Optional[str] = None, + ): + super().__init__(dataset_id, client) + self._client = client + if model_run_id is not None: + assert model_id is None and dataset_id is None + self._route = f"modelRun/{model_run_id}/predict" + else: + assert ( + model_id is not None and dataset_id is not None + ), "Model ID and dataset ID are required if not using model run id." + self._route = ( + f"dataset/{dataset_id}/model/{model_id}/uploadPredictions" + ) diff --git a/nucleus/async_utils.py b/nucleus/async_utils.py new file mode 100644 index 00000000..5e9b2ee9 --- /dev/null +++ b/nucleus/async_utils.py @@ -0,0 +1,215 @@ +import asyncio +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, BinaryIO, Callable, Sequence, Tuple + +import aiohttp +import nest_asyncio +from tqdm import tqdm + +from nucleus.constants import DEFAULT_NETWORK_TIMEOUT_SEC +from nucleus.errors import NucleusAPIError +from nucleus.retry_strategy import RetryStrategy + +from .logger import logger + +if TYPE_CHECKING: + from . import NucleusClient + + +@dataclass +class FileFormField: + name: str + filename: str + value: BinaryIO + content_type: str + + +FileFormData = Sequence[FileFormField] + + +async def gather_with_concurrency(n, *tasks): + """Helper method to limit the concurrency when gathering the results from multiple tasks.""" + semaphore = asyncio.Semaphore(n) + + async def sem_task(task): + async with semaphore: + return await task + + return await asyncio.gather(*(sem_task(task) for task in tasks)) + + +class FormDataContextHandler: + """A context handler for file form data that handles closing all files in a request. + + Why do I need to wrap my requests in such a funny way? + + 1. Form data must be regenerated on each request to avoid errors + see https://github.com/Rapptz/discord.py/issues/6531 + 2. Files must be properly open/closed for each request. + 3. We need to be able to do 1/2 above multiple times so that we can implement retries + properly. + + Write a function that returns a tuple of form data and file pointers, then pass it to the + constructor of this class, and this class will handle the rest for you. 
+    """
+
+    def __init__(
+        self,
+        form_data_and_file_pointers_fn: Callable[
+            ..., Tuple[FileFormData, Sequence[BinaryIO]]
+        ],
+    ):
+        self._form_data_and_file_pointer_fn = form_data_and_file_pointers_fn
+        self._file_pointers = None
+
+    def __enter__(self):
+        (
+            file_form_data,
+            self._file_pointers,
+        ) = self._form_data_and_file_pointer_fn()
+        form = aiohttp.FormData()
+        for field in file_form_data:
+            form.add_field(
+                name=field.name,
+                filename=field.filename,
+                value=field.value,
+                content_type=field.content_type,
+            )
+        return form
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for file_pointer in self._file_pointers:
+            file_pointer.close()
+
+
+def get_event_loop():
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:  # no event loop running
+        loop = asyncio.new_event_loop()
+    else:
+        nest_asyncio.apply(loop)
+    return loop
+
+
+def make_many_form_data_requests_concurrently(
+    client: "NucleusClient",
+    requests: Sequence[FormDataContextHandler],
+    route: str,
+    progressbar: tqdm,
+    concurrency: int = 30,
+):
+    """
+    Makes many concurrent async POST requests with form data to a Nucleus endpoint.
+
+    Args:
+        client: The client to use for the request.
+        requests: Each request should be a FormDataContextHandler object which will
+            handle generating form data and opening/closing files for each request.
+        route: Route for the request.
+        progressbar: A tqdm progress bar to use for showing progress to the user.
+        concurrency: How many concurrent requests to run at once. Should be exposed
+            to the user.
+    """
+    loop = get_event_loop()
+    return loop.run_until_complete(
+        form_data_request_helper(
+            client, requests, route, progressbar, concurrency
+        )
+    )
+
+
+async def form_data_request_helper(
+    client: "NucleusClient",
+    requests: Sequence[FormDataContextHandler],
+    route: str,
+    progressbar: tqdm,
+    concurrency: int = 30,
+):
+    """
+    Makes concurrent async POST requests with files to a Nucleus endpoint.
+
+    Args:
+        client: The client to use for the request.
+        requests: Each request should be a FormDataContextHandler object which will
+            handle generating form data and opening/closing files for each request.
+        route: Route for the request.
+    """
+    async with aiohttp.ClientSession() as session:
+        tasks = [
+            asyncio.ensure_future(
+                _post_form_data(
+                    client=client,
+                    request=request,
+                    route=route,
+                    session=session,
+                    progressbar=progressbar,
+                )
+            )
+            for request in requests
+        ]
+        return await gather_with_concurrency(concurrency, *tasks)
+
+
+async def _post_form_data(
+    client: "NucleusClient",
+    request: FormDataContextHandler,
+    route: str,
+    session: aiohttp.ClientSession,
+    progressbar: tqdm,
+):
+    """
+    Makes an async POST request with files to a Nucleus endpoint.
+
+    Args:
+        client: The client to use for the request.
+        request: The request to make (see FormDataContextHandler for more details).
+        route: Route for the request.
+        session: The session to use for the request.
+    """
+    endpoint = f"{client.endpoint}/{route}"
+
+    logger.info("Posting to %s", endpoint)
+
+    for sleep_time in RetryStrategy.sleep_times() + [-1]:
+        with request as form:
+            async with session.post(
+                endpoint,
+                data=form,
+                auth=aiohttp.BasicAuth(client.api_key, ""),
+                timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+            ) as response:
+                logger.info(
+                    "API request has response code %s", response.status
+                )
+
+                try:
+                    data = await response.json()
+                except aiohttp.client_exceptions.ContentTypeError:
+                    # In case of 404, the server returns text
+                    data = await response.text()
+                if (
+                    response.status in RetryStrategy.statuses
+                    and sleep_time != -1
+                ):
+                    time.sleep(sleep_time)
+                    continue
+
+                if response.status == 503:
+                    raise TimeoutError(
+                        "The request to upload your mask is timing out. Please lower local_files_per_upload_request in your API call."
+                    )
+
+                if not response.ok:
+                    raise NucleusAPIError(
+                        endpoint,
+                        session.post,
+                        aiohttp_response=(
+                            response.status,
+                            response.reason,
+                            data,
+                        ),
+                    )
+                progressbar.update(1)
+                return data
diff --git a/nucleus/connection.py b/nucleus/connection.py
index 4054748f..11d07ba4 100644
--- a/nucleus/connection.py
+++ b/nucleus/connection.py
@@ -55,7 +55,7 @@ def make_request(
 
         logger.info("Make request to %s", endpoint)
 
-        for retry_wait_time in RetryStrategy.sleep_times:
+        for retry_wait_time in RetryStrategy.sleep_times():
             response = requests_command(
                 endpoint,
                 json=payload,
diff --git a/nucleus/constants.py b/nucleus/constants.py
index 69837640..426fc7a4 100644
--- a/nucleus/constants.py
+++ b/nucleus/constants.py
@@ -99,6 +99,7 @@
 REFERENCE_ID_KEY = "reference_id"
 REQUEST_ID_KEY = "requestId"
 SCENES_KEY = "scenes"
+SERIALIZED_REQUEST_KEY = "serialized_request"
 SEGMENTATIONS_KEY = "segmentations"
 SLICE_ID_KEY = "slice_id"
 STATUS_CODE_KEY = "status_code"
diff --git a/nucleus/dataset.py b/nucleus/dataset.py
index a7959507..811dd270 100644
--- a/nucleus/dataset.py
+++ b/nucleus/dataset.py
@@ -3,6 +3,7 @@
 
 import requests
 
+from nucleus.annotation_uploader import AnnotationUploader, PredictionUploader
 from nucleus.job import AsyncJob
 from nucleus.prediction import (
     BoxPrediction,
@@ -20,16 +21,7 @@
     serialize_and_write_to_presigned_url,
 )
 
-from .annotation import (
-    Annotation,
-    BoxAnnotation,
-    CategoryAnnotation,
-    CuboidAnnotation,
-    MultiCategoryAnnotation,
-    PolygonAnnotation,
-    SegmentationAnnotation,
-    check_all_mask_paths_remote,
-)
+from .annotation import Annotation, check_all_mask_paths_remote
 from .constants import (
     ANNOTATIONS_KEY,
     AUTOTAG_SCORE_THRESHOLD,
@@ -304,19 +296,13 @@ def create_model_run(
 
     def annotate(
         self,
-        annotations: Sequence[
-            Union[
-                BoxAnnotation,
-                PolygonAnnotation,
-                CuboidAnnotation,
-                CategoryAnnotation,
-                MultiCategoryAnnotation,
-                SegmentationAnnotation,
-            ]
-        ],
+        annotations: Sequence[Annotation],
         update: bool = DEFAULT_ANNOTATION_UPDATE_MODE,
         batch_size: int = 5000,
         asynchronous: bool = False,
+        remote_files_per_upload_request: int = 20,
+        local_files_per_upload_request: int = 10,
+        local_file_upload_concurrency: int = 30,
     ) -> Union[Dict[str, Any], AsyncJob]:
         """Uploads ground truth annotations to the dataset.
 
@@ -349,9 +335,20 @@ def annotate(
             objects to upload.
             update: Whether to ignore or overwrite metadata for conflicting annotations.
             batch_size: Number of annotations processed in each concurrent batch.
-                Default is 5000.
+                Default is 5000. If you get timeouts when uploading geometric annotations,
+                you can try lowering this batch size.
             asynchronous: Whether or not to process the upload asynchronously (and
                 return an :class:`AsyncJob` object). Default is False.
+            remote_files_per_upload_request: Number of remote files to upload in each
+                request. Segmentations have either local or remote files; if you are
+                getting timeouts while uploading segmentations with remote URLs, you
+                should lower this value from its default of 20.
+            local_files_per_upload_request: Number of local files to upload in each
+                request. Segmentations have either local or remote files; if you are
+                getting timeouts while uploading segmentations with local files, you
+                should lower this value from its default of 10. The maximum is 10.
+            local_file_upload_concurrency: Number of concurrent local file uploads.
+
 
         Returns:
             If synchronous, payload describing the upload result::
 
                 {
                     "dataset_id": str,
                     "annotations_processed": int
                 }
 
             Otherwise, returns an :class:`AsyncJob` object.
         """
-        check_all_mask_paths_remote(annotations)
-
         if asynchronous:
+            check_all_mask_paths_remote(annotations)
             request_id = serialize_and_write_to_presigned_url(
                 annotations, self.id, self._client
             )
@@ -374,8 +370,14 @@
                 route=f"dataset/{self.id}/annotate?async=1",
             )
             return AsyncJob.from_json(response, self._client)
-        return self._client.annotate_dataset(
-            self.id, annotations, update=update, batch_size=batch_size
+        uploader = AnnotationUploader(dataset_id=self.id, client=self._client)
+        return uploader.upload(
+            annotations=annotations,
+            update=update,
+            batch_size=batch_size,
+            remote_files_per_upload_request=remote_files_per_upload_request,
+            local_files_per_upload_request=local_files_per_upload_request,
+            local_file_upload_concurrency=local_file_upload_concurrency,
         )
 
     def ingest_tasks(self, task_ids: List[str]) -> dict:
@@ -412,6 +414,8 @@ def append(
         update: bool = False,
         batch_size: int = 20,
         asynchronous: bool = False,
+        local_files_per_upload_request: int = 10,
+        local_file_upload_concurrency: int = 30,
     ) -> Union[Dict[Any, Any], AsyncJob, UploadResponse]:
         """Appends items or scenes to a dataset.
 
@@ -482,12 +486,20 @@ def append(
                 Sequence[:class:`LidarScene`] \
                 Sequence[:class:`VideoScene`]
             ]): List of items or scenes to upload.
-            batch_size: Size of the batch for larger uploads. Default is 20.
+            batch_size: Size of the batch for larger uploads. Default is 20. This is
+                for items that have a remote URL and do not require a local upload.
+                If you get timeouts when uploading remote URLs, try decreasing this.
             update: Whether or not to overwrite metadata on reference ID collision.
                 Default is False.
             asynchronous: Whether or not to process the upload asynchronously (and
                 return an :class:`AsyncJob` object). This is required when uploading
                 scenes. Default is False.
+            local_files_per_upload_request: How large to make each upload request
+                when your files are local. If you get timeouts, you may need to lower
+                this from its default of 10. The maximum is 10.
+            local_file_upload_concurrency: How many local file requests to send
+                concurrently. If you start to see gateway timeouts or Cloudflare-related
+                errors, you may need to lower this from its default of 30.
Returns: For scenes @@ -557,6 +569,8 @@ def append( dataset_items, update=update, batch_size=batch_size, + local_files_per_upload_request=local_files_per_upload_request, + local_file_upload_concurrency=local_file_upload_concurrency, ) @deprecated("Prefer using Dataset.append instead.") @@ -1270,6 +1284,10 @@ def upload_predictions( ], update: bool = False, asynchronous: bool = False, + batch_size: int = 5000, + remote_files_per_upload_request: int = 20, + local_files_per_upload_request: int = 10, + local_file_upload_concurrency: int = 30, ): """Uploads predictions and associates them with an existing :class:`Model`. @@ -1312,6 +1330,21 @@ def upload_predictions( collision. Default is False. asynchronous: Whether or not to process the upload asynchronously (and return an :class:`AsyncJob` object). Default is False. + batch_size: Number of predictions processed in each concurrent batch. + Default is 5000. If you get timeouts when uploading geometric predictions, + you can try lowering this batch size. This is only relevant for + asynchronous=False + remote_files_per_upload_request: Number of remote files to upload in each + request. Segmentations have either local or remote files, if you are + getting timeouts while uploading segmentations with remote urls, you + should lower this value from its default of 20. This is only relevant for + asynchronous=False. + local_files_per_upload_request: Number of local files to upload in each + request. Segmentations have either local or remote files, if you are + getting timeouts while uploading segmentations with local files, you + should lower this value from its default of 10. The maximum is 10. + This is only relevant for asynchronous=False + local_file_upload_concurrency: Number of concurrent local file uploads. Returns: Payload describing the synchronous upload:: @@ -1335,12 +1368,19 @@ def upload_predictions( ) return AsyncJob.from_json(response, self._client) else: - return self._client.predict( + uploader = PredictionUploader( model_run_id=None, dataset_id=self.id, model_id=model.id, + client=self._client, + ) + return uploader.upload( annotations=predictions, + batch_size=batch_size, update=update, + remote_files_per_upload_request=remote_files_per_upload_request, + local_files_per_upload_request=local_files_per_upload_request, + local_file_upload_concurrency=local_file_upload_concurrency, ) def predictions_iloc(self, model, index): @@ -1431,6 +1471,8 @@ def _upload_items( dataset_items: List[DatasetItem], batch_size: int = 20, update: bool = False, + local_files_per_upload_request: int = 10, + local_file_upload_concurrency: int = 30, ) -> UploadResponse: """ Appends images to a dataset with given dataset_id. @@ -1438,8 +1480,16 @@ def _upload_items( Args: dataset_items: Items to Upload - batch_size: size of the batch for long payload + batch_size: how many items with remote urls to include in each request. + If you get timeouts for uploading remote urls, try decreasing this. update: Update records on conflict otherwise overwrite + local_files_per_upload_request: How large to make each upload request when your + files are local. If you get timeouts, you may need to lower this from + its default of 10. The maximum is 10. + local_file_upload_concurrency: How many local file requests to send + concurrently. If you start to see gateway timeouts or cloudflare related + errors, you may need to lower this from its default of 30. 
+ Returns: UploadResponse """ @@ -1450,9 +1500,14 @@ def _upload_items( "client.create_dataset(, is_scene=False) or add the dataset items to " "an existing dataset supporting dataset items." ) - - populator = DatasetItemUploader(self.id, self._client) - return populator.upload(dataset_items, batch_size, update) + uploader = DatasetItemUploader(self.id, self._client) + return uploader.upload( + dataset_items=dataset_items, + batch_size=batch_size, + update=update, + local_files_per_upload_request=local_files_per_upload_request, + local_file_upload_concurrency=local_file_upload_concurrency, + ) def update_scene_metadata(self, mapping: Dict[str, dict]): """ diff --git a/nucleus/dataset_item_uploader.py b/nucleus/dataset_item_uploader.py index e2e33fca..a803ca92 100644 --- a/nucleus/dataset_item_uploader.py +++ b/nucleus/dataset_item_uploader.py @@ -1,25 +1,26 @@ -import asyncio import json import os -import time -from typing import TYPE_CHECKING, Any, List - -import aiohttp -import nest_asyncio +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + Callable, + List, + Sequence, + Tuple, +) -from .constants import ( - DATASET_ID_KEY, - DEFAULT_NETWORK_TIMEOUT_SEC, - IMAGE_KEY, - IMAGE_URL_KEY, - ITEMS_KEY, - UPDATE_KEY, +from nucleus.async_utils import ( + FileFormData, + FileFormField, + FormDataContextHandler, + make_many_form_data_requests_concurrently, ) + +from .constants import DATASET_ID_KEY, IMAGE_KEY, ITEMS_KEY, UPDATE_KEY from .dataset_item import DatasetItem from .errors import NotFoundError -from .logger import logger from .payload_constructor import construct_append_payload -from .retry_strategy import RetryStrategy from .upload_response import UploadResponse if TYPE_CHECKING: @@ -34,14 +35,20 @@ def __init__(self, dataset_id: str, client: "NucleusClient"): # noqa: F821 def upload( self, dataset_items: List[DatasetItem], - batch_size: int = 20, + batch_size: int = 5000, update: bool = False, + local_files_per_upload_request: int = 10, + local_file_upload_concurrency: int = 30, ) -> UploadResponse: """ Args: dataset_items: Items to Upload - batch_size: How many items to pool together for a single request + batch_size: How many items to pool together for a single request for items + without files to upload + files_per_upload_request: How many items to pool together for a single + request for items with files to upload + update: Update records instead of overwriting Returns: @@ -49,6 +56,8 @@ def upload( """ local_items = [] remote_items = [] + if local_files_per_upload_request > 10: + raise ValueError("local_files_per_upload_request should be <= 10") # Check local files exist before sending requests for item in dataset_items: @@ -59,41 +68,35 @@ def upload( else: remote_items.append(item) - local_batches = [ - local_items[i : i + batch_size] - for i in range(0, len(local_items), batch_size) - ] - - remote_batches = [ - remote_items[i : i + batch_size] - for i in range(0, len(remote_items), batch_size) - ] - agg_response = UploadResponse(json={DATASET_ID_KEY: self.dataset_id}) async_responses: List[Any] = [] - if local_batches: - tqdm_local_batches = self._client.tqdm_bar( - local_batches, desc="Local file batches" + if local_items: + async_responses.extend( + self._process_append_requests_local( + self.dataset_id, + items=local_items, + update=update, + batch_size=batch_size, + local_files_per_upload_request=local_files_per_upload_request, + local_file_upload_concurrency=local_file_upload_concurrency, + ) ) - for batch in tqdm_local_batches: - payload = 
construct_append_payload(batch, update) - responses = self._process_append_requests_local( - self.dataset_id, payload, update - ) - async_responses.extend(responses) + remote_batches = [ + remote_items[i : i + batch_size] + for i in range(0, len(remote_items), batch_size) + ] if remote_batches: tqdm_remote_batches = self._client.tqdm_bar( remote_batches, desc="Remote file batches" ) for batch in tqdm_remote_batches: - payload = construct_append_payload(batch, update) responses = self._process_append_requests( dataset_id=self.dataset_id, - payload=payload, + payload=construct_append_payload(batch, update), update=update, batch_size=batch_size, ) @@ -107,173 +110,34 @@ def upload( def _process_append_requests_local( self, dataset_id: str, - payload: dict, - update: bool, # TODO: understand how to pass this in. - local_batch_size: int = 10, + items: Sequence[DatasetItem], + update: bool, + batch_size: int, + local_files_per_upload_request: int, + local_file_upload_concurrency: int, ): - def get_files(batch): - for item in batch: - item[UPDATE_KEY] = update - request_payload = [ - ( - ITEMS_KEY, - ( - None, - json.dumps(batch, allow_nan=False), - "application/json", - ), - ) - ] - for item in batch: - image = open( # pylint: disable=R1732 - item.get(IMAGE_URL_KEY), "rb" # pylint: disable=R1732 - ) # pylint: disable=R1732 - img_name = os.path.basename(image.name) - img_type = ( - f"image/{os.path.splitext(image.name)[1].strip('.')}" - ) - request_payload.append( - (IMAGE_KEY, (img_name, image, img_type)) - ) - return request_payload + # Batch into requests + requests = [] + batch_size = local_files_per_upload_request + for i in range(0, len(items), batch_size): + batch = items[i : i + batch_size] + request = FormDataContextHandler( + self.get_form_data_and_file_pointers_fn(batch, update) + ) + requests.append(request) - items = payload[ITEMS_KEY] - responses: List[Any] = [] - files_per_request = [] - payload_items = [] - for i in range(0, len(items), local_batch_size): - batch = items[i : i + local_batch_size] - files_per_request.append(get_files(batch)) - payload_items.append(batch) + progressbar = self._client.tqdm_bar( + total=len(requests), desc="Local file batches" + ) - future = self.make_many_files_requests_asynchronously( - files_per_request, + return make_many_form_data_requests_concurrently( + self._client, + requests, f"dataset/{dataset_id}/append", + progressbar=progressbar, + concurrency=local_file_upload_concurrency, ) - try: - loop = asyncio.get_event_loop() - except RuntimeError: # no event loop running: - loop = asyncio.new_event_loop() - responses = loop.run_until_complete(future) - else: - nest_asyncio.apply(loop) - return loop.run_until_complete(future) - - def close_files(request_items): - for item in request_items: - # file buffer in location [1][1] - if item[0] == IMAGE_KEY: - item[1][1].close() - - # don't forget to close all open files - for p in files_per_request: - close_files(p) - - return responses - - async def make_many_files_requests_asynchronously( - self, files_per_request, route - ): - """ - Makes an async post request with files to a Nucleus endpoint. - - :param files_per_request: A list of lists of tuples (name, (filename, file_pointer, content_type)) - name will become the name by which the multer can build an array. 
- :param route: route for the request - :return: awaitable list(response) - """ - async with aiohttp.ClientSession() as session: - tasks = [ - asyncio.ensure_future( - self._make_files_request( - files=files, route=route, session=session - ) - ) - for files in files_per_request - ] - return await asyncio.gather(*tasks) - - async def _make_files_request( - self, - files, - route: str, - session: aiohttp.ClientSession, - retry_attempt=0, - max_retries=3, - sleep_intervals=(1, 3, 9), - ): - """ - Makes an async post request with files to a Nucleus endpoint. - - :param files: A list of tuples (name, (filename, file_pointer, file_type)) - :param route: route for the request - :param session: Session to use for post. - :return: response - """ - endpoint = f"{self._client.endpoint}/{route}" - - logger.info("Posting to %s", endpoint) - - form = aiohttp.FormData() - - for file in files: - form.add_field( - name=file[0], - filename=file[1][0], - value=file[1][1], - content_type=file[1][2], - ) - - for sleep_time in RetryStrategy.sleep_times + [-1]: - - async with session.post( - endpoint, - data=form, - auth=aiohttp.BasicAuth(self._client.api_key, ""), - timeout=DEFAULT_NETWORK_TIMEOUT_SEC, - ) as response: - logger.info( - "API request has response code %s", response.status - ) - - try: - data = await response.json() - except aiohttp.client_exceptions.ContentTypeError: - # In case of 404, the server returns text - data = await response.text() - if ( - response.status in RetryStrategy.statuses - and sleep_time != -1 - ): - time.sleep(sleep_time) - continue - - if not response.ok: - if retry_attempt < max_retries: - time.sleep(sleep_intervals[retry_attempt]) - retry_attempt += 1 - return self._make_files_request( - files, - route, - session, - retry_attempt, - max_retries, - sleep_intervals, - ) - else: - self._client.handle_bad_response( - endpoint, - session.post, - aiohttp_response=( - response.status, - response.reason, - data, - ), - ) - - return data - def _process_append_requests( self, dataset_id: str, @@ -283,11 +147,9 @@ def _process_append_requests( ): items = payload[ITEMS_KEY] payloads = [ - # batch_size images per request {ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update} for i in range(0, len(items), batch_size) ] - return [ self._client.make_request( payload, @@ -295,3 +157,55 @@ def _process_append_requests( ) for payload in payloads ] + + def get_form_data_and_file_pointers_fn( + self, items: Sequence[DatasetItem], update: bool + ) -> Callable[..., Tuple[FileFormData, Sequence[BinaryIO]]]: + """Defines a function to be called on each retry. + + File pointers are also returned so whoever calls this function can + appropriately close the files. This is intended for use with a + FormDataContextHandler in order to make form data requests. + """ + + def fn(): + + # For some reason, our backend only accepts this reformatting of items when + # doing local upload. + # TODO: make it just accept the same exact format as a normal append request + # i.e. the output of construct_append_payload(items, update) + json_data = [] + for item in items: + item_payload = item.to_payload() + item_payload[UPDATE_KEY] = update + json_data.append(item_payload) + + form_data = [ + FileFormField( + name=ITEMS_KEY, + filename=None, + value=json.dumps(json_data, allow_nan=False), + content_type="application/json", + ) + ] + + file_pointers = [] + for item in items: + # I don't know of a way to use with, since all files in the request + # need to be opened at the same time. 
+ # pylint: disable=consider-using-with + image_fp = open(item.image_location, "rb") + # pylint: enable=consider-using-with + img_type = f"image/{os.path.splitext(item.image_location)[1].strip('.')}" + form_data.append( + FileFormField( + name=IMAGE_KEY, + filename=item.image_location, + value=image_fp, + content_type=img_type, + ) + ) + file_pointers.append(image_fp) + return form_data, file_pointers + + return fn diff --git a/nucleus/model_run.py b/nucleus/model_run.py index 137d7cf8..993eda7e 100644 --- a/nucleus/model_run.py +++ b/nucleus/model_run.py @@ -18,6 +18,7 @@ import requests from nucleus.annotation import check_all_mask_paths_remote +from nucleus.annotation_uploader import PredictionUploader from nucleus.job import AsyncJob from nucleus.utils import ( format_prediction_response, @@ -114,12 +115,38 @@ def predict( SegmentationPrediction, ] ], - update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE, + update: bool = DEFAULT_ANNOTATION_UPDATE_MODE, asynchronous: bool = False, + batch_size: int = 5000, + remote_files_per_upload_request: int = 20, + local_files_per_upload_request: int = 10, + local_file_upload_concurrency: int = 30, ) -> Union[dict, AsyncJob]: """ Uploads model outputs as predictions for a model_run. Returns info about the upload. - :param annotations: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]], + + Args: + annotations: Predictions to upload for this model run, + update: If True, existing predictions for the same (reference_id, annotation_id) + will be overwritten. If False, existing predictions will be skipped. + asynchronous: Whether or not to process the upload asynchronously (and + return an :class:`AsyncJob` object). Default is False. + batch_size: Number of predictions processed in each concurrent batch. + Default is 5000. If you get timeouts when uploading geometric annotations, + you can try lowering this batch size. This is only relevant for + asynchronous=False. + remote_files_per_upload_request: Number of remote files to upload in each + request. Segmentations have either local or remote files, if you are + getting timeouts while uploading segmentations with remote urls, you + should lower this value from its default of 20. This is only relevant for + asynchronous=False + local_files_per_upload_request: Number of local files to upload in each + request. Segmentations have either local or remote files, if you are + getting timeouts while uploading segmentations with local files, you + should lower this value from its default of 10. The maximum is 10. + This is only relevant for asynchronous=False + local_file_upload_concurrency: Number of concurrent local file uploads. 
+ This is only relevant for asynchronous=False :return: { "model_run_id": str, @@ -138,12 +165,17 @@ def predict( route=f"modelRun/{self.model_run_id}/predict?async=1", ) return AsyncJob.from_json(response, self._client) - else: - return self._client.predict( - model_run_id=self.model_run_id, - annotations=annotations, - update=update, - ) + uploader = PredictionUploader( + model_run_id=self.model_run_id, client=self._client + ) + return uploader.upload( + annotations=annotations, + update=update, + batch_size=batch_size, + remote_files_per_upload_request=remote_files_per_upload_request, + local_files_per_upload_request=local_files_per_upload_request, + local_file_upload_concurrency=local_file_upload_concurrency, + ) def iloc(self, i: int): """ diff --git a/nucleus/payload_constructor.py b/nucleus/payload_constructor.py index 38609368..f21f6990 100644 --- a/nucleus/payload_constructor.py +++ b/nucleus/payload_constructor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from .annotation import ( BoxAnnotation, @@ -72,11 +72,22 @@ def construct_annotation_payload( ], update: bool, ) -> dict: - annotations = [] - for annotation_item in annotation_items: - annotations.append(annotation_item.to_payload()) - - return {ANNOTATIONS_KEY: annotations, ANNOTATION_UPDATE_KEY: update} + annotations = [ + annotation.to_payload() + for annotation in annotation_items + if not isinstance(annotation, SegmentationAnnotation) + ] + segmentations = [ + annotation.to_payload() + for annotation in annotation_items + if isinstance(annotation, SegmentationAnnotation) + ] + payload: Dict[str, Any] = {ANNOTATION_UPDATE_KEY: update} + if annotations: + payload[ANNOTATIONS_KEY] = annotations + if segmentations: + payload[SEGMENTATIONS_KEY] = segmentations + return payload def construct_segmentation_payload( diff --git a/nucleus/retry_strategy.py b/nucleus/retry_strategy.py index eabc1309..fae0ff33 100644 --- a/nucleus/retry_strategy.py +++ b/nucleus/retry_strategy.py @@ -1,4 +1,12 @@ # TODO: use retry library instead of custom code. Tenacity is one option. 
+import random + + class RetryStrategy: statuses = {503, 524, 520, 504} - sleep_times = [1, 3, 9] + + @staticmethod + def sleep_times(): + sleep_times = [1, 3, 9, 27] # These are in seconds + + return [2 * random.random() * t for t in sleep_times] diff --git a/tests/helpers.py b/tests/helpers.py index d5f67d87..6a4a34aa 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -301,6 +301,11 @@ def reference_id_from_url(url): TEST_MASK_URL = "https://raw.githubusercontent.com/scaleapi/nucleus-python-client/master/tests/testdata/000000000285.png" +this_dir = os.path.dirname(os.path.realpath(__file__)) +TEST_LOCAL_MASK_URL = os.path.join(this_dir, "testdata/000000000285.png") + + +NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET = len(TEST_DATASET_ITEMS) TEST_SEGMENTATION_ANNOTATIONS = [ { "reference_id": reference_id_from_url(TEST_IMG_URLS[i]), diff --git a/tests/test_annotation.py b/tests/test_annotation.py index ced5ff11..6b218817 100644 --- a/tests/test_annotation.py +++ b/tests/test_annotation.py @@ -34,7 +34,6 @@ assert_multicategory_annotation_matches_dict, assert_partial_equality, assert_polygon_annotation_matches_dict, - assert_segmentation_annotation_matches_dict, reference_id_from_url, ) @@ -242,70 +241,6 @@ def test_default_multicategory_gt_upload(dataset): ) -def test_single_semseg_gt_upload(dataset): - annotation = SegmentationAnnotation.from_json( - TEST_SEGMENTATION_ANNOTATIONS[0] - ) - response = dataset.annotate(annotations=[annotation]) - assert response["dataset_id"] == dataset.id - assert response["annotations_processed"] == 1 - assert response["annotations_ignored"] == 0 - - response_annotation = dataset.refloc(annotation.reference_id)[ - "annotations" - ]["segmentation"][0] - assert_segmentation_annotation_matches_dict( - response_annotation, TEST_SEGMENTATION_ANNOTATIONS[0] - ) - - -def test_batch_semseg_gt_upload(dataset): - annotations = [ - SegmentationAnnotation.from_json(ann) - for ann in TEST_SEGMENTATION_ANNOTATIONS - ] - response = dataset.annotate(annotations=annotations) - assert response["dataset_id"] == dataset.id - assert response["annotations_processed"] == 5 - assert response["annotations_ignored"] == 0 - - -def test_batch_semseg_gt_upload_ignore(dataset): - # First upload annotations - annotations = [ - SegmentationAnnotation.from_json(ann) - for ann in TEST_SEGMENTATION_ANNOTATIONS - ] - response = dataset.annotate(annotations=annotations) - assert response["dataset_id"] == dataset.id - assert response["annotations_processed"] == 5 - assert response["annotations_ignored"] == 0 - - # When we re-upload, expect them to be ignored - response = dataset.annotate(annotations=annotations) - assert response["dataset_id"] == dataset.id - assert response["annotations_processed"] == 0 - assert response["annotations_ignored"] == 5 - - -def test_batch_semseg_gt_upload_update(dataset): - # First upload annotations - annotations = [ - SegmentationAnnotation.from_json(ann) - for ann in TEST_SEGMENTATION_ANNOTATIONS - ] - response = dataset.annotate(annotations=annotations) - assert response["dataset_id"] == dataset.id - assert response["annotations_processed"] == 5 - assert response["annotations_ignored"] == 0 - - # When we re-upload, expect uploads to be processed - response = dataset.annotate(annotations=annotations, update=True) - assert response["dataset_id"] == dataset.id - assert response["annotations_processed"] == 5 - assert response["annotations_ignored"] == 0 - - def test_mixed_annotation_upload(dataset): # First upload annotations semseg_annotations = [ diff --git 
a/tests/test_dataset.py b/tests/test_dataset.py index 8594f4d0..adc3c7cc 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -262,6 +262,7 @@ def test_dataset_append_local(CLIENT, dataset): reference_id="bad", ) ] + num_local_items_to_test = 10 with pytest.raises(ValueError) as e: dataset.append(ds_items_local_error) assert "Out of range float values are not JSON compliant" in str( @@ -271,8 +272,9 @@ def test_dataset_append_local(CLIENT, dataset): DatasetItem( image_location=LOCAL_FILENAME, metadata={"test": 0}, - reference_id=LOCAL_FILENAME.split("/")[-1], + reference_id=LOCAL_FILENAME.split("/")[-1] + str(i), ) + for i in range(num_local_items_to_test) ] response = dataset.append(ds_items_local) @@ -280,7 +282,7 @@ def test_dataset_append_local(CLIENT, dataset): assert isinstance(response, UploadResponse) resp_json = response.json() assert resp_json[DATASET_ID_KEY] == dataset.id - assert resp_json[NEW_ITEMS] == 1 + assert resp_json[NEW_ITEMS] == num_local_items_to_test assert resp_json[UPDATED_ITEMS] == 0 assert resp_json[IGNORED_ITEMS] == 0 assert resp_json[ERROR_ITEMS] == 0 diff --git a/tests/test_prediction.py b/tests/test_prediction.py index cc8ba6f4..3b06fb0e 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -194,39 +194,6 @@ def test_non_existent_taxonomy_category_gt_upload(model_run): ) -def test_segmentation_pred_upload(model_run): - prediction = SegmentationPrediction.from_json( - TEST_SEGMENTATION_PREDICTIONS[0] - ) - response = model_run.predict(annotations=[prediction]) - - assert response["model_run_id"] == model_run.model_run_id - assert response["predictions_processed"] == 1 - assert response["predictions_ignored"] == 0 - - response = model_run.refloc(prediction.reference_id)["segmentation"] - assert isinstance(response[0], SegmentationPrediction) - - assert_segmentation_annotation_matches_dict( - response[0], TEST_SEGMENTATION_PREDICTIONS[0] - ) - - -def test_segmentation_pred_upload_ignore(model_run): - prediction = SegmentationPrediction.from_json( - TEST_SEGMENTATION_PREDICTIONS[0] - ) - response1 = model_run.predict(annotations=[prediction]) - - assert response1["predictions_processed"] == 1 - - # Upload Duplicate annotation - response = model_run.predict(annotations=[prediction]) - assert response["model_run_id"] == model_run.model_run_id - assert response["predictions_processed"] == 0 - assert response["predictions_ignored"] == 1 - - def test_box_pred_upload_update(model_run): prediction = BoxPrediction(**TEST_BOX_PREDICTIONS[0]) response = model_run.predict(annotations=[prediction]) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py new file mode 100644 index 00000000..7bf0fe23 --- /dev/null +++ b/tests/test_segmentation.py @@ -0,0 +1,188 @@ +from nucleus.annotation import SegmentationAnnotation +from nucleus.dataset import Dataset +from nucleus.model import Model +from nucleus.prediction import SegmentationPrediction +from tests.helpers import ( + NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET, + TEST_LOCAL_MASK_URL, + TEST_SEGMENTATION_ANNOTATIONS, + TEST_SEGMENTATION_PREDICTIONS, + assert_segmentation_annotation_matches_dict, +) + + +def test_segmentation_pred_upload_local(dataset: Dataset, model: Model): + prediction = SegmentationPrediction.from_json( + TEST_SEGMENTATION_PREDICTIONS[0] + ) + prediction.mask_url = TEST_LOCAL_MASK_URL + response = dataset.upload_predictions(model, [prediction]) + + assert response["predictions_processed"] == 1 + + response = dataset.predictions_refloc(model, 
prediction.reference_id)[ + "segmentation" + ][0] + assert isinstance(response, SegmentationPrediction) + assert response == prediction + + +def test_segmentation_pred_upload(dataset: Dataset, model: Model): + prediction = SegmentationPrediction.from_json( + TEST_SEGMENTATION_PREDICTIONS[0] + ) + response = dataset.upload_predictions(model, [prediction]) + + assert response["predictions_processed"] == 1 + assert response["predictions_ignored"] == 0 + + response = dataset.predictions_refloc(model, prediction.reference_id)[ + "segmentation" + ] + assert isinstance(response[0], SegmentationPrediction) + + assert_segmentation_annotation_matches_dict( + response[0], TEST_SEGMENTATION_PREDICTIONS[0] + ) + + +def test_segmentation_pred_upload_ignore(dataset: Dataset, model: Model): + prediction = SegmentationPrediction.from_json( + TEST_SEGMENTATION_PREDICTIONS[0] + ) + response1 = dataset.upload_predictions(model, [prediction]) + + assert response1["predictions_processed"] == 1 + + # Upload Duplicate annotation + response = dataset.upload_predictions(model, [prediction]) + assert response["predictions_processed"] == 0 + assert response["predictions_ignored"] == 1 + + +def test_single_local_semseg_gt_upload(dataset: Dataset): + request_annotation = SegmentationAnnotation.from_json( + TEST_SEGMENTATION_ANNOTATIONS[0] + ) + request_annotation.mask_url = TEST_LOCAL_MASK_URL + response = dataset.annotate(annotations=[request_annotation]) + + assert response["dataset_id"] == dataset.id + assert response["annotations_processed"] == 1 + assert response["annotations_ignored"] == 0 + + response_annotation = dataset.refloc(request_annotation.reference_id)[ + "annotations" + ]["segmentation"][0] + + assert response_annotation == request_annotation + + +def test_batch_local_semseg_gt_upload(dataset: Dataset): + + # This reference id is not in the dataset. 
+ bad_reference_id = TEST_SEGMENTATION_ANNOTATIONS[-1]["reference_id"] + + request_annotations = [ + SegmentationAnnotation.from_json(json_data) + for json_data in TEST_SEGMENTATION_ANNOTATIONS + ] + for request_annotation in request_annotations: + request_annotation.mask_url = TEST_LOCAL_MASK_URL + response = dataset.annotate(annotations=request_annotations) + + assert response["dataset_id"] == dataset.id + assert ( + response["annotations_processed"] + == NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ) + assert response["annotations_ignored"] == 0 + assert bad_reference_id in response["errors"][0] + + for request_annotation in request_annotations[ + :NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ]: + response_annotation = dataset.refloc(request_annotation.reference_id)[ + "annotations" + ]["segmentation"][0] + + assert response_annotation == request_annotation + + +def test_single_semseg_gt_upload(dataset): + annotation = SegmentationAnnotation.from_json( + TEST_SEGMENTATION_ANNOTATIONS[0] + ) + response = dataset.annotate(annotations=[annotation]) + assert response["dataset_id"] == dataset.id + assert response["annotations_processed"] == 1 + assert response["annotations_ignored"] == 0 + + response_annotation = dataset.refloc(annotation.reference_id)[ + "annotations" + ]["segmentation"][0] + assert_segmentation_annotation_matches_dict( + response_annotation, TEST_SEGMENTATION_ANNOTATIONS[0] + ) + + +def test_batch_semseg_gt_upload(dataset): + annotations = [ + SegmentationAnnotation.from_json(ann) + for ann in TEST_SEGMENTATION_ANNOTATIONS + ] + response = dataset.annotate(annotations=annotations) + assert response["dataset_id"] == dataset.id + assert ( + response["annotations_processed"] + == NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ) + assert response["annotations_ignored"] == 0 + + +def test_batch_semseg_gt_upload_ignore(dataset): + # First upload annotations + annotations = [ + SegmentationAnnotation.from_json(ann) + for ann in TEST_SEGMENTATION_ANNOTATIONS + ] + response = dataset.annotate(annotations=annotations) + assert response["dataset_id"] == dataset.id + assert ( + response["annotations_processed"] + == NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ) + assert response["annotations_ignored"] == 0 + + # When we re-upload, expect them to be ignored + response = dataset.annotate(annotations=annotations) + assert response["dataset_id"] == dataset.id + assert response["annotations_processed"] == 0 + assert ( + response["annotations_ignored"] + == NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ) + + +def test_batch_semseg_gt_upload_update(dataset): + # First upload annotations + annotations = [ + SegmentationAnnotation.from_json(ann) + for ann in TEST_SEGMENTATION_ANNOTATIONS + ] + response = dataset.annotate(annotations=annotations) + assert response["dataset_id"] == dataset.id + assert ( + response["annotations_processed"] + == NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ) + assert response["annotations_ignored"] == 0 + + # When we re-upload, expect uploads to be processed + response = dataset.annotate(annotations=annotations, update=True) + assert response["dataset_id"] == dataset.id + assert ( + response["annotations_processed"] + == NUM_VALID_SEGMENTATIONS_IN_MAIN_DATASET + ) + assert response["annotations_ignored"] == 0
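Below is a minimal usage sketch (not part of the diff) of the synchronous local-mask upload path this change enables, using the new tuning parameters documented above. The API key, dataset ID, reference ID, and mask path are hypothetical placeholders, and the use of the library's existing Segment class is an assumption based on its current public API.

import nucleus
from nucleus.annotation import Segment, SegmentationAnnotation

client = nucleus.NucleusClient("YOUR_API_KEY")          # hypothetical key
dataset = client.get_dataset("ds_existing_dataset_id")  # hypothetical dataset

# A segmentation annotation whose mask lives on local disk; the new local
# upload path only accepts .png masks and checks that the file exists.
annotation = SegmentationAnnotation(
    mask_url="/path/to/local_mask.png",   # hypothetical local path
    annotations=[Segment(label="car", index=1)],
    reference_id="image_1",               # hypothetical reference id
)

# Synchronous upload. Lower the knobs below if requests time out;
# local_files_per_upload_request must be <= 10.
response = dataset.annotate(
    annotations=[annotation],
    update=False,
    remote_files_per_upload_request=20,
    local_files_per_upload_request=10,
    local_file_upload_concurrency=30,
)
print(response)  # e.g. {"dataset_id": ..., "annotations_processed": ..., "annotations_ignored": ..., "errors": [...]}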