From 2dd90a869bd4556702362f2ab640c0e3adb16316 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Thu, 8 Sep 2022 09:42:59 +0200 Subject: [PATCH 1/8] prepares database for embedders as code --- enums.py | 2 ++ models.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/enums.py b/enums.py index 82f8d644..6cd08128 100644 --- a/enums.py +++ b/enums.py @@ -65,6 +65,8 @@ class Tablenames(Enum): LABELING_TASK_LABEL = "labeling_task_label" EMBEDDING = "embedding" EMBEDDING_TENSOR = "embedding_tensor" + EMBEDDER = "embedder" + EMBEDDER_PAYLOAD = "embedder_payload" RECORD = "record" RECORD_TOKENIZED = "record_tokenized" RECORD_TOKENIZATION_TASK = "record_tokenization_task" diff --git a/models.py b/models.py index 13a82f70..843ce83e 100644 --- a/models.py +++ b/models.py @@ -556,6 +556,66 @@ class RecordAttributeTokenStatistics(Base): # -------------------- EMBEDDING_ -------------------- +class Embedder(Base): + __tablename__ = Tablenames.EMBEDDER.value + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + project_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), + index=True, + ) + # e.g. ATTRIBUTE-LEVEL, TOKEN-LEVEL + type = Column(String) + name = Column(String) + description = Column(String) + source_code = Column(String) + version = Column(Integer, default=1) + + created_at = Column(DateTime, default=sql.func.now()) + created_by = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.USER.value}.id", ondelete="CASCADE"), + index=True, + ) + + payloads = parent_to_child_relationship( + Tablenames.EMBEDDER, + Tablenames.EMBEDDER_PAYLOAD, + order_by="iteration.desc()", + ) + + +class EmbedderPayload(Base): + __tablename__ = Tablenames.EMBEDDER_PAYLOAD.value + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + project_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), + index=True, + ) + embedder_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.EMBEDDER.value}.id", ondelete="CASCADE"), + index=True, + ) + state = Column( + String, default=PayloadState.CREATED.value + ) # e.g. CREATED, FINISHED, FAILED + progress = Column(Float, default=0.0) + created_at = Column(DateTime, default=sql.func.now()) + created_by = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.USER.value}.id", ondelete="CASCADE"), + index=True, + ) + finished_at = Column(DateTime) + iteration = Column(Integer) + source_code = Column(String) + input_data = Column(JSON) + output_data = Column(JSON) + logs = Column(ARRAY(String)) + + class Embedding(Base): __tablename__ = Tablenames.EMBEDDING.value id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) From 25e8f7efdb9a040552a95b59696d5a616fd980b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Thu, 8 Sep 2022 10:46:25 +0200 Subject: [PATCH 2/8] build api around embedders --- business_objects/embedder.py | 236 +++++++++++++++++++++++++++++++++++ models.py | 5 + 2 files changed, 241 insertions(+) create mode 100644 business_objects/embedder.py diff --git a/business_objects/embedder.py b/business_objects/embedder.py new file mode 100644 index 00000000..2c032698 --- /dev/null +++ b/business_objects/embedder.py @@ -0,0 +1,236 @@ +from datetime import datetime +from typing import Dict, List, Any, Optional + +from . import general +from ..models import ( + Embedder, + EmbedderPayload, +) +from ..session import session + + +def get(project_id: str, source_id: str) -> Embedder: + return ( + session.query(Embedder) + .filter( + Embedder.project_id == project_id, + Embedder.id == source_id, + ) + .first() + ) + + +def get_all(project_id: str) -> List[Embedder]: + return ( + session.query(Embedder) + .filter( + Embedder.project_id == project_id, + ) + .all() + ) + + +def get_payload(project_id: str, payload_id: str) -> EmbedderPayload: + return ( + session.query(EmbedderPayload) + .filter( + EmbedderPayload.project_id == project_id, + EmbedderPayload.id == payload_id, + ) + .first() + ) + + +def get_last_payload(project_id: str, source_id: str) -> EmbedderPayload: + return ( + session.query(EmbedderPayload) + .filter( + EmbedderPayload.project_id == project_id, + EmbedderPayload.embedder_id == source_id, + ) + .order_by(EmbedderPayload.created_at.desc()) + .first() + ) + + +def get_payloads_by_project_id(project_id: str) -> List[Any]: + query: str = f""" + SELECT + payload.id, + payload.embedder_id, + payload.created_at, + payload.finished_at, + payload.iteration, + payload.source_code, + payload.logs, + payload.state + FROM + embedder_payload AS payload + INNER JOIN + embedder + ON + payload.embedder_id=embedder.id + WHERE + embedder.project_id='{project_id}' + ; + """ + return general.execute_all(query) + + +def get_overview_data(project_id: str) -> List[Dict[str, Any]]: + query = f""" + SELECT array_agg(row_to_json(data_select)) + FROM ( + SELECT + _is.id, + _is.name, + _is.type "embedderType", + _is.description, + _is.created_at "createdAt", + _is.created_by "createdBy", + isp.state, + isp.created_at "lastRun", + FROM embedder _is + LEFT JOIN LATERAL( + SELECT isp.id,isp.state,isp.created_at + FROM embedder_payload isp + WHERE _is.id = isp.embedder_id + AND _is.project_id = isp.project_id + ORDER BY isp.iteration DESC + LIMIT 1 + ) isp ON TRUE + WHERE _is.project_id = '{project_id}' + ORDER BY "createdAt" DESC,name + )data_select """ + values = general.execute_first(query) + + if values: + return values[0] + + +def continue_payload(project_id: str, embedder_id: str, payload_id: str) -> bool: + query = f""" + SELECT isp.state + FROM embedder_payload isp + INNER JOIN embedder _is + ON isp.embedder_id = _is.id AND isp.project_id = _is.project_id + WHERE isp.id = '{payload_id}' + AND isp.source_id = '{embedder_id}' + AND isp.project_id = '{project_id}' """ + + value = general.execute_first(query) + if not value or value[0] != "CREATED": + return False + return True + + +def create( + project_id: str, + name: str, + type: str, + description: str, + source_code: str, + version: Optional[int] = None, + created_at: Optional[datetime] = None, + created_by: Optional[str] = None, + with_commit: bool = False, +) -> Embedder: + embedder: Embedder = Embedder( + project_id=project_id, + name=name, + type=type, + description=description, + source_code=source_code, + version=version, + created_at=created_at, + created_by=created_by, + ) + general.add(embedder, with_commit) + return embedder + + +def create_payload( + project_id: str, + embedder_id: str, + state: str, + created_by: Optional[str] = None, + created_at: Optional[datetime] = None, + finished_at: Optional[datetime] = None, + iteration: Optional[int] = None, + source_code: Optional[str] = None, + logs: List[str] = None, + with_commit: bool = False, +) -> EmbedderPayload: + payload: EmbedderPayload = EmbedderPayload( + embedder_id=embedder_id, project_id=project_id, state=state + ) + if created_by: + payload.created_by = created_by + if iteration: + payload.iteration = iteration + if source_code: + payload.source_code = source_code + if created_at: + payload.created_at = created_at + if finished_at: + payload.finished_at = finished_at + if logs: + payload.logs = logs + general.add(payload, with_commit) + return payload + + +def delete(project_id: str, embedder_id: str, with_commit: bool = False) -> None: + session.query(Embedder).filter( + Embedder.project_id == project_id, + Embedder.id == embedder_id, + ).delete() + general.flush_or_commit(with_commit) + + +def update( + project_id: str, + embedder_id: str, + name: Optional[str] = None, + type: Optional[str] = None, + description: Optional[str] = None, + source_code: Optional[str] = None, + version: Optional[int] = None, + created_at: Optional[datetime] = None, + created_by: Optional[str] = None, + with_commit: bool = False, +) -> None: + information_source = get(project_id, embedder_id) + + if name is not None: + information_source.name = name + if type is not None: + information_source.type = type + if description is not None: + information_source.description = description + if source_code is not None: + information_source.source_code = source_code + if version is not None: + information_source.version = version + if created_at is not None: + information_source.created_at = created_at + if created_by is not None: + information_source.created_at = created_by + general.flush_or_commit(with_commit) + + +def update_payload( + project_id: str, + payload_id: str, + state: str = None, + progress: float = None, + with_commit: bool = False, +) -> None: + payload_item = get_payload(project_id, payload_id) + if not payload_item: + return + if state: + payload_item.state = state + if progress: + payload_item.progress = progress + general.flush_or_commit(with_commit) diff --git a/models.py b/models.py index 843ce83e..26aa0b27 100644 --- a/models.py +++ b/models.py @@ -261,6 +261,11 @@ class Project(Base): Tablenames.INFORMATION_SOURCE, order_by=["created_at.desc()", "name.asc()", "id.desc()"], ) + embedders = parent_to_child_relationship( + Tablenames.PROJECT, + Tablenames.EMBEDDER, + order_by=["created_at.desc()", "name.asc()", "id.desc()"], + ) knowledge_bases = parent_to_child_relationship( Tablenames.PROJECT, Tablenames.KNOWLEDGE_BASE, From 59ae2676840f2e640defa4884000446c77322945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Sat, 10 Sep 2022 15:57:58 +0200 Subject: [PATCH 3/8] adds embedder states --- business_objects/embedder_payload.py | 42 ++++++++++++++++++++++++++++ enums.py | 6 ++++ 2 files changed, 48 insertions(+) create mode 100644 business_objects/embedder_payload.py diff --git a/business_objects/embedder_payload.py b/business_objects/embedder_payload.py new file mode 100644 index 00000000..e9a99e47 --- /dev/null +++ b/business_objects/embedder_payload.py @@ -0,0 +1,42 @@ +from datetime import datetime +from typing import List, Optional + +from . import general +from .. import enums +from ..models import Embedder, EmbedderPayload +from ..session import session + + +def get(project_id: str, payload_id: str) -> EmbedderPayload: + return ( + session.query(EmbedderPayload) + .filter( + Embedder.project_id == project_id, + Embedder.id == EmbedderPayload.embedder_id, + EmbedderPayload.id == payload_id, + ) + .first() + ) + + +def create( + project_id: str, + source_code: str, + state: enums.PayloadState, + iteration: int, + embedder_id: str, + created_by: str, + created_at: Optional[datetime] = None, + with_commit: bool = False, +) -> EmbedderPayload: + payload: EmbedderPayload = EmbedderPayload( + source_code=source_code, + state=state.value, + iteration=iteration, + embedder_id=embedder_id, + created_by=created_by, + project_id=project_id, + created_at=created_at, + ) + general.add(payload, with_commit) + return payload diff --git a/enums.py b/enums.py index 6cd08128..bde80d68 100644 --- a/enums.py +++ b/enums.py @@ -188,6 +188,12 @@ class NotificationType(Enum): WEAK_SUPERVISION_TASK_STARTED = "WEAK_SUPERVISION_TASK_STARTED" WEAK_SUPERVISION_TASK_DONE = "WEAK_SUPERVISION_TASK_DONE" WEAK_SUPERVISION_TASK_FAILED = "WEAK_SUPERVISION_TASK_FAILED" + + EMBEDDER_STARTED = "EMBEDDER_STARTED" + EMBEDDER_PREPARATION_STARTED = "EMBEDDER_PREPARATION_STARTED" + EMBEDDER_COMPLETED = "EMBEDDER_COMPLETED" + EMBEDDER_FAILED = "EMBEDDER_FAILED" + INFORMATION_SOURCE_STARTED = "INFORMATION_SOURCE_STARTED" INFORMATION_SOURCE_PREPARATION_STARTED = "INFORMATION_SOURCE_PREPARATION_STARTED" INFORMATION_SOURCE_COMPLETED = "INFORMATION_SOURCE_COMPLETED" From a1911ced10200cece3fa7d1d1aeec704d0cf60aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Sun, 11 Sep 2022 15:18:26 +0200 Subject: [PATCH 4/8] adds option to add external information sources like model callbacks --- business_objects/attribute.py | 11 +++++++++++ business_objects/information_source.py | 16 ++++++++++++++-- business_objects/labeling_task.py | 8 ++++++++ business_objects/record_label_association.py | 14 ++++++++++++++ enums.py | 1 + 5 files changed, 48 insertions(+), 2 deletions(-) diff --git a/business_objects/attribute.py b/business_objects/attribute.py index 80143b20..4daeedb5 100644 --- a/business_objects/attribute.py +++ b/business_objects/attribute.py @@ -39,6 +39,17 @@ def get_by_name(project_id: str, name: str) -> Attribute: ) +def get_all_by_names(project_id: str, attribute_names: List[str]) -> List[Attribute]: + return ( + session.query(Attribute) + .filter( + Attribute.project_id == project_id, + Attribute.name.in_(attribute_names), + ) + .all() + ) + + def get_all(project_id: str) -> List[Attribute]: return session.query(Attribute).filter(Attribute.project_id == project_id).all() diff --git a/business_objects/information_source.py b/business_objects/information_source.py index 67dd2c12..562e4666 100644 --- a/business_objects/information_source.py +++ b/business_objects/information_source.py @@ -22,6 +22,17 @@ def get(project_id: str, source_id: str) -> InformationSource: ) +def get_by_name(project_id: str, name: str) -> InformationSource: + return ( + session.query(InformationSource) + .filter( + InformationSource.project_id == project_id, + InformationSource.name == name, + ) + .first() + ) + + def get_all(project_id: str) -> List[InformationSource]: return ( session.query(InformationSource) @@ -212,7 +223,8 @@ def get_overview_data(project_id: str) -> List[Dict[str, Any]]: if values: return values[0] -def continue_payload(project_id:str,source_id:str,payload_id:str)->bool: + +def continue_payload(project_id: str, source_id: str, payload_id: str) -> bool: query = f""" SELECT isp.state FROM information_source_payload isp @@ -223,7 +235,7 @@ def continue_payload(project_id:str,source_id:str,payload_id:str)->bool: AND isp.project_id = '{project_id}' """ value = general.execute_first(query) - if not value or value[0]!= "CREATED": + if not value or value[0] != "CREATED": return False return True diff --git a/business_objects/labeling_task.py b/business_objects/labeling_task.py index 6ff8dae4..f9a03dce 100644 --- a/business_objects/labeling_task.py +++ b/business_objects/labeling_task.py @@ -32,6 +32,14 @@ def get_task_name_id_dict(project_id: str) -> Dict[str, str]: return {labeling_task.name: labeling_task.id for labeling_task in labeling_tasks} +def get_labeling_task_by_name(project_id: str, task_name: str) -> LabelingTask: + return ( + session.query(LabelingTask) + .filter(LabelingTask.project_id == project_id, LabelingTask.name == task_name) + .first() + ) + + def get_labeling_tasks_by_selected_sources(project_id: str) -> List[LabelingTask]: return ( session.query(LabelingTask) diff --git a/business_objects/record_label_association.py b/business_objects/record_label_association.py index 137c3cbc..0e560f75 100644 --- a/business_objects/record_label_association.py +++ b/business_objects/record_label_association.py @@ -474,6 +474,20 @@ def delete_by_source_id( general.flush_or_commit(with_commit) +def delete_by_source_id_and_record_ids( + project_id: str, + information_source_id: str, + record_ids: List[str], + with_commit: bool = False, +) -> None: + session.query(RecordLabelAssociation).filter( + RecordLabelAssociation.project_id == project_id, + RecordLabelAssociation.source_id == information_source_id, + RecordLabelAssociation.record_id.in_(record_ids), + ).delete() + general.flush_or_commit(with_commit) + + def delete_record_label_associations( project_id: str, record_task_concatenation: str, with_commit: bool = False ) -> None: diff --git a/enums.py b/enums.py index bde80d68..2bf32a47 100644 --- a/enums.py +++ b/enums.py @@ -27,6 +27,7 @@ class LabelSource(Enum): # WEAK_SUPERVISION = Output of the Weak Supervision Model - ehemeals "programmatic" WEAK_SUPERVISION = "WEAK_SUPERVISION" INFORMATION_SOURCE = "INFORMATION_SOURCE" + MODEL_CALLBACK = "MODEL_CALLBACK" class InformationSourceType(Enum): From 1bac975cbf9e9da708596c0976818cfa63dee7b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Mon, 12 Sep 2022 01:04:30 +0200 Subject: [PATCH 5/8] update to differentiate between callbacks and heuristics --- business_objects/information_source.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/business_objects/information_source.py b/business_objects/information_source.py index 562e4666..6d1eed19 100644 --- a/business_objects/information_source.py +++ b/business_objects/information_source.py @@ -179,7 +179,7 @@ def get_exclusion_record_ids_for_task(task_id: str) -> List[str]: return exclusion_ids -def get_overview_data(project_id: str) -> List[Dict[str, Any]]: +def get_overview_data(project_id: str, operator: str) -> List[Dict[str, Any]]: query = f""" SELECT array_agg(row_to_json(data_select)) FROM ( @@ -216,6 +216,7 @@ def get_overview_data(project_id: str) -> List[Dict[str, Any]]: GROUP BY source_id) stats ON _is.id = stats.source_id WHERE _is.project_id = '{project_id}' + AND _is.type {operator} 'MODEL_CALLBACK' ORDER BY "createdAt" DESC,name )data_select """ values = general.execute_first(query) @@ -456,7 +457,10 @@ def update_quantity_stats( def update_is_selected_for_project( - project_id: str, update_value: bool, with_commit: bool = False + project_id: str, + update_value: bool, + with_commit: bool = False, + operator: Optional[str] = "!=", ) -> None: if update_value: str_value = "TRUE" @@ -466,7 +470,9 @@ def update_is_selected_for_project( query = f""" UPDATE information_source SET is_selected = {str_value} - WHERE project_id = '{project_id}' """ + WHERE project_id = '{project_id}' + AND type {operator} 'MODEL_CALLBACK' + """ general.execute(query) general.flush_or_commit(with_commit) From 8fba0154a6bb787f29dd21769b14e1804cf1f7d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Tue, 13 Sep 2022 15:10:35 +0200 Subject: [PATCH 6/8] implements suggestions --- business_objects/embedder.py | 236 --------------------------- business_objects/embedder_payload.py | 42 ----- enums.py | 2 - 3 files changed, 280 deletions(-) delete mode 100644 business_objects/embedder.py delete mode 100644 business_objects/embedder_payload.py diff --git a/business_objects/embedder.py b/business_objects/embedder.py deleted file mode 100644 index 2c032698..00000000 --- a/business_objects/embedder.py +++ /dev/null @@ -1,236 +0,0 @@ -from datetime import datetime -from typing import Dict, List, Any, Optional - -from . import general -from ..models import ( - Embedder, - EmbedderPayload, -) -from ..session import session - - -def get(project_id: str, source_id: str) -> Embedder: - return ( - session.query(Embedder) - .filter( - Embedder.project_id == project_id, - Embedder.id == source_id, - ) - .first() - ) - - -def get_all(project_id: str) -> List[Embedder]: - return ( - session.query(Embedder) - .filter( - Embedder.project_id == project_id, - ) - .all() - ) - - -def get_payload(project_id: str, payload_id: str) -> EmbedderPayload: - return ( - session.query(EmbedderPayload) - .filter( - EmbedderPayload.project_id == project_id, - EmbedderPayload.id == payload_id, - ) - .first() - ) - - -def get_last_payload(project_id: str, source_id: str) -> EmbedderPayload: - return ( - session.query(EmbedderPayload) - .filter( - EmbedderPayload.project_id == project_id, - EmbedderPayload.embedder_id == source_id, - ) - .order_by(EmbedderPayload.created_at.desc()) - .first() - ) - - -def get_payloads_by_project_id(project_id: str) -> List[Any]: - query: str = f""" - SELECT - payload.id, - payload.embedder_id, - payload.created_at, - payload.finished_at, - payload.iteration, - payload.source_code, - payload.logs, - payload.state - FROM - embedder_payload AS payload - INNER JOIN - embedder - ON - payload.embedder_id=embedder.id - WHERE - embedder.project_id='{project_id}' - ; - """ - return general.execute_all(query) - - -def get_overview_data(project_id: str) -> List[Dict[str, Any]]: - query = f""" - SELECT array_agg(row_to_json(data_select)) - FROM ( - SELECT - _is.id, - _is.name, - _is.type "embedderType", - _is.description, - _is.created_at "createdAt", - _is.created_by "createdBy", - isp.state, - isp.created_at "lastRun", - FROM embedder _is - LEFT JOIN LATERAL( - SELECT isp.id,isp.state,isp.created_at - FROM embedder_payload isp - WHERE _is.id = isp.embedder_id - AND _is.project_id = isp.project_id - ORDER BY isp.iteration DESC - LIMIT 1 - ) isp ON TRUE - WHERE _is.project_id = '{project_id}' - ORDER BY "createdAt" DESC,name - )data_select """ - values = general.execute_first(query) - - if values: - return values[0] - - -def continue_payload(project_id: str, embedder_id: str, payload_id: str) -> bool: - query = f""" - SELECT isp.state - FROM embedder_payload isp - INNER JOIN embedder _is - ON isp.embedder_id = _is.id AND isp.project_id = _is.project_id - WHERE isp.id = '{payload_id}' - AND isp.source_id = '{embedder_id}' - AND isp.project_id = '{project_id}' """ - - value = general.execute_first(query) - if not value or value[0] != "CREATED": - return False - return True - - -def create( - project_id: str, - name: str, - type: str, - description: str, - source_code: str, - version: Optional[int] = None, - created_at: Optional[datetime] = None, - created_by: Optional[str] = None, - with_commit: bool = False, -) -> Embedder: - embedder: Embedder = Embedder( - project_id=project_id, - name=name, - type=type, - description=description, - source_code=source_code, - version=version, - created_at=created_at, - created_by=created_by, - ) - general.add(embedder, with_commit) - return embedder - - -def create_payload( - project_id: str, - embedder_id: str, - state: str, - created_by: Optional[str] = None, - created_at: Optional[datetime] = None, - finished_at: Optional[datetime] = None, - iteration: Optional[int] = None, - source_code: Optional[str] = None, - logs: List[str] = None, - with_commit: bool = False, -) -> EmbedderPayload: - payload: EmbedderPayload = EmbedderPayload( - embedder_id=embedder_id, project_id=project_id, state=state - ) - if created_by: - payload.created_by = created_by - if iteration: - payload.iteration = iteration - if source_code: - payload.source_code = source_code - if created_at: - payload.created_at = created_at - if finished_at: - payload.finished_at = finished_at - if logs: - payload.logs = logs - general.add(payload, with_commit) - return payload - - -def delete(project_id: str, embedder_id: str, with_commit: bool = False) -> None: - session.query(Embedder).filter( - Embedder.project_id == project_id, - Embedder.id == embedder_id, - ).delete() - general.flush_or_commit(with_commit) - - -def update( - project_id: str, - embedder_id: str, - name: Optional[str] = None, - type: Optional[str] = None, - description: Optional[str] = None, - source_code: Optional[str] = None, - version: Optional[int] = None, - created_at: Optional[datetime] = None, - created_by: Optional[str] = None, - with_commit: bool = False, -) -> None: - information_source = get(project_id, embedder_id) - - if name is not None: - information_source.name = name - if type is not None: - information_source.type = type - if description is not None: - information_source.description = description - if source_code is not None: - information_source.source_code = source_code - if version is not None: - information_source.version = version - if created_at is not None: - information_source.created_at = created_at - if created_by is not None: - information_source.created_at = created_by - general.flush_or_commit(with_commit) - - -def update_payload( - project_id: str, - payload_id: str, - state: str = None, - progress: float = None, - with_commit: bool = False, -) -> None: - payload_item = get_payload(project_id, payload_id) - if not payload_item: - return - if state: - payload_item.state = state - if progress: - payload_item.progress = progress - general.flush_or_commit(with_commit) diff --git a/business_objects/embedder_payload.py b/business_objects/embedder_payload.py deleted file mode 100644 index e9a99e47..00000000 --- a/business_objects/embedder_payload.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from typing import List, Optional - -from . import general -from .. import enums -from ..models import Embedder, EmbedderPayload -from ..session import session - - -def get(project_id: str, payload_id: str) -> EmbedderPayload: - return ( - session.query(EmbedderPayload) - .filter( - Embedder.project_id == project_id, - Embedder.id == EmbedderPayload.embedder_id, - EmbedderPayload.id == payload_id, - ) - .first() - ) - - -def create( - project_id: str, - source_code: str, - state: enums.PayloadState, - iteration: int, - embedder_id: str, - created_by: str, - created_at: Optional[datetime] = None, - with_commit: bool = False, -) -> EmbedderPayload: - payload: EmbedderPayload = EmbedderPayload( - source_code=source_code, - state=state.value, - iteration=iteration, - embedder_id=embedder_id, - created_by=created_by, - project_id=project_id, - created_at=created_at, - ) - general.add(payload, with_commit) - return payload diff --git a/enums.py b/enums.py index 2bf32a47..072e6fbf 100644 --- a/enums.py +++ b/enums.py @@ -66,8 +66,6 @@ class Tablenames(Enum): LABELING_TASK_LABEL = "labeling_task_label" EMBEDDING = "embedding" EMBEDDING_TENSOR = "embedding_tensor" - EMBEDDER = "embedder" - EMBEDDER_PAYLOAD = "embedder_payload" RECORD = "record" RECORD_TOKENIZED = "record_tokenized" RECORD_TOKENIZATION_TASK = "record_tokenization_task" From 1d61da9edc8ee32e5fa01738c1ec993e22550da1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Tue, 13 Sep 2022 15:37:05 +0200 Subject: [PATCH 7/8] removes embedder and implements suggestion from PR 39 --- business_objects/information_source.py | 13 +++-- enums.py | 5 -- models.py | 66 -------------------------- 3 files changed, 9 insertions(+), 75 deletions(-) diff --git a/business_objects/information_source.py b/business_objects/information_source.py index 6d1eed19..e9a48b62 100644 --- a/business_objects/information_source.py +++ b/business_objects/information_source.py @@ -179,7 +179,7 @@ def get_exclusion_record_ids_for_task(task_id: str) -> List[str]: return exclusion_ids -def get_overview_data(project_id: str, operator: str) -> List[Dict[str, Any]]: +def get_overview_data(project_id: str, type_selection: str) -> List[Dict[str, Any]]: query = f""" SELECT array_agg(row_to_json(data_select)) FROM ( @@ -216,7 +216,7 @@ def get_overview_data(project_id: str, operator: str) -> List[Dict[str, Any]]: GROUP BY source_id) stats ON _is.id = stats.source_id WHERE _is.project_id = '{project_id}' - AND _is.type {operator} 'MODEL_CALLBACK' + AND _is.type {type_selection} ORDER BY "createdAt" DESC,name )data_select """ values = general.execute_first(query) @@ -460,8 +460,13 @@ def update_is_selected_for_project( project_id: str, update_value: bool, with_commit: bool = False, - operator: Optional[str] = "!=", + is_model_callback: bool = False, ) -> None: + + if is_model_callback: + type_selection = " = 'MODEL_CALLBACK'" + else: + type_selection = " != 'MODEL_CALLBACK'" if update_value: str_value = "TRUE" else: @@ -471,7 +476,7 @@ def update_is_selected_for_project( UPDATE information_source SET is_selected = {str_value} WHERE project_id = '{project_id}' - AND type {operator} 'MODEL_CALLBACK' + AND type {type_selection} """ general.execute(query) general.flush_or_commit(with_commit) diff --git a/enums.py b/enums.py index 072e6fbf..12b40e34 100644 --- a/enums.py +++ b/enums.py @@ -188,11 +188,6 @@ class NotificationType(Enum): WEAK_SUPERVISION_TASK_DONE = "WEAK_SUPERVISION_TASK_DONE" WEAK_SUPERVISION_TASK_FAILED = "WEAK_SUPERVISION_TASK_FAILED" - EMBEDDER_STARTED = "EMBEDDER_STARTED" - EMBEDDER_PREPARATION_STARTED = "EMBEDDER_PREPARATION_STARTED" - EMBEDDER_COMPLETED = "EMBEDDER_COMPLETED" - EMBEDDER_FAILED = "EMBEDDER_FAILED" - INFORMATION_SOURCE_STARTED = "INFORMATION_SOURCE_STARTED" INFORMATION_SOURCE_PREPARATION_STARTED = "INFORMATION_SOURCE_PREPARATION_STARTED" INFORMATION_SOURCE_COMPLETED = "INFORMATION_SOURCE_COMPLETED" diff --git a/models.py b/models.py index 26aa0b27..2b1bcfa8 100644 --- a/models.py +++ b/models.py @@ -261,11 +261,6 @@ class Project(Base): Tablenames.INFORMATION_SOURCE, order_by=["created_at.desc()", "name.asc()", "id.desc()"], ) - embedders = parent_to_child_relationship( - Tablenames.PROJECT, - Tablenames.EMBEDDER, - order_by=["created_at.desc()", "name.asc()", "id.desc()"], - ) knowledge_bases = parent_to_child_relationship( Tablenames.PROJECT, Tablenames.KNOWLEDGE_BASE, @@ -560,67 +555,6 @@ class RecordAttributeTokenStatistics(Base): num_token = Column(Integer) -# -------------------- EMBEDDING_ -------------------- -class Embedder(Base): - __tablename__ = Tablenames.EMBEDDER.value - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - project_id = Column( - UUID(as_uuid=True), - ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), - index=True, - ) - # e.g. ATTRIBUTE-LEVEL, TOKEN-LEVEL - type = Column(String) - name = Column(String) - description = Column(String) - source_code = Column(String) - version = Column(Integer, default=1) - - created_at = Column(DateTime, default=sql.func.now()) - created_by = Column( - UUID(as_uuid=True), - ForeignKey(f"{Tablenames.USER.value}.id", ondelete="CASCADE"), - index=True, - ) - - payloads = parent_to_child_relationship( - Tablenames.EMBEDDER, - Tablenames.EMBEDDER_PAYLOAD, - order_by="iteration.desc()", - ) - - -class EmbedderPayload(Base): - __tablename__ = Tablenames.EMBEDDER_PAYLOAD.value - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - project_id = Column( - UUID(as_uuid=True), - ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), - index=True, - ) - embedder_id = Column( - UUID(as_uuid=True), - ForeignKey(f"{Tablenames.EMBEDDER.value}.id", ondelete="CASCADE"), - index=True, - ) - state = Column( - String, default=PayloadState.CREATED.value - ) # e.g. CREATED, FINISHED, FAILED - progress = Column(Float, default=0.0) - created_at = Column(DateTime, default=sql.func.now()) - created_by = Column( - UUID(as_uuid=True), - ForeignKey(f"{Tablenames.USER.value}.id", ondelete="CASCADE"), - index=True, - ) - finished_at = Column(DateTime) - iteration = Column(Integer) - source_code = Column(String) - input_data = Column(JSON) - output_data = Column(JSON) - logs = Column(ARRAY(String)) - - class Embedding(Base): __tablename__ = Tablenames.EMBEDDING.value id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) From 657a5097223e2659bcf4cfc5ec2b483ef5a12496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Tue, 13 Sep 2022 16:03:41 +0200 Subject: [PATCH 8/8] implements PR 7 --- business_objects/information_source.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/business_objects/information_source.py b/business_objects/information_source.py index e9a48b62..b9980126 100644 --- a/business_objects/information_source.py +++ b/business_objects/information_source.py @@ -179,7 +179,13 @@ def get_exclusion_record_ids_for_task(task_id: str) -> List[str]: return exclusion_ids -def get_overview_data(project_id: str, type_selection: str) -> List[Dict[str, Any]]: +def get_overview_data( + project_id: str, is_model_callback: bool = False +) -> List[Dict[str, Any]]: + if is_model_callback: + type_selection = " = 'MODEL_CALLBACK'" + else: + type_selection = " != 'MODEL_CALLBACK'" query = f""" SELECT array_agg(row_to_json(data_select)) FROM (