diff --git a/refinery/__init__.py b/refinery/__init__.py
index d2b030d..b477ed6 100644
--- a/refinery/__init__.py
+++ b/refinery/__init__.py
@@ -81,6 +81,21 @@ def get_project_details(self) -> Dict[str, str]:
         api_response = api_calls.get_request(url, self.session_token)
         return api_response
 
+    def get_primary_keys(self) -> List[str]:
+        """Fetches the primary keys of your current project.
+
+        Returns:
+            List[str]: Containing the primary keys of your project.
+        """
+        project_details = self.get_project_details()
+        project_attributes = project_details["attributes"]
+
+        primary_keys = []
+        for attribute in project_attributes:
+            if attribute["is_primary_key"]:
+                primary_keys.append(attribute["name"])
+        return primary_keys
+
     def get_lookup_list(self, list_id: str) -> Dict[str, str]:
         """Fetches a lookup list of your current project.
 
diff --git a/refinery/adapter/sklearn.py b/refinery/adapter/sklearn.py
index 5da7954..ed888ef 100644
--- a/refinery/adapter/sklearn.py
+++ b/refinery/adapter/sklearn.py
@@ -25,7 +25,7 @@ def build_classification_dataset(
         Dict[str, Dict[str, Any]]: Containing the train and test datasets, with
             embedded inputs.
""" - df_train, df_test, _ = split_train_test_on_weak_supervision( + df_train, df_test, _, primary_keys = split_train_test_on_weak_supervision( client, sentence_input, classification_label, num_train ) @@ -40,7 +40,12 @@ def build_classification_dataset( return { "train": { "inputs": inputs_train, + "index": df_train[primary_keys].to_dict("records"), "labels": df_train["label"], }, - "test": {"inputs": inputs_test, "labels": df_test["label"]}, + "test": { + "inputs": inputs_test, + "index": df_test[primary_keys].to_dict("records"), + "labels": df_test["label"], + }, } diff --git a/refinery/adapter/transformers.py b/refinery/adapter/transformers.py index dd5da5c..889529f 100644 --- a/refinery/adapter/transformers.py +++ b/refinery/adapter/transformers.py @@ -18,7 +18,12 @@ def build_classification_dataset( _type_: HuggingFace dataset """ - df_train, df_test, label_options = split_train_test_on_weak_supervision( + ( + df_train, + df_test, + label_options, + primary_keys, + ) = split_train_test_on_weak_supervision( client, sentence_input, classification_label ) @@ -44,4 +49,6 @@ def build_classification_dataset( if os.path.exists(test_file_path): os.remove(test_file_path) - return dataset, mapping + index = {"train": df_train[primary_keys], "test": df_test[primary_keys]} + + return dataset, mapping, index diff --git a/refinery/adapter/util.py b/refinery/adapter/util.py index 64175a0..53f1b69 100644 --- a/refinery/adapter/util.py +++ b/refinery/adapter/util.py @@ -20,20 +20,27 @@ def split_train_test_on_weak_supervision( Tuple[pd.DataFrame, pd.DataFrame, List[str]]: Containing the train and test dataframes and the label name options. 
""" + primary_keys = client.get_primary_keys() + label_attribute_train = f"{_label}__WEAK_SUPERVISION" label_attribute_test = f"{_label}__MANUAL" df_test = client.get_record_export( tokenize=False, - keep_attributes=[_input, label_attribute_test], + keep_attributes=primary_keys + [_input, label_attribute_test], dropna=True, ).rename(columns={label_attribute_test: "label"}) + if num_train is not None: + num_samples = num_train + len(df_test) + else: + num_samples = None + df_train = client.get_record_export( tokenize=False, - keep_attributes=[_input, label_attribute_train], + keep_attributes=primary_keys + [_input, label_attribute_train], dropna=True, - num_samples=num_train + len(df_test), + num_samples=num_samples, ).rename(columns={label_attribute_train: "label"}) # Remove overlapping data @@ -47,4 +54,5 @@ def split_train_test_on_weak_supervision( df_train.reset_index(drop=True), df_test.reset_index(drop=True), label_options, + primary_keys, )