Skip to content

Commit 07d1e74

Browse files
authored
External information sources (#7)
* prepares database for embedders as code * build api around embedders * adds embedder states * adds option to add external information sources like model callbacks * update to differentiate between callbacks and heuristics * implements suggestions * removes embedder and implements suggestion from PR 39 * implements PR 7
1 parent 6b159de commit 07d1e74

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

business_objects/attribute.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ def get_by_name(project_id: str, name: str) -> Attribute:
3939
)
4040

4141

42+
def get_all_by_names(project_id: str, attribute_names: List[str]) -> List[Attribute]:
43+
return (
44+
session.query(Attribute)
45+
.filter(
46+
Attribute.project_id == project_id,
47+
Attribute.name.in_(attribute_names),
48+
)
49+
.all()
50+
)
51+
52+
4253
def get_all(project_id: str) -> List[Attribute]:
4354
return session.query(Attribute).filter(Attribute.project_id == project_id).all()
4455

business_objects/information_source.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def get(project_id: str, source_id: str) -> InformationSource:
2222
)
2323

2424

25+
def get_by_name(project_id: str, name: str) -> InformationSource:
26+
return (
27+
session.query(InformationSource)
28+
.filter(
29+
InformationSource.project_id == project_id,
30+
InformationSource.name == name,
31+
)
32+
.first()
33+
)
34+
35+
2536
def get_all(project_id: str) -> List[InformationSource]:
2637
return (
2738
session.query(InformationSource)
@@ -168,7 +179,13 @@ def get_exclusion_record_ids_for_task(task_id: str) -> List[str]:
168179
return exclusion_ids
169180

170181

171-
def get_overview_data(project_id: str) -> List[Dict[str, Any]]:
182+
def get_overview_data(
183+
project_id: str, is_model_callback: bool = False
184+
) -> List[Dict[str, Any]]:
185+
if is_model_callback:
186+
type_selection = " = 'MODEL_CALLBACK'"
187+
else:
188+
type_selection = " != 'MODEL_CALLBACK'"
172189
query = f"""
173190
SELECT array_agg(row_to_json(data_select))
174191
FROM (
@@ -205,14 +222,16 @@ def get_overview_data(project_id: str) -> List[Dict[str, Any]]:
205222
GROUP BY source_id) stats
206223
ON _is.id = stats.source_id
207224
WHERE _is.project_id = '{project_id}'
225+
AND _is.type {type_selection}
208226
ORDER BY "createdAt" DESC,name
209227
)data_select """
210228
values = general.execute_first(query)
211229

212230
if values:
213231
return values[0]
214232

215-
def continue_payload(project_id:str,source_id:str,payload_id:str)->bool:
233+
234+
def continue_payload(project_id: str, source_id: str, payload_id: str) -> bool:
216235
query = f"""
217236
SELECT isp.state
218237
FROM information_source_payload isp
@@ -223,7 +242,7 @@ def continue_payload(project_id:str,source_id:str,payload_id:str)->bool:
223242
AND isp.project_id = '{project_id}' """
224243

225244
value = general.execute_first(query)
226-
if not value or value[0]!= "CREATED":
245+
if not value or value[0] != "CREATED":
227246
return False
228247
return True
229248

@@ -444,8 +463,16 @@ def update_quantity_stats(
444463

445464

446465
def update_is_selected_for_project(
447-
project_id: str, update_value: bool, with_commit: bool = False
466+
project_id: str,
467+
update_value: bool,
468+
with_commit: bool = False,
469+
is_model_callback: bool = False,
448470
) -> None:
471+
472+
if is_model_callback:
473+
type_selection = " = 'MODEL_CALLBACK'"
474+
else:
475+
type_selection = " != 'MODEL_CALLBACK'"
449476
if update_value:
450477
str_value = "TRUE"
451478
else:
@@ -454,7 +481,9 @@ def update_is_selected_for_project(
454481
query = f"""
455482
UPDATE information_source
456483
SET is_selected = {str_value}
457-
WHERE project_id = '{project_id}' """
484+
WHERE project_id = '{project_id}'
485+
AND type {type_selection}
486+
"""
458487
general.execute(query)
459488
general.flush_or_commit(with_commit)
460489

business_objects/labeling_task.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ def get_task_name_id_dict(project_id: str) -> Dict[str, str]:
3232
return {labeling_task.name: labeling_task.id for labeling_task in labeling_tasks}
3333

3434

35+
def get_labeling_task_by_name(project_id: str, task_name: str) -> LabelingTask:
36+
return (
37+
session.query(LabelingTask)
38+
.filter(LabelingTask.project_id == project_id, LabelingTask.name == task_name)
39+
.first()
40+
)
41+
42+
3543
def get_labeling_tasks_by_selected_sources(project_id: str) -> List[LabelingTask]:
3644
return (
3745
session.query(LabelingTask)

business_objects/record_label_association.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,20 @@ def delete_by_source_id(
474474
general.flush_or_commit(with_commit)
475475

476476

477+
def delete_by_source_id_and_record_ids(
478+
project_id: str,
479+
information_source_id: str,
480+
record_ids: List[str],
481+
with_commit: bool = False,
482+
) -> None:
483+
session.query(RecordLabelAssociation).filter(
484+
RecordLabelAssociation.project_id == project_id,
485+
RecordLabelAssociation.source_id == information_source_id,
486+
RecordLabelAssociation.record_id.in_(record_ids),
487+
).delete()
488+
general.flush_or_commit(with_commit)
489+
490+
477491
def delete_record_label_associations(
478492
project_id: str, record_task_concatenation: str, with_commit: bool = False
479493
) -> None:

enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class LabelSource(Enum):
2727
# WEAK_SUPERVISION = Output of the Weak Supervision Model - ehemeals "programmatic"
2828
WEAK_SUPERVISION = "WEAK_SUPERVISION"
2929
INFORMATION_SOURCE = "INFORMATION_SOURCE"
30+
MODEL_CALLBACK = "MODEL_CALLBACK"
3031

3132

3233
class InformationSourceType(Enum):
@@ -186,6 +187,7 @@ class NotificationType(Enum):
186187
WEAK_SUPERVISION_TASK_STARTED = "WEAK_SUPERVISION_TASK_STARTED"
187188
WEAK_SUPERVISION_TASK_DONE = "WEAK_SUPERVISION_TASK_DONE"
188189
WEAK_SUPERVISION_TASK_FAILED = "WEAK_SUPERVISION_TASK_FAILED"
190+
189191
INFORMATION_SOURCE_STARTED = "INFORMATION_SOURCE_STARTED"
190192
INFORMATION_SOURCE_PREPARATION_STARTED = "INFORMATION_SOURCE_PREPARATION_STARTED"
191193
INFORMATION_SOURCE_COMPLETED = "INFORMATION_SOURCE_COMPLETED"

models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,6 @@ class RecordAttributeTokenStatistics(Base):
555555
num_token = Column(Integer)
556556

557557

558-
# -------------------- EMBEDDING_ --------------------
559558
class Embedding(Base):
560559
__tablename__ = Tablenames.EMBEDDING.value
561560
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

0 commit comments

Comments
 (0)