Skip to content

Commit d561340

Browse files
authored
minor bugfix in adapters for sklearn (#20)
1 parent a68a99e commit d561340

File tree

3 files changed

+19
-14
lines changed

3 files changed

+19
-14
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ project_id = "your-project-id" # can be found in the URL of the web application
4141

4242
client = Client(user_name, password, project_id)
4343
# if you run the application locally, please use the following instead:
44-
# client = Client(username, password, project_id, uri="http://localhost:4455")
44+
# client = Client(user_name, password, project_id, uri="http://localhost:4455")
4545
```
4646

4747
The `project_id` can be found in your browser, e.g. if you run the app on your localhost: `http://localhost:4455/app/projects/{project_id}/overview`

refinery/adapter/sklearn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def build_classification_dataset(
99
sentence_input: str,
1010
classification_label: str,
1111
config_string: Optional[str] = None,
12+
num_train: Optional[int] = None,
1213
) -> Dict[str, Dict[str, Any]]:
1314
"""
1415
Builds a classification dataset from a refinery client and a config string.
@@ -18,22 +19,23 @@ def build_classification_dataset(
1819
sentence_input (str): Name of the column containing the sentence input.
1920
classification_label (str): Name of the label; if this is a task on the full record, enter the string with as "__<label>". Else, input it as "<attribute>__<label>".
2021
config_string (Optional[str], optional): Config string for the TransformerSentenceEmbedder. Defaults to None; if None is provided, the text will not be embedded.
22+
num_train (Optional[int], optional): Number of training examples to use. Defaults to None; if None is provided, all examples will be used.
2123
2224
Returns:
2325
Dict[str, Dict[str, Any]]: Containing the train and test datasets, with embedded inputs.
2426
"""
2527

26-
df_test, df_train, _ = split_train_test_on_weak_supervision(
27-
client, sentence_input, classification_label
28+
df_train, df_test, _ = split_train_test_on_weak_supervision(
29+
client, sentence_input, classification_label, num_train
2830
)
2931

3032
if config_string is not None:
3133
embedder = TransformerSentenceEmbedder(config_string)
32-
inputs_test = embedder.transform(df_test[sentence_input].tolist())
3334
inputs_train = embedder.transform(df_train[sentence_input].tolist())
35+
inputs_test = embedder.transform(df_test[sentence_input].tolist())
3436
else:
35-
inputs_test = df_test[sentence_input].tolist()
3637
inputs_train = df_train[sentence_input].tolist()
38+
inputs_test = df_test[sentence_input].tolist()
3739

3840
return {
3941
"train": {

refinery/adapter/util.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import List, Tuple
1+
from typing import List, Optional, Tuple
22
from refinery import Client
33
import pandas as pd
44

55

66
def split_train_test_on_weak_supervision(
7-
client: Client, _input: str, _label: str
7+
client: Client, _input: str, _label: str, num_train: Optional[int] = None
88
) -> Tuple[pd.DataFrame, pd.DataFrame, List[str]]:
99
"""
1010
Puts the data into a train (weakly supervised data) and test set (manually labeled data).
@@ -14,6 +14,7 @@ def split_train_test_on_weak_supervision(
1414
client (Client): Refinery client
1515
_input (str): Name of the column containing the sentence input.
1616
_label (str): Name of the label; if this is a task on the full record, enter the string with as "__<label>". Else, input it as "<attribute>__<label>".
17+
num_train (Optional[int], optional): Number of training examples to use. Defaults to None; if None is provided, all examples will be used.
1718
1819
Returns:
1920
Tuple[pd.DataFrame, pd.DataFrame, List[str]]: Containing the train and test dataframes and the label name options.
@@ -22,19 +23,21 @@ def split_train_test_on_weak_supervision(
2223
label_attribute_train = f"{_label}__WEAK_SUPERVISION"
2324
label_attribute_test = f"{_label}__MANUAL"
2425

25-
df_train = client.get_record_export(
26-
tokenize=False,
27-
keep_attributes=[_input, label_attribute_train],
28-
dropna=True,
29-
).rename(columns={label_attribute_train: "label"})
30-
3126
df_test = client.get_record_export(
3227
tokenize=False,
3328
keep_attributes=[_input, label_attribute_test],
3429
dropna=True,
3530
).rename(columns={label_attribute_test: "label"})
3631

37-
df_train = df_train.drop(df_test.index)
32+
df_train = client.get_record_export(
33+
tokenize=False,
34+
keep_attributes=[_input, label_attribute_train],
35+
dropna=True,
36+
num_samples=num_train + len(df_test),
37+
).rename(columns={label_attribute_train: "label"})
38+
39+
# Remove overlapping data
40+
df_train = df_train.drop(df_test.index.intersection(df_train.index))[:num_train]
3841

3942
label_options = list(
4043
set(df_test.label.unique().tolist() + df_train.label.unique().tolist())

0 commit comments

Comments
 (0)