Skip to content

Commit a273a85

Browse files
authored
Return indices (#21)
* minor bugfix in adapters for sklearn * adapter returns indices of records
1 parent d561340 commit a273a85

File tree

4 files changed

+42
-7
lines changed

4 files changed

+42
-7
lines changed

refinery/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ def get_project_details(self) -> Dict[str, str]:
8181
api_response = api_calls.get_request(url, self.session_token)
8282
return api_response
8383

84+
def get_primary_keys(self) -> List[str]:
    """Fetches the primary keys of your current project.

    The project details are requested via ``get_project_details`` and the
    entries under its ``"attributes"`` key are scanned; the ``"name"`` of
    every attribute whose ``"is_primary_key"`` value is truthy is collected,
    in the order the attributes are listed.

    Returns:
        List[str]: Containing the primary keys of your project.
    """
    project_attributes = self.get_project_details()["attributes"]
    return [
        attribute["name"]
        for attribute in project_attributes
        if attribute["is_primary_key"]
    ]
98+
8499
def get_lookup_list(self, list_id: str) -> Dict[str, str]:
85100
"""Fetches a lookup list of your current project.
86101

refinery/adapter/sklearn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def build_classification_dataset(
2525
Dict[str, Dict[str, Any]]: Containing the train and test datasets, with embedded inputs.
2626
"""
2727

28-
df_train, df_test, _ = split_train_test_on_weak_supervision(
28+
df_train, df_test, _, primary_keys = split_train_test_on_weak_supervision(
2929
client, sentence_input, classification_label, num_train
3030
)
3131

@@ -40,7 +40,12 @@ def build_classification_dataset(
4040
return {
4141
"train": {
4242
"inputs": inputs_train,
43+
"index": df_train[primary_keys].to_dict("records"),
4344
"labels": df_train["label"],
4445
},
45-
"test": {"inputs": inputs_test, "labels": df_test["label"]},
46+
"test": {
47+
"inputs": inputs_test,
48+
"index": df_test[primary_keys].to_dict("records"),
49+
"labels": df_test["label"],
50+
},
4651
}

refinery/adapter/transformers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@ def build_classification_dataset(
1818
_type_: HuggingFace dataset
1919
"""
2020

21-
df_train, df_test, label_options = split_train_test_on_weak_supervision(
21+
(
22+
df_train,
23+
df_test,
24+
label_options,
25+
primary_keys,
26+
) = split_train_test_on_weak_supervision(
2227
client, sentence_input, classification_label
2328
)
2429

@@ -44,4 +49,6 @@ def build_classification_dataset(
4449
if os.path.exists(test_file_path):
4550
os.remove(test_file_path)
4651

47-
return dataset, mapping
52+
index = {"train": df_train[primary_keys], "test": df_test[primary_keys]}
53+
54+
return dataset, mapping, index

refinery/adapter/util.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,27 @@ def split_train_test_on_weak_supervision(
2020
Tuple[pd.DataFrame, pd.DataFrame, List[str], List[str]]: Containing the train and test dataframes, the label name options, and the primary keys of the project.
2121
"""
2222

23+
primary_keys = client.get_primary_keys()
24+
2325
label_attribute_train = f"{_label}__WEAK_SUPERVISION"
2426
label_attribute_test = f"{_label}__MANUAL"
2527

2628
df_test = client.get_record_export(
2729
tokenize=False,
28-
keep_attributes=[_input, label_attribute_test],
30+
keep_attributes=primary_keys + [_input, label_attribute_test],
2931
dropna=True,
3032
).rename(columns={label_attribute_test: "label"})
3133

34+
if num_train is not None:
35+
num_samples = num_train + len(df_test)
36+
else:
37+
num_samples = None
38+
3239
df_train = client.get_record_export(
3340
tokenize=False,
34-
keep_attributes=[_input, label_attribute_train],
41+
keep_attributes=primary_keys + [_input, label_attribute_train],
3542
dropna=True,
36-
num_samples=num_train + len(df_test),
43+
num_samples=num_samples,
3744
).rename(columns={label_attribute_train: "label"})
3845

3946
# Remove overlapping data
@@ -47,4 +54,5 @@ def split_train_test_on_weak_supervision(
4754
df_train.reset_index(drop=True),
4855
df_test.reset_index(drop=True),
4956
label_options,
57+
primary_keys,
5058
)

0 commit comments

Comments
 (0)