diff --git a/nucleus/metrics/__init__.py b/nucleus/metrics/__init__.py index 1fd038a2..460561f7 100644 --- a/nucleus/metrics/__init__.py +++ b/nucleus/metrics/__init__.py @@ -1,5 +1,12 @@ from .base import Metric, ScalarResult from .categorization_metrics import CategorizationF1 +from .cuboid_metrics import CuboidIOU, CuboidPrecision, CuboidRecall +from .filtering import ( + FieldFilter, + ListOfOrAndFilters, + MetadataFilter, + apply_filters, +) from .polygon_metrics import ( PolygonAveragePrecision, PolygonIOU, diff --git a/nucleus/metrics/categorization_metrics.py b/nucleus/metrics/categorization_metrics.py index 416f831a..80979c8e 100644 --- a/nucleus/metrics/categorization_metrics.py +++ b/nucleus/metrics/categorization_metrics.py @@ -143,7 +143,8 @@ def __init__( ): """ Args: - confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. + Must be in [0, 1]. Default 0.0 f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \ default='macro' This parameter is required for multiclass/multilabel targets. diff --git a/nucleus/metrics/cuboid_metrics.py b/nucleus/metrics/cuboid_metrics.py new file mode 100644 index 00000000..313ef759 --- /dev/null +++ b/nucleus/metrics/cuboid_metrics.py @@ -0,0 +1,274 @@ +import sys +from abc import abstractmethod +from typing import List, Optional, Union + +from nucleus.annotation import AnnotationList, CuboidAnnotation +from nucleus.prediction import CuboidPrediction, PredictionList + +from .base import Metric, ScalarResult +from .cuboid_utils import detection_iou, label_match_wrapper, recall_precision +from .filtering import ListOfAndFilters, ListOfOrAndFilters, apply_filters +from .filters import confidence_filter + + +class CuboidMetric(Metric): + """Abstract class for metrics of cuboids. + + The CuboidMetric class automatically filters incoming annotations and + predictions for only cuboid annotations. It also filters + predictions whose confidence is less than the provided confidence_threshold. + Finally, it provides support for enforcing matching labels. If + `enforce_label_match` is set to True, then annotations and predictions will + only be matched if they have the same label. + + To create a new concrete CuboidMetric, override the `eval` function + with logic to define a metric between cuboid annotations and predictions. + """ + + def __init__( + self, + enforce_label_match: bool = False, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + ): + """Initializes CuboidMetric abstract object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Default False + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), + like [[MetadataFilter('x', '==', 0), FieldFilter('label', '==', 'pedestrian')], ...]. + DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures + each describe a single field predicate. The list of inner predicates is interpreted as a conjunction + (AND), forming a more selective and multiple column predicate. Finally, the most outer list combines + these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), + like [[MetadataFilter('x', '==', 0), FieldFilter('label', '==', 'pedestrian')], ...]. + DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures + each describe a single field predicate. The list of inner predicates is interpreted as a conjunction + (AND), forming a more selective and multiple column predicate. Finally, the most outer list combines + these filters as a disjunction (OR). + """ + self.enforce_label_match = enforce_label_match + assert 0 <= confidence_threshold <= 1 + self.confidence_threshold = confidence_threshold + self.annotation_filters = annotation_filters + self.prediction_filters = prediction_filters + + @abstractmethod + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + # Main evaluation function that subclasses must override. + pass + + def aggregate_score(self, results: List[ScalarResult]) -> ScalarResult: # type: ignore[override] + return ScalarResult.aggregate(results) + + def __call__( + self, annotations: AnnotationList, predictions: PredictionList + ) -> ScalarResult: + if self.confidence_threshold > 0: + predictions = confidence_filter( + predictions, self.confidence_threshold + ) + cuboid_annotations: List[CuboidAnnotation] = [] + cuboid_annotations.extend(annotations.cuboid_annotations) + cuboid_predictions: List[CuboidPrediction] = [] + cuboid_predictions.extend(predictions.cuboid_predictions) + + eval_fn = label_match_wrapper(self.eval) + cuboid_annotations = apply_filters( + cuboid_annotations, self.annotation_filters # type: ignore + ) + cuboid_predictions = apply_filters( + cuboid_predictions, self.prediction_filters # type: ignore + ) + result = eval_fn( + cuboid_annotations, + cuboid_predictions, + enforce_label_match=self.enforce_label_match, + ) + return result + + +class CuboidIOU(CuboidMetric): + """Calculates the average IOU between cuboid annotations and predictions.""" + + # TODO: Remove defaults once these are surfaced more cleanly to users. + def __init__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + iou_2d: bool = False, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + ): + """Initializes CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + assert ( + 0 <= iou_threshold <= 1 + ), "IoU threshold must be between 0 and 1." + self.iou_threshold = iou_threshold + self.iou_2d = iou_2d + super().__init__( + enforce_label_match=enforce_label_match, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + ) + + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + iou_3d_metric, iou_2d_metric = detection_iou( + predictions, + annotations, + threshold_in_overlap_ratio=self.iou_threshold, + ) + + weight = max(len(annotations), len(predictions)) + if self.iou_2d: + avg_iou = iou_2d_metric.sum() / max(weight, sys.float_info.epsilon) + else: + avg_iou = iou_3d_metric.sum() / max(weight, sys.float_info.epsilon) + + return ScalarResult(avg_iou, weight) + + +class CuboidPrecision(CuboidMetric): + """Calculates the average precision between cuboid annotations and predictions.""" + + # TODO: Remove defaults once these are surfaced more cleanly to users. + def __init__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + ): + """Initializes CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + assert ( + 0 <= iou_threshold <= 1 + ), "IoU threshold must be between 0 and 1." + self.iou_threshold = iou_threshold + super().__init__( + enforce_label_match=enforce_label_match, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + ) + + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + stats = recall_precision( + predictions, + annotations, + threshold_in_overlap_ratio=self.iou_threshold, + ) + weight = stats["tp_sum"] + stats["fp_sum"] + precision = stats["tp_sum"] / max(weight, sys.float_info.epsilon) + return ScalarResult(precision, weight) + + +class CuboidRecall(CuboidMetric): + """Calculates the average recall between cuboid annotations and predictions.""" + + # TODO: Remove defaults once these are surfaced more cleanly to users. + def __init__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + prediction_filters: Optional[ + Union[ListOfOrAndFilters, ListOfAndFilters] + ] = None, + ): + """Initializes CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + assert ( + 0 <= iou_threshold <= 1 + ), "IoU threshold must be between 0 and 1." + self.iou_threshold = iou_threshold + super().__init__( + enforce_label_match=enforce_label_match, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + ) + + def eval( + self, + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + ) -> ScalarResult: + stats = recall_precision( + predictions, + annotations, + threshold_in_overlap_ratio=self.iou_threshold, + ) + weight = stats["tp_sum"] + stats["fn_sum"] + recall = stats["tp_sum"] / max(weight, sys.float_info.epsilon) + return ScalarResult(recall, weight) diff --git a/nucleus/metrics/cuboid_utils.py b/nucleus/metrics/cuboid_utils.py new file mode 100644 index 00000000..0ebf0716 --- /dev/null +++ b/nucleus/metrics/cuboid_utils.py @@ -0,0 +1,355 @@ +from functools import wraps +from typing import Dict, List, Tuple + +import numpy as np +from shapely.geometry import Polygon + +from nucleus.annotation import CuboidAnnotation +from nucleus.prediction import CuboidPrediction + +from .base import ScalarResult + + +def group_cuboids_by_label( + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], +) -> Dict[str, Tuple[List[CuboidAnnotation], List[CuboidPrediction]]]: + """Groups input annotations and predictions by label. + + Args: + annotations: list of input cuboid annotations + predictions: list of input cuboid predictions + + Returns: + Mapping from each label to (annotations, predictions) tuple + """ + labels = set(annotation.label for annotation in annotations) + labels |= set(prediction.label for prediction in predictions) + grouped: Dict[ + str, Tuple[List[CuboidAnnotation], List[CuboidPrediction]] + ] = {label: ([], []) for label in labels} + for annotation in annotations: + grouped[annotation.label][0].append(annotation) + for prediction in predictions: + grouped[prediction.label][1].append(prediction) + return grouped + + +def label_match_wrapper(metric_fn): + """Decorator to add the ability to only apply metric to annotations and + predictions with matching labels. + + Args: + metric_fn: Metric function that takes a list of annotations, a list + of predictions, and optional args and kwargs. + + Returns: + Metric function which can optionally enforce matching labels. + """ + + @wraps(metric_fn) + def wrapper( + annotations: List[CuboidAnnotation], + predictions: List[CuboidPrediction], + *args, + enforce_label_match: bool = False, + **kwargs, + ) -> ScalarResult: + # Simply return the metric if we are not enforcing label matches. + if not enforce_label_match: + return metric_fn(annotations, predictions, *args, **kwargs) + + # For each bin of annotations/predictions, compute the metric applied + # only to that bin. Then aggregate results across all bins. + grouped_inputs = group_cuboids_by_label(annotations, predictions) + metric_results = [] + for binned_annotations, binned_predictions in grouped_inputs.values(): + metric_result = metric_fn( + binned_annotations, binned_predictions, *args, **kwargs + ) + metric_results.append(metric_result) + assert all( + isinstance(r, ScalarResult) for r in metric_results + ), "Expected every result to be a ScalarResult" + return ScalarResult.aggregate(metric_results) + + return wrapper + + +def process_dataitem(dataitem): + processed_item = {} + processed_item["xyz"] = np.array( + [[ann.position.x, ann.position.y, ann.position.z] for ann in dataitem] + ) + processed_item["wlh"] = np.array( + [ + [ann.dimensions.x, ann.dimensions.y, ann.dimensions.z] + for ann in dataitem + ] + ) + processed_item["yaw"] = np.array([ann.yaw for ann in dataitem]) + return processed_item + + +def compute_outer_iou( + xyz_0: np.ndarray, + wlh_0: np.ndarray, + yaw_0: np.ndarray, + xyz_1: np.ndarray, + wlh_1: np.ndarray, + yaw_1: np.ndarray, + scale_convention: bool = True, + distance_threshold=25, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Computes outer 3D and 2D IoU + :param xyz_0: (n, 3) + :param wlh_0: (n, 3) + :param yaw_0: (n,) + :param xyz_1: (m, 3) + :param wlh_1: (m, 3) + :param yaw_1: (m,) + :param scale_convention: flag whether the internal Scale convention is used (have to be adjusted by pi/2) + :param distance_threshold: computes iou only within this distance (~3x speedup) + :return: (n, m) 3D IoU, (n, m) 2D IoU + """ + + bottom_z = np.maximum.outer( + xyz_0[:, 2] - (wlh_0[:, 2] / 2), xyz_1[:, 2] - (wlh_1[:, 2] / 2) + ) + top_z = np.minimum.outer( + xyz_0[:, 2] + (wlh_0[:, 2] / 2), xyz_1[:, 2] + (wlh_1[:, 2] / 2) + ) + height_intersection = np.maximum(0, top_z - bottom_z) + + cuboid_corners_0 = get_batch_cuboid_corners( + xyz_0, wlh_0, yaw_0, scale_convention=scale_convention + ) + cuboid_corners_1 = get_batch_cuboid_corners( + xyz_1, wlh_1, yaw_1, scale_convention=scale_convention + ) + polygons_1 = [ + Polygon(corners_1[[1, 0, 4, 5, 1], :2]) + for corners_1 in cuboid_corners_1 + ] + area_intersection = np.zeros( + (cuboid_corners_0.shape[0], cuboid_corners_1.shape[0]), + dtype=np.float32, + ) + + if cuboid_corners_0.shape[0] != 0 and cuboid_corners_1.shape[0] != 0: + distance_mask = ( + np.linalg.norm( + xyz_0[:, np.newaxis, :] - xyz_1[np.newaxis, :, :], axis=2 + ) + < distance_threshold + ) + + for i, corners_0 in enumerate(cuboid_corners_0): + for j, polygon_1 in enumerate(polygons_1): + if distance_mask[i, j]: + area_intersection[i, j] = ( + Polygon(corners_0[[1, 0, 4, 5, 1], :2]) + .intersection(polygon_1) + .area + ) + + intersection = height_intersection * area_intersection + area_0 = wlh_0[:, 0] * wlh_0[:, 1] + area_1 = wlh_1[:, 0] * wlh_1[:, 1] + union_2d = np.add.outer(area_0, area_1) - area_intersection + + volume_0 = area_0 * wlh_0[:, 2] + volume_1 = area_1 * wlh_1[:, 2] + union = np.add.outer(volume_0, volume_1) - intersection + return intersection / union, area_intersection / union_2d + + +def get_batch_cuboid_corners( + xyz: np.ndarray, + wlh: np.ndarray, + yaw: np.ndarray, + pitch: np.ndarray = None, + roll: np.ndarray = None, + scale_convention: bool = True, +) -> np.ndarray: + """ + Vectorized batch version of get_cuboid_corners + :param xyz: (n, 3) + :param wlh: (n, 3) + :param yaw: (n,) + :param pitch: (n,) + :param roll: (n,) + :param scale_convention: flag whether the internal Scale convention is used (have to be adjusted by pi/2) + :return: (n, 8, 3) + """ + if scale_convention: + yaw = yaw.copy() + np.pi / 2 + + w, l, h = wlh[:, 0, None], wlh[:, 1, None], wlh[:, 2, None] + + x_corners = l / 2 * np.array([1, 1, 1, 1, -1, -1, -1, -1]) + y_corners = w / 2 * np.array([1, -1, -1, 1, 1, -1, -1, 1]) + z_corners = h / 2 * np.array([1, 1, -1, -1, 1, 1, -1, -1]) + corners = np.stack((x_corners, y_corners, z_corners), axis=1) + + rot_mats = get_batch_rotation_matrices(yaw, pitch, roll) + corners = np.matmul(rot_mats, corners) + + x, y, z = xyz[:, 0, None], xyz[:, 1, None], xyz[:, 2, None] + corners[:, 0, :] = corners[:, 0, :] + x + corners[:, 1, :] = corners[:, 1, :] + y + corners[:, 2, :] = corners[:, 2, :] + z + return corners.swapaxes(1, 2) + + +def get_batch_rotation_matrices( + yaw: np.ndarray, pitch: np.ndarray = None, roll: np.ndarray = None +) -> np.ndarray: + if pitch is None: + pitch = np.zeros_like(yaw) + if roll is None: + roll = np.zeros_like(yaw) + cy = np.cos(yaw) + sy = np.sin(yaw) + cp = np.cos(pitch) + sp = np.sin(pitch) + cr = np.cos(roll) + sr = np.sin(roll) + return np.stack( + ( + np.stack( + (cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr), 1 + ), + np.stack( + (sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr), 1 + ), + np.stack((-sp, cp * sr, cp * cr), 1), + ), + 1, + ) + + +def associate_cuboids_on_iou( + xyz_0: np.ndarray, + wlh_0: np.ndarray, + yaw_0: np.ndarray, + xyz_1: np.ndarray, + wlh_1: np.ndarray, + yaw_1: np.ndarray, + threshold_in_overlap_ratio: float = 0.1, +) -> List[Tuple[int, int]]: + if xyz_0.shape[0] < 1 or xyz_1.shape[0] < 1: + return [] + iou_matrix, _ = compute_outer_iou(xyz_0, wlh_0, yaw_0, xyz_1, wlh_1, yaw_1) + mapping = [] + for i, m in enumerate(iou_matrix.max(axis=1)): + if m >= threshold_in_overlap_ratio: + mapping.append((i, iou_matrix[i].argmax())) + return mapping + + +def recall_precision( + prediction: List[CuboidPrediction], + groundtruth: List[CuboidAnnotation], + threshold_in_overlap_ratio: float, +) -> Dict[str, float]: + """ + Calculates the precision and recall of each lidar frame. + + Args: + :param predictions: list of cuboid annotation predictions. + :param ground_truth: list of cuboid annotation groundtruths. + :param threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. + """ + + tp_sum = 0 + fp_sum = 0 + fn_sum = 0 + num_predicted = 0 + num_instances = 0 + + gt_items = process_dataitem(groundtruth) + pred_items = process_dataitem(prediction) + + num_predicted += pred_items["xyz"].shape[0] + num_instances += gt_items["xyz"].shape[0] + + tp = np.zeros(pred_items["xyz"].shape[0]) + fp = np.ones(pred_items["xyz"].shape[0]) + fn = np.ones(gt_items["xyz"].shape[0]) + + mapping = associate_cuboids_on_iou( + pred_items["xyz"], + pred_items["wlh"], + pred_items["yaw"] + np.pi / 2, + gt_items["xyz"], + gt_items["wlh"], + gt_items["yaw"] + np.pi / 2, + threshold_in_overlap_ratio=threshold_in_overlap_ratio, + ) + + for pred_id, gt_id in mapping: + if fn[gt_id] == 0: + continue + tp[pred_id] = 1 + fp[pred_id] = 0 + fn[gt_id] = 0 + + tp_sum += tp.sum() + fp_sum += fp.sum() + fn_sum += fn.sum() + + return { + "tp_sum": tp_sum, + "fp_sum": fp_sum, + "fn_sum": fn_sum, + "precision": tp_sum / (tp_sum + fp_sum), + "recall": tp_sum / (tp_sum + fn_sum), + "num_predicted": num_predicted, + "num_instances": num_instances, + } + + +def detection_iou( + prediction: List[CuboidPrediction], + groundtruth: List[CuboidAnnotation], + threshold_in_overlap_ratio: float, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculates the 2D IOU and 3D IOU overlap between predictions and groundtruth. + Uses linear sum assignment to associate cuboids. + + Args: + :param predictions: list of cuboid annotation predictions. + :param ground_truth: list of cuboid annotation groundtruths. + :param threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. + """ + + gt_items = process_dataitem(groundtruth) + pred_items = process_dataitem(prediction) + + meter_2d = [] + meter_3d = [] + + if gt_items["xyz"].shape[0] == 0 or pred_items["xyz"].shape[0] == 0: + return np.array([0.0]), np.array([0.0]) + + iou_3d, iou_2d = compute_outer_iou( + gt_items["xyz"], + gt_items["wlh"], + gt_items["yaw"], + pred_items["xyz"], + pred_items["wlh"], + pred_items["yaw"], + ) + + for i, m in enumerate(iou_3d.max(axis=1)): + if m >= threshold_in_overlap_ratio: + j = iou_3d[i].argmax() + meter_3d.append(iou_3d[i, j]) + meter_2d.append(iou_2d[i, j]) + + meter_3d = np.array(meter_3d) + meter_2d = np.array(meter_2d) + return meter_3d, meter_2d diff --git a/nucleus/metrics/filtering.py b/nucleus/metrics/filtering.py new file mode 100644 index 00000000..3713b82d --- /dev/null +++ b/nucleus/metrics/filtering.py @@ -0,0 +1,283 @@ +import enum +import functools +from enum import Enum +from typing import Callable, Iterable, List, NamedTuple, Sequence, Set, Union + +from nucleus.annotation import ( + BoxAnnotation, + CategoryAnnotation, + CuboidAnnotation, + LineAnnotation, + MultiCategoryAnnotation, + PolygonAnnotation, +) +from nucleus.prediction import ( + BoxPrediction, + CategoryPrediction, + CuboidPrediction, + LinePrediction, + PolygonPrediction, +) + + +class FilterOp(str, Enum): + GT = ">" + GTE = ">=" + LT = "<" + LTE = "<=" + EQ = "=" + EQEQ = "==" + NEQ = "!=" + IN = "in" + NOT_IN = "not in" + + +class FilterType(str, enum.Enum): + """The type of the filter decides the getter used for the comparison. + Attributes: + FIELD: Access the attribute field of an object + METADATA: Access the metadata dictionary of an object + """ + + FIELD = "field" + METADATA = "metadata" + + +FilterableBaseVals = Union[str, float, int] +FilterableValues = Union[ + FilterableBaseVals, + Sequence[FilterableBaseVals], + Set[FilterableBaseVals], + Iterable[FilterableBaseVals], +] + + +class AnnotationOrPredictionFilter(NamedTuple): + key: str + op: Union[FilterOp, str] + value: FilterableValues + allow_missing: bool + type: FilterType + + +class FieldFilter(NamedTuple): + key: str + op: Union[FilterOp, str] + value: FilterableValues + allow_missing: bool = False + type: FilterType = FilterType.FIELD + + +class MetadataFilter(NamedTuple): + key: str + op: Union[FilterOp, str] + value: FilterableValues + allow_missing: bool = False + type: FilterType = FilterType.METADATA + + +Filter = Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter] +ListOfOrAndFilters = List[List[Union[Filter]]] +ListOfOrAndFilters.__doc__ = """\ +Disjunctive normal form (DNF) filters. +DNF allows arbitrary boolean logical combinations of single field predicates. +The innermost structures each describe a single field predicate. + +The list of inner predicates is interpreted as a conjunction (AND), forming a more selective and multiple column +predicate. + +Finally, the most outer list combines these filters as a disjunction (OR). +""" +ListOfAndFilters = List[ + Union[FieldFilter, MetadataFilter, AnnotationOrPredictionFilter] +] + +AnnotationsWithMetadata = Union[ + BoxAnnotation, + CategoryAnnotation, + CuboidAnnotation, + LineAnnotation, + MultiCategoryAnnotation, + PolygonAnnotation, +] +PredictionsWithMetadata = Union[ + BoxPrediction, + CategoryPrediction, + CuboidPrediction, + LinePrediction, + PolygonPrediction, +] + + +def _attribute_getter( + field_name: str, + allow_missing: bool, + ann_or_pred: Union[AnnotationsWithMetadata, PredictionsWithMetadata], +): + """Create a function to get object fields""" + if allow_missing: + return ( + getattr(ann_or_pred, field_name) + if hasattr(ann_or_pred, field_name) + else AlwaysFalseComparison() + ) + else: + return getattr(ann_or_pred, field_name) + + +class AlwaysFalseComparison: + """Helper class to make sure that allow filtering out missing fields (by always returning a false comparison)""" + + def __gt__(self, other): + return False + + def __ge__(self, other): + return False + + def __lt__(self, other): + return False + + def __le__(self, other): + return False + + def __eq__(self, other): + return False + + def __ne__(self, other): + return False + + +def _metadata_field_getter( + field_name: str, + allow_missing: bool, + ann_or_pred: Union[AnnotationsWithMetadata, PredictionsWithMetadata], +): + """Create a function to get a metadata field""" + if allow_missing: + return ( + ann_or_pred.metadata.get(field_name, AlwaysFalseComparison()) + if ann_or_pred.metadata + else AlwaysFalseComparison() + ) + else: + return ( + ann_or_pred.metadata[field_name] + if ann_or_pred.metadata + else RuntimeError( + f"No metadata on {ann_or_pred}, trying to access {field_name}" + ) + ) + + +def _filter_to_comparison_function( # pylint: disable=too-many-return-statements + filter_def: Filter, +) -> Callable[[Union[AnnotationsWithMetadata, PredictionsWithMetadata]], bool]: + """Creates a comparison function from a filter configuration to apply to annotations or predictions + + Parameters: + filter_def: Definition of a filter conditions + + Returns: + + """ + if FilterType(filter_def.type) == FilterType.FIELD: + getter = functools.partial( + _attribute_getter, filter_def.key, filter_def.allow_missing + ) + elif FilterType(filter_def.type) == FilterType.METADATA: + getter = functools.partial( + _metadata_field_getter, filter_def.key, filter_def.allow_missing + ) + op = FilterOp(filter_def.op) + if op is FilterOp.GT: + return lambda ann_or_pred: getter(ann_or_pred) > filter_def.value + elif op is FilterOp.GTE: + return lambda ann_or_pred: getter(ann_or_pred) >= filter_def.value + elif op is FilterOp.LT: + return lambda ann_or_pred: getter(ann_or_pred) < filter_def.value + elif op is FilterOp.LTE: + return lambda ann_or_pred: getter(ann_or_pred) <= filter_def.value + elif op is FilterOp.EQ or op is FilterOp.EQEQ: + return lambda ann_or_pred: getter(ann_or_pred) == filter_def.value + elif op is FilterOp.NEQ: + return lambda ann_or_pred: getter(ann_or_pred) != filter_def.value + elif op is FilterOp.IN: + return lambda ann_or_pred: getter(ann_or_pred) in set( + filter_def.value # type: ignore + ) + elif op is FilterOp.NOT_IN: + return lambda ann_or_pred: getter(ann_or_pred) not in set( + filter_def.value # type:ignore + ) + else: + raise RuntimeError( + f"Fell through all op cases, no match for: '{op}' - MetadataFilter: {filter_def}," + ) + + +def apply_filters( + ann_or_pred: Union[ + Sequence[AnnotationsWithMetadata], Sequence[PredictionsWithMetadata] + ], + filters: Union[ + ListOfOrAndFilters, ListOfAndFilters, List[List[List]], List[List] + ], +): + """Apply filters to list of annotations or list of predictions + Attributes: + ann_or_pred: Prediction or Annotation + filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '=', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + if filters is None or len(filters) == 0: + return ann_or_pred + + filters = ensureListOfOrAndFilters(filters) + + dnf_condition_functions = [] + for or_branch in filters: + and_conditions = [ + _filter_to_comparison_function(cond) for cond in or_branch + ] + dnf_condition_functions.append(and_conditions) + + filtered = [] + for item in ann_or_pred: + for or_conditions in dnf_condition_functions: + if all(c(item) for c in or_conditions): + filtered.append(item) + break + return filtered + + +def ensureListOfOrAndFilters(filters) -> ListOfOrAndFilters: + """JSON encoding creates a triple nested lists from the doubly nested tuples. This function creates the + tuple form again.""" + if isinstance(filters[0], (MetadataFilter, FieldFilter)): + # Normalize into DNF + filters: ListOfOrAndFilters = [filters] # type: ignore + + # NOTE: We have to handle JSON transformed tuples which become two or three layers of lists + if ( + isinstance(filters, list) + and isinstance(filters[0], list) + and isinstance(filters[0][0], str) + ): + filters = [filters] + if ( + isinstance(filters, list) + and isinstance(filters[0], list) + and isinstance(filters[0][0], list) + ): + formatted_filter = [] + for or_branch in filters: + and_chain = [ + AnnotationOrPredictionFilter(*condition) + for condition in or_branch + ] + formatted_filter.append(and_chain) + filters = formatted_filter + return filters diff --git a/nucleus/validate/client.py b/nucleus/validate/client.py index d99afc07..43e2f7f1 100644 --- a/nucleus/validate/client.py +++ b/nucleus/validate/client.py @@ -5,12 +5,13 @@ from .constants import SCENARIO_TEST_ID_KEY from .data_transfer_objects.eval_function import GetEvalFunctions -from .data_transfer_objects.scenario_test import CreateScenarioTestRequest -from .errors import CreateScenarioTestError -from .eval_functions.available_eval_functions import ( - AvailableEvalFunctions, - EvalFunction, +from .data_transfer_objects.scenario_test import ( + CreateScenarioTestRequest, + EvalFunctionListEntry, ) +from .errors import CreateScenarioTestError +from .eval_functions.available_eval_functions import AvailableEvalFunctions +from .eval_functions.base_eval_function import EvalFunctionConfig from .scenario_test import ScenarioTest SUCCESS_KEY = "success" @@ -36,7 +37,8 @@ def eval_functions(self) -> AvailableEvalFunctions: import nucleus client = nucleus.NucleusClient("YOUR_SCALE_API_KEY") - scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5 # Creates an EvaluationCriterion by comparison + # Creates an EvaluationCriterion by using a comparison op + scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5 Returns: :class:`AvailableEvalFunctions`: A container for all the available eval functions @@ -51,7 +53,7 @@ def create_scenario_test( self, name: str, slice_id: str, - evaluation_functions: List[EvalFunction], + evaluation_functions: List[EvalFunctionConfig], ) -> ScenarioTest: """Creates a new Scenario Test from an existing Nucleus :class:`Slice`:. :: @@ -78,12 +80,16 @@ def create_scenario_test( "Must pass an evaluation_function to the scenario test! I.e. " "evaluation_functions=[client.validate.eval_functions.bbox_iou()]" ) + response = self.connection.post( CreateScenarioTestRequest( name=name, slice_id=slice_id, evaluation_functions=[ - ef.to_entry() for ef in evaluation_functions # type:ignore + EvalFunctionListEntry( + id=ef.id, eval_func_arguments=ef.eval_func_arguments + ) + for ef in evaluation_functions ], ).dict(), "validate/scenario_test", diff --git a/nucleus/validate/data_transfer_objects/eval_function.py b/nucleus/validate/data_transfer_objects/eval_function.py index 1e165de5..2592be3d 100644 --- a/nucleus/validate/data_transfer_objects/eval_function.py +++ b/nucleus/validate/data_transfer_objects/eval_function.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, Dict, List, Optional from pydantic import validator @@ -50,12 +50,14 @@ class EvaluationCriterion(ImmutableModel): eval_function_id (str): ID of evaluation function threshold_comparison (:class:`ThresholdComparison`): comparator for evaluation. i.e. threshold=0.5 and threshold_comparator > implies that a test only passes if score > 0.5. threshold (float): numerical threshold that together with threshold comparison, defines success criteria for test evaluation. + eval_func_arguments: Arguments to pass to the eval function constructor """ # TODO: Having only eval_function_id hurts readability -> Add function name eval_function_id: str threshold_comparison: ThresholdComparison threshold: float + eval_func_arguments: Dict[str, Any] @validator("eval_function_id") def valid_eval_function_id(cls, v): # pylint: disable=no-self-argument diff --git a/nucleus/validate/data_transfer_objects/scenario_test.py b/nucleus/validate/data_transfer_objects/scenario_test.py index 174571ff..4e029867 100644 --- a/nucleus/validate/data_transfer_objects/scenario_test.py +++ b/nucleus/validate/data_transfer_objects/scenario_test.py @@ -4,13 +4,16 @@ from nucleus.pydantic_base import ImmutableModel -from .eval_function import EvalFunctionEntry + +class EvalFunctionListEntry(ImmutableModel): + id: str + eval_func_arguments: dict class CreateScenarioTestRequest(ImmutableModel): name: str slice_id: str - evaluation_functions: List[EvalFunctionEntry] + evaluation_functions: List[EvalFunctionListEntry] @validator("slice_id") def startswith_slice_indicator(cls, v): # pylint: disable=no-self-argument diff --git a/nucleus/validate/eval_functions/available_eval_functions.py b/nucleus/validate/eval_functions/available_eval_functions.py index 090e63f4..d5a83910 100644 --- a/nucleus/validate/eval_functions/available_eval_functions.py +++ b/nucleus/validate/eval_functions/available_eval_functions.py @@ -1,46 +1,396 @@ import itertools -from typing import Callable, Dict, List, Type, Union +from typing import Callable, Dict, List, Optional, Union from nucleus.logger import logger -from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction +from nucleus.validate.eval_functions.base_eval_function import ( + EvalFunctionConfig, +) +from ...metrics.filtering import ListOfOrAndFilters from ..data_transfer_objects.eval_function import EvalFunctionEntry from ..errors import EvalFunctionNotAvailableError MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes" -class BoundingBoxIOU(BaseEvalFunction): +class PolygonIOUConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = False, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonIOU` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_iou: BoundingBoxIOU = client.validate.eval_functions.bbox_iou + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_iou(confidence_threshold=0.8) > 0.5] + ) + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_iou" -class BoundingBoxMeanAveragePrecision(BaseEvalFunction): +class PolygonMAPConfig(EvalFunctionConfig): + def __call__( + self, + iou_threshold: float = 0.5, + **kwargs, + ): + """Configures a call to :class:`PolygonMAP` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_map: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_map + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_map(iou_threshold=0.6) > 0.8] + ) + + Args: + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + iou_threshold=iou_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_map" -class BoundingBoxRecall(BaseEvalFunction): +class PolygonRecallConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = False, + iou_threshold: float = 0.5, + confidence_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonRecall` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_recall: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_recall + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_recall(iou_threshold=0.6, confidence_threshold=0.4) > 0.9] + ) + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_recall" -class BoundingBoxPrecision(BaseEvalFunction): +class PolygonPrecisionConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = False, + iou_threshold: float = 0.5, + confidence_threshold: float = 0.0, + **kwargs, + ): + """Configures a call to :class:`PolygonPrecision` object. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + bbox_precision: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_precision + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[bbox_precision(iou_threshold=0.6, confidence_threshold=0.4) > 0.9] + ) + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + **kwargs, + ) + @classmethod def expected_name(cls) -> str: return "bbox_precision" -class CategorizationF1(BaseEvalFunction): +class CuboidIOU2DConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, + **kwargs, + ): + """Configure a call to CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + iou_2d=True, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_iou_2d" + + +class CuboidIOU3DConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, + **kwargs, + ): + """Configure a call to CuboidIOU object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + iou_2d=False, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_iou_3d" + + +class CuboidPrecisionConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, + **kwargs, + ): + """Configure a call to CuboidPrecision object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_precision" + + +class CuboidRecallConfig(EvalFunctionConfig): + def __call__( + self, + enforce_label_match: bool = True, + iou_threshold: float = 0.0, + confidence_threshold: float = 0.0, + annotation_filters: Optional[ListOfOrAndFilters] = None, + prediction_filters: Optional[ListOfOrAndFilters] = None, + **kwargs, + ): + """Configure a call to a CuboidRecall object. + + Args: + enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to True + iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0 + iou_2d: whether to return the BEV 2D IOU if true, or the 3D IOU if false. + confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0 + annotation_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + prediction_filters: MetadataFilter predicates. Predicates are expressed in disjunctive normal form (DNF), like + [[MetadataFilter('x', '==', 0), ...], ...]. DNF allows arbitrary boolean logical combinations of single field + predicates. The innermost structures each describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and multiple column predicate. + Finally, the most outer list combines these filters as a disjunction (OR). + """ + return super().__call__( + enforce_label_match=enforce_label_match, + iou_threshold=iou_threshold, + confidence_threshold=confidence_threshold, + annotation_filters=annotation_filters, + prediction_filters=prediction_filters, + **kwargs, + ) + + @classmethod + def expected_name(cls) -> str: + return "cuboid_recall" + + +class CategorizationF1Config(EvalFunctionConfig): + def __call__( + self, + confidence_threshold: Optional[float] = None, + f1_method: Optional[str] = None, + **kwargs, + ): + """ Configure an evaluation of :class:`CategorizationF1`. + :: + + import nucleus + + client = nucleus.NucleusClient(YOUR_SCALE_API_KEY) + cat_f1: CategorizationF1 = client.validate.eval_functions.cat_f1 + slice_id = "slc_" + scenario_test = client.validate.create_scenario_test( + "Example test", + slice_id=slice_id, + evaluation_criteria=[cat_f1(confidence_threshold=0.6, f1_method="weighted") > 0.7] + ) + + Args: + confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. + Must be in [0, 1]. Default 0.0 + f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \ + default='macro' + This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + ``'weighted'``: + Calculate metrics for each label, and find their average weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + """ + return super().__call__( + confidence_threshold=confidence_threshold, f1_method=f1_method + ) + @classmethod def expected_name(cls) -> str: return "cat_f1" -class CustomEvalFunction(BaseEvalFunction): +class CustomEvalFunction(EvalFunctionConfig): @classmethod def expected_name(cls) -> str: raise NotImplementedError( @@ -48,7 +398,7 @@ def expected_name(cls) -> str: ) # Placeholder: See super().eval_func_entry for actual name -class StandardEvalFunction(BaseEvalFunction): +class StandardEvalFunction(EvalFunctionConfig): """Class for standard Model CI eval functions that have not been added as attributes on AvailableEvalFunctions yet. """ @@ -65,7 +415,7 @@ def expected_name(cls) -> str: return "public_function" # Placeholder: See super().eval_func_entry for actual name -class EvalFunctionNotAvailable(BaseEvalFunction): +class EvalFunctionNotAvailable(EvalFunctionConfig): def __init__( self, not_available_name: str ): # pylint: disable=super-init-not-called @@ -89,13 +439,14 @@ def expected_name(cls) -> str: EvalFunction = Union[ - Type[BoundingBoxIOU], - Type[BoundingBoxMeanAveragePrecision], - Type[BoundingBoxPrecision], - Type[BoundingBoxRecall], - Type[CustomEvalFunction], - Type[EvalFunctionNotAvailable], - Type[StandardEvalFunction], + PolygonIOUConfig, + PolygonMAPConfig, + PolygonPrecisionConfig, + PolygonRecallConfig, + CategorizationF1Config, + CustomEvalFunction, + EvalFunctionNotAvailable, + StandardEvalFunction, ] @@ -124,25 +475,29 @@ def __init__(self, available_functions: List[EvalFunctionEntry]): f.name: f for f in available_functions if f.is_public } # NOTE: Public are assigned - self._public_to_function: Dict[str, BaseEvalFunction] = {} + self._public_to_function: Dict[str, EvalFunctionConfig] = {} self._custom_to_function: Dict[str, CustomEvalFunction] = { f.name: CustomEvalFunction(f) for f in available_functions if not f.is_public } - self.bbox_iou = self._assign_eval_function_if_defined(BoundingBoxIOU) # type: ignore - self.bbox_precision = self._assign_eval_function_if_defined( - BoundingBoxPrecision # type: ignore + self.bbox_iou: PolygonIOUConfig = self._assign_eval_function_if_defined(PolygonIOUConfig) # type: ignore + self.bbox_precision: PolygonPrecisionConfig = self._assign_eval_function_if_defined( + PolygonPrecisionConfig # type: ignore ) - self.bbox_recall = self._assign_eval_function_if_defined( - BoundingBoxRecall # type: ignore + self.bbox_recall: PolygonRecallConfig = self._assign_eval_function_if_defined( + PolygonRecallConfig # type: ignore ) - self.bbox_map = self._assign_eval_function_if_defined( - BoundingBoxMeanAveragePrecision # type: ignore + self.bbox_map: PolygonMAPConfig = self._assign_eval_function_if_defined( + PolygonMAPConfig # type: ignore ) - self.cat_f1 = self._assign_eval_function_if_defined( - CategorizationF1 # type: ignore + self.cat_f1: CategorizationF1Config = self._assign_eval_function_if_defined( + CategorizationF1Config # type: ignore ) + self.cuboid_iou_2d: CuboidIOU2DConfig = self._assign_eval_function_if_defined(CuboidIOU2DConfig) # type: ignore + self.cuboid_iou_3d: CuboidIOU3DConfig = self._assign_eval_function_if_defined(CuboidIOU3DConfig) # type: ignore + self.cuboid_precision: CuboidPrecisionConfig = self._assign_eval_function_if_defined(CuboidPrecisionConfig) # type: ignore + self.cuboid_recall: CuboidRecallConfig = self._assign_eval_function_if_defined(CuboidRecallConfig) # type: ignore # Add public entries that have not been implemented as an attribute on this class for func_entry in self._public_func_entries.values(): @@ -163,7 +518,7 @@ def __repr__(self): ) @property - def public_functions(self) -> Dict[str, BaseEvalFunction]: + def public_functions(self) -> Dict[str, EvalFunctionConfig]: """Standard functions provided by Model CI. Notes: diff --git a/nucleus/validate/eval_functions/base_eval_function.py b/nucleus/validate/eval_functions/base_eval_function.py index 1ea4c931..fa59ec31 100644 --- a/nucleus/validate/eval_functions/base_eval_function.py +++ b/nucleus/validate/eval_functions/base_eval_function.py @@ -1,4 +1,5 @@ import abc +from typing import Any, Dict from ..constants import ThresholdComparison from ..data_transfer_objects.eval_function import ( @@ -7,8 +8,8 @@ ) -class BaseEvalFunction(abc.ABC): - """Abstract base class for concrete implementations of EvalFunctions +class EvalFunctionConfig(abc.ABC): + """Abstract base class for concrete implementations of EvalFunctionsConfigs Operating on this class with comparison operators produces an EvaluationCriterion """ @@ -17,6 +18,7 @@ def __init__(self, eval_func_entry: EvalFunctionEntry): self.eval_func_entry = eval_func_entry self.id = eval_func_entry.id self.name = eval_func_entry.name + self.eval_func_arguments: Dict[str, Any] = {} def __repr__(self): return f"" @@ -26,7 +28,7 @@ def __repr__(self): def expected_name(cls) -> str: """Name to look for in the EvalFunctionDefinitions""" - def __call__(self) -> "BaseEvalFunction": + def __call__(self, **kwargs) -> "EvalFunctionConfig": """Adding call to prepare for being able to pass parameters to function Notes: @@ -34,6 +36,7 @@ def __call__(self) -> "BaseEvalFunction": to look like eval_function() > 0.5 to support eval_function(parameters) > 0.5 in the future """ + self.eval_func_arguments.update(**kwargs) return self def __gt__(self, other) -> EvaluationCriterion: @@ -57,7 +60,5 @@ def _op_to_test_metric(self, comparison: ThresholdComparison, value): eval_function_id=self.eval_func_entry.id, threshold_comparison=comparison, threshold=value, + eval_func_arguments=self.eval_func_arguments, ) - - def to_entry(self): - return self.eval_func_entry diff --git a/poetry.lock b/poetry.lock index 55dc71a2..59bda3e5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1669,6 +1669,19 @@ nativelib = ["pyobjc-framework-cocoa", "pywin32"] objc = ["pyobjc-framework-cocoa"] win32 = ["pywin32"] +[[package]] +name = "shapely" +version = "1.8.1.post1" +description = "Geometric objects, predicates, and operations" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +all = ["pytest", "pytest-cov", "numpy"] +test = ["pytest", "pytest-cov"] +vectorized = ["numpy"] + [[package]] name = "shellingham" version = "1.4.0" @@ -2064,7 +2077,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.6.2" -content-hash = "8d6269a2d5d30685e0d1e9b38187b6105826042e868acfb5bdb5a3e7188560f8" +content-hash = "4830dd3f9f593703fbbae85efdbd6feca5b0513285ad87132d4f90934fca1098" [metadata.files] absl-py = [ @@ -3206,32 +3219,24 @@ pyzmq = [ {file = "pyzmq-22.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f89468059ebc519a7acde1ee50b779019535db8dcf9b8c162ef669257fef7a93"}, {file = "pyzmq-22.3.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea12133df25e3a6918718fbb9a510c6ee5d3fdd5a346320421aac3882f4feeea"}, {file = "pyzmq-22.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c532fd68b93998aab92356be280deec5de8f8fe59cd28763d2cc8a58747b7f"}, - {file = "pyzmq-22.3.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f907c7359ce8bf7f7e63c82f75ad0223384105f5126f313400b7e8004d9b33c3"}, - {file = "pyzmq-22.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:902319cfe23366595d3fa769b5b751e6ee6750a0a64c5d9f757d624b2ac3519e"}, {file = "pyzmq-22.3.0-cp310-cp310-win32.whl", hash = "sha256:67db33bea0a29d03e6eeec55a8190e033318cee3cbc732ba8fd939617cbf762d"}, {file = "pyzmq-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:7661fc1d5cb73481cf710a1418a4e1e301ed7d5d924f91c67ba84b2a1b89defd"}, {file = "pyzmq-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:79244b9e97948eaf38695f4b8e6fc63b14b78cc37f403c6642ba555517ac1268"}, {file = "pyzmq-22.3.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab888624ed68930442a3f3b0b921ad7439c51ba122dbc8c386e6487a658e4a4e"}, {file = "pyzmq-22.3.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18cd854b423fce44951c3a4d3e686bac8f1243d954f579e120a1714096637cc0"}, {file = "pyzmq-22.3.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:de8df0684398bd74ad160afdc2a118ca28384ac6f5e234eb0508858d8d2d9364"}, - {file = "pyzmq-22.3.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:62bcade20813796c426409a3e7423862d50ff0639f5a2a95be4b85b09a618666"}, - {file = "pyzmq-22.3.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ea5a79e808baef98c48c884effce05c31a0698c1057de8fc1c688891043c1ce1"}, {file = "pyzmq-22.3.0-cp36-cp36m-win32.whl", hash = "sha256:3c1895c95be92600233e476fe283f042e71cf8f0b938aabf21b7aafa62a8dac9"}, {file = "pyzmq-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:851977788b9caa8ed011f5f643d3ee8653af02c5fc723fa350db5125abf2be7b"}, {file = "pyzmq-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b4ebed0977f92320f6686c96e9e8dd29eed199eb8d066936bac991afc37cbb70"}, {file = "pyzmq-22.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42abddebe2c6a35180ca549fadc7228d23c1e1f76167c5ebc8a936b5804ea2df"}, {file = "pyzmq-22.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1e41b32d6f7f9c26bc731a8b529ff592f31fc8b6ef2be9fa74abd05c8a342d7"}, {file = "pyzmq-22.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:be4e0f229cf3a71f9ecd633566bd6f80d9fa6afaaff5489492be63fe459ef98c"}, - {file = "pyzmq-22.3.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:08c4e315a76ef26eb833511ebf3fa87d182152adf43dedee8d79f998a2162a0b"}, - {file = "pyzmq-22.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:badb868fff14cfd0e200eaa845887b1011146a7d26d579aaa7f966c203736b92"}, {file = "pyzmq-22.3.0-cp37-cp37m-win32.whl", hash = "sha256:7c58f598d9fcc52772b89a92d72bf8829c12d09746a6d2c724c5b30076c1f11d"}, {file = "pyzmq-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2b97502c16a5ec611cd52410bdfaab264997c627a46b0f98d3f666227fd1ea2d"}, {file = "pyzmq-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d728b08448e5ac3e4d886b165385a262883c34b84a7fe1166277fe675e1c197a"}, {file = "pyzmq-22.3.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:480b9931bfb08bf8b094edd4836271d4d6b44150da051547d8c7113bf947a8b0"}, {file = "pyzmq-22.3.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7dc09198e4073e6015d9a8ea093fc348d4e59de49382476940c3dd9ae156fba8"}, {file = "pyzmq-22.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ca6cd58f62a2751728016d40082008d3b3412a7f28ddfb4a2f0d3c130f69e74"}, - {file = "pyzmq-22.3.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:468bd59a588e276961a918a3060948ae68f6ff5a7fa10bb2f9160c18fe341067"}, - {file = "pyzmq-22.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c88fa7410e9fc471e0858638f403739ee869924dd8e4ae26748496466e27ac59"}, {file = "pyzmq-22.3.0-cp38-cp38-win32.whl", hash = "sha256:c0f84360dcca3481e8674393bdf931f9f10470988f87311b19d23cda869bb6b7"}, {file = "pyzmq-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:f762442bab706fd874064ca218b33a1d8e40d4938e96c24dafd9b12e28017f45"}, {file = "pyzmq-22.3.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:954e73c9cd4d6ae319f1c936ad159072b6d356a92dcbbabfd6e6204b9a79d356"}, @@ -3239,8 +3244,6 @@ pyzmq = [ {file = "pyzmq-22.3.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:acebba1a23fb9d72b42471c3771b6f2f18dcd46df77482612054bd45c07dfa36"}, {file = "pyzmq-22.3.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cf98fd7a6c8aaa08dbc699ffae33fd71175696d78028281bc7b832b26f00ca57"}, {file = "pyzmq-22.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d072f7dfbdb184f0786d63bda26e8a0882041b1e393fbe98940395f7fab4c5e2"}, - {file = "pyzmq-22.3.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:53f4fd13976789ffafedd4d46f954c7bb01146121812b72b4ddca286034df966"}, - {file = "pyzmq-22.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1b5d457acbadcf8b27561deeaa386b0217f47626b29672fa7bd31deb6e91e1b"}, {file = "pyzmq-22.3.0-cp39-cp39-win32.whl", hash = "sha256:e6a02cf7271ee94674a44f4e62aa061d2d049001c844657740e156596298b70b"}, {file = "pyzmq-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:d3dcb5548ead4f1123851a5ced467791f6986d68c656bc63bfff1bf9e36671e2"}, {file = "pyzmq-22.3.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3a4c9886d61d386b2b493377d980f502186cd71d501fffdba52bd2a0880cef4f"}, @@ -3407,6 +3410,34 @@ send2trash = [ {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"}, {file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"}, ] +shapely = [ + {file = "Shapely-1.8.1.post1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0ca96a3314b7a38a3bb385531469de1fcf2b2c2979ec2aa4f37b4c70632cf1ad"}, + {file = "Shapely-1.8.1.post1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:493902923fdd135316161a4ece5294ba3ce81accaa54540d2af3b93f7231143a"}, + {file = "Shapely-1.8.1.post1-cp310-cp310-win_amd64.whl", hash = "sha256:b82fc74d5efb11a71283c4ed69b4e036997cc70db4b73c646207ddf0476ade44"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:89bc5f3abc1ccbc7682c2e1664153c4f8f125fa9c24bff4abca48685739d5636"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:44cb895b1710f7559c28d69dfa08cafe4f58cd4b7a87091a55bdf6711ad9ad66"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:437fff3b6274be26ffa3e450de711ee01e436324b5a405952add2146227e3eb5"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-win32.whl", hash = "sha256:dc0f46212f84c57d13189fc33cf61e13eee292704d7652e931e4b51c54b0c73c"}, + {file = "Shapely-1.8.1.post1-cp36-cp36m-win_amd64.whl", hash = "sha256:9248aad099ecf228fbdd877b0c668823dd83c48798cf04d49a1be75167e3a7ce"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bab5ff7c576588acccd665ecce2a0fe7b47d4ce0398f2d5c1e5b2e27d09398d2"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2381ce0aff67d569eb509bcc051264aa5fbdc1fdd54f4c09963d0e09f16a8f1b"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b4d35e72022b2dbf152d476b0362596011c674ff68be9fc8f2e68e71d86502ca"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-win32.whl", hash = "sha256:5a420e7112b55a1587412a5b03ebf59e302ddd354da68516d3721718f6b8a7c5"}, + {file = "Shapely-1.8.1.post1-cp37-cp37m-win_amd64.whl", hash = "sha256:c4c366e18edf91196a399f8f0f046f93516002e6d8af0b57c23e7c7d91944b16"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2020fda37c708d44a613c020cea09e81e476f96866f348afc2601e66c0e71db1"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:69d5352fb977655c85d2f40a05ae24fc5053cccee77d0a8b1f773e54804e723e"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:83f3c8191d30ae0e3dd557434c48ca591d75342d5a3f42fc5148ec42796be624"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3e792635e92c9aacd1452a589a4fa2970114b6a9b1165e09655481f6e58970f5"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-win32.whl", hash = "sha256:8cf7331f61780506976fe2175e069d898e1b04ace73be21aad55c3ee92e58e3a"}, + {file = "Shapely-1.8.1.post1-cp38-cp38-win_amd64.whl", hash = "sha256:f109064bdb0753a6bac6238538cfeeb4a09739e2d556036b343b2eabeb9520b2"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:aea1e87450adffba3d04ccbaa790df719bb7aa23b05ac797ad16be236a5d0db8"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3a3602ba2e7715ddd5d4114173dec83d3181bfb2497e8589676c284aa739fd67"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:679789d774cfe09ca05118cab78c0a6a42985b3ed23bc93606272a4509b4df28"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:363df36370f28fdc7789857929f6ff27e659f64087b4c89f7a47ed43bd3bfe4d"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-win32.whl", hash = "sha256:bc6063875182515d3888180cc4cbdbaa6443e4a4386c4bb25499e9875b75dcac"}, + {file = "Shapely-1.8.1.post1-cp39-cp39-win_amd64.whl", hash = "sha256:54aeb2a57978ce731fd52289d0e1deee7c232d41aed53091f38776378f644184"}, + {file = "Shapely-1.8.1.post1.tar.gz", hash = "sha256:93ff06ff05fbe2be843b93c7b1ad8292e56e665ba01b4708f75ae8a757972e9f"}, +] shellingham = [ {file = "shellingham-1.4.0-py2.py3-none-any.whl", hash = "sha256:536b67a0697f2e4af32ab176c00a50ac2899c5a05e0d8e2dadac8e58888283f9"}, {file = "shellingham-1.4.0.tar.gz", hash = "sha256:4855c2458d6904829bd34c299f11fdeed7cfefbf8a2c522e4caea6cd76b3171e"}, diff --git a/pyproject.toml b/pyproject.toml index a3a47502..93cfd03d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.8.3" +version = "0.9b4" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] @@ -45,6 +45,7 @@ click = ">=7.1.2,<9.0" # NOTE: COLAB has 7.1.2 and has problems updating rich = "^10.15.2" shellingham = "^1.4.0" scikit-learn = ">=0.24.0" +Shapely = "^1.8.1" [tool.poetry.dev-dependencies] poetry = "^1.1.5" diff --git a/tests/helpers.py b/tests/helpers.py index 7c1205e8..6e16d147 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -172,6 +172,7 @@ def reference_id_from_url(url): "height": 80 + i * 10, "reference_id": reference_id_from_url(TEST_IMG_URLS[i]), "annotation_id": f"[Pytest] Box Annotation Annotation Id{i}", + "metadata": {"field_1": "string", "index": i}, } for i in range(len(TEST_IMG_URLS)) ] diff --git a/tests/metrics/test_filtering.py b/tests/metrics/test_filtering.py new file mode 100644 index 00000000..ac8c4119 --- /dev/null +++ b/tests/metrics/test_filtering.py @@ -0,0 +1,280 @@ +import json + +import pytest + +from nucleus.metrics import FieldFilter, MetadataFilter, apply_filters +from tests.metrics.helpers import ( + TEST_BOX_ANNOTATION_LIST, + TEST_BOX_PREDICTION_LIST, +) + + +@pytest.fixture( + params=[ + TEST_BOX_ANNOTATION_LIST.box_annotations, + TEST_BOX_PREDICTION_LIST.box_predictions, + ] +) +def annotations_or_predictions(request): + yield request.param + + +def test_filter_field(annotations_or_predictions): + dnf_filters = [ + FieldFilter("label", "==", annotations_or_predictions[0].label) + ] + filtered = apply_filters(annotations_or_predictions, dnf_filters) + assert filtered == [annotations_or_predictions[0]] + + +def test_filter_metadata_field(annotations_or_predictions): + dnf_filters = [ + MetadataFilter( + "index", "==", annotations_or_predictions[0].metadata["index"] + ) + ] + filtered = apply_filters(annotations_or_predictions, dnf_filters) + assert filtered == [annotations_or_predictions[0]] + + +def test_only_and(annotations_or_predictions): + and_filters = [ + MetadataFilter( + "index", "==", annotations_or_predictions[0].metadata["index"] + ) + ] + or_filters = [and_filters] + filtered_and = apply_filters(annotations_or_predictions, and_filters) + filtered_or = apply_filters(annotations_or_predictions, or_filters) + assert filtered_and == filtered_or + + +def test_json_encoded_filters(annotations_or_predictions): + filters = [ + [MetadataFilter("index", ">", 0), MetadataFilter("index", "<", 4)] + ] + expected = apply_filters(annotations_or_predictions, filters) + json_string = json.dumps(filters) + filters_from_json = json.loads(json_string) + json_filtered = apply_filters( + annotations_or_predictions, filters_from_json + ) + assert json_filtered == expected + + +def test_json_encoded_and_filters(annotations_or_predictions): + filters = [ + MetadataFilter("index", ">", 0), + MetadataFilter("index", "<", 4), + ] + expected = apply_filters(annotations_or_predictions, filters) + json_string = json.dumps(filters) + filters_from_json = json.loads(json_string) + json_filtered = apply_filters( + annotations_or_predictions, filters_from_json + ) + assert json_filtered == expected + + +def test_or_branches(annotations_or_predictions): + index_0_or_2 = [ + [MetadataFilter("index", "==", 0)], + [MetadataFilter("index", "==", 2)], + ] + filtered = apply_filters(annotations_or_predictions, index_0_or_2) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_only_one_or(annotations_or_predictions): + later_matches = [ + [MetadataFilter("index", "==", -1)], + [MetadataFilter("index", "==", 2)], + ] + filtered = apply_filters(annotations_or_predictions, later_matches) + assert filtered == [ + annotations_or_predictions[2], + ] + + +def test_and_branches(annotations_or_predictions): + index_0_or_2 = [ + [ + MetadataFilter("index", "==", 0), + FieldFilter("label", "==", annotations_or_predictions[0].label), + ], + [ + MetadataFilter("index", "==", 2), + FieldFilter("label", "==", annotations_or_predictions[2].label), + ], + ] + filtered = apply_filters(annotations_or_predictions, index_0_or_2) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_multi_or(annotations_or_predictions): + all_match = [ + [MetadataFilter("index", "==", i)] + for i in range(len(annotations_or_predictions)) + ] + filtered = apply_filters(annotations_or_predictions, all_match) + assert filtered == annotations_or_predictions + + +def test_missing_field_raises(annotations_or_predictions): + missing_field = [[FieldFilter("i_dont_exist", "==", 1)]] + with pytest.raises(AttributeError): + apply_filters(annotations_or_predictions, missing_field) + + +def test_allow_missing_field(annotations_or_predictions): + missing_field = [ + [FieldFilter("i_dont_exist", "==", 1, allow_missing=True)] + ] + filtered = apply_filters(annotations_or_predictions, missing_field) + assert filtered == [] + + +def test_missing_metadata_raises(annotations_or_predictions): + missing_field = [[MetadataFilter("i_dont_exist", "==", 1)]] + with pytest.raises(KeyError): + apply_filters(annotations_or_predictions, missing_field) + + +def test_allow_missing_metadata_field(annotations_or_predictions): + missing_field = [ + [FieldFilter("i_dont_exist", "==", 1, allow_missing=True)] + ] + filtered = apply_filters(annotations_or_predictions, missing_field) + assert filtered == [] + + +def test_gt_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", ">", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_gt_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", ">", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_gte_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", ">=", 1)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_gte_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", ">=", annotations_or_predictions[1].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_lt_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "<", 1)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:1] + + +def test_lt_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "<", annotations_or_predictions[1].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:1] + + +def test_lte_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "<=", 1)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:2] + + +def test_lte_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "<=", annotations_or_predictions[1].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[:2] + + +def test_eqeq_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "==", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_eqeq_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "==", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_eq_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "=", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_eq_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "=", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [annotations_or_predictions[0]] + + +def test_neq_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "!=", 0)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_neq_field(annotations_or_predictions): + valid_gt = [FieldFilter("x", "!=", annotations_or_predictions[0].x)] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[1:] + + +def test_in_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "in", [0, 2])] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_in_field(annotations_or_predictions): + valid_gt = [ + FieldFilter( + "x", + "in", + [annotations_or_predictions[0].x, annotations_or_predictions[2].x], + ) + ] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == [ + annotations_or_predictions[0], + annotations_or_predictions[2], + ] + + +def test_not_in_metadata(annotations_or_predictions): + valid_gt = [MetadataFilter("index", "not in", [0, 1])] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[2:] + + +def test_not_in_field(annotations_or_predictions): + valid_gt = [ + FieldFilter( + "x", + "not in", + [annotations_or_predictions[0].x, annotations_or_predictions[1].x], + ) + ] + filtered = apply_filters(annotations_or_predictions, valid_gt) + assert filtered == annotations_or_predictions[2:] diff --git a/tests/validate/test_scenario_test.py b/tests/validate/test_scenario_test.py index f6b3e61c..cdf63f0e 100644 --- a/tests/validate/test_scenario_test.py +++ b/tests/validate/test_scenario_test.py @@ -88,3 +88,14 @@ def test_scenario_test_set_model_baseline(CLIENT, annotations, scenario_test): # create some dataset_items for the scenario test to reference with pytest.raises(Exception): scenario_test.set_baseline_model("nonexistent_model_id") + + +def test_passing_eval_arguments(CLIENT, test_slice, annotations): + test_name = "scenario_test_" + get_uuid() # use uuid to make unique + CLIENT.validate.create_scenario_test( + name=test_name, + slice_id=test_slice.id, + evaluation_functions=[ + CLIENT.validate.eval_functions.bbox_iou(iou_threshold=0.5) + ], + )