diff --git a/business_objects/general.py b/business_objects/general.py index ed8bbeb1..180c8e98 100644 --- a/business_objects/general.py +++ b/business_objects/general.py @@ -30,7 +30,6 @@ def commit() -> None: session.commit() - def remove_and_refresh_session( session_token: Any, request_new: bool = False ) -> Union[Any, None]: @@ -70,6 +69,10 @@ def execute_distinct_count(count_sql: str) -> int: return session.execute(count_sql).first().distinct_count +def set_seed(seed: float = 0) -> None: + execute(f"SELECT setseed({seed});") + + def get_bind() -> Any: return session.get_bind() diff --git a/business_objects/project.py b/business_objects/project.py index dbdfba0e..3e36d082 100644 --- a/business_objects/project.py +++ b/business_objects/project.py @@ -1,9 +1,17 @@ from typing import List, Optional, Any, Dict, Union +from sqlalchemy.sql import func from . import general + from .. import enums from ..session import session -from ..models import Project +from ..models import ( + DataSliceRecordAssociation, + LabelingTask, + LabelingTaskLabel, + Project, + RecordLabelAssociation, +) def get(project_id: str) -> Project: @@ -60,6 +68,57 @@ def get_label_distribution( return values[0] +def get_confidence_distribution( + project_id: str, + labeling_task_id: str, + data_slice_id: Optional[str] = None, + num_samples: Optional[int] = None, +) -> List[float]: + query_filter = ( + session.query(RecordLabelAssociation.confidence) + .join( + LabelingTaskLabel, + (RecordLabelAssociation.labeling_task_label_id == LabelingTaskLabel.id) + & (LabelingTaskLabel.project_id == RecordLabelAssociation.project_id), + ) + .join( + LabelingTask, + (LabelingTask.id == LabelingTaskLabel.labeling_task_id) + & (LabelingTask.project_id == LabelingTaskLabel.project_id), + ) + .filter( + RecordLabelAssociation.project_id == project_id, + LabelingTask.id == labeling_task_id, + RecordLabelAssociation.source_type + == enums.LabelSource.WEAK_SUPERVISION.value, + RecordLabelAssociation.project_id == project_id, + ) + ) + + if data_slice_id is not None: + query_filter = query_filter.join( + DataSliceRecordAssociation, + (DataSliceRecordAssociation.record_id == RecordLabelAssociation.record_id) + & ( + DataSliceRecordAssociation.project_id + == RecordLabelAssociation.project_id + ), + ).filter( + DataSliceRecordAssociation.data_slice_id == data_slice_id, + ) + + if num_samples is not None: + query_filter = query_filter.order_by(func.random()).limit(num_samples) + general.set_seed(0) + confidence_scores = [confidence for confidence, in (query_filter.all())] + confidence_scores = sorted(confidence_scores) + else: + query_filter = query_filter.order_by(RecordLabelAssociation.confidence.asc()) + confidence_scores = [confidence for confidence, in (query_filter.all())] + + return confidence_scores + + def get_confusion_matrix( project_id: str, labeling_task_id: str, @@ -130,7 +189,7 @@ def create( created_by: str, created_at: Optional[str] = None, with_commit: bool = False, - status: enums.ProjectStatus = enums.ProjectStatus.INIT_UPLOAD + status: enums.ProjectStatus = enums.ProjectStatus.INIT_UPLOAD, ) -> Project: project: Project = Project( name=name,