diff --git a/README.md b/README.md index 41d5592..588c363 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![refinery repository](https://uploads-ssl.webflow.com/61e47fafb12bd56b40022a49/62cf1c3cb8272b1e9c01127e_refinery%20sdk%20banner.png)](https://github.com/code-kern-ai/refinery) [![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-390/) -[![pypi 1.0.2](https://img.shields.io/badge/pypi-1.0.2-yellow.svg)](https://pypi.org/project/refinery-python-sdk/1.0.2/) +[![pypi 1.1.0](https://img.shields.io/badge/pypi-1.1.0-yellow.svg)](https://pypi.org/project/refinery-python-sdk/1.1.0/) This is the official Python SDK for [*refinery*](https://github.com/code-kern-ai/refinery), the **open-source** data-centric IDE for NLP. @@ -12,6 +12,8 @@ This is the official Python SDK for [*refinery*](https://github.com/code-kern-ai - [Fetching lookup lists](#fetching-lookup-lists) - [Upload files](#upload-files) - [Adapters](#adapters) + - [HuggingFace](#hugging-face) + - [Sklearn](#sklearn) - [Rasa](#rasa) - [What's missing?](#whats-missing) - [Roadmap](#roadmap) @@ -120,6 +122,77 @@ Alternatively, you can `rsdk push ` via CLI, given that you h ### Adapters +#### 🤗 Hugging Face +Transformers are great, but often times, you want to finetune them for your downstream task. With *refinery*, you can do so easily by letting the SDK build the dataset for you that you can use as a plug-and-play base for your training: + +```python +from refinery.adapter import transformers +dataset, mapping = transformers.build_dataset(client, "headline", "__clickbait") +``` + +From here, you can follow the [finetuning example](https://huggingface.co/docs/transformers/training) provided in the official Hugging Face documentation. A next step could look as follows: + +```python +small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000)) +small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(1000)) + +from transformers import ( + AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer +) +import numpy as np +from datasets import load_metric + +tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") + +def tokenize_function(examples): + return tokenizer(examples["headline"], padding="max_length", truncation=True) + +tokenized_datasets = dataset.map(tokenize_function, batched=True) +model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2) +training_args = TrainingArguments(output_dir="test_trainer") +metric = load_metric("accuracy") + +def compute_metrics(eval_pred): + logits, labels = eval_pred + predictions = np.argmax(logits, axis=-1) + return metric.compute(predictions=predictions, references=labels) + +training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch") + +small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000)) +small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000)) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=small_train_dataset, + eval_dataset=small_eval_dataset, + compute_metrics=compute_metrics, +) + +trainer.train() + +trainer.save_model("path/to/model") +``` + +#### Sklearn +You can use *refinery* to directly pull data into a format you can apply for building [sklearn](https://github.com/scikit-learn/scikit-learn) models. This can look as follows: + +```python +from refinery.adapter.embedders import build_classification_dataset +from sklearn.tree import DecisionTreeClassifier + +data = build_classification_dataset(client, "headline", "__clickbait", "distilbert-base-uncased") + +clf = DecisionTreeClassifier() +clf.fit(data["train"]["inputs"], data["train"]["labels"]) + +pred_test = clf.predict(data["test"]["inputs"]) +accuracy = (pred_test == data["test"]["labels"]).mean() +``` + +By the way, we can highly recommend to combine this with [Truss](https://github.com/basetenlabs/truss) for easy model serving! + #### Rasa *refinery* is perfect to be used for building chatbots with [Rasa](https://github.com/RasaHQ/rasa). We've built an adapter with which you can easily create the required Rasa training data directly from *refinery*. diff --git a/refinery/__init__.py b/refinery/__init__.py index 6e70e9a..d2b030d 100644 --- a/refinery/__init__.py +++ b/refinery/__init__.py @@ -111,6 +111,8 @@ def get_record_export( num_samples: Optional[int] = None, download_to: Optional[str] = None, tokenize: Optional[bool] = True, + keep_attributes: Optional[List[str]] = None, + dropna: Optional[bool] = False, ) -> pd.DataFrame: """Collects the export data of your project (i.e. the same data if you would export in the web app). @@ -155,6 +157,12 @@ def get_record_export( "There are no attributes that can be tokenized in this project." ) + if keep_attributes is not None: + df = df[keep_attributes] + + if dropna: + df = df.dropna() + if download_to is not None: df.to_json(download_to, orient="records") msg.good(f"Downloaded export to {download_to}") @@ -263,7 +271,9 @@ def __monitor_task(self, upload_task_id: str) -> None: if print_success_message: msg.good("File upload successful.") else: - msg.fail("Upload failed. Please look into the UI notification center for more details.") + msg.fail( + "Upload failed. Please look into the UI notification center for more details." + ) def __get_task(self, upload_task_id: str) -> Dict[str, Any]: api_response = api_calls.get_request( diff --git a/refinery/adapter/sklearn.py b/refinery/adapter/sklearn.py new file mode 100644 index 0000000..2f47e62 --- /dev/null +++ b/refinery/adapter/sklearn.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, Optional +from embedders.classification.contextual import TransformerSentenceEmbedder +from refinery import Client +from refinery.adapter.util import split_train_test_on_weak_supervision + + +def build_classification_dataset( + client: Client, + sentence_input: str, + classification_label: str, + config_string: Optional[str] = None, +) -> Dict[str, Dict[str, Any]]: + """ + Builds a classification dataset from a refinery client and a config string. + + Args: + client (Client): Refinery client + sentence_input (str): Name of the column containing the sentence input. + classification_label (str): Name of the label; if this is a task on the full record, enter the string with as "__