From bd76baf8d9e8c18db4a1c43c277b33b498daca82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Sun, 11 Sep 2022 15:21:41 +0200 Subject: [PATCH 1/6] adds option to add external information sources like model callbacks --- refinery/__init__.py | 31 +++++++++++++++++++++++++++++++ refinery/callbacks/inference.py | 31 ++++++++++++++++++++++++++++--- refinery/settings.py | 3 +++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/refinery/__init__.py b/refinery/__init__.py index b477ed6..1a823fe 100644 --- a/refinery/__init__.py +++ b/refinery/__init__.py @@ -183,6 +183,37 @@ def get_record_export( msg.good(f"Downloaded export to {download_to}") return df + def post_associations( + self, + associations, + indices, + name, + label_task_name, + source_type: Optional[str] = "heuristic", + ): + """Posts associations to the server. + + Args: + associations (List[Dict[str, str]]): List of associations to post. + indices (List[str]): List of indices to post to. + name (str): Name of the association set. + label_task_name (str): Name of the label task. + source_type (Optional[str], optional): Source type of the associations. Defaults to "heuristic". + """ + url = settings.get_associations_url(self.project_id) + api_response = api_calls.post_request( + url, + { + "associations": associations, + "indices": indices, + "name": name, + "label_task_name": label_task_name, + "source_type": source_type, + }, + self.session_token, + ) + return api_response + def post_file_import( self, path: str, import_file_options: Optional[str] = "" ) -> bool: diff --git a/refinery/callbacks/inference.py b/refinery/callbacks/inference.py index 42bc870..7c4c340 100644 --- a/refinery/callbacks/inference.py +++ b/refinery/callbacks/inference.py @@ -1,11 +1,30 @@ +from typing import Callable, Optional import pandas as pd -from refinery import exceptions +from refinery import Client, exceptions class ModelCallback: def __init__( - self, client, inference_fn, preprocessing_fn=None, postprocessing_fn=None + self, + model_name: str, + label_task_name: str, + inference_fn: Callable, + client: Client, + preprocessing_fn: Optional[Callable] = None, + postprocessing_fn: Optional[Callable] = None, ): + """ + + Args: + model_name (str): Name of the model (as an idenfitier in refinery) + label_task_name (str): Name of the label task (from refinery) + inference_fn (Callable): Function to predict the output + client (Client): Refinery client + preprocessing_fn (Optional[Callable], optional): Function to apply preprocessing to your inputs. Defaults to None. + postprocessing_fn (Optional[Callable], optional): Function to apply postprocessing to the inference function's output. Defaults to None. + """ + self.model_name = model_name + self.label_task_name = label_task_name self.client = client self.inference_fn = inference_fn self.preprocessing_fn = preprocessing_fn @@ -37,4 +56,10 @@ def run(self, inputs, indices): if self.postprocessing_fn is not None: batched_outputs = self.postprocessing_fn(batched_outputs) - yield {"index": batched_indices, "associations": batched_outputs} + response = self.client.post_associations( + batched_outputs, + batched_indices, + self.model_name, + self.label_task_name, + "model_callback", + ) diff --git a/refinery/settings.py b/refinery/settings.py index 64c7c5a..65075f4 100644 --- a/refinery/settings.py +++ b/refinery/settings.py @@ -43,6 +43,9 @@ def get_export_url(project_id: str) -> str: def get_import_url(project_id: str) -> str: return f"{get_project_url(project_id)}/import" +def get_associations_url(project_id: str) -> str: + return f"{get_project_url(project_id)}/associations" + def get_base_config(project_id: str) -> str: return f"{get_project_url(project_id)}/import/base_config" From 8e3e8b1c611519f32c6848585b246ba290dcb907 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Mon, 12 Sep 2022 23:03:32 +0200 Subject: [PATCH 2/6] refactor model callback --- refinery/callbacks/inference.py | 72 +++++++++++++++++++++++++++------ refinery/callbacks/sklearn.py | 52 ++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 refinery/callbacks/sklearn.py diff --git a/refinery/callbacks/inference.py b/refinery/callbacks/inference.py index 7c4c340..8d16f33 100644 --- a/refinery/callbacks/inference.py +++ b/refinery/callbacks/inference.py @@ -1,4 +1,5 @@ -from typing import Callable, Optional +from email.generator import Generator +from typing import Any, Callable, Dict, List, Optional import pandas as pd from refinery import Client, exceptions @@ -6,40 +7,73 @@ class ModelCallback: def __init__( self, + client: Client, model_name: str, label_task_name: str, inference_fn: Callable, - client: Client, + initialize_fn: Optional[Callable] = None, preprocessing_fn: Optional[Callable] = None, postprocessing_fn: Optional[Callable] = None, + **kwargs ): """ Args: - model_name (str): Name of the model (as an idenfitier in refinery) - label_task_name (str): Name of the label task (from refinery) - inference_fn (Callable): Function to predict the output client (Client): Refinery client - preprocessing_fn (Optional[Callable], optional): Function to apply preprocessing to your inputs. Defaults to None. - postprocessing_fn (Optional[Callable], optional): Function to apply postprocessing to the inference function's output. Defaults to None. + model_name (str): Name of the model + label_task_name (str): Name of the label task + inference_fn (Callable): Function you want to apply for inference + initialize_fn (Optional[Callable], optional): Function to execute to compute internal states. Defaults to None. + preprocessing_fn (Optional[Callable], optional): Function to preprocess model inputs. Defaults to None. + postprocessing_fn (Optional[Callable], optional): Function to postprocess model outputs. Defaults to None. """ self.model_name = model_name self.label_task_name = label_task_name self.client = client self.inference_fn = inference_fn + self.initialize_fn = initialize_fn self.preprocessing_fn = preprocessing_fn self.postprocessing_fn = postprocessing_fn - self.primary_keys = client.get_primary_keys() + self.kwargs = kwargs @staticmethod - def __batch(documents): + def __batch(documents: List[Any]) -> Generator: + """Batch documents into chunks of BATCH_SIZE. + + Args: + documents (List[Any]): List of documents + + Yields: + Generator: Generator of batches + """ BATCH_SIZE = 32 length = len(documents) for idx in range(0, length, BATCH_SIZE): yield documents[idx : min(idx + BATCH_SIZE, length)] - def run(self, inputs, indices): + def initialize( + self, inputs: Optional[List[Any]], labels: Optional[List[Any]] = None + ) -> None: + """Initialize states for the computation. + + Args: + inputs (Optional[List[Any]], optional): List of inputs. Defaults to None. + labels (Optional[List[Any]], optional): List of labels. Defaults to None. + """ + if self.initialize_fn: + self.kwargs = self.initialize_fn(inputs, labels, **self.kwargs) + + def run(self, inputs: List[Any], indices: List[Dict[str, Any]]) -> None: + """Run the pipeline and send the results to refinery. + + Args: + inputs (List[Any]): List of inputs + indices (List[Dict[str, Any]]): List of indices + + Raises: + exceptions.PrimaryKeyError: If the primary key is not found in the indices + """ indices_df = pd.DataFrame(indices) if not all([key in indices_df.columns for key in self.primary_keys]): raise exceptions.PrimaryKeyError("Errorneous primary keys given for index.") @@ -49,17 +83,29 @@ def run(self, inputs, indices): batched_indices = next(index_generator) if self.preprocessing_fn is not None: - batched_inputs = self.preprocessing_fn(batched_inputs) + batched_inputs = self.preprocessing_fn(batched_inputs, **self.kwargs) batched_outputs = self.inference_fn(batched_inputs) if self.postprocessing_fn is not None: - batched_outputs = self.postprocessing_fn(batched_outputs) + batched_outputs = self.postprocessing_fn(batched_outputs, **self.kwargs) - response = self.client.post_associations( + self.client.post_associations( batched_outputs, batched_indices, self.model_name, self.label_task_name, "model_callback", ) + + def initialize_and_run( + self, inputs: List[Any], indices: List[Dict[str, Any]] + ) -> None: + """Initialize and run the pipeline. + + Args: + inputs (List[Any]): List of inputs + indices (List[Dict[str, Any]]): List of indices + """ + self.initialize(inputs) + self.run(inputs, indices) diff --git a/refinery/callbacks/sklearn.py b/refinery/callbacks/sklearn.py new file mode 100644 index 0000000..165ab69 --- /dev/null +++ b/refinery/callbacks/sklearn.py @@ -0,0 +1,52 @@ +from typing import Optional, List, Any, Dict, Callable +from refinery import Client +from refinery.callbacks.inference import ModelCallback +from sklearn.base import BaseEstimator + + +def initialize_fn(inputs, labels, **kwargs): + return {"clf": kwargs["clf"]} + + +def postprocessing_fn(outputs, **kwargs): + named_outputs = [] + for prediction in outputs: + pred_index = prediction.argmax() + label = kwargs["clf"].classes_[pred_index] + confidence = prediction[pred_index] + named_outputs.append([label, confidence]) + return named_outputs + + +class SklearnCallback(ModelCallback): + def __init__( + self, + client: Client, + sklearn_model: BaseEstimator, + labeling_task_name: str, + ) -> None: + """Callback for sklearn models. + + Args: + client (Client): Refinery client + sklearn_model (BaseEstimator): Sklearn model + labeling_task_name (str): Name of the labeling task + """ + + super().__init__( + client, + sklearn_model.__class__.__name__, + labeling_task_name, + inference_fn=sklearn_model.predict_proba, + initialize_fn=initialize_fn, + postprocessing_fn=postprocessing_fn, + ) + self.sklearn_model = sklearn_model + self.initialized = False + self.kwargs = {"clf": self.sklearn_model} + + def run(self, inputs: List[Any], indices: List[Dict[str, Any]]) -> None: + if not self.initialized: + self.initialize(None, None) + self.initialized = True + super().run(inputs, indices) From 7f8eb759e71526eee59e6afcde31f4554566b59e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20H=C3=B6tter?= Date: Mon, 12 Sep 2022 23:44:11 +0200 Subject: [PATCH 3/6] adds pytorch adapter --- refinery/adapter/torch.py | 66 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 refinery/adapter/torch.py diff --git a/refinery/adapter/torch.py b/refinery/adapter/torch.py new file mode 100644 index 0000000..b200672 --- /dev/null +++ b/refinery/adapter/torch.py @@ -0,0 +1,66 @@ +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +from sklearn import preprocessing +from .sklearn import ( + build_classification_dataset as sklearn_build_classification_dataset, +) +from typing import Any, Dict, Optional, Tuple +from refinery import Client + + +class Data(Dataset): + def __init__(self, X, y, encoder): + # need to convert float64 to float32 else + # will get the following error + # RuntimeError: expected scalar type Double but found Float + self.X = torch.FloatTensor(X) + # need to convert float64 to Long else + # will get the following error + # RuntimeError: expected scalar type Long but found Float + y_encoded = encoder.transform(y.values) + self.y = torch.from_numpy(y_encoded).type(torch.LongTensor) + self.len = self.X.shape[0] + + def __getitem__(self, index): + return self.X[index], self.y[index] + + def __len__(self): + return self.len + + +def build_classification_dataset( + client: Client, + sentence_input: str, + classification_label: str, + config_string: Optional[str] = None, + num_train: Optional[int] = None, + batch_size: Optional[int] = 32, +) -> Tuple[DataLoader, DataLoader, np.array]: + """ + Builds a classification dataset from a refinery client and a config string. + + Args: + client (Client): Refinery client + sentence_input (str): Name of the column containing the sentence input. + classification_label (str): Name of the label; if this is a task on the full record, enter the string with as "__