From e65a451bee4ea99f7a08e8b3761822a21be16e13 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Fri, 25 Feb 2022 09:29:10 +0100 Subject: [PATCH 01/33] Pass eval_func_arguments to backend with EvaluationCriteria --- .../data_transfer_objects/eval_function.py | 4 +++- .../eval_functions/available_eval_functions.py | 14 ++++++++++++-- .../validate/eval_functions/base_eval_function.py | 6 +++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/nucleus/validate/data_transfer_objects/eval_function.py b/nucleus/validate/data_transfer_objects/eval_function.py index 1e165de5..2592be3d 100644 --- a/nucleus/validate/data_transfer_objects/eval_function.py +++ b/nucleus/validate/data_transfer_objects/eval_function.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, Dict, List, Optional from pydantic import validator @@ -50,12 +50,14 @@ class EvaluationCriterion(ImmutableModel): eval_function_id (str): ID of evaluation function threshold_comparison (:class:`ThresholdComparison`): comparator for evaluation. i.e. threshold=0.5 and threshold_comparator > implies that a test only passes if score > 0.5. threshold (float): numerical threshold that together with threshold comparison, defines success criteria for test evaluation. + eval_func_arguments: Arguments to pass to the eval function constructor """ # TODO: Having only eval_function_id hurts readability -> Add function name eval_function_id: str threshold_comparison: ThresholdComparison threshold: float + eval_func_arguments: Dict[str, Any] @validator("eval_function_id") def valid_eval_function_id(cls, v): # pylint: disable=no-self-argument diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index 090e63f4..d2fd4e74 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -1,5 +1,5 @@ import itertools -from typing import Callable, Dict, List, Type, Union +from typing import Callable, Dict, List, Optional, Type, Union from nucleus.logger import logger from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction @@ -35,6 +35,16 @@ def expected_name(cls) -> str: class CategorizationF1(BaseEvalFunction): + def __call__( + self, + confidence_threshold: Optional[float] = None, + f1_method: Optional[str] = None, + **kwargs, + ): + return super().__call__( + confidence_threshold=confidence_threshold, f1_method=f1_method + ) + @classmethod def expected_name(cls) -> str: return "cat_f1" @@ -140,7 +150,7 @@ def __init__(self, available_functions: List[EvalFunctionEntry]): self.bbox_map = self._assign_eval_function_if_defined( BoundingBoxMeanAveragePrecision # type: ignore ) - self.cat_f1 = self._assign_eval_function_if_defined( + self.cat_f1: CategorizationF1 = self._assign_eval_function_if_defined( CategorizationF1 # type: ignore ) diff --git a/nucleus/validate/eval_functions/base_eval_function.py b/nucleus/validate/eval_functions/base_eval_function.py index 1ea4c931..67ee8659 100644 --- a/nucleus/validate/eval_functions/base_eval_function.py +++ b/nucleus/validate/eval_functions/base_eval_function.py @@ -1,4 +1,5 @@ import abc +from typing import Any, Dict from ..constants import ThresholdComparison from ..data_transfer_objects.eval_function import ( @@ -17,6 +18,7 @@ def __init__(self, eval_func_entry: EvalFunctionEntry): self.eval_func_entry = eval_func_entry self.id = eval_func_entry.id self.name = eval_func_entry.name + 
self.eval_func_arguments: Dict[str, Any] = {} def __repr__(self): return f"" @@ -26,7 +28,7 @@ def __repr__(self): def expected_name(cls) -> str: """Name to look for in the EvalFunctionDefinitions""" - def __call__(self) -> "BaseEvalFunction": + def __call__(self, **kwargs) -> "BaseEvalFunction": """Adding call to prepare for being able to pass parameters to function Notes: @@ -34,6 +36,7 @@ def __call__(self) -> "BaseEvalFunction": to look like eval_function() > 0.5 to support eval_function(parameters) > 0.5 in the future """ + self.eval_func_arguments.update(**kwargs) return self def __gt__(self, other) -> EvaluationCriterion: @@ -57,6 +60,7 @@ def _op_to_test_metric(self, comparison: ThresholdComparison, value): eval_function_id=self.eval_func_entry.id, threshold_comparison=comparison, threshold=value, + eval_func_arguments=self.eval_func_arguments, ) def to_entry(self): From d3231e330ed0551a992f69a6e8ce2ae1f674889c Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Fri, 25 Feb 2022 10:41:55 +0100 Subject: [PATCH 02/33] Add better error message for scenario_test misconfiguration and arguments to all public functions --- nucleus/metrics/categorization_metrics.py | 3 +- nucleus/validate/client.py | 34 ++++- .../available_eval_functions.py | 132 ++++++++++++++++-- tests/validate/test_scenario_test.py | 26 ++++ 4 files changed, 178 insertions(+), 17 deletions(-) diff --git a/nucleus/metrics/categorization_metrics.py b/nucleus/metrics/categorization_metrics.py index 416f831a..80979c8e 100644 --- a/nucleus/metrics/categorization_metrics.py +++ b/nucleus/metrics/categorization_metrics.py @@ -143,7 +143,8 @@ def __init__( ): """ Args: - confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. + Must be in [0, 1]. Default 0.0 f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \ default='macro' This parameter is required for multiclass/multilabel targets. diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index d99afc07..61538fba 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -7,16 +7,18 @@ from .data_transfer_objects.eval_function import GetEvalFunctions from .data_transfer_objects.scenario_test import CreateScenarioTestRequest from .errors import CreateScenarioTestError -from .eval_functions.available_eval_functions import ( - AvailableEvalFunctions, - EvalFunction, -) +from .eval_functions.available_eval_functions import AvailableEvalFunctions +from .eval_functions.base_eval_function import BaseEvalFunction from .scenario_test import ScenarioTest SUCCESS_KEY = "success" EVAL_FUNCTIONS_KEY = "eval_functions" +class InvalidEvaluationCriteria(Exception): + pass + + class Validate: """Model CI Python Client extension.""" @@ -78,6 +80,30 @@ def create_scenario_test( "Must pass an evaluation_function to the scenario test! I.e. " "evaluation_functions=[client.validate.eval_functions.bbox_iou()]" ) + incorrect_type = [ + crit + for crit in evaluation_criteria + if not isinstance(crit, EvaluationCriterion) + ] + if incorrect_type: + # NOTE: We expect people to forget adding comparison to these calls so make an explicit error msg. 
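# --- illustrative aside, not part of the patch -----------------------------
# The misuse this branch guards against: calling a configured eval function
# without a comparison yields the config object itself, not an
# EvaluationCriterion. A minimal, self-contained sketch of the
# operator-overloading mechanic (all names below are hypothetical stand-ins):
class _ConfigSketch:
    def __init__(self):
        self.arguments = {}

    def __call__(self, **kwargs):
        self.arguments.update(kwargs)  # mirrors eval_func_arguments.update(**kwargs)
        return self

    def __gt__(self, threshold):
        # the comparison is what produces the criterion object
        return ("criterion", threshold, dict(self.arguments))

_fn = _ConfigSketch()
assert not isinstance(_fn(iou_threshold=0.5), tuple)    # forgot "> 0.5": still a config
assert isinstance(_fn(iou_threshold=0.5) > 0.5, tuple)  # comparison yields a criterion
# ---------------------------------------------------------------------------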
+ eval_funcs = [ + incorrect + for incorrect in incorrect_type + if isinstance(incorrect, BaseEvalFunction) + ] + if eval_funcs: + example: BaseEvalFunction = eval_funcs[0] + example_call = f"{example.name}()" + msg = ( + f"Expected a comparison (<, <=, >, >=) for every `evaluation_criteria`. " + f"You should add a comparison to {eval_funcs}. " + f"I.e. `{example_call} > 0.5` instead of just `{example_call}`" + ) + else: + msg = f"Received an incorrect `evaluation_criteria`: {repr(incorrect_type)}" + raise InvalidEvaluationCriteria(msg) + response = self.connection.post( CreateScenarioTestRequest( name=name, diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index d2fd4e74..bef31582 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -1,5 +1,5 @@ import itertools -from typing import Callable, Dict, List, Optional, Type, Union +from typing import Callable, Dict, List, Optional, Union from nucleus.logger import logger from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction @@ -11,24 +11,102 @@ class BoundingBoxIOU(BaseEvalFunction): + def __call__( + self, + enforce_label_match: bool = False, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonIOU` object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_iou" class BoundingBoxMeanAveragePrecision(BaseEvalFunction): + def __call__( + self, + iou_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonMAP` object. + + Args: + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + iou_threshold=iou_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_map" class BoundingBoxRecall(BaseEvalFunction): + def __call__( + self, + enforce_label_match: bool = False, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonRecall` object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_recall" class BoundingBoxPrecision(BaseEvalFunction): + def __call__( + self, + enforce_label_match: bool = False, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonPrecision` object. 
+ + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_precision" @@ -41,6 +119,35 @@ def __call__( f1_method: Optional[str] = None, **kwargs, ): + """ + Args: + confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. + Must be in [0, 1]. Default 0.0 + f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \ + default='macro' + This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + ``'weighted'``: + Calculate metrics for each label, and find their average weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). 
+ """ return super().__call__( confidence_threshold=confidence_threshold, f1_method=f1_method ) @@ -99,13 +206,14 @@ def expected_name(cls) -> str: EvalFunction = Union[ - Type[BoundingBoxIOU], - Type[BoundingBoxMeanAveragePrecision], - Type[BoundingBoxPrecision], - Type[BoundingBoxRecall], - Type[CustomEvalFunction], - Type[EvalFunctionNotAvailable], - Type[StandardEvalFunction], + BoundingBoxIOU, + BoundingBoxMeanAveragePrecision, + BoundingBoxPrecision, + BoundingBoxRecall, + CategorizationF1, + CustomEvalFunction, + EvalFunctionNotAvailable, + StandardEvalFunction, ] @@ -140,14 +248,14 @@ def __init__(self, available_functions: List[EvalFunctionEntry]): for f in available_functions if not f.is_public } - self.bbox_iou = self._assign_eval_function_if_defined(BoundingBoxIOU) # type: ignore - self.bbox_precision = self._assign_eval_function_if_defined( + self.bbox_iou: BoundingBoxIOU = self._assign_eval_function_if_defined(BoundingBoxIOU) # type: ignore + self.bbox_precision: BoundingBoxPrecision = self._assign_eval_function_if_defined( BoundingBoxPrecision # type: ignore ) - self.bbox_recall = self._assign_eval_function_if_defined( + self.bbox_recall: BoundingBoxRecall = self._assign_eval_function_if_defined( BoundingBoxRecall # type: ignore ) - self.bbox_map = self._assign_eval_function_if_defined( + self.bbox_map: BoundingBoxMeanAveragePrecision = self._assign_eval_function_if_defined( BoundingBoxMeanAveragePrecision # type: ignore ) self.cat_f1: CategorizationF1 = self._assign_eval_function_if_defined( diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index f6b3e61c..ec0a5ccf 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -1,6 +1,7 @@ import pytest from nucleus.validate import CreateScenarioTestError +from nucleus.validate.client import InvalidEvaluationCriteria from nucleus.validate.scenario_test import ScenarioTest from tests.helpers import ( EVAL_FUNCTION_COMPARISON, @@ -88,3 +89,28 @@ def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): # create some dataset_items for the scenario test to reference with pytest.raises(Exception): scenario_test.set_baseline_model("nonexistent_model_id") + + +def test_missing_comparison_raises_invalid_criteria( + CLIENT, test_slice, annotations +): + test_name = "scenario_test_" + get_uuid() # use uuid to make unique + with pytest.raises(InvalidEvaluationCriteria): + CLIENT.validate.create_scenario_test( + name=test_name, + slice_id=test_slice.id, + evaluation_criteria=[ + CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) + ], + ) + + +def test_passing_eval_arguments(CLIENT, test_slice, annotations): + test_name = "scenario_test_" + get_uuid() # use uuid to make unique + CLIENT.validate.create_scenario_test( + name=test_name, + slice_id=test_slice.id, + evaluation_criteria=[ + CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) > 0 + ], + ) From 9b9f68d5908ac005e98cb33df5c787c184b96969 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Tue, 1 Mar 2022 11:34:33 +0100 Subject: [PATCH 03/33] Update defaults to match metrics --- nucleus/validate/eval_functions/available_eval_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index bef31582..19b51380 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ 
b/nucleus/validate/eval_functions/available_eval_functions.py @@ -40,7 +40,7 @@ def expected_name(cls) -> str: class BoundingBoxMeanAveragePrecision(BaseEvalFunction): def __call__( self, - iou_threshold: float = 0.0, + iou_threshold: float = 0.5, **kwargs, ): """Configures a call to :class:`PolygonMAP` object. @@ -62,7 +62,7 @@ class BoundingBoxRecall(BaseEvalFunction): def __call__( self, enforce_label_match: bool = False, - iou_threshold: float = 0.0, + iou_threshold: float = 0.5, confidence_threshold: float = 0.0, **kwargs, ): @@ -89,7 +89,7 @@ class BoundingBoxPrecision(BaseEvalFunction): def __call__( self, enforce_label_match: bool = False, - iou_threshold: float = 0.0, + iou_threshold: float = 0.5, confidence_threshold: float = 0.0, **kwargs, ): From d86cef28256bfddc4b71f0a6b64a121f10facff7 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Tue, 1 Mar 2022 11:36:27 +0100 Subject: [PATCH 04/33] Address @phil-scale comments! --- nucleus/validate/client.py | 17 +++++++---------- nucleus/validate/errors.py | 4 ++++ tests/validate/test_scenario_test.py | 2 +- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index 61538fba..37ddc866 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -6,7 +6,7 @@ from .constants import SCENARIO_TEST_ID_KEY from .data_transfer_objects.eval_function import GetEvalFunctions from .data_transfer_objects.scenario_test import CreateScenarioTestRequest -from .errors import CreateScenarioTestError +from .errors import CreateScenarioTestError, InvalidEvaluationCriteria from .eval_functions.available_eval_functions import AvailableEvalFunctions from .eval_functions.base_eval_function import BaseEvalFunction from .scenario_test import ScenarioTest @@ -15,10 +15,6 @@ EVAL_FUNCTIONS_KEY = "eval_functions" -class InvalidEvaluationCriteria(Exception): - pass - - class Validate: """Model CI Python Client extension.""" @@ -38,7 +34,8 @@ def eval_functions(self) -> AvailableEvalFunctions: import nucleus client = nucleus.NucleusClient("YOUR_SCALE_API_KEY") - scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5 # Creates an EvaluationCriterion by comparison + # Creates an EvaluationCriterion by using a comparison op + scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5 Returns: :class:`AvailableEvalFunctions`: A container for all the available eval functions @@ -80,16 +77,16 @@ def create_scenario_test( "Must pass an evaluation_function to the scenario test! I.e. " "evaluation_functions=[client.validate.eval_functions.bbox_iou()]" ) - incorrect_type = [ + incorrect_types = [ crit for crit in evaluation_criteria if not isinstance(crit, EvaluationCriterion) ] - if incorrect_type: + if len(incorrect_types) > 0: # NOTE: We expect people to forget adding comparison to these calls so make an explicit error msg. eval_funcs = [ incorrect - for incorrect in incorrect_type + for incorrect in incorrect_types if isinstance(incorrect, BaseEvalFunction) ] if eval_funcs: @@ -101,7 +98,7 @@ def create_scenario_test( f"I.e. 
`{example_call} > 0.5` instead of just `{example_call}`" ) else: - msg = f"Received an incorrect `evaluation_criteria`: {repr(incorrect_type)}" + msg = f"Received an incorrect `evaluation_criteria`: {repr(incorrect_types)}" raise InvalidEvaluationCriteria(msg) response = self.connection.post( diff --git a/nucleus/validate/errors.py b/nucleus/validate/errors.py index 19cc249b..87253c71 100644 --- a/nucleus/validate/errors.py +++ b/nucleus/validate/errors.py @@ -4,3 +4,7 @@ class CreateScenarioTestError(Exception): class EvalFunctionNotAvailableError(Exception): pass + + +class InvalidEvaluationCriteria(Exception): + pass diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index ec0a5ccf..2521d414 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -1,7 +1,7 @@ import pytest from nucleus.validate import CreateScenarioTestError -from nucleus.validate.client import InvalidEvaluationCriteria +from nucleus.validate.errors import InvalidEvaluationCriteria from nucleus.validate.scenario_test import ScenarioTest from tests.helpers import ( EVAL_FUNCTION_COMPARISON, From 2d1e73858573fd8b05d5d067d4dcc00b7ba0a2d3 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Tue, 1 Mar 2022 11:59:26 +0100 Subject: [PATCH 05/33] Add examples to configuration functions and clear up class naming --- nucleus/validate/client.py | 6 +- .../available_eval_functions.py | 115 ++++++++++++++---- .../eval_functions/base_eval_function.py | 6 +- 3 files changed, 95 insertions(+), 32 deletions(-) diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index 37ddc866..df12dda2 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -8,7 +8,7 @@ from .data_transfer_objects.scenario_test import CreateScenarioTestRequest from .errors import CreateScenarioTestError, InvalidEvaluationCriteria from .eval_functions.available_eval_functions import AvailableEvalFunctions -from .eval_functions.base_eval_function import BaseEvalFunction +from .eval_functions.base_eval_function import EvalFunctionConfig from .scenario_test import ScenarioTest SUCCESS_KEY = "success" @@ -87,10 +87,10 @@ def create_scenario_test( eval_funcs = [ incorrect for incorrect in incorrect_types - if isinstance(incorrect, BaseEvalFunction) + if isinstance(incorrect, EvalFunctionConfig) ] if eval_funcs: - example: BaseEvalFunction = eval_funcs[0] + example: EvalFunctionConfig = eval_funcs[0] example_call = f"{example.name}()" msg = ( f"Expected a comparison (<, <=, >, >=) for every `evaluation_criteria`. 
" diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index 19b51380..a2ba14cd 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -2,7 +2,9 @@ from typing import Callable, Dict, List, Optional, Union from nucleus.logger import logger -from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction +from nucleus.validate.eval_functions.base_eval_function import ( + EvalFunctionConfig, +) from ..data_transfer_objects.eval_function import EvalFunctionEntry from ..errors import EvalFunctionNotAvailableError @@ -10,7 +12,7 @@ MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes" -class BoundingBoxIOU(BaseEvalFunction): +class PolygonIOUConfig(EvalFunctionConfig): def __call__( self, enforce_label_match: bool = False, @@ -19,6 +21,18 @@ def __call__( **kwargs, ): """Configures a call to :class:`PolygonIOU` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_iou: BoundingBoxIOU = client.validate.eval_functions.bbox_iou + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_iou(confidence_threshold=0.8) > 0.5] + ) Args: enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False @@ -37,13 +51,25 @@ def expected_name(cls) -> str: return "bbox_iou" -class BoundingBoxMeanAveragePrecision(BaseEvalFunction): +class PolygonMAPConfig(EvalFunctionConfig): def __call__( self, iou_threshold: float = 0.5, **kwargs, ): """Configures a call to :class:`PolygonMAP` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_map: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_map + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_map(iou_threshold=0.6) > 0.8] + ) Args: iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 @@ -58,7 +84,7 @@ def expected_name(cls) -> str: return "bbox_map" -class BoundingBoxRecall(BaseEvalFunction): +class PolygonRecallConfig(EvalFunctionConfig): def __call__( self, enforce_label_match: bool = False, @@ -67,6 +93,18 @@ def __call__( **kwargs, ): """Configures a call to :class:`PolygonRecall` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_recall: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_recall + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_recall(iou_threshold=0.6, confidence_threshold=0.4) > 0.9] + ) Args: enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False @@ -85,7 +123,7 @@ def expected_name(cls) -> str: return "bbox_recall" -class BoundingBoxPrecision(BaseEvalFunction): +class PolygonPrecisionConfig(EvalFunctionConfig): def __call__( self, enforce_label_match: bool = False, @@ -94,6 +132,18 @@ def __call__( **kwargs, ): """Configures a call to :class:`PolygonPrecision` object. 
+ :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_precision: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_precision + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_precision(iou_threshold=0.6, confidence_threshold=0.4) > 0.9] + ) Args: enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False @@ -112,14 +162,27 @@ def expected_name(cls) -> str: return "bbox_precision" -class CategorizationF1(BaseEvalFunction): +class CategorizationF1Config(EvalFunctionConfig): def __call__( self, confidence_threshold: Optional[float] = None, f1_method: Optional[str] = None, **kwargs, ): - """ + """ Configure an evaluation of :class:`CategorizationF1`. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + cat_f1: CategorizationF1 = client.validate.eval_functions.cat_f1 + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[cat_f1(confidence_threshold=0.6, f1_method="weighted") > 0.7] + ) + Args: confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0 @@ -157,7 +220,7 @@ def expected_name(cls) -> str: return "cat_f1" -class CustomEvalFunction(BaseEvalFunction): +class CustomEvalFunction(EvalFunctionConfig): @classmethod def expected_name(cls) -> str: raise NotImplementedError( @@ -165,7 +228,7 @@ def expected_name(cls) -> str: ) # Placeholder: See super().eval_func_entry for actual name -class StandardEvalFunction(BaseEvalFunction): +class StandardEvalFunction(EvalFunctionConfig): """Class for standard Model CI eval functions that have not been added as attributes on AvailableEvalFunctions yet. 
""" @@ -182,7 +245,7 @@ def expected_name(cls) -> str: return "public_function" # Placeholder: See super().eval_func_entry for actual name -class EvalFunctionNotAvailable(BaseEvalFunction): +class EvalFunctionNotAvailable(EvalFunctionConfig): def __init__( self, not_available_name: str ): # pylint: disable=super-init-not-called @@ -206,11 +269,11 @@ def expected_name(cls) -> str: EvalFunction = Union[ - BoundingBoxIOU, - BoundingBoxMeanAveragePrecision, - BoundingBoxPrecision, - BoundingBoxRecall, - CategorizationF1, + PolygonIOUConfig, + PolygonMAPConfig, + PolygonPrecisionConfig, + PolygonRecallConfig, + CategorizationF1Config, CustomEvalFunction, EvalFunctionNotAvailable, StandardEvalFunction, @@ -242,24 +305,24 @@ def __init__(self, available_functions: List[EvalFunctionEntry]): f.name: f for f in available_functions if f.is_public } # NOTE: Public are assigned - self._public_to_function: Dict[str, BaseEvalFunction] = {} + self._public_to_function: Dict[str, EvalFunctionConfig] = {} self._custom_to_function: Dict[str, CustomEvalFunction] = { f.name: CustomEvalFunction(f) for f in available_functions if not f.is_public } - self.bbox_iou: BoundingBoxIOU = self._assign_eval_function_if_defined(BoundingBoxIOU) # type: ignore - self.bbox_precision: BoundingBoxPrecision = self._assign_eval_function_if_defined( - BoundingBoxPrecision # type: ignore + self.bbox_iou: PolygonIOUConfig = self._assign_eval_function_if_defined(PolygonIOUConfig) # type: ignore + self.bbox_precision: PolygonPrecisionConfig = self._assign_eval_function_if_defined( + PolygonPrecisionConfig # type: ignore ) - self.bbox_recall: BoundingBoxRecall = self._assign_eval_function_if_defined( - BoundingBoxRecall # type: ignore + self.bbox_recall: PolygonRecallConfig = self._assign_eval_function_if_defined( + PolygonRecallConfig # type: ignore ) - self.bbox_map: BoundingBoxMeanAveragePrecision = self._assign_eval_function_if_defined( - BoundingBoxMeanAveragePrecision # type: ignore + self.bbox_map: PolygonMAPConfig = self._assign_eval_function_if_defined( + PolygonMAPConfig # type: ignore ) - self.cat_f1: CategorizationF1 = self._assign_eval_function_if_defined( - CategorizationF1 # type: ignore + self.cat_f1: CategorizationF1Config = self._assign_eval_function_if_defined( + CategorizationF1Config # type: ignore ) # Add public entries that have not been implemented as an attribute on this class @@ -281,7 +344,7 @@ def __repr__(self): ) @property - def public_functions(self) -> Dict[str, BaseEvalFunction]: + def public_functions(self) -> Dict[str, EvalFunctionConfig]: """Standard functions provided by Model CI. 
Notes: diff --git a/nucleus/validate/eval_functions/base_eval_function.py b/nucleus/validate/eval_functions/base_eval_function.py index 67ee8659..af823f70 100644 --- a/nucleus/validate/eval_functions/base_eval_function.py +++ b/nucleus/validate/eval_functions/base_eval_function.py @@ -8,8 +8,8 @@ ) -class BaseEvalFunction(abc.ABC): - """Abstract base class for concrete implementations of EvalFunctions +class EvalFunctionConfig(abc.ABC): + """Abstract base class for concrete implementations of EvalFunctionsConfigs Operating on this class with comparison operators produces an EvaluationCriterion """ @@ -28,7 +28,7 @@ def __repr__(self): def expected_name(cls) -> str: """Name to look for in the EvalFunctionDefinitions""" - def __call__(self, **kwargs) -> "BaseEvalFunction": + def __call__(self, **kwargs) -> "EvalFunctionConfig": """Adding call to prepare for being able to pass parameters to function Notes: From 4b4ffee9ee4d17070e729831eb3d95d15757fd67 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Wed, 30 Mar 2022 13:09:17 +0200 Subject: [PATCH 06/33] Fix rebase errors --- nucleus/validate/client.py | 27 ++------------------------- nucleus/validate/errors.py | 4 ---- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index df12dda2..cfc5fbb5 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -6,7 +6,7 @@ from .constants import SCENARIO_TEST_ID_KEY from .data_transfer_objects.eval_function import GetEvalFunctions from .data_transfer_objects.scenario_test import CreateScenarioTestRequest -from .errors import CreateScenarioTestError, InvalidEvaluationCriteria +from .errors import CreateScenarioTestError from .eval_functions.available_eval_functions import AvailableEvalFunctions from .eval_functions.base_eval_function import EvalFunctionConfig from .scenario_test import ScenarioTest @@ -50,7 +50,7 @@ def create_scenario_test( self, name: str, slice_id: str, - evaluation_functions: List[EvalFunction], + evaluation_functions: List[EvalFunctionConfig], ) -> ScenarioTest: """Creates a new Scenario Test from an existing Nucleus :class:`Slice`:. :: @@ -77,29 +77,6 @@ def create_scenario_test( "Must pass an evaluation_function to the scenario test! I.e. " "evaluation_functions=[client.validate.eval_functions.bbox_iou()]" ) - incorrect_types = [ - crit - for crit in evaluation_criteria - if not isinstance(crit, EvaluationCriterion) - ] - if len(incorrect_types) > 0: - # NOTE: We expect people to forget adding comparison to these calls so make an explicit error msg. - eval_funcs = [ - incorrect - for incorrect in incorrect_types - if isinstance(incorrect, EvalFunctionConfig) - ] - if eval_funcs: - example: EvalFunctionConfig = eval_funcs[0] - example_call = f"{example.name}()" - msg = ( - f"Expected a comparison (<, <=, >, >=) for every `evaluation_criteria`. " - f"You should add a comparison to {eval_funcs}. " - f"I.e. 
`{example_call} > 0.5` instead of just `{example_call}`" - ) - else: - msg = f"Received an incorrect `evaluation_criteria`: {repr(incorrect_types)}" - raise InvalidEvaluationCriteria(msg) response = self.connection.post( CreateScenarioTestRequest( diff --git a/nucleus/validate/errors.py b/nucleus/validate/errors.py index 87253c71..19cc249b 100644 --- a/nucleus/validate/errors.py +++ b/nucleus/validate/errors.py @@ -4,7 +4,3 @@ class CreateScenarioTestError(Exception): class EvalFunctionNotAvailableError(Exception): pass - - -class InvalidEvaluationCriteria(Exception): - pass From 442eaf810940ded50a558ac0372a7c58869d869c Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Wed, 30 Mar 2022 13:23:58 +0200 Subject: [PATCH 07/33] Another rebasing error bites the dust --- tests/validate/test_scenario_test.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index 2521d414..9060b026 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -1,7 +1,6 @@ import pytest from nucleus.validate import CreateScenarioTestError -from nucleus.validate.errors import InvalidEvaluationCriteria from nucleus.validate.scenario_test import ScenarioTest from tests.helpers import ( EVAL_FUNCTION_COMPARISON, @@ -91,20 +90,6 @@ def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): scenario_test.set_baseline_model("nonexistent_model_id") -def test_missing_comparison_raises_invalid_criteria( - CLIENT, test_slice, annotations -): - test_name = "scenario_test_" + get_uuid() # use uuid to make unique - with pytest.raises(InvalidEvaluationCriteria): - CLIENT.validate.create_scenario_test( - name=test_name, - slice_id=test_slice.id, - evaluation_criteria=[ - CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) - ], - ) - - def test_passing_eval_arguments(CLIENT, test_slice, annotations): test_name = "scenario_test_" + get_uuid() # use uuid to make unique CLIENT.validate.create_scenario_test( From 203d0f298e59780d8f6cd56fdc2c4d2d81201c17 Mon Sep 17 00:00:00 2001 From: Diego Ardila Date: Wed, 16 Mar 2022 11:20:33 -0700 Subject: [PATCH 08/33] Refactor a lot of segmentation local upload and async logic (#256) * work in progress * work in progress * Big refactor to make things cleaner + enable retries properly on infra flakes for local upload * work in progress refactor of annotation upload * Fixed segmentation bugs * Fix one more bug and remove use of annotate_segmentation endpoint * refactor tests and add segmentation local upload test * Tests passing * Review feedback * Initial pass at client changes for prediction segmentation upload * relevant tests pass Co-authored-by: Ubuntu --- conftest.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/conftest.py b/conftest.py index 90946313..b2cdfef1 100644 --- a/conftest.py +++ b/conftest.py @@ -39,6 +39,13 @@ def model(CLIENT): CLIENT.delete_model(model.id) +@pytest.fixture() +def model(CLIENT): + model = CLIENT.create_model(TEST_DATASET_NAME, "fake_reference_id") + yield model + CLIENT.delete_model(model.id) + + if __name__ == "__main__": client = nucleus.NucleusClient(API_KEY) # ds = client.create_dataset("Test Dataset With Autotags") From 577d4fc4f88adedd11892b1e71c46ebdb9d2949d Mon Sep 17 00:00:00 2001 From: Diego Ardila Date: Fri, 18 Mar 2022 15:11:11 -0700 Subject: [PATCH 09/33] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml 
b/pyproject.toml index a3a47502..c97bc4bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.8.3" +version = "0.8.4" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] From 940741cb3064fc2497b4c752935bacef72126fcc Mon Sep 17 00:00:00 2001 From: Sasha Harrison <70984140+sasha-scale@users.noreply.github.com> Date: Thu, 24 Mar 2022 09:59:40 -0700 Subject: [PATCH 10/33] fix camera_model initialization (#264) * fix camera_model initialization * bump version number --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c97bc4bc..712d9280 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.8.4" +version = "0.8.2" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] From e6a9058e483272f56061380443ab6f27a432f0d0 Mon Sep 17 00:00:00 2001 From: Sasha Harrison <70984140+sasha-scale@users.noreply.github.com> Date: Tue, 29 Mar 2022 16:57:08 -0700 Subject: [PATCH 11/33] Validate feature: setting baseline models (#266) add new set model as baseline functions to client, remove add_criteria in favor of add_eval_function, bump version number and changelog --- pyproject.toml | 2 +- tests/validate/test_scenario_test.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 712d9280..a3a47502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.8.2" +version = "0.8.3" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index 9060b026..95731fdc 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -99,3 +99,21 @@ def test_passing_eval_arguments(CLIENT, test_slice, annotations): CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) > 0 ], ) + + +def test_scenario_test_set_metric_threshold( + CLIENT, annotations, scenario_test +): + # create some dataset_items for the scenario test to reference + threshold = 0.5 + scenario_test_metrics = scenario_test.get_eval_functions() + metric = scenario_test_metrics[0] + assert metric + metric.set_threshold(threshold) + assert metric.threshold == threshold + + +def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): + # create some dataset_items for the scenario test to reference + with pytest.raises(Exception): + scenario_test.set_baseline_model("nonexistent_model_id") From e18b124eac18128ed9d0a5c07c19bbc775b713aa Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Fri, 25 Feb 2022 10:41:55 +0100 Subject: [PATCH 12/33] Add better error message for scenario_test misconfiguration and arguments to all public functions --- tests/validate/test_scenario_test.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index 95731fdc..a7893d0d 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -95,8 +95,8 @@ def 
test_passing_eval_arguments(CLIENT, test_slice, annotations): CLIENT.validate.create_scenario_test( name=test_name, slice_id=test_slice.id, - evaluation_criteria=[ - CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) > 0 + evaluation_functions=[ + CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) ], ) @@ -117,3 +117,15 @@ def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): # create some dataset_items for the scenario test to reference with pytest.raises(Exception): scenario_test.set_baseline_model("nonexistent_model_id") + + + +def test_passing_eval_arguments(CLIENT, test_slice, annotations): + test_name = "scenario_test_" + get_uuid() # use uuid to make unique + CLIENT.validate.create_scenario_test( + name=test_name, + slice_id=test_slice.id, + evaluation_functions=[ + CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) + ], + ) From 58c66eb0ac7247828478810e946bba768a09cd75 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Tue, 1 Mar 2022 11:36:27 +0100 Subject: [PATCH 13/33] Address @phil-scale comments! --- nucleus/validate/client.py | 2 +- nucleus/validate/errors.py | 4 ++++ tests/validate/test_scenario_test.py | 30 ---------------------------- 3 files changed, 5 insertions(+), 31 deletions(-) diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index cfc5fbb5..f8189fe0 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -6,7 +6,7 @@ from .constants import SCENARIO_TEST_ID_KEY from .data_transfer_objects.eval_function import GetEvalFunctions from .data_transfer_objects.scenario_test import CreateScenarioTestRequest -from .errors import CreateScenarioTestError +from .errors import CreateScenarioTestError, InvalidEvaluationCriteria from .eval_functions.available_eval_functions import AvailableEvalFunctions from .eval_functions.base_eval_function import EvalFunctionConfig from .scenario_test import ScenarioTest diff --git a/nucleus/validate/errors.py b/nucleus/validate/errors.py index 19cc249b..87253c71 100644 --- a/nucleus/validate/errors.py +++ b/nucleus/validate/errors.py @@ -4,3 +4,7 @@ class CreateScenarioTestError(Exception): class EvalFunctionNotAvailableError(Exception): pass + + +class InvalidEvaluationCriteria(Exception): + pass diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index a7893d0d..cdf63f0e 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -90,36 +90,6 @@ def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): scenario_test.set_baseline_model("nonexistent_model_id") -def test_passing_eval_arguments(CLIENT, test_slice, annotations): - test_name = "scenario_test_" + get_uuid() # use uuid to make unique - CLIENT.validate.create_scenario_test( - name=test_name, - slice_id=test_slice.id, - evaluation_functions=[ - CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) - ], - ) - - -def test_scenario_test_set_metric_threshold( - CLIENT, annotations, scenario_test -): - # create some dataset_items for the scenario test to reference - threshold = 0.5 - scenario_test_metrics = scenario_test.get_eval_functions() - metric = scenario_test_metrics[0] - assert metric - metric.set_threshold(threshold) - assert metric.threshold == threshold - - -def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): - # create some dataset_items for the scenario test to reference - with pytest.raises(Exception): - 
scenario_test.set_baseline_model("nonexistent_model_id") - - - def test_passing_eval_arguments(CLIENT, test_slice, annotations): test_name = "scenario_test_" + get_uuid() # use uuid to make unique CLIENT.validate.create_scenario_test( From c95598411f303028e31579588cca0756b05ad4e8 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Wed, 16 Mar 2022 21:52:07 +0000 Subject: [PATCH 14/33] flake fix --- nucleus/metrics/cuboid_utils.py | 331 ++++++++++++++++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 nucleus/metrics/cuboid_utils.py diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py new file mode 100644 index 00000000..413ffc41 --- /dev/null +++ b/nucleus/metrics/cuboid_utils.py @@ -0,0 +1,331 @@ +import numpy as np +from functools import wraps +from typing import Dict, List, Tuple +from shapely.geometry import Polygon + +from .base import ScalarResult +from nucleus.annotation import CuboidAnnotation +from nucleus.prediction import CuboidPrediction + + +def group_cuboids_by_label( + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], +) -> Dict[str, Tuple[List[CuboidAnnotation], List[CuboidPrediction]]]: + """Groups input annotations and predictions by label. + + Args: + annotations: list of input cuboid annotations + predictions: list of input cuboid predictions + + Returns: + Mapping from each label to (annotations, predictions) tuple + """ + labels = set(annotation.label for annotation in annotations) + labels |= set(prediction.label for prediction in predictions) + grouped: Dict[ + str, Tuple[List[CuboidAnnotation], List[CuboidPrediction]] + ] = {label: ([], []) for label in labels} + for annotation in annotations: + grouped[annotation.label][0].append(annotation) + for prediction in predictions: + grouped[prediction.label][1].append(prediction) + return grouped + + +def label_match_wrapper(metric_fn): + """Decorator to add the ability to only apply metric to annotations and + predictions with matching labels. + + Args: + metric_fn: Metric function that takes a list of annotations, a list + of predictions, and optional args and kwargs. + + Returns: + Metric function which can optionally enforce matching labels. + """ + + @wraps(metric_fn) + def wrapper( + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + *args, + enforce_label_match: bool = False, + **kwargs, + ) -> ScalarResult: + # Simply return the metric if we are not enforcing label matches. + if not enforce_label_match: + return metric_fn(annotations, predictions, *args, **kwargs) + + # For each bin of annotations/predictions, compute the metric applied + # only to that bin. Then aggregate results across all bins. 
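# --- illustrative aside, not part of the patch -----------------------------
# The binning-then-aggregation idea in miniature, assuming ScalarResult
# aggregation is a weight-weighted mean of the per-label results:
per_label = {"car": (0.8, 10.0), "bus": (0.5, 2.0)}  # label -> (value, weight)
total_weight = sum(w for _, w in per_label.values())
aggregate = sum(v * w for v, w in per_label.values()) / total_weight
assert round(aggregate, 2) == 0.75  # weighted mean, not the plain mean of 0.65
# ---------------------------------------------------------------------------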
+ grouped_inputs = group_cuboids_by_label(annotations, predictions) + metric_results = [] + for binned_annotations, binned_predictions in grouped_inputs.values(): + metric_result = metric_fn( + binned_annotations, binned_predictions, *args, **kwargs + ) + metric_results.append(metric_result) + assert all( + isinstance(r, ScalarResult) for r in metric_results + ), "Expected every result to be a ScalarResult" + return ScalarResult.aggregate(metric_results) + + return wrapper + + +def process_dataitem(dataitem): + processed_item = {} + processed_item["xyz"] = np.array( + [[ann.position.x, ann.position.y, ann.position.z] for ann in dataitem] + ) + processed_item["wlh"] = np.array( + [ + [ann.dimensions.x, ann.dimensions.y, ann.dimensions.z] + for ann in dataitem + ] + ) + processed_item["yaw"] = np.array([ann.yaw for ann in dataitem]) + return processed_item + + +def compute_outer_iou( + xyz_0: np.ndarray, + wlh_0: np.ndarray, + yaw_0: np.ndarray, + xyz_1: np.ndarray, + wlh_1: np.ndarray, + yaw_1: np.ndarray, + scale_convention: bool = True, + distance_threshold=25, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Computes outer 3D and 2D IoU + :param xyz_0: (n, 3) + :param wlh_0: (n, 3) + :param yaw_0: (n,) + :param xyz_1: (m, 3) + :param wlh_1: (m, 3) + :param yaw_1: (m,) + :param scale_convention: flag whether the internal Scale convention is used (have to be adjusted by pi/2) + :param distance_threshold: computes iou only within this distance (~3x speedup) + :return: (n, m) 3D IoU, (n, m) 2D IoU + """ + bottom_z = np.maximum.outer( + xyz_0[:, 2] - (wlh_0[:, 2] / 2), xyz_1[:, 2] - (wlh_1[:, 2] / 2) + ) + top_z = np.minimum.outer( + xyz_0[:, 2] + (wlh_0[:, 2] / 2), xyz_1[:, 2] + (wlh_1[:, 2] / 2) + ) + height_intersection = np.maximum(0, top_z - bottom_z) + + cuboid_corners_0 = get_batch_cuboid_corners( + xyz_0, wlh_0, yaw_0, scale_convention=scale_convention + ) + cuboid_corners_1 = get_batch_cuboid_corners( + xyz_1, wlh_1, yaw_1, scale_convention=scale_convention + ) + polygons_1 = [ + Polygon(corners_1[[1, 0, 4, 5, 1], :2]) + for corners_1 in cuboid_corners_1 + ] + area_intersection = np.zeros( + (cuboid_corners_0.shape[0], cuboid_corners_1.shape[0]), + dtype=np.float32, + ) + + if cuboid_corners_0.shape[0] != 0 and cuboid_corners_1.shape[0] != 0: + distance_mask = ( + np.linalg.norm( + xyz_0[:, np.newaxis, :] - xyz_1[np.newaxis, :, :], axis=2 + ) + < distance_threshold + ) + + for i, corners_0 in enumerate(cuboid_corners_0): + for j, polygon_1 in enumerate(polygons_1): + if distance_mask[i, j]: + area_intersection[i, j] = ( + Polygon(corners_0[[1, 0, 4, 5, 1], :2]) + .intersection(polygon_1) + .area + ) + + intersection = height_intersection * area_intersection + area_0 = wlh_0[:, 0] * wlh_0[:, 1] + area_1 = wlh_1[:, 0] * wlh_1[:, 1] + union_2d = np.add.outer(area_0, area_1) - area_intersection + + volume_0 = area_0 * wlh_0[:, 2] + volume_1 = area_1 * wlh_1[:, 2] + union = np.add.outer(volume_0, volume_1) - intersection + return intersection / union, area_intersection / union_2d + + +def get_batch_cuboid_corners( + xyz: np.ndarray, + wlh: np.ndarray, + yaw: np.ndarray, + pitch: np.ndarray = None, + roll: np.ndarray = None, + scale_convention: bool = True, +) -> np.ndarray: + """ + Vectorized batch version of get_cuboid_corners + :param xyz: (n, 3) + :param wlh: (n, 3) + :param yaw: (n,) + :param pitch: (n,) + :param roll: (n,) + :param scale_convention: flag whether the internal Scale convention is used (have to be adjusted by pi/2) + :return: (n, 8, 3) + """ + if scale_convention: + yaw = 
yaw.copy() + np.pi / 2 + + w, l, h = wlh[:, 0, None], wlh[:, 1, None], wlh[:, 2, None] + + x_corners = l / 2 * np.array([1, 1, 1, 1, -1, -1, -1, -1]) + y_corners = w / 2 * np.array([1, -1, -1, 1, 1, -1, -1, 1]) + z_corners = h / 2 * np.array([1, 1, -1, -1, 1, 1, -1, -1]) + corners = np.stack((x_corners, y_corners, z_corners), axis=1) + + rot_mats = get_batch_rotation_matrices(yaw, pitch, roll) + corners = np.matmul(rot_mats, corners) + + x, y, z = xyz[:, 0, None], xyz[:, 1, None], xyz[:, 2, None] + corners[:, 0, :] = corners[:, 0, :] + x + corners[:, 1, :] = corners[:, 1, :] + y + corners[:, 2, :] = corners[:, 2, :] + z + return corners.swapaxes(1, 2) + + +def get_batch_rotation_matrices( + yaw: np.ndarray, pitch: np.ndarray = None, roll: np.ndarray = None +) -> np.ndarray: + if pitch is None: + pitch = np.zeros_like(yaw) + if roll is None: + roll = np.zeros_like(yaw) + cy = np.cos(yaw) + sy = np.sin(yaw) + cp = np.cos(pitch) + sp = np.sin(pitch) + cr = np.cos(roll) + sr = np.sin(roll) + return np.stack( + ( + np.stack( + (cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr), 1 + ), + np.stack( + (sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr), 1 + ), + np.stack((-sp, cp * sr, cp * cr), 1), + ), + 1, + ) + + +def associate_cuboids_on_iou( + xyz_0: np.ndarray, + wlh_0: np.ndarray, + yaw_0: np.ndarray, + xyz_1: np.ndarray, + wlh_1: np.ndarray, + yaw_1: np.ndarray, + threshold_in_overlap_ratio: float = 0.1, +) -> List[Tuple[int, int]]: + if xyz_0.shape[0] < 1 or xyz_1.shape[0] < 1: + return [] + iou_matrix, _ = compute_outer_iou(xyz_0, wlh_0, yaw_0, xyz_1, wlh_1, yaw_1) + mapping = [] + for i, m in enumerate(iou_matrix.max(axis=1)): + if m >= threshold_in_overlap_ratio: + mapping.append((i, iou_matrix[i].argmax())) + return mapping + + +def recall_precision( + prediction, + groundtruth, + threshold_in_overlap_ratio: float, +) -> Dict[str, float]: + """ + :param predictions: + :param ground_truth: + :param threshold: threshold in overlap ratio if IoU + """ + tp_sum = 0 + fp_sum = 0 + fn_sum = 0 + num_predicted = 0 + num_instances = 0 + + gt_items = process_dataitem(groundtruth) + pred_items = process_dataitem(prediction) + + num_predicted += pred_items["xyz"].shape[0] + num_instances += gt_items["xyz"].shape[0] + + tp = np.zeros(pred_items["xyz"].shape[0]) + fp = np.ones(pred_items["xyz"].shape[0]) + fn = np.ones(gt_items["xyz"].shape[0]) + + mapping = associate_cuboids_on_iou( + pred_items["xyz"], + pred_items["wlh"], + pred_items["yaw"] + np.pi / 2, + gt_items["xyz"], + gt_items["wlh"], + gt_items["yaw"] + np.pi / 2, + threshold_in_overlap_ratio=threshold_in_overlap_ratio, + ) + + for pred_id, gt_id in mapping: + if fn[gt_id] == 0: + continue + tp[pred_id] = 1 + fp[pred_id] = 0 + fn[gt_id] = 0 + + tp_sum += tp.sum() + fp_sum += fp.sum() + fn_sum += fn.sum() + + return { + "tp_sum": tp_sum, + "fp_sum": fp_sum, + "fn_sum": fn_sum, + "precision": tp_sum / (tp_sum + fp_sum), + "recall": tp_sum / (tp_sum + fn_sum), + "num_predicted": num_predicted, + "num_instances": num_instances, + } + + +def detection_iou(prediction, groundtruth, threshold_in_overlap_ratio): + + gt_items = process_dataitem(groundtruth) + pred_items = process_dataitem(prediction) + + meter_2d = [] + meter_3d = [] + iou_3d, iou_2d = compute_outer_iou( + gt_items["xyz"], + gt_items["wlh"], + gt_items["yaw"], + pred_items["xyz"], + pred_items["wlh"], + pred_items["yaw"], + ) + + for i, m in enumerate(iou_3d.max(axis=1)): + if m >= threshold_in_overlap_ratio: + j = iou_3d[i].argmax() + meter_3d.append(iou_3d[i, j]) + 
meter_2d.append(iou_2d[i, j]) + + meter_3d = np.array(meter_3d) + meter_2d = np.array(meter_2d) + return meter_3d, meter_2d From 94949ecda9ace188bcd4dd4b073bb0616c4e4eb6 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Wed, 16 Mar 2022 23:59:51 +0000 Subject: [PATCH 15/33] lint --- nucleus/metrics/cuboid_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py index 413ffc41..0326315a 100644 --- a/nucleus/metrics/cuboid_utils.py +++ b/nucleus/metrics/cuboid_utils.py @@ -1,11 +1,11 @@ -import numpy as np -from functools import wraps from typing import Dict, List, Tuple +from functools import wraps +import numpy as np from shapely.geometry import Polygon -from .base import ScalarResult from nucleus.annotation import CuboidAnnotation from nucleus.prediction import CuboidPrediction +from .base import ScalarResult def group_cuboids_by_label( From 9065ebe3ab5f33ed1a8c78b5d70cfccbfa20ff77 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 00:03:44 +0000 Subject: [PATCH 16/33] linting for circle ci --- nucleus/metrics/__init__.py | 1 + nucleus/metrics/cuboid_metrics.py | 193 ++++++++++++++++++++++++++++++ nucleus/metrics/cuboid_utils.py | 4 +- 3 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 nucleus/metrics/cuboid_metrics.py diff --git a/nucleus/metrics/__init__.py b/nucleus/metrics/__init__.py index 1fd038a2..cb19d8a8 100644 --- a/nucleus/metrics/__init__.py +++ b/nucleus/metrics/__init__.py @@ -1,5 +1,6 @@ from .base import Metric, ScalarResult from .categorization_metrics import CategorizationF1 +from .cuboid_metrics import CuboidIOU, CuboidPrecision, CuboidRecall from .polygon_metrics import ( PolygonAveragePrecision, PolygonIOU, diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py new file mode 100644 index 00000000..ed9fc509 --- /dev/null +++ b/nucleus/metrics/cuboid_metrics.py @@ -0,0 +1,193 @@ +import sys +from abc import abstractmethod +from typing import List + +from nucleus.annotation import AnnotationList, CuboidAnnotation +from nucleus.prediction import CuboidPrediction, PredictionList + +from .base import Metric, ScalarResult +from .cuboid_utils import detection_iou, label_match_wrapper, recall_precision +from .filters import confidence_filter + + +class CuboidMetric(Metric): + """Abstract class for metrics of cuboids. + + The CuboidMetric class automatically filters incoming annotations and + predictions for only cuboid annotations. It also filters + predictions whose confidence is less than the provided confidence_threshold. + Finally, it provides support for enforcing matching labels. If + `enforce_label_match` is set to True, then annotations and predictions will + only be matched if they have the same label. + + To create a new concrete CuboidMetric, override the `eval` function + with logic to define a metric between cuboid annotations and predictions. + """ + + def __init__( + self, + enforce_label_match: bool = False, + confidence_threshold: float = 0.0, + ): + """Initializes CuboidMetric abstract object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Default False + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. 
Default 0.0 + """ + self.enforce_label_match = enforce_label_match + assert 0 <= confidence_threshold <= 1 + self.confidence_threshold = confidence_threshold + + @abstractmethod + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + # Main evaluation function that subclasses must override. + pass + + def aggregate_score(self, results: List[ScalarResult]) -> ScalarResult: # type: ignore[override] + return ScalarResult.aggregate(results) + + def __call__( + self, annotations: AnnotationList, predictions: PredictionList + ) -> ScalarResult: + if self.confidence_threshold > 0: + predictions = confidence_filter( + predictions, self.confidence_threshold + ) + cuboid_annotations: List[CuboidAnnotation] = [] + cuboid_annotations.extend(annotations.cuboid_annotations) + cuboid_predictions: List[CuboidPrediction] = [] + cuboid_predictions.extend(predictions.cuboid_predictions) + + eval_fn = label_match_wrapper(self.eval) + result = eval_fn( + cuboid_annotations, + cuboid_predictions, + enforce_label_match=self.enforce_label_match, + ) + return result + + +class CuboidIOU(CuboidMetric): + """Calculates the average IOU between cuboid annotations and predictions.""" + + # TODO: Remove defaults once these are surfaced more cleanly to users. + def __init__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + birds_eye_view: bool = False, + ): + """Initializes CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + birds_eye_view: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + assert ( + 0 <= iou_threshold <= 1 + ), "IoU threshold must be between 0 and 1." + self.iou_threshold = iou_threshold + self.birds_eye_view = birds_eye_view + super().__init__(enforce_label_match, confidence_threshold) + + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + iou_3d, iou_2d = detection_iou( + predictions, + annotations, + threshold_in_overlap_ratio=self.iou_threshold, + ) + weight = max(len(annotations), len(predictions)) + if self.birds_eye_view: + avg_iou = iou_2d.sum() / max(weight, sys.float_info.epsilon) + else: + avg_iou = iou_3d.sum() / max(weight, sys.float_info.epsilon) + + return ScalarResult(avg_iou, weight) + + +class CuboidPrecision(CuboidMetric): + """Calculates the average precision between cuboid annotations and predictions.""" + + # TODO: Remove defaults once these are surfaced more cleanly to users. + def __init__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + ): + """Initializes CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + assert ( + 0 <= iou_threshold <= 1 + ), "IoU threshold must be between 0 and 1." 
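# --- illustrative aside, not part of the patch -----------------------------
# Hand-worked miniature of how recall_precision() (defined earlier in
# cuboid_utils.py) turns a greedy IoU matching into the stats consumed below:
import numpy as np

tp = np.array([1, 1, 0])  # predictions matched to some ground truth
fp = 1 - tp               # unmatched predictions
fn = np.array([0, 1])     # ground truths left unmatched
precision = tp.sum() / (tp.sum() + fp.sum())  # 2 / 3
recall = tp.sum() / (tp.sum() + fn.sum())     # 2 / 3
assert precision == recall == 2 / 3
# ---------------------------------------------------------------------------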
+ self.iou_threshold = iou_threshold + super().__init__(enforce_label_match, confidence_threshold) + + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + stats = recall_precision( + predictions, + annotations, + threshold_in_overlap_ratio=self.iou_threshold, + ) + weight = stats["tp_sum"] + stats["fp_sum"] + precision = stats["tp_sum"] / max(weight, sys.float_info.epsilon) + return ScalarResult(precision, weight) + + +class CuboidRecall(CuboidMetric): + """Calculates the average recall between cuboid annotations and predictions.""" + + # TODO: Remove defaults once these are surfaced more cleanly to users. + def __init__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + ): + """Initializes CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + assert ( + 0 <= iou_threshold <= 1 + ), "IoU threshold must be between 0 and 1." + self.iou_threshold = iou_threshold + super().__init__(enforce_label_match, confidence_threshold) + + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + stats = recall_precision( + predictions, + annotations, + threshold_in_overlap_ratio=self.iou_threshold, + ) + weight = stats["tp_sum"] + stats["fn_sum"] + recall = stats["tp_sum"] / max(weight, sys.float_info.epsilon) + return ScalarResult(recall, weight) diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py index 0326315a..34458a9c 100644 --- a/nucleus/metrics/cuboid_utils.py +++ b/nucleus/metrics/cuboid_utils.py @@ -1,10 +1,12 @@ -from typing import Dict, List, Tuple from functools import wraps +from typing import Dict, List, Tuple + import numpy as np from shapely.geometry import Polygon from nucleus.annotation import CuboidAnnotation from nucleus.prediction import CuboidPrediction + from .base import ScalarResult From 3e9f7da57f6043bcc2d1bee46ee1170eb1974766 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 00:22:01 +0000 Subject: [PATCH 17/33] version --- nucleus/metrics/cuboid_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py index 34458a9c..ad959147 100644 --- a/nucleus/metrics/cuboid_utils.py +++ b/nucleus/metrics/cuboid_utils.py @@ -258,6 +258,7 @@ def recall_precision( :param ground_truth: :param threshold: threshold in overlap ratio if IoU """ + tp_sum = 0 fp_sum = 0 fn_sum = 0 From a494ccb8e7c65ca71fcb9b3dd20555ff4af0067e Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 00:53:23 +0000 Subject: [PATCH 18/33] used native polygon --- nucleus/metrics/cuboid_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py index ad959147..91b26e66 100644 --- a/nucleus/metrics/cuboid_utils.py +++ b/nucleus/metrics/cuboid_utils.py @@ -2,13 +2,15 @@ from typing import Dict, List, Tuple import numpy as np -from shapely.geometry import Polygon from nucleus.annotation import CuboidAnnotation from nucleus.prediction import CuboidPrediction from .base import ScalarResult +# from shapely.geometry import Polygon +from 
.geometry import GeometryPolygon, polygon_intersection_area + def group_cuboids_by_label( annotations: List[CuboidAnnotation], @@ -128,7 +130,7 @@ def compute_outer_iou( xyz_1, wlh_1, yaw_1, scale_convention=scale_convention ) polygons_1 = [ - Polygon(corners_1[[1, 0, 4, 5, 1], :2]) + GeometryPolygon(points=corners_1[[1, 0, 4, 5, 1], :2]) for corners_1 in cuboid_corners_1 ] area_intersection = np.zeros( @@ -147,10 +149,9 @@ def compute_outer_iou( for i, corners_0 in enumerate(cuboid_corners_0): for j, polygon_1 in enumerate(polygons_1): if distance_mask[i, j]: - area_intersection[i, j] = ( - Polygon(corners_0[[1, 0, 4, 5, 1], :2]) - .intersection(polygon_1) - .area + area_intersection[i, j] = polygon_intersection_area( + GeometryPolygon(points=corners_0[[1, 0, 4, 5, 1], :2]), + polygon_1, ) intersection = height_intersection * area_intersection From 29d048355bf8e1cd5ec4cd9735eb68dca698a8e4 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 01:15:03 +0000 Subject: [PATCH 19/33] adding shapely --- nucleus/metrics/cuboid_utils.py | 13 ++++++------- pyproject.toml | 1 + 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py index 91b26e66..ad959147 100644 --- a/nucleus/metrics/cuboid_utils.py +++ b/nucleus/metrics/cuboid_utils.py @@ -2,15 +2,13 @@ from typing import Dict, List, Tuple import numpy as np +from shapely.geometry import Polygon from nucleus.annotation import CuboidAnnotation from nucleus.prediction import CuboidPrediction from .base import ScalarResult -# from shapely.geometry import Polygon -from .geometry import GeometryPolygon, polygon_intersection_area - def group_cuboids_by_label( annotations: List[CuboidAnnotation], @@ -130,7 +128,7 @@ def compute_outer_iou( xyz_1, wlh_1, yaw_1, scale_convention=scale_convention ) polygons_1 = [ - GeometryPolygon(points=corners_1[[1, 0, 4, 5, 1], :2]) + Polygon(corners_1[[1, 0, 4, 5, 1], :2]) for corners_1 in cuboid_corners_1 ] area_intersection = np.zeros( @@ -149,9 +147,10 @@ def compute_outer_iou( for i, corners_0 in enumerate(cuboid_corners_0): for j, polygon_1 in enumerate(polygons_1): if distance_mask[i, j]: - area_intersection[i, j] = polygon_intersection_area( - GeometryPolygon(points=corners_0[[1, 0, 4, 5, 1], :2]), - polygon_1, + area_intersection[i, j] = ( + Polygon(corners_0[[1, 0, 4, 5, 1], :2]) + .intersection(polygon_1) + .area ) intersection = height_intersection * area_intersection diff --git a/pyproject.toml b/pyproject.toml index a3a47502..8bf56b6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ sphinx-autobuild = "^2021.3.14" furo = "^2021.10.9" sphinx-autoapi = "^1.8.4" pytest-xdist = "^2.5.0" +shapely = "^1.7.1" [tool.poetry.scripts] nu = "cli.nu:nu" From 677c777c696b07add5802d3e3b03b7c8e6db3dd8 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 01:20:36 +0000 Subject: [PATCH 20/33] adding shapely --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8bf56b6f..6e343d29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ sphinx-autobuild = "^2021.3.14" furo = "^2021.10.9" sphinx-autoapi = "^1.8.4" pytest-xdist = "^2.5.0" -shapely = "^1.7.1" +Shapely = "^1.7.1" [tool.poetry.scripts] nu = "cli.nu:nu" From d8b7c34904e68672c9819d53c9fc04f245ecc726 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 01:23:00 +0000 Subject: [PATCH 21/33] changing shapely --- pyproject.toml | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6e343d29..cd1a154f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ sphinx-autobuild = "^2021.3.14" furo = "^2021.10.9" sphinx-autoapi = "^1.8.4" pytest-xdist = "^2.5.0" -Shapely = "^1.7.1" +Shapely = "^1.8.0" [tool.poetry.scripts] nu = "cli.nu:nu" From ac6f54207daacf3ce30568927b7e02f1004ec211 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 01:27:35 +0000 Subject: [PATCH 22/33] changing shapely --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cd1a154f..bb735655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ click = ">=7.1.2,<9.0" # NOTE: COLAB has 7.1.2 and has problems updating rich = "^10.15.2" shellingham = "^1.4.0" scikit-learn = ">=0.24.0" +Shapely = "^1.7.1" [tool.poetry.dev-dependencies] poetry = "^1.1.5" @@ -63,7 +64,6 @@ sphinx-autobuild = "^2021.3.14" furo = "^2021.10.9" sphinx-autoapi = "^1.8.4" pytest-xdist = "^2.5.0" -Shapely = "^1.8.0" [tool.poetry.scripts] nu = "cli.nu:nu" From 9f5b6bc17aa2c5c86c7ee35b5964a3296b991211 Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 01:30:00 +0000 Subject: [PATCH 23/33] updating shapely --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bb735655..cad3f245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ click = ">=7.1.2,<9.0" # NOTE: COLAB has 7.1.2 and has problems updating rich = "^10.15.2" shellingham = "^1.4.0" scikit-learn = ">=0.24.0" -Shapely = "^1.7.1" +Shapely = ">=1.7.1" [tool.poetry.dev-dependencies] poetry = "^1.1.5" From e9450897122c2ae79068ca59742d673e96e8f88c Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Thu, 17 Mar 2022 16:49:15 +0000 Subject: [PATCH 24/33] poetry added shapely --- poetry.lock | 53 +++++++++++++++++++++++++++++++++++++++----------- pyproject.toml | 2 +- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/poetry.lock b/poetry.lock index 55dc71a2..59bda3e5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1669,6 +1669,19 @@ nativelib = ["pyobjc-framework-cocoa", "pywin32"] objc = ["pyobjc-framework-cocoa"] win32 = ["pywin32"] +[[package]] +name = "shapely" +version = "1.8.1.post1" +description = "Geometric objects, predicates, and operations" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +all = ["pytest", "pytest-cov", "numpy"] +test = ["pytest", "pytest-cov"] +vectorized = ["numpy"] + [[package]] name = "shellingham" version = "1.4.0" @@ -2064,7 +2077,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.6.2" -content-hash = "8d6269a2d5d30685e0d1e9b38187b6105826042e868acfb5bdb5a3e7188560f8" +content-hash = "4830dd3f9f593703fbbae85efdbd6feca5b0513285ad87132d4f90934fca1098" [metadata.files] absl-py = [ @@ -3206,32 +3219,24 @@ pyzmq = [ {file = "pyzmq-22.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f89468059ebc519a7acde1ee50b779019535db8dcf9b8c162ef669257fef7a93"}, {file = "pyzmq-22.3.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea12133df25e3a6918718fbb9a510c6ee5d3fdd5a346320421aac3882f4feeea"}, {file = "pyzmq-22.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c532fd68b93998aab92356be280deec5de8f8fe59cd28763d2cc8a58747b7f"}, - {file = 
"pyzmq-22.3.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f907c7359ce8bf7f7e63c82f75ad0223384105f5126f313400b7e8004d9b33c3"}, - {file = "pyzmq-22.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:902319cfe23366595d3fa769b5b751e6ee6750a0a64c5d9f757d624b2ac3519e"}, {file = "pyzmq-22.3.0-cp310-cp310-win32.whl", hash = "sha256:67db33bea0a29d03e6eeec55a8190e033318cee3cbc732ba8fd939617cbf762d"}, {file = "pyzmq-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:7661fc1d5cb73481cf710a1418a4e1e301ed7d5d924f91c67ba84b2a1b89defd"}, {file = "pyzmq-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:79244b9e97948eaf38695f4b8e6fc63b14b78cc37f403c6642ba555517ac1268"}, {file = "pyzmq-22.3.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab888624ed68930442a3f3b0b921ad7439c51ba122dbc8c386e6487a658e4a4e"}, {file = "pyzmq-22.3.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18cd854b423fce44951c3a4d3e686bac8f1243d954f579e120a1714096637cc0"}, {file = "pyzmq-22.3.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:de8df0684398bd74ad160afdc2a118ca28384ac6f5e234eb0508858d8d2d9364"}, - {file = "pyzmq-22.3.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:62bcade20813796c426409a3e7423862d50ff0639f5a2a95be4b85b09a618666"}, - {file = "pyzmq-22.3.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ea5a79e808baef98c48c884effce05c31a0698c1057de8fc1c688891043c1ce1"}, {file = "pyzmq-22.3.0-cp36-cp36m-win32.whl", hash = "sha256:3c1895c95be92600233e476fe283f042e71cf8f0b938aabf21b7aafa62a8dac9"}, {file = "pyzmq-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:851977788b9caa8ed011f5f643d3ee8653af02c5fc723fa350db5125abf2be7b"}, {file = "pyzmq-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b4ebed0977f92320f6686c96e9e8dd29eed199eb8d066936bac991afc37cbb70"}, {file = "pyzmq-22.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42abddebe2c6a35180ca549fadc7228d23c1e1f76167c5ebc8a936b5804ea2df"}, {file = "pyzmq-22.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1e41b32d6f7f9c26bc731a8b529ff592f31fc8b6ef2be9fa74abd05c8a342d7"}, {file = "pyzmq-22.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:be4e0f229cf3a71f9ecd633566bd6f80d9fa6afaaff5489492be63fe459ef98c"}, - {file = "pyzmq-22.3.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:08c4e315a76ef26eb833511ebf3fa87d182152adf43dedee8d79f998a2162a0b"}, - {file = "pyzmq-22.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:badb868fff14cfd0e200eaa845887b1011146a7d26d579aaa7f966c203736b92"}, {file = "pyzmq-22.3.0-cp37-cp37m-win32.whl", hash = "sha256:7c58f598d9fcc52772b89a92d72bf8829c12d09746a6d2c724c5b30076c1f11d"}, {file = "pyzmq-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2b97502c16a5ec611cd52410bdfaab264997c627a46b0f98d3f666227fd1ea2d"}, {file = "pyzmq-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d728b08448e5ac3e4d886b165385a262883c34b84a7fe1166277fe675e1c197a"}, {file = "pyzmq-22.3.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:480b9931bfb08bf8b094edd4836271d4d6b44150da051547d8c7113bf947a8b0"}, {file = "pyzmq-22.3.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7dc09198e4073e6015d9a8ea093fc348d4e59de49382476940c3dd9ae156fba8"}, {file = "pyzmq-22.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ca6cd58f62a2751728016d40082008d3b3412a7f28ddfb4a2f0d3c130f69e74"}, - {file = 
"pyzmq-22.3.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:468bd59a588e276961a918a3060948ae68f6ff5a7fa10bb2f9160c18fe341067"}, - {file = "pyzmq-22.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c88fa7410e9fc471e0858638f403739ee869924dd8e4ae26748496466e27ac59"}, {file = "pyzmq-22.3.0-cp38-cp38-win32.whl", hash = "sha256:c0f84360dcca3481e8674393bdf931f9f10470988f87311b19d23cda869bb6b7"}, {file = "pyzmq-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:f762442bab706fd874064ca218b33a1d8e40d4938e96c24dafd9b12e28017f45"}, {file = "pyzmq-22.3.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:954e73c9cd4d6ae319f1c936ad159072b6d356a92dcbbabfd6e6204b9a79d356"}, @@ -3239,8 +3244,6 @@ pyzmq = [ {file = "pyzmq-22.3.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:acebba1a23fb9d72b42471c3771b6f2f18dcd46df77482612054bd45c07dfa36"}, {file = "pyzmq-22.3.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cf98fd7a6c8aaa08dbc699ffae33fd71175696d78028281bc7b832b26f00ca57"}, {file = "pyzmq-22.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d072f7dfbdb184f0786d63bda26e8a0882041b1e393fbe98940395f7fab4c5e2"}, - {file = "pyzmq-22.3.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:53f4fd13976789ffafedd4d46f954c7bb01146121812b72b4ddca286034df966"}, - {file = "pyzmq-22.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1b5d457acbadcf8b27561deeaa386b0217f47626b29672fa7bd31deb6e91e1b"}, {file = "pyzmq-22.3.0-cp39-cp39-win32.whl", hash = "sha256:e6a02cf7271ee94674a44f4e62aa061d2d049001c844657740e156596298b70b"}, {file = "pyzmq-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:d3dcb5548ead4f1123851a5ced467791f6986d68c656bc63bfff1bf9e36671e2"}, {file = "pyzmq-22.3.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3a4c9886d61d386b2b493377d980f502186cd71d501fffdba52bd2a0880cef4f"}, @@ -3407,6 +3410,34 @@ send2trash = [ {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"}, {file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"}, ] +shapely = [ + {file = "Shapely-1.8.1.post1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0ca96a3314b7a38a3bb385531469de1fcf2b2c2979ec2aa4f37b4c70632cf1ad"}, + {file = "Shapely-1.8.1.post1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:493902923fdd135316161a4ece5294ba3ce81accaa54540d2af3b93f7231143a"}, + {file = "Shapely-1.8.1.post1-cp310-cp310-win_amd64.whl", hash = "sha256:b82fc74d5efb11a71283c4ed69b4e036997cc70db4b73c646207ddf0476ade44"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:89bc5f3abc1ccbc7682c2e1664153c4f8f125fa9c24bff4abca48685739d5636"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:44cb895b1710f7559c28d69dfa08cafe4f58cd4b7a87091a55bdf6711ad9ad66"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:437fff3b6274be26ffa3e450de711ee01e436324b5a405952add2146227e3eb5"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-win32.whl", hash = "sha256:dc0f46212f84c57d13189fc33cf61e13eee292704d7652e931e4b51c54b0c73c"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-win_amd64.whl", hash = "sha256:9248aad099ecf228fbdd877b0c668823dd83c48798cf04d49a1be75167e3a7ce"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = 
"sha256:bab5ff7c576588acccd665ecce2a0fe7b47d4ce0398f2d5c1e5b2e27d09398d2"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2381ce0aff67d569eb509bcc051264aa5fbdc1fdd54f4c09963d0e09f16a8f1b"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b4d35e72022b2dbf152d476b0362596011c674ff68be9fc8f2e68e71d86502ca"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-win32.whl", hash = "sha256:5a420e7112b55a1587412a5b03ebf59e302ddd354da68516d3721718f6b8a7c5"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-win_amd64.whl", hash = "sha256:c4c366e18edf91196a399f8f0f046f93516002e6d8af0b57c23e7c7d91944b16"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2020fda37c708d44a613c020cea09e81e476f96866f348afc2601e66c0e71db1"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:69d5352fb977655c85d2f40a05ae24fc5053cccee77d0a8b1f773e54804e723e"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:83f3c8191d30ae0e3dd557434c48ca591d75342d5a3f42fc5148ec42796be624"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3e792635e92c9aacd1452a589a4fa2970114b6a9b1165e09655481f6e58970f5"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-win32.whl", hash = "sha256:8cf7331f61780506976fe2175e069d898e1b04ace73be21aad55c3ee92e58e3a"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-win_amd64.whl", hash = "sha256:f109064bdb0753a6bac6238538cfeeb4a09739e2d556036b343b2eabeb9520b2"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:aea1e87450adffba3d04ccbaa790df719bb7aa23b05ac797ad16be236a5d0db8"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3a3602ba2e7715ddd5d4114173dec83d3181bfb2497e8589676c284aa739fd67"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:679789d774cfe09ca05118cab78c0a6a42985b3ed23bc93606272a4509b4df28"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:363df36370f28fdc7789857929f6ff27e659f64087b4c89f7a47ed43bd3bfe4d"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-win32.whl", hash = "sha256:bc6063875182515d3888180cc4cbdbaa6443e4a4386c4bb25499e9875b75dcac"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-win_amd64.whl", hash = "sha256:54aeb2a57978ce731fd52289d0e1deee7c232d41aed53091f38776378f644184"}, + {file = "Shapely-1.8.1.post1.tar.gz", hash = "sha256:93ff06ff05fbe2be843b93c7b1ad8292e56e665ba01b4708f75ae8a757972e9f"}, +] shellingham = [ {file = "shellingham-1.4.0-py2.py3-none-any.whl", hash = "sha256:536b67a0697f2e4af32ab176c00a50ac2899c5a05e0d8e2dadac8e58888283f9"}, {file = "shellingham-1.4.0.tar.gz", hash = "sha256:4855c2458d6904829bd34c299f11fdeed7cfefbf8a2c522e4caea6cd76b3171e"}, diff --git a/pyproject.toml b/pyproject.toml index cad3f245..dafcd587 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ click = ">=7.1.2,<9.0" # NOTE: COLAB has 7.1.2 and has problems updating rich = "^10.15.2" shellingham = "^1.4.0" scikit-learn = ">=0.24.0" -Shapely = ">=1.7.1" +Shapely = "^1.8.1" [tool.poetry.dev-dependencies] poetry = "^1.1.5" From 155f2709f4e9ecf1f926518dec28ae02a9809efe Mon Sep 17 00:00:00 2001 From: Anirudh-Scale Date: Wed, 23 Mar 2022 11:38:18 +0000 Subject: [PATCH 25/33] edge case --- nucleus/metrics/cuboid_metrics.py | 19 +++++++++-------- nucleus/metrics/cuboid_utils.py | 35 
+++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py index ed9fc509..62e1f09f 100644 --- a/nucleus/metrics/cuboid_metrics.py +++ b/nucleus/metrics/cuboid_metrics.py @@ -81,12 +81,12 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - birds_eye_view: bool = False, + iou_2d: bool = False, ): """Initializes CuboidIOU object. Args: - enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 birds_eye_view: whether to return the BEV 2D IOU if true, or the 3D IOU if false. confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 @@ -95,7 +95,7 @@ def __init__( 0 <= iou_threshold <= 1 ), "IoU threshold must be between 0 and 1." self.iou_threshold = iou_threshold - self.birds_eye_view = birds_eye_view + self.iou_2d = iou_2d super().__init__(enforce_label_match, confidence_threshold) def eval( @@ -103,16 +103,17 @@ def eval( annotations: List[CuboidAnnotation], predictions: List[CuboidPrediction], ) -> ScalarResult: - iou_3d, iou_2d = detection_iou( + iou_3d_metric, iou_2d_metric = detection_iou( predictions, annotations, threshold_in_overlap_ratio=self.iou_threshold, ) + weight = max(len(annotations), len(predictions)) - if self.birds_eye_view: - avg_iou = iou_2d.sum() / max(weight, sys.float_info.epsilon) + if self.iou_2d: + avg_iou = iou_2d_metric.sum() / max(weight, sys.float_info.epsilon) else: - avg_iou = iou_3d.sum() / max(weight, sys.float_info.epsilon) + avg_iou = iou_3d_metric.sum() / max(weight, sys.float_info.epsilon) return ScalarResult(avg_iou, weight) @@ -130,7 +131,7 @@ def __init__( """Initializes CuboidIOU object. Args: - enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 """ @@ -168,7 +169,7 @@ def __init__( """Initializes CuboidIOU object. Args: - enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. 
Default 0.0
         """
diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py
index ad959147..edc6af46 100644
--- a/nucleus/metrics/cuboid_utils.py
+++ b/nucleus/metrics/cuboid_utils.py
@@ -113,6 +113,7 @@ def compute_outer_iou(
     :param distance_threshold: computes iou only within this distance (~3x speedup)
     :return: (n, m) 3D IoU, (n, m) 2D IoU
     """
+
    bottom_z = np.maximum.outer(
        xyz_0[:, 2] - (wlh_0[:, 2] / 2), xyz_1[:, 2] - (wlh_1[:, 2] / 2)
    )
@@ -249,14 +250,17 @@ def associate_cuboids_on_iou(
 
 
 def recall_precision(
-    prediction,
-    groundtruth,
+    prediction: List[CuboidPrediction],
+    groundtruth: List[CuboidAnnotation],
     threshold_in_overlap_ratio: float,
 ) -> Dict[str, float]:
     """
-    :param predictions:
-    :param ground_truth:
-    :param threshold: threshold in overlap ratio if IoU
+    Calculates precision and recall across lidar frames.
+
+    Args:
+        prediction: list of cuboid predictions.
+        groundtruth: list of ground truth cuboid annotations.
+        threshold_in_overlap_ratio: IOU threshold to consider a detection as valid. Must be in [0, 1].
     """
 
     tp_sum = 0
@@ -307,13 +311,32 @@ def recall_precision(
     }
 
 
-def detection_iou(prediction, groundtruth, threshold_in_overlap_ratio):
+def detection_iou(
+    prediction: List[CuboidPrediction],
+    groundtruth: List[CuboidAnnotation],
+    threshold_in_overlap_ratio: float,
+) -> Tuple[
+    np.ndarray[float], np.ndarray[float]
+]:  # pylint: disable=unsubscriptable-object
+    """
+    Calculates the 2D IOU and 3D IOU overlap between predictions and groundtruth.
+    Uses linear sum assignment to associate cuboids.
+
+    Args:
+        prediction: list of cuboid predictions.
+        groundtruth: list of ground truth cuboid annotations.
+        threshold_in_overlap_ratio: IOU threshold to consider a detection as valid. Must be in [0, 1].
+    """
     gt_items = process_dataitem(groundtruth)
     pred_items = process_dataitem(prediction)
 
     meter_2d = []
     meter_3d = []
+
+    if gt_items["xyz"].shape[0] == 0 or pred_items["xyz"].shape[0] == 0:
+        return np.array([0.0]), np.array([0.0])
+
     iou_3d, iou_2d = compute_outer_iou(
         gt_items["xyz"],
         gt_items["wlh"],

From 003127d8346b3b51aef1b649fd1b8ac8357fb480 Mon Sep 17 00:00:00 2001
From: Anirudh-Scale
Date: Wed, 23 Mar 2022 11:42:28 +0000
Subject: [PATCH 26/33] np type

---
 nucleus/metrics/cuboid_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py
index edc6af46..0ebf0716 100644
--- a/nucleus/metrics/cuboid_utils.py
+++ b/nucleus/metrics/cuboid_utils.py
@@ -315,9 +315,7 @@ def detection_iou(
     prediction: List[CuboidPrediction],
     groundtruth: List[CuboidAnnotation],
     threshold_in_overlap_ratio: float,
-) -> Tuple[
-    np.ndarray[float], np.ndarray[float]
-]:  # pylint: disable=unsubscriptable-object
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Calculates the 2D IOU and 3D IOU overlap between predictions and groundtruth.
     Uses linear sum assignment to associate cuboids.
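
The two patches above pin down the numeric behavior of the cuboid matching utilities: detection_iou now returns plain np.ndarray values and short-circuits on empty inputs, and CuboidIOU averages the matched IoUs over max(#annotations, #predictions). A minimal sketch of that arithmetic, using only numpy and the logic introduced above (the IoU values and counts are hypothetical):

    import sys

    import numpy as np

    # detection_iou short-circuits when either side is empty, returning
    # (np.array([0.0]), np.array([0.0])) instead of associating zero-size arrays.
    iou_3d = np.array([0.71, 0.64])  # 3D IoU per matched pair (hypothetical values)
    num_annotations, num_predictions = 3, 2

    # CuboidIOU.eval divides the summed IoU by max(#annotations, #predictions),
    # so unmatched cuboids dilute the average; the epsilon only guards the
    # division when both lists are empty.
    weight = max(num_annotations, num_predictions)
    avg_iou = iou_3d.sum() / max(weight, sys.float_info.epsilon)  # 1.35 / 3 = 0.45

CuboidPrecision and CuboidRecall use the same epsilon guard, dividing tp_sum by tp_sum + fp_sum and by tp_sum + fn_sum respectively.
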
From 213bafa73f8bfb15ed27f4652568eaade9c12274 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Wed, 30 Mar 2022 12:30:55 +0200 Subject: [PATCH 27/33] CuboidMetrics can filter metadata --- nucleus/metrics/cuboid_metrics.py | 150 ++++++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 6 deletions(-) diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py index 62e1f09f..14357160 100644 --- a/nucleus/metrics/cuboid_metrics.py +++ b/nucleus/metrics/cuboid_metrics.py @@ -1,15 +1,102 @@ import sys from abc import abstractmethod -from typing import List +from collections import namedtuple +from enum import Enum +from typing import Callable, List, Optional, Union -from nucleus.annotation import AnnotationList, CuboidAnnotation -from nucleus.prediction import CuboidPrediction, PredictionList +from nucleus.annotation import Annotation, AnnotationList, CuboidAnnotation +from nucleus.prediction import CuboidPrediction, Prediction, PredictionList from .base import Metric, ScalarResult from .cuboid_utils import detection_iou, label_match_wrapper, recall_precision from .filters import confidence_filter +class FilterOp(str, Enum): + GT = ">" + GTE = ">=" + LT = "<" + LTE = "<=" + EQ = "==" + NEQ = "!=" + + +MetadataFilter = namedtuple("MetadataFilter", ["key", "op", "value"]) +DNFMetadataFilters = List[List[MetadataFilter]] +DNFMetadataFilters.__doc__ = """\ +Disjunctive normal form (DNF) filters. +DNF allows arbitrary boolean logical combinations of single field predicates. +The innermost structures each describe a single column predicate. The list of inner predicates is +interpreted as a conjunction (AND), forming a more selective and multiple column predicate. +Finally, the most outer list combines these filters as a disjunction (OR). +""" + + +def filter_to_comparison_function( + metadata_filter: MetadataFilter, +) -> Callable[[Union[Annotation, Prediction]], bool]: + op = FilterOp(metadata_filter.op) + if op is FilterOp.GT: + return ( + lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] + > metadata_filter.value + ) + elif op is FilterOp.GTE: + return ( + lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] + >= metadata_filter.value + ) + elif op is FilterOp.LT: + return ( + lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] + < metadata_filter.value + ) + elif op is FilterOp.LTE: + return ( + lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] + <= metadata_filter.value + ) + elif op is FilterOp.EQ: + return ( + lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] + == metadata_filter.value + ) + else: + raise RuntimeError(f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {metadata_filter},") + + +def filter_metadata( + ann_or_pred: Union[List[Annotation], List[Prediction]], + metadata_filter: Union[DNFMetadataFilters, List[MetadataFilter]], +): + """ + Attributes: + ann_or_pred: Prediction or Annotation + metadata_filter: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). 
+ """ + if len(metadata_filter) == 0: + return ann_or_pred + + if isinstance(metadata_filter[0], MetadataFilter): + # Normalize into DNF + metadata_filter: DNFMetadataFilters = [metadata_filter] + + filtered = [] + for item in ann_or_pred: + for or_branch in metadata_filter: + and_conditions = ( + filter_to_comparison_function(cond) for cond in or_branch + ) + if all(c(item) for c in and_conditions): + filtered.append(item) + break + return filtered + + class CuboidMetric(Metric): """Abstract class for metrics of cuboids. @@ -28,16 +115,30 @@ def __init__( self, enforce_label_match: bool = False, confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, ): """Initializes CuboidMetric abstract object. Args: enforce_label_match: whether to enforce that annotation and prediction labels must match. Default False confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). """ self.enforce_label_match = enforce_label_match assert 0 <= confidence_threshold <= 1 self.confidence_threshold = confidence_threshold + self.annotation_filters = annotation_filters + self.prediction_filters = prediction_filters @abstractmethod def eval( @@ -64,6 +165,12 @@ def __call__( cuboid_predictions.extend(predictions.cuboid_predictions) eval_fn = label_match_wrapper(self.eval) + cuboid_annotations = filter_metadata( + cuboid_annotations, self.annotation_filters + ) + cuboid_predictions = filter_metadata( + cuboid_annotations, self.prediction_filters + ) result = eval_fn( cuboid_annotations, cuboid_predictions, @@ -82,6 +189,8 @@ def __init__( iou_threshold: float = 0.0, confidence_threshold: float = 0.0, iou_2d: bool = False, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, ): """Initializes CuboidIOU object. @@ -90,13 +199,28 @@ def __init__( iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 birds_eye_view: whether to return the BEV 2D IOU if true, or the 3D IOU if false. confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. 
The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). """ assert ( 0 <= iou_threshold <= 1 ), "IoU threshold must be between 0 and 1." self.iou_threshold = iou_threshold self.iou_2d = iou_2d - super().__init__(enforce_label_match, confidence_threshold) + super().__init__( + enforce_label_match=enforce_label_match, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + ) def eval( self, @@ -127,6 +251,8 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, ): """Initializes CuboidIOU object. @@ -139,7 +265,12 @@ def __init__( 0 <= iou_threshold <= 1 ), "IoU threshold must be between 0 and 1." self.iou_threshold = iou_threshold - super().__init__(enforce_label_match, confidence_threshold) + super().__init__( + enforce_label_match=enforce_label_match, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + ) def eval( self, @@ -165,6 +296,8 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, ): """Initializes CuboidIOU object. @@ -177,7 +310,12 @@ def __init__( 0 <= iou_threshold <= 1 ), "IoU threshold must be between 0 and 1." 
self.iou_threshold = iou_threshold - super().__init__(enforce_label_match, confidence_threshold) + super().__init__( + enforce_label_match=enforce_label_match, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + ) def eval( self, From 72f3c0d676cd460bde05f0423251603fd3a9cf58 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Wed, 30 Mar 2022 13:08:05 +0200 Subject: [PATCH 28/33] Add Cuboid configs --- nucleus/metrics/cuboid_metrics.py | 29 ++- .../available_eval_functions.py | 181 ++++++++++++++++++ 2 files changed, 204 insertions(+), 6 deletions(-) diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py index 14357160..1869864f 100644 --- a/nucleus/metrics/cuboid_metrics.py +++ b/nucleus/metrics/cuboid_metrics.py @@ -61,11 +61,18 @@ def filter_to_comparison_function( lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] == metadata_filter.value ) + elif op is FilterOp.NEQ: + return ( + lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] + != metadata_filter.value + ) else: - raise RuntimeError(f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {metadata_filter},") + raise RuntimeError( + f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {metadata_filter}," + ) -def filter_metadata( +def filter_by_metadata_fields( ann_or_pred: Union[List[Annotation], List[Prediction]], metadata_filter: Union[DNFMetadataFilters, List[MetadataFilter]], ): @@ -78,7 +85,7 @@ def filter_metadata( interpreted as a conjunction (AND), forming a more selective and multiple column predicate. Finally, the most outer list combines these filters as a disjunction (OR). """ - if len(metadata_filter) == 0: + if metadata_filter is None or len(metadata_filter) == 0: return ann_or_pred if isinstance(metadata_filter[0], MetadataFilter): @@ -165,10 +172,10 @@ def __call__( cuboid_predictions.extend(predictions.cuboid_predictions) eval_fn = label_match_wrapper(self.eval) - cuboid_annotations = filter_metadata( + cuboid_annotations = filter_by_metadata_fields( cuboid_annotations, self.annotation_filters ) - cuboid_predictions = filter_metadata( + cuboid_predictions = filter_by_metadata_fields( cuboid_annotations, self.prediction_filters ) result = eval_fn( @@ -197,7 +204,7 @@ def __init__( Args: enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 - birds_eye_view: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field @@ -260,6 +267,16 @@ def __init__( enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. 
DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). """ assert ( 0 <= iou_threshold <= 1 diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index a2ba14cd..5382c3ba 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -8,6 +8,7 @@ from ..data_transfer_objects.eval_function import EvalFunctionEntry from ..errors import EvalFunctionNotAvailableError +from ...metrics.cuboid_metrics import DNFMetadataFilters MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes" @@ -162,6 +163,177 @@ def expected_name(cls) -> str: return "bbox_precision" +class CuboidIOU2DConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, + **kwargs, + ): + """Configure a call to CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). 
+ """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + iou_2d=True, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_iou_2d" + + +class CuboidIOU3DConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, + **kwargs, + ): + """Configure a call to CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + iou_2d=False, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_iou_3d" + + +class CuboidPrecisionConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, + **kwargs, + ): + """Configure a call to CuboidPrecision object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. 
The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + iou_2d=False, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_precision" + + +class CuboidRecallConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[DNFMetadataFilters] = None, + prediction_filters: Optional[DNFMetadataFilters] = None, + **kwargs, + ): + """Configure a call to a CuboidRecall object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). 
+ """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + iou_2d=False, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_recall" + + class CategorizationF1Config(EvalFunctionConfig): def __call__( self, @@ -324,6 +496,15 @@ def __init__(self, available_functions: List[EvalFunctionEntry]): self.cat_f1: CategorizationF1Config = self._assign_eval_function_if_defined( CategorizationF1Config # type: ignore ) + self.cuboid_iou_2d: CuboidIOU2DConfig = ( + self._assign_eval_function_if_defined(CuboidIOU2DConfig) + ) + self.cuboid_iou_3d: CuboidIOU3DConfig = ( + self._assign_eval_function_if_defined(CuboidIOU3DConfig) + ) + self.cuboid_precision: CuboidPrecision = ( + self._assign_eval_function_if_defined(CuboidPrecisionConfig) + ) # Add public entries that have not been implemented as an attribute on this class for func_entry in self._public_func_entries.values(): From f0c7399d26745ba13b9730358fc9618e692c501c Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Wed, 30 Mar 2022 13:22:53 +0200 Subject: [PATCH 29/33] Fix mypy errors --- nucleus/metrics/cuboid_metrics.py | 45 +++++++++++++++++-- nucleus/validate/client.py | 2 +- nucleus/validate/errors.py | 4 -- .../available_eval_functions.py | 15 +++---- 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py index 1869864f..3b0544f8 100644 --- a/nucleus/metrics/cuboid_metrics.py +++ b/nucleus/metrics/cuboid_metrics.py @@ -4,8 +4,27 @@ from enum import Enum from typing import Callable, List, Optional, Union -from nucleus.annotation import Annotation, AnnotationList, CuboidAnnotation -from nucleus.prediction import CuboidPrediction, Prediction, PredictionList +from nucleus.annotation import ( + Annotation, + AnnotationList, + BoxAnnotation, + CategoryAnnotation, + CuboidAnnotation, + LineAnnotation, + MultiCategoryAnnotation, + PolygonAnnotation, + SegmentationAnnotation, +) +from nucleus.prediction import ( + BoxPrediction, + CategoryPrediction, + CuboidPrediction, + LinePrediction, + PolygonPrediction, + Prediction, + PredictionList, + SegmentationPrediction, +) from .base import Metric, ScalarResult from .cuboid_utils import detection_iou, label_match_wrapper, recall_precision @@ -31,10 +50,30 @@ class FilterOp(str, Enum): Finally, the most outer list combines these filters as a disjunction (OR). 
""" +AnnotationsWithMetadata = Union[ + BoxAnnotation, + CategoryAnnotation, + CuboidAnnotation, + LineAnnotation, + MultiCategoryAnnotation, + PolygonAnnotation, + SegmentationAnnotation, +] + + +PredictionsWithMetadata = Union[ + BoxPrediction, + CategoryPrediction, + CuboidPrediction, + LinePrediction, + PolygonPrediction, + SegmentationPrediction, +] + def filter_to_comparison_function( metadata_filter: MetadataFilter, -) -> Callable[[Union[Annotation, Prediction]], bool]: +) -> Callable[[Union[AnnotationsWithMetadata, PredictionsWithMetadata]], bool]: op = FilterOp(metadata_filter.op) if op is FilterOp.GT: return ( diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index f8189fe0..cfc5fbb5 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -6,7 +6,7 @@ from .constants import SCENARIO_TEST_ID_KEY from .data_transfer_objects.eval_function import GetEvalFunctions from .data_transfer_objects.scenario_test import CreateScenarioTestRequest -from .errors import CreateScenarioTestError, InvalidEvaluationCriteria +from .errors import CreateScenarioTestError from .eval_functions.available_eval_functions import AvailableEvalFunctions from .eval_functions.base_eval_function import EvalFunctionConfig from .scenario_test import ScenarioTest diff --git a/nucleus/validate/errors.py b/nucleus/validate/errors.py index 87253c71..19cc249b 100644 --- a/nucleus/validate/errors.py +++ b/nucleus/validate/errors.py @@ -4,7 +4,3 @@ class CreateScenarioTestError(Exception): class EvalFunctionNotAvailableError(Exception): pass - - -class InvalidEvaluationCriteria(Exception): - pass diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index 5382c3ba..0c3f135f 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -6,9 +6,9 @@ EvalFunctionConfig, ) +from ...metrics.cuboid_metrics import DNFMetadataFilters from ..data_transfer_objects.eval_function import EvalFunctionEntry from ..errors import EvalFunctionNotAvailableError -from ...metrics.cuboid_metrics import DNFMetadataFilters MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes" @@ -496,15 +496,10 @@ def __init__(self, available_functions: List[EvalFunctionEntry]): self.cat_f1: CategorizationF1Config = self._assign_eval_function_if_defined( CategorizationF1Config # type: ignore ) - self.cuboid_iou_2d: CuboidIOU2DConfig = ( - self._assign_eval_function_if_defined(CuboidIOU2DConfig) - ) - self.cuboid_iou_3d: CuboidIOU3DConfig = ( - self._assign_eval_function_if_defined(CuboidIOU3DConfig) - ) - self.cuboid_precision: CuboidPrecision = ( - self._assign_eval_function_if_defined(CuboidPrecisionConfig) - ) + self.cuboid_iou_2d: CuboidIOU2DConfig = self._assign_eval_function_if_defined(CuboidIOU2DConfig) # type: ignore + self.cuboid_iou_3d: CuboidIOU3DConfig = self._assign_eval_function_if_defined(CuboidIOU3DConfig) # type: ignore + self.cuboid_precision: CuboidPrecisionConfig = self._assign_eval_function_if_defined(CuboidPrecisionConfig) # type: ignore + self.cuboid_recall: CuboidRecallConfig = self._assign_eval_function_if_defined(CuboidRecallConfig) # type: ignore # Add public entries that have not been implemented as an attribute on this class for func_entry in self._public_func_entries.values(): From d72c5657cc45509861ad4e82923a8ab664f9285c Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Wed, 30 Mar 2022 17:19:03 +0200 Subject: [PATCH 30/33] 
Add field filters --- nucleus/metrics/cuboid_metrics.py | 116 +++++++++++------- nucleus/validate/client.py | 10 +- .../data_transfer_objects/eval_function.py | 2 +- .../data_transfer_objects/scenario_test.py | 7 +- .../available_eval_functions.py | 20 ++- .../eval_functions/base_eval_function.py | 3 - pyproject.toml | 2 +- 7 files changed, 100 insertions(+), 60 deletions(-) diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py index 3b0544f8..5431f283 100644 --- a/nucleus/metrics/cuboid_metrics.py +++ b/nucleus/metrics/cuboid_metrics.py @@ -1,3 +1,5 @@ +import enum +import functools import sys from abc import abstractmethod from collections import namedtuple @@ -40,9 +42,29 @@ class FilterOp(str, Enum): NEQ = "!=" -MetadataFilter = namedtuple("MetadataFilter", ["key", "op", "value"]) -DNFMetadataFilters = List[List[MetadataFilter]] -DNFMetadataFilters.__doc__ = """\ +class FilterType(str, enum.Enum): + FIELD = "field" + METADATA = "metadata" + + +AnnotationOrPredictionFilter = namedtuple( + "AnnotationOrPredictionFilter", ["key", "op", "value", "type"] +) +FieldFilter = namedtuple( + "FieldFilter", + ["key", "op", "value", "type"], + defaults=[None, None, None, FilterType.FIELD], +) +MetadataFilter = namedtuple( + "MetadataFilter", + ["key", "op", "value", "type"], + defaults=[None, None, None, FilterType.METADATA], +) + +DNFFilter = List[ + List[Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]] +] +DNFFilter.__doc__ = """\ Disjunctive normal form (DNF) filters. DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures each describe a single column predicate. The list of inner predicates is @@ -71,40 +93,34 @@ class FilterOp(str, Enum): ] +def field_getter(field_name): + return lambda ann_or_pred: getattr(ann_or_pred, field_name) + + +def metadata_field_getter(field_name): + return lambda ann_or_pred: ann_or_pred.metadata[field_name] + + def filter_to_comparison_function( - metadata_filter: MetadataFilter, + metadata_filter: AnnotationOrPredictionFilter, ) -> Callable[[Union[AnnotationsWithMetadata, PredictionsWithMetadata]], bool]: + if FilterType(metadata_filter.type) == FilterType.FIELD: + getter = field_getter(metadata_filter.key) + elif FilterType(metadata_filter.type) == FilterType.METADATA: + getter = metadata_field_getter(metadata_filter.key) op = FilterOp(metadata_filter.op) if op is FilterOp.GT: - return ( - lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] - > metadata_filter.value - ) + return lambda ann_or_pred: getter(ann_or_pred) > metadata_filter.value elif op is FilterOp.GTE: - return ( - lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] - >= metadata_filter.value - ) + return lambda ann_or_pred: getter(ann_or_pred) >= metadata_filter.value elif op is FilterOp.LT: - return ( - lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] - < metadata_filter.value - ) + return lambda ann_or_pred: getter(ann_or_pred) < metadata_filter.value elif op is FilterOp.LTE: - return ( - lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] - <= metadata_filter.value - ) + return lambda ann_or_pred: getter(ann_or_pred) <= metadata_filter.value elif op is FilterOp.EQ: - return ( - lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] - == metadata_filter.value - ) + return lambda ann_or_pred: getter(ann_or_pred) == metadata_filter.value elif op is FilterOp.NEQ: - return ( - lambda ann_or_pred: ann_or_pred.metadata[metadata_filter.key] - != 
metadata_filter.value - ) + return lambda ann_or_pred: getter(ann_or_pred) != metadata_filter.value else: raise RuntimeError( f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {metadata_filter}," @@ -112,8 +128,10 @@ def filter_to_comparison_function( def filter_by_metadata_fields( - ann_or_pred: Union[List[Annotation], List[Prediction]], - metadata_filter: Union[DNFMetadataFilters, List[MetadataFilter]], + ann_or_pred: Union[ + List[AnnotationsWithMetadata], List[PredictionsWithMetadata] + ], + metadata_filter: Union[DNFFilter, List[MetadataFilter], List[List[List]]], ): """ Attributes: @@ -127,9 +145,25 @@ def filter_by_metadata_fields( if metadata_filter is None or len(metadata_filter) == 0: return ann_or_pred - if isinstance(metadata_filter[0], MetadataFilter): + if isinstance(metadata_filter[0], MetadataFilter) or isinstance( + metadata_filter[0], FieldFilter + ): # Normalize into DNF - metadata_filter: DNFMetadataFilters = [metadata_filter] + metadata_filter: DNFFilter = [metadata_filter] + # NOTE: We have to handle JSON transformed tuples which become three layers of lists + if ( + isinstance(metadata_filter, list) + and isinstance(metadata_filter[0], list) + and isinstance(metadata_filter[0][0], list) + ): + formatted_filter = [] + for or_branch in metadata_filter: + and_chain = [ + AnnotationOrPredictionFilter(*condition) + for condition in or_branch + ] + formatted_filter.append(and_chain) + metadata_filter = formatted_filter filtered = [] for item in ann_or_pred: @@ -161,8 +195,8 @@ def __init__( self, enforce_label_match: bool = False, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, ): """Initializes CuboidMetric abstract object. @@ -215,7 +249,7 @@ def __call__( cuboid_annotations, self.annotation_filters ) cuboid_predictions = filter_by_metadata_fields( - cuboid_annotations, self.prediction_filters + cuboid_predictions, self.prediction_filters ) result = eval_fn( cuboid_annotations, @@ -235,8 +269,8 @@ def __init__( iou_threshold: float = 0.0, confidence_threshold: float = 0.0, iou_2d: bool = False, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, ): """Initializes CuboidIOU object. @@ -297,8 +331,8 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, ): """Initializes CuboidIOU object. @@ -352,8 +386,8 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, ): """Initializes CuboidIOU object. 
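A quick illustration of the DNF semantics introduced above may help review: the outer list is a disjunction (OR), each inner list a conjunction (AND). A minimal sketch, assuming the namedtuples from this patch; the annotation values, metadata keys, and reference ids are illustrative only:

from nucleus.annotation import BoxAnnotation
from nucleus.metrics.cuboid_metrics import (
    FieldFilter,
    MetadataFilter,
    filter_by_metadata_fields,
)

# Two toy annotations carrying a "distance" metadata field.
anns = [
    BoxAnnotation(label="car", x=0, y=0, width=10, height=10,
                  reference_id="img_0", metadata={"distance": 12.0}),
    BoxAnnotation(label="pedestrian", x=5, y=5, width=4, height=8,
                  reference_id="img_1", metadata={"distance": 45.0}),
]

# (distance <= 20 AND label == "car") OR (label == "pedestrian")
dnf = [
    [MetadataFilter("distance", "<=", 20.0), FieldFilter("label", "==", "car")],
    [FieldFilter("label", "==", "pedestrian")],
]

# Both annotations survive, each through a different OR-branch.
filtered = filter_by_metadata_fields(anns, dnf)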
diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index cfc5fbb5..43e2f7f1 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -5,7 +5,10 @@ from .constants import SCENARIO_TEST_ID_KEY from .data_transfer_objects.eval_function import GetEvalFunctions -from .data_transfer_objects.scenario_test import CreateScenarioTestRequest +from .data_transfer_objects.scenario_test import ( + CreateScenarioTestRequest, + EvalFunctionListEntry, +) from .errors import CreateScenarioTestError from .eval_functions.available_eval_functions import AvailableEvalFunctions from .eval_functions.base_eval_function import EvalFunctionConfig @@ -83,7 +86,10 @@ def create_scenario_test( name=name, slice_id=slice_id, evaluation_functions=[ - ef.to_entry() for ef in evaluation_functions # type:ignore + EvalFunctionListEntry( + id=ef.id, eval_func_arguments=ef.eval_func_arguments + ) + for ef in evaluation_functions ], ).dict(), "validate/scenario_test", diff --git a/nucleus/validate/data_transfer_objects/eval_function.py b/nucleus/validate/data_transfer_objects/eval_function.py index 2592be3d..5f7727b1 100644 --- a/nucleus/validate/data_transfer_objects/eval_function.py +++ b/nucleus/validate/data_transfer_objects/eval_function.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic import validator +from pydantic import BaseModel, validator from ...pydantic_base import ImmutableModel from ..constants import ThresholdComparison diff --git a/nucleus/validate/data_transfer_objects/scenario_test.py b/nucleus/validate/data_transfer_objects/scenario_test.py index 174571ff..2232efb7 100644 --- a/nucleus/validate/data_transfer_objects/scenario_test.py +++ b/nucleus/validate/data_transfer_objects/scenario_test.py @@ -7,10 +7,15 @@ from .eval_function import EvalFunctionEntry +class EvalFunctionListEntry(ImmutableModel): + id: str + eval_func_arguments: dict + + class CreateScenarioTestRequest(ImmutableModel): name: str slice_id: str - evaluation_functions: List[EvalFunctionEntry] + evaluation_functions: List[EvalFunctionListEntry] @validator("slice_id") def startswith_slice_indicator(cls, v): # pylint: disable=no-self-argument diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index 0c3f135f..d2dc7efc 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -6,7 +6,7 @@ EvalFunctionConfig, ) -from ...metrics.cuboid_metrics import DNFMetadataFilters +from ...metrics.cuboid_metrics import DNFFilter from ..data_transfer_objects.eval_function import EvalFunctionEntry from ..errors import EvalFunctionNotAvailableError @@ -169,8 +169,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, **kwargs, ): """Configure a call to CuboidIOU object. 
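To make the end-to-end flow above concrete, here is a hedged sketch of how a configured eval function now reaches the backend. It assumes create_scenario_test at this point in the series accepts configured functions via evaluation_functions; the API key and slice id are placeholders:

import nucleus
from nucleus.metrics.cuboid_metrics import MetadataFilter

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key

# Configuring the function records the kwargs in eval_func_arguments.
config = client.validate.eval_functions.cuboid_iou_2d(
    iou_threshold=0.5,
    annotation_filters=[[MetadataFilter("distance", "<", 40.0)]],
)

# On submission, each function is serialized as
# EvalFunctionListEntry(id=..., eval_func_arguments=...), so the filter
# arguments travel with the request.
scenario_test = client.validate.create_scenario_test(
    name="cuboid iou, close range",  # illustrative name
    slice_id="slc_fake_slice_id",    # placeholder slice id
    evaluation_functions=[config],
)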
@@ -211,8 +211,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, **kwargs, ): """Configure a call to CuboidIOU object. @@ -254,8 +254,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, **kwargs, ): """Configure a call to CuboidPrecision object. @@ -280,7 +280,6 @@ def __call__( enforce_label_match=enforce_label_match, iou_threshold=iou_threshold, confidence_threshold=confidence_threshold, - iou_2d=False, annotation_filters=annotation_filters, prediction_filters=prediction_filters, **kwargs, @@ -297,8 +296,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFMetadataFilters] = None, - prediction_filters: Optional[DNFMetadataFilters] = None, + annotation_filters: Optional[DNFFilter] = None, + prediction_filters: Optional[DNFFilter] = None, **kwargs, ): """Configure a call to a CuboidRecall object. @@ -323,7 +322,6 @@ def __call__( enforce_label_match=enforce_label_match, iou_threshold=iou_threshold, confidence_threshold=confidence_threshold, - iou_2d=False, annotation_filters=annotation_filters, prediction_filters=prediction_filters, **kwargs, diff --git a/nucleus/validate/eval_functions/base_eval_function.py b/nucleus/validate/eval_functions/base_eval_function.py index af823f70..fa59ec31 100644 --- a/nucleus/validate/eval_functions/base_eval_function.py +++ b/nucleus/validate/eval_functions/base_eval_function.py @@ -62,6 +62,3 @@ def _op_to_test_metric(self, comparison: ThresholdComparison, value): threshold=value, eval_func_arguments=self.eval_func_arguments, ) - - def to_entry(self): - return self.eval_func_entry diff --git a/pyproject.toml b/pyproject.toml index dafcd587..9c518db4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.8.3" +version = "0.9b3" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] From 3ffffd0be17d29888b1456b40b93ee8b6a65677d Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Thu, 31 Mar 2022 14:42:56 +0200 Subject: [PATCH 31/33] Add tests for filtering functions and move them to seperate module --- nucleus/metrics/__init__.py | 6 + nucleus/metrics/cuboid_metrics.py | 236 +++------------- nucleus/metrics/filtering.py | 264 ++++++++++++++++++ .../data_transfer_objects/eval_function.py | 2 +- .../data_transfer_objects/scenario_test.py | 2 - .../available_eval_functions.py | 18 +- pyproject.toml | 2 +- tests/helpers.py | 1 + tests/metrics/test_filtering.py | 238 ++++++++++++++++ 9 files changed, 564 insertions(+), 205 deletions(-) create mode 100644 nucleus/metrics/filtering.py create mode 100644 tests/metrics/test_filtering.py diff --git a/nucleus/metrics/__init__.py b/nucleus/metrics/__init__.py index cb19d8a8..460561f7 100644 --- a/nucleus/metrics/__init__.py +++ b/nucleus/metrics/__init__.py @@ -1,6 +1,12 @@ from .base import Metric, 
ScalarResult from .categorization_metrics import CategorizationF1 from .cuboid_metrics import CuboidIOU, CuboidPrecision, CuboidRecall +from .filtering import ( + FieldFilter, + ListOfOrAndFilters, + MetadataFilter, + apply_filters, +) from .polygon_metrics import ( PolygonAveragePrecision, PolygonIOU, diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py index 5431f283..313ef759 100644 --- a/nucleus/metrics/cuboid_metrics.py +++ b/nucleus/metrics/cuboid_metrics.py @@ -1,182 +1,16 @@ -import enum -import functools import sys from abc import abstractmethod -from collections import namedtuple -from enum import Enum -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union -from nucleus.annotation import ( - Annotation, - AnnotationList, - BoxAnnotation, - CategoryAnnotation, - CuboidAnnotation, - LineAnnotation, - MultiCategoryAnnotation, - PolygonAnnotation, - SegmentationAnnotation, -) -from nucleus.prediction import ( - BoxPrediction, - CategoryPrediction, - CuboidPrediction, - LinePrediction, - PolygonPrediction, - Prediction, - PredictionList, - SegmentationPrediction, -) +from nucleus.annotation import AnnotationList, CuboidAnnotation +from nucleus.prediction import CuboidPrediction, PredictionList from .base import Metric, ScalarResult from .cuboid_utils import detection_iou, label_match_wrapper, recall_precision +from .filtering import ListOfAndFilters, ListOfOrAndFilters, apply_filters from .filters import confidence_filter -class FilterOp(str, Enum): - GT = ">" - GTE = ">=" - LT = "<" - LTE = "<=" - EQ = "==" - NEQ = "!=" - - -class FilterType(str, enum.Enum): - FIELD = "field" - METADATA = "metadata" - - -AnnotationOrPredictionFilter = namedtuple( - "AnnotationOrPredictionFilter", ["key", "op", "value", "type"] -) -FieldFilter = namedtuple( - "FieldFilter", - ["key", "op", "value", "type"], - defaults=[None, None, None, FilterType.FIELD], -) -MetadataFilter = namedtuple( - "MetadataFilter", - ["key", "op", "value", "type"], - defaults=[None, None, None, FilterType.METADATA], -) - -DNFFilter = List[ - List[Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter]] -] -DNFFilter.__doc__ = """\ -Disjunctive normal form (DNF) filters. -DNF allows arbitrary boolean logical combinations of single field predicates. -The innermost structures each describe a single column predicate. The list of inner predicates is -interpreted as a conjunction (AND), forming a more selective and multiple column predicate. -Finally, the most outer list combines these filters as a disjunction (OR). 
-""" - -AnnotationsWithMetadata = Union[ - BoxAnnotation, - CategoryAnnotation, - CuboidAnnotation, - LineAnnotation, - MultiCategoryAnnotation, - PolygonAnnotation, - SegmentationAnnotation, -] - - -PredictionsWithMetadata = Union[ - BoxPrediction, - CategoryPrediction, - CuboidPrediction, - LinePrediction, - PolygonPrediction, - SegmentationPrediction, -] - - -def field_getter(field_name): - return lambda ann_or_pred: getattr(ann_or_pred, field_name) - - -def metadata_field_getter(field_name): - return lambda ann_or_pred: ann_or_pred.metadata[field_name] - - -def filter_to_comparison_function( - metadata_filter: AnnotationOrPredictionFilter, -) -> Callable[[Union[AnnotationsWithMetadata, PredictionsWithMetadata]], bool]: - if FilterType(metadata_filter.type) == FilterType.FIELD: - getter = field_getter(metadata_filter.key) - elif FilterType(metadata_filter.type) == FilterType.METADATA: - getter = metadata_field_getter(metadata_filter.key) - op = FilterOp(metadata_filter.op) - if op is FilterOp.GT: - return lambda ann_or_pred: getter(ann_or_pred) > metadata_filter.value - elif op is FilterOp.GTE: - return lambda ann_or_pred: getter(ann_or_pred) >= metadata_filter.value - elif op is FilterOp.LT: - return lambda ann_or_pred: getter(ann_or_pred) < metadata_filter.value - elif op is FilterOp.LTE: - return lambda ann_or_pred: getter(ann_or_pred) <= metadata_filter.value - elif op is FilterOp.EQ: - return lambda ann_or_pred: getter(ann_or_pred) == metadata_filter.value - elif op is FilterOp.NEQ: - return lambda ann_or_pred: getter(ann_or_pred) != metadata_filter.value - else: - raise RuntimeError( - f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {metadata_filter}," - ) - - -def filter_by_metadata_fields( - ann_or_pred: Union[ - List[AnnotationsWithMetadata], List[PredictionsWithMetadata] - ], - metadata_filter: Union[DNFFilter, List[MetadataFilter], List[List[List]]], -): - """ - Attributes: - ann_or_pred: Prediction or Annotation - metadata_filter: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like - [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field - predicates. The innermost structures each describe a single column predicate. The list of inner predicates is - interpreted as a conjunction (AND), forming a more selective and multiple column predicate. - Finally, the most outer list combines these filters as a disjunction (OR). - """ - if metadata_filter is None or len(metadata_filter) == 0: - return ann_or_pred - - if isinstance(metadata_filter[0], MetadataFilter) or isinstance( - metadata_filter[0], FieldFilter - ): - # Normalize into DNF - metadata_filter: DNFFilter = [metadata_filter] - # NOTE: We have to handle JSON transformed tuples which become three layers of lists - if ( - isinstance(metadata_filter, list) - and isinstance(metadata_filter[0], list) - and isinstance(metadata_filter[0][0], list) - ): - formatted_filter = [] - for or_branch in metadata_filter: - and_chain = [ - AnnotationOrPredictionFilter(*condition) - for condition in or_branch - ] - formatted_filter.append(and_chain) - metadata_filter = formatted_filter - - filtered = [] - for item in ann_or_pred: - for or_branch in metadata_filter: - and_conditions = ( - filter_to_comparison_function(cond) for cond in or_branch - ) - if all(c(item) for c in and_conditions): - filtered.append(item) - break - return filtered - - class CuboidMetric(Metric): """Abstract class for metrics of cuboids. 
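With the extraction done, the metrics only touch the public surface of the new filtering module. A minimal sketch of constructing a filtered metric under this refactor; the metadata key and thresholds are illustrative:

from nucleus.metrics import CuboidIOU, MetadataFilter

# Restrict both annotations and predictions to close-range cuboids before
# matching; a single [[...]] list is one AND-chain inside the DNF.
close_range = [[MetadataFilter("distance", "<", 40.0)]]
metric = CuboidIOU(
    enforce_label_match=True,
    iou_threshold=0.1,
    annotation_filters=close_range,
    prediction_filters=close_range,
)
# Inside __call__, apply_filters() prunes cuboid_annotations and
# cuboid_predictions before the IOU evaluation runs.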
@@ -195,24 +29,30 @@ def __init__( self, enforce_label_match: bool = False, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, ): """Initializes CuboidMetric abstract object. Args: enforce_label_match: whether to enforce that annotation and prediction labels must match. Default False confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 - annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like - [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field - predicates. The innermost structures each describe a single column predicate. The list of inner predicates is - interpreted as a conjunction (AND), forming a more selective and multiple column predicate. - Finally, the most outer list combines these filters as a disjunction (OR). - prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like - [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field - predicates. The innermost structures each describe a single column predicate. The list of inner predicates is - interpreted as a conjunction (AND), forming a more selective and multiple column predicate. - Finally, the most outer list combines these filters as a disjunction (OR). + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), + like [[MetadataFilter('x', '==', 0), FieldFilter('label', '==', 'pedestrian')], ...]. + DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures + each describe a single field predicate. The list of inner predicates is interpreted as a conjunction + (AND), forming a more selective and multiple column predicate. Finally, the most outer list combines + these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), + like [[MetadataFilter('x', '==', 0), FieldFilter('label', '==', 'pedestrian')], ...]. + DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures + each describe a single field predicate. The list of inner predicates is interpreted as a conjunction + (AND), forming a more selective and multiple column predicate. Finally, the most outer list combines + these filters as a disjunction (OR). 
""" self.enforce_label_match = enforce_label_match assert 0 <= confidence_threshold <= 1 @@ -245,11 +85,11 @@ def __call__( cuboid_predictions.extend(predictions.cuboid_predictions) eval_fn = label_match_wrapper(self.eval) - cuboid_annotations = filter_by_metadata_fields( - cuboid_annotations, self.annotation_filters + cuboid_annotations = apply_filters( + cuboid_annotations, self.annotation_filters # type: ignore ) - cuboid_predictions = filter_by_metadata_fields( - cuboid_predictions, self.prediction_filters + cuboid_predictions = apply_filters( + cuboid_predictions, self.prediction_filters # type: ignore ) result = eval_fn( cuboid_annotations, @@ -269,8 +109,12 @@ def __init__( iou_threshold: float = 0.0, confidence_threshold: float = 0.0, iou_2d: bool = False, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, ): """Initializes CuboidIOU object. @@ -331,8 +175,12 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, ): """Initializes CuboidIOU object. @@ -386,8 +234,12 @@ def __init__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, ): """Initializes CuboidIOU object. diff --git a/nucleus/metrics/filtering.py b/nucleus/metrics/filtering.py new file mode 100644 index 00000000..f419abf5 --- /dev/null +++ b/nucleus/metrics/filtering.py @@ -0,0 +1,264 @@ +import enum +import functools +from enum import Enum +from typing import Callable, List, NamedTuple, Sequence, Union + +from nucleus.annotation import ( + BoxAnnotation, + CategoryAnnotation, + CuboidAnnotation, + LineAnnotation, + MultiCategoryAnnotation, + PolygonAnnotation, +) +from nucleus.prediction import ( + BoxPrediction, + CategoryPrediction, + CuboidPrediction, + LinePrediction, + PolygonPrediction, +) + + +class FilterOp(str, Enum): + GT = ">" + GTE = ">=" + LT = "<" + LTE = "<=" + EQ = "=" + EQEQ = "==" + NEQ = "!=" + + +class FilterType(str, enum.Enum): + """The type of the filter decides the getter used for the comparison. 
+ Attributes: + FIELD: Access the attribute field of an object + METADATA: Access the metadata dictionary of an object + """ + + FIELD = "field" + METADATA = "metadata" + + +class AnnotationOrPredictionFilter(NamedTuple): + key: str + op: Union[FilterOp, str] + value: Union[str, float, int] + allow_missing: bool + type: FilterType + + +class FieldFilter(NamedTuple): + key: str + op: Union[FilterOp, str] + value: Union[str, float, int] + allow_missing: bool = False + type: FilterType = FilterType.FIELD + + +class MetadataFilter(NamedTuple): + key: str + op: Union[FilterOp, str] + value: Union[str, float, int] + allow_missing: bool = False + type: FilterType = FilterType.METADATA + + +Filter = Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter] +ListOfOrAndFilters = List[List[Union[Filter]]] +ListOfOrAndFilters.__doc__ = """\ +Disjunctive normal form (DNF) filters. +DNF allows arbitrary boolean logical combinations of single field predicates. +The innermost structures each describe a single field predicate. + +The list of inner predicates is interpreted as a conjunction (AND), forming a more selective and multiple column +predicate. + +Finally, the most outer list combines these filters as a disjunction (OR). +""" +ListOfAndFilters = List[ + Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter] +] + +AnnotationsWithMetadata = Union[ + BoxAnnotation, + CategoryAnnotation, + CuboidAnnotation, + LineAnnotation, + MultiCategoryAnnotation, + PolygonAnnotation, +] +PredictionsWithMetadata = Union[ + BoxPrediction, + CategoryPrediction, + CuboidPrediction, + LinePrediction, + PolygonPrediction, +] + + +def _attribute_getter( + field_name: str, + allow_missing: bool, + ann_or_pred: Union[AnnotationsWithMetadata, PredictionsWithMetadata], +): + """Create a function to get object fields""" + if allow_missing: + return ( + getattr(ann_or_pred, field_name) + if hasattr(ann_or_pred, field_name) + else AlwaysFalseComparison() + ) + else: + return getattr(ann_or_pred, field_name) + + +class AlwaysFalseComparison: + """Helper class to make sure that allow filtering out missing fields (by always returning a false comparison)""" + + def __gt__(self, other): + return False + + def __ge__(self, other): + return False + + def __lt__(self, other): + return False + + def __le__(self, other): + return False + + def __eq__(self, other): + return False + + def __ne__(self, other): + return False + + +def _metadata_field_getter( + field_name: str, + allow_missing: bool, + ann_or_pred: Union[AnnotationsWithMetadata, PredictionsWithMetadata], +): + """Create a function to get a metadata field""" + if allow_missing: + return ( + ann_or_pred.metadata.get(field_name, AlwaysFalseComparison()) + if ann_or_pred.metadata + else AlwaysFalseComparison() + ) + else: + return ( + ann_or_pred.metadata[field_name] + if ann_or_pred.metadata + else RuntimeError( + f"No metadata on {ann_or_pred}, trying to access {field_name}" + ) + ) + + +def _filter_to_comparison_function( + filter_def: Filter, +) -> Callable[[Union[AnnotationsWithMetadata, PredictionsWithMetadata]], bool]: + """Creates a comparison function from a filter configuration to apply to annotations or predictions + + Parameters: + filter_def: Definition of a filter conditions + + Returns: + + """ + if FilterType(filter_def.type) == FilterType.FIELD: + getter = functools.partial( + _attribute_getter, filter_def.key, filter_def.allow_missing + ) + elif FilterType(filter_def.type) == FilterType.METADATA: + getter = functools.partial( + 
_metadata_field_getter, filter_def.key, filter_def.allow_missing + ) + op = FilterOp(filter_def.op) + if op is FilterOp.GT: + return lambda ann_or_pred: getter(ann_or_pred) > filter_def.value + elif op is FilterOp.GTE: + return lambda ann_or_pred: getter(ann_or_pred) >= filter_def.value + elif op is FilterOp.LT: + return lambda ann_or_pred: getter(ann_or_pred) < filter_def.value + elif op is FilterOp.LTE: + return lambda ann_or_pred: getter(ann_or_pred) <= filter_def.value + elif op is FilterOp.EQ or op is FilterOp.EQEQ: + return lambda ann_or_pred: getter(ann_or_pred) == filter_def.value + elif op is FilterOp.NEQ: + return lambda ann_or_pred: getter(ann_or_pred) != filter_def.value + else: + raise RuntimeError( + f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {filter_def}," + ) + + +def apply_filters( + ann_or_pred: Union[ + Sequence[AnnotationsWithMetadata], Sequence[PredictionsWithMetadata] + ], + filters: Union[ + ListOfOrAndFilters, ListOfAndFilters, List[List[List]], List[List] + ], +): + """Apply filters to list of annotations or list of predictions + Attributes: + ann_or_pred: Prediction or Annotation + filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + if filters is None or len(filters) == 0: + return ann_or_pred + + filters = ensureListOfOrAndFilters(filters) + + dnf_condition_functions = [] + for or_branch in filters: + and_conditions = [ + _filter_to_comparison_function(cond) for cond in or_branch + ] + dnf_condition_functions.append(and_conditions) + + filtered = [] + for item in ann_or_pred: + for or_conditions in dnf_condition_functions: + if all(c(item) for c in or_conditions): + filtered.append(item) + break + return filtered + + +def ensureListOfOrAndFilters(filters) -> ListOfOrAndFilters: + """JSON encoding creates a triple nested lists from the doubly nested tuples. 
This function creates the + tuple form again.""" + if isinstance(filters[0], (MetadataFilter, FieldFilter)): + # Normalize into DNF + filters: ListOfOrAndFilters = [filters] # type: ignore + + # NOTE: We have to handle JSON transformed tuples which become two or three layers of lists + if ( + isinstance(filters, list) + and isinstance(filters[0], list) + and isinstance(filters[0][0], str) + ): + filters = [filters] + if ( + isinstance(filters, list) + and isinstance(filters[0], list) + and isinstance(filters[0][0], list) + ): + formatted_filter = [] + for or_branch in filters: + and_chain = [ + AnnotationOrPredictionFilter(*condition) + for condition in or_branch + ] + formatted_filter.append(and_chain) + filters = formatted_filter + return filters diff --git a/nucleus/validate/data_transfer_objects/eval_function.py b/nucleus/validate/data_transfer_objects/eval_function.py index 5f7727b1..2592be3d 100644 --- a/nucleus/validate/data_transfer_objects/eval_function.py +++ b/nucleus/validate/data_transfer_objects/eval_function.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, validator +from pydantic import validator from ...pydantic_base import ImmutableModel from ..constants import ThresholdComparison diff --git a/nucleus/validate/data_transfer_objects/scenario_test.py b/nucleus/validate/data_transfer_objects/scenario_test.py index 2232efb7..4e029867 100644 --- a/nucleus/validate/data_transfer_objects/scenario_test.py +++ b/nucleus/validate/data_transfer_objects/scenario_test.py @@ -4,8 +4,6 @@ from nucleus.pydantic_base import ImmutableModel -from .eval_function import EvalFunctionEntry - class EvalFunctionListEntry(ImmutableModel): id: str diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index d2dc7efc..d5a83910 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -6,7 +6,7 @@ EvalFunctionConfig, ) -from ...metrics.cuboid_metrics import DNFFilter +from ...metrics.filtering import ListOfOrAndFilters from ..data_transfer_objects.eval_function import EvalFunctionEntry from ..errors import EvalFunctionNotAvailableError @@ -169,8 +169,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, **kwargs, ): """Configure a call to CuboidIOU object. @@ -211,8 +211,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, **kwargs, ): """Configure a call to CuboidIOU object. @@ -254,8 +254,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, **kwargs, ): """Configure a call to CuboidPrecision object. 
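One wrinkle deserving a sketch: filters can arrive JSON-decoded (for example, from eval_func_arguments stored on the backend), where the NamedTuples have collapsed into plain nested lists. ensureListOfOrAndFilters rebuilds the tuple form, so both representations below filter identically; the annotations are toy data:

import json

from nucleus.annotation import BoxAnnotation
from nucleus.metrics.filtering import MetadataFilter, apply_filters

anns = [
    BoxAnnotation(label="car", x=i, y=0, width=5, height=5,
                  reference_id=f"img_{i}", metadata={"index": i})
    for i in range(5)
]

filters = [[MetadataFilter("index", ">", 0), MetadataFilter("index", "<", 4)]]
# NamedTuples serialize as arrays:
# [[["index", ">", 0, false, "metadata"], ["index", "<", 4, false, "metadata"]]]
wire = json.loads(json.dumps(filters))

# Both calls keep the annotations with index 1, 2 and 3.
assert apply_filters(anns, wire) == apply_filters(anns, filters)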
@@ -296,8 +296,8 @@ def __call__( enforce_label_match: bool = True, iou_threshold: float = 0.0, confidence_threshold: float = 0.0, - annotation_filters: Optional[DNFFilter] = None, - prediction_filters: Optional[DNFFilter] = None, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, **kwargs, ): """Configure a call to a CuboidRecall object. diff --git a/pyproject.toml b/pyproject.toml index 9c518db4..93cfd03d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.9b3" +version = "0.9b4" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/helpers.py b/tests/helpers.py index 7c1205e8..6e16d147 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -172,6 +172,7 @@ def reference_id_from_url(url): "height": 80 + i * 10, "reference_id": reference_id_from_url(TEST_IMG_URLS[i]), "annotation_id": f"[Pytest] Box Annotation Annotation Id{i}", + "metadata": {"field_1": "string", "index": i}, } for i in range(len(TEST_IMG_URLS)) ] diff --git a/tests/metrics/test_filtering.py b/tests/metrics/test_filtering.py new file mode 100644 index 00000000..ccd80e6d --- /dev/null +++ b/tests/metrics/test_filtering.py @@ -0,0 +1,238 @@ +import json + +import pytest + +from nucleus.metrics import FieldFilter, MetadataFilter, apply_filters +from tests.metrics.helpers import ( + TEST_BOX_ANNOTATION_LIST, + TEST_BOX_PREDICTION_LIST, +) + + +@pytest.fixture( + params=[ + TEST_BOX_ANNOTATION_LIST.box_annotations, + TEST_BOX_PREDICTION_LIST.box_predictions, + ] +) +def annotations_or_predictions(request): + yield request.param + + +def test_filter_field(annotations_or_predictions): + dnf_filters = [ + FieldFilter("label", "==", annotations_or_predictions[0].label) + ] + filtered = apply_filters(annotations_or_predictions, dnf_filters) + assert filtered == [annotations_or_predictions[0]] + + +def test_filter_metadata_field(annotations_or_predictions): + dnf_filters = [ + MetadataFilter( + "index", "==", annotations_or_predictions[0].metadata["index"] + ) + ] + filtered = apply_filters(annotations_or_predictions, dnf_filters) + assert filtered == [annotations_or_predictions[0]] + + +def test_only_and(annotations_or_predictions): + and_filters = [ + MetadataFilter( + "index", "==", annotations_or_predictions[0].metadata["index"] + ) + ] + or_filters = [and_filters] + filtered_and = apply_filters(annotations_or_predictions, and_filters) + filtered_or = apply_filters(annotations_or_predictions, or_filters) + assert filtered_and == filtered_or + + +def test_json_encoded_filters(annotations_or_predictions): + filters = [ + [MetadataFilter("index", ">", 0), MetadataFilter("index", "<", 4)] + ] + expected = apply_filters(annotations_or_predictions, filters) + json_string = json.dumps(filters) + filters_from_json = json.loads(json_string) + json_filtered = apply_filters( + annotations_or_predictions, filters_from_json + ) + assert json_filtered == expected + + +def test_json_encoded_and_filters(annotations_or_predictions): + filters = [ + MetadataFilter("index", ">", 0), + MetadataFilter("index", "<", 4), + ] + expected = apply_filters(annotations_or_predictions, filters) + json_string = json.dumps(filters) + filters_from_json = json.loads(json_string) + json_filtered = apply_filters( + annotations_or_predictions, filters_from_json + ) + assert json_filtered == expected + + +def 
test_or_branches(annotations_or_predictions): + index_0_or_2 = [ + [MetadataFilter("index", "==", 0)], + [MetadataFilter("index", "==", 2)], + ] + filtered = apply_filters(annotations_or_predictions, index_0_or_2) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_only_one_or(annotations_or_predictions): + later_matches = [ + [MetadataFilter("index", "==", -1)], + [MetadataFilter("index", "==", 2)], + ] + filtered = apply_filters(annotations_or_predictions, later_matches) + assert filtered == [ + annotations_or_predictions[2], + ] + + +def test_and_branches(annotations_or_predictions): + index_0_or_2 = [ + [ + MetadataFilter("index", "==", 0), + FieldFilter("label", "==", annotations_or_predictions[0].label), + ], + [ + MetadataFilter("index", "==", 2), + FieldFilter("label", "==", annotations_or_predictions[2].label), + ], + ] + filtered = apply_filters(annotations_or_predictions, index_0_or_2) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_multi_or(annotations_or_predictions): + all_match = [ + [MetadataFilter("index", "==", i)] + for i in range(len(annotations_or_predictions)) + ] + filtered = apply_filters(annotations_or_predictions, all_match) + assert filtered == annotations_or_predictions + + +def test_missing_field_raises(annotations_or_predictions): + missing_field = [[FieldFilter("i_dont_exist", "==", 1)]] + with pytest.raises(AttributeError): + apply_filters(annotations_or_predictions, missing_field) + + +def test_allow_missing_field(annotations_or_predictions): + missing_field = [ + [FieldFilter("i_dont_exist", "==", 1, allow_missing=True)] + ] + filtered = apply_filters(annotations_or_predictions, missing_field) + assert filtered == [] + + +def test_missing_metadata_raises(annotations_or_predictions): + missing_field = [[MetadataFilter("i_dont_exist", "==", 1)]] + with pytest.raises(KeyError): + apply_filters(annotations_or_predictions, missing_field) + + +def test_allow_missing_metadata_field(annotations_or_predictions): + missing_field = [ + [FieldFilter("i_dont_exist", "==", 1, allow_missing=True)] + ] + filtered = apply_filters(annotations_or_predictions, missing_field) + assert filtered == [] + + +def test_gt_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", ">", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_gt_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", ">", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_gte_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", ">=", 1)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_gte_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", ">=", annotations_or_predictions[1].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_lt_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "<", 1)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:1] + + +def test_lt_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "<", annotations_or_predictions[1].x)] + filtered = 
apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:1] + + +def test_lte_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "<=", 1)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:2] + + +def test_lte_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "<=", annotations_or_predictions[1].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:2] + + +def test_eqeq_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "==", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_eqeq_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "==", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_eq_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "=", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_eq_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "=", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_neq_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "!=", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_neq_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "!=", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] From 894894536226304db803fd87365125e6869e0897 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Thu, 31 Mar 2022 15:02:22 +0200 Subject: [PATCH 32/33] Add in and not in statements --- nucleus/metrics/filtering.py | 29 +++++++++++++++++++---- tests/metrics/test_filtering.py | 42 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/nucleus/metrics/filtering.py b/nucleus/metrics/filtering.py index f419abf5..3713b82d 100644 --- a/nucleus/metrics/filtering.py +++ b/nucleus/metrics/filtering.py @@ -1,7 +1,7 @@ import enum import functools from enum import Enum -from typing import Callable, List, NamedTuple, Sequence, Union +from typing import Callable, Iterable, List, NamedTuple, Sequence, Set, Union from nucleus.annotation import ( BoxAnnotation, @@ -28,6 +28,8 @@ class FilterOp(str, Enum): EQ = "=" EQEQ = "==" NEQ = "!=" + IN = "in" + NOT_IN = "not in" class FilterType(str, enum.Enum): @@ -41,10 +43,19 @@ class FilterType(str, enum.Enum): METADATA = "metadata" +FilterableBaseVals = Union[str, float, int] +FilterableValues = Union[ + FilterableBaseVals, + Sequence[FilterableBaseVals], + Set[FilterableBaseVals], + Iterable[FilterableBaseVals], +] + + class AnnotationOrPredictionFilter(NamedTuple): key: str op: Union[FilterOp, str] - value: Union[str, float, int] + value: FilterableValues allow_missing: bool type: FilterType @@ -52,7 +63,7 @@ class AnnotationOrPredictionFilter(NamedTuple): class FieldFilter(NamedTuple): key: str op: Union[FilterOp, str] - value: Union[str, float, int] + value: FilterableValues allow_missing: 
bool = False type: FilterType = FilterType.FIELD @@ -60,7 +71,7 @@ class FieldFilter(NamedTuple): class MetadataFilter(NamedTuple): key: str op: Union[FilterOp, str] - value: Union[str, float, int] + value: FilterableValues allow_missing: bool = False type: FilterType = FilterType.METADATA @@ -158,7 +169,7 @@ def _metadata_field_getter( ) -def _filter_to_comparison_function( +def _filter_to_comparison_function( # pylint: disable=too-many-return-statements filter_def: Filter, ) -> Callable[[Union[AnnotationsWithMetadata, PredictionsWithMetadata]], bool]: """Creates a comparison function from a filter configuration to apply to annotations or predictions @@ -190,6 +201,14 @@ def _filter_to_comparison_function( return lambda ann_or_pred: getter(ann_or_pred) == filter_def.value elif op is FilterOp.NEQ: return lambda ann_or_pred: getter(ann_or_pred) != filter_def.value + elif op is FilterOp.IN: + return lambda ann_or_pred: getter(ann_or_pred) in set( + filter_def.value # type: ignore + ) + elif op is FilterOp.NOT_IN: + return lambda ann_or_pred: getter(ann_or_pred) not in set( + filter_def.value # type:ignore + ) else: raise RuntimeError( f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {filter_def}," diff --git a/tests/metrics/test_filtering.py b/tests/metrics/test_filtering.py index ccd80e6d..ac8c4119 100644 --- a/tests/metrics/test_filtering.py +++ b/tests/metrics/test_filtering.py @@ -236,3 +236,45 @@ def test_neq_field(annotations_or_predictions): valid_gt = [FieldFilter("x", "!=", annotations_or_predictions[0].x)] filtered = apply_filters(annotations_or_predictions, valid_gt) assert filtered == annotations_or_predictions[1:] + + +def test_in_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "in", [0, 2])] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_in_field(annotations_or_predictions): + valid_gt = [ + FieldFilter( + "x", + "in", + [annotations_or_predictions[0].x, annotations_or_predictions[2].x], + ) + ] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_not_in_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "not in", [0, 1])] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[2:] + + +def test_not_in_field(annotations_or_predictions): + valid_gt = [ + FieldFilter( + "x", + "not in", + [annotations_or_predictions[0].x, annotations_or_predictions[1].x], + ) + ] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[2:] From a980d4cc90b622502ab713fc56766a3f63901e77 Mon Sep 17 00:00:00 2001 From: Gunnar Atli Thoroddsen Date: Thu, 31 Mar 2022 15:28:44 +0200 Subject: [PATCH 33/33] Fix rebase error with conftest.py --- conftest.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/conftest.py b/conftest.py index b2cdfef1..90946313 100644 --- a/conftest.py +++ b/conftest.py @@ -39,13 +39,6 @@ def model(CLIENT): CLIENT.delete_model(model.id) -@pytest.fixture() -def model(CLIENT): - model = CLIENT.create_model(TEST_DATASET_NAME, "fake_reference_id") - yield model - CLIENT.delete_model(model.id) - - if __name__ == "__main__": client = nucleus.NucleusClient(API_KEY) # ds = client.create_dataset("Test Dataset With Autotags")
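To close out the series, a short sketch of the membership operators added in PATCH 32, mirroring the new tests; the annotation values are illustrative:

from nucleus.annotation import BoxAnnotation
from nucleus.metrics import MetadataFilter, apply_filters

anns = [
    BoxAnnotation(label="car", x=i, y=0, width=5, height=5,
                  reference_id=f"img_{i}", metadata={"index": i})
    for i in range(4)
]

subset = apply_filters(anns, [MetadataFilter("index", "in", [0, 2])])        # anns[0], anns[2]
remainder = apply_filters(anns, [MetadataFilter("index", "not in", [0, 1])])  # anns[2:]

# Both operators wrap the filter value in set(), so the listed values must
# be hashable.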