Skip to content

add task batch metadata and task templates support #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion scaleapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scaleapi.evaluation_tasks import EvaluationTask
from scaleapi.exceptions import ScaleInvalidRequest
from scaleapi.files import File
from scaleapi.projects import Project
from scaleapi.projects import Project, TaskTemplate
from scaleapi.training_tasks import TrainingTask

from ._version import __version__ # noqa: F401
Expand Down Expand Up @@ -619,6 +619,7 @@ def create_batch(
callback: str = "",
calibration_batch: bool = False,
self_label_batch: bool = False,
metadata: Dict = None,
) -> Batch:
"""Create a new Batch within a project.
https://docs.scale.com/reference#batch-creation
Expand All @@ -639,6 +640,8 @@ def create_batch(
Only applicable for self serve projects.
Create a self_label batch by setting
the self_label_batch flag to true.
metadata (Dict):
Optional metadata to be stored at the TaskBatch level

Returns:
Batch: Created batch object
Expand All @@ -650,6 +653,7 @@ def create_batch(
calibration_batch=calibration_batch,
self_label_batch=self_label_batch,
callback=callback,
metadata=metadata or {},
)
batchdata = self.api.post_request(endpoint, body=payload)
return Batch(batchdata, self)
Expand Down Expand Up @@ -828,6 +832,22 @@ def get_batches(
offset += batches.limit
has_more = batches.has_more

def set_batch_metadata(self, batch_name: str, metadata: Dict) -> Batch:
    """Set the metadata of a TaskBatch.

    Posts ``metadata`` as the request body to the batch's
    ``setMetadata`` endpoint and wraps the response in a Batch.

    Args:
        batch_name (str):
            Name of the batch; it is quoted before being placed
            in the endpoint path.
        metadata (Dict):
            Metadata to set for the TaskBatch.

    Returns:
        Batch: built from the API response.
    """
    quoted_name = Api.quote_string(batch_name)
    response = self.api.post_request(
        f"batches/{quoted_name}/setMetadata",
        body=metadata,
    )
    return Batch(response, self)

def create_project(
self,
project_name: str,
Expand Down Expand Up @@ -934,6 +954,23 @@ def update_project(self, project_name: str, **kwargs) -> Project:
projectdata = self.api.post_request(endpoint, body=kwargs)
return Project(projectdata, self)

def get_project_template(self, project_name: str) -> TaskTemplate:
    """Fetch the task template of a project, if one exists.

    An error is thrown when the project's task type does not
    support Task Templates. Currently only the TextCollection and
    Chat task types support them.

    Args:
        project_name (str):
            Project's name; it is quoted before being placed in
            the endpoint path.

    Returns:
        TaskTemplate: built from the API response.
    """
    quoted_name = Api.quote_string(project_name)
    response = self.api.get_request(f"projects/{quoted_name}/taskTemplates")
    return TaskTemplate(response, self)

def upload_file(self, file: IO, **kwargs) -> File:
"""Upload file.
Refer to Files API Reference:
Expand Down
2 changes: 1 addition & 1 deletion scaleapi/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "2.15.2"
__version__ = "2.15.3"
__package_name__ = "scaleapi"
1 change: 1 addition & 0 deletions scaleapi/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, json, client):
self.project = json["project"]
self.created_at = json["created_at"]
self.project = json["project"]
self.metadata = json["metadata"]

self.tasks_pending = None
self.tasks_completed = None
Expand Down
36 changes: 36 additions & 0 deletions scaleapi/projects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,34 @@
class TaskTemplate:
    """Task Template Object.

    Wraps the JSON payload describing a task template and exposes its
    fields as attributes. Two instances are equal, and hash the same,
    when they carry the same ``id``.
    """

    def __init__(self, json, client):
        """Build a TaskTemplate from a JSON payload.

        Args:
            json (dict):
                Template payload. The keys ``id``, ``project``,
                ``version``, ``created_at``, ``updated_at`` and
                ``template_variables`` are read directly, so a missing
                key raises ``KeyError``.
            client:
                Client object, stored on the instance as ``_client``.
        """
        self._json = json
        self._client = client
        self.id = json["id"]
        self.project = json["project"]
        self.version = json["version"]
        self.created_at = json["created_at"]
        self.updated_at = json["updated_at"]
        self.template_variables = json["template_variables"]

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        # Keep equality consistent with __hash__ (both keyed on ``id``) so
        # that instances de-duplicate correctly in sets and dict keys.
        if not isinstance(other, TaskTemplate):
            return NotImplemented
        return self.id == other.id

    def __str__(self):
        return f"TaskTemplate(id={self.id}, project={self.project})"

    def __repr__(self):
        return f"TaskTemplate({self._json})"

    def get_template_variables(self):
        """Returns template variables dictionary"""
        return self.template_variables

    def as_dict(self):
        """Returns task template object as JSON dictionary"""
        return self._json


class Project:
"""Project class, containing Project information."""

Expand Down Expand Up @@ -25,6 +56,11 @@ def __str__(self):
def __repr__(self):
return f"Project({self._json})"

def get_template(self) -> TaskTemplate:
    """Fetch this project's TaskTemplate through the client.

    Only works for Chat and TextCollection type.
    """
    project_name = self.name
    return self._client.get_project_template(project_name)

def as_dict(self):
"""Returns all attributes as a dictionary"""
return self._json
17 changes: 14 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=missing-function-docstring

import os
import time
import uuid
from datetime import datetime
Expand All @@ -21,7 +21,7 @@

try:
print(f"SDK Version: {scaleapi.__version__}")
test_api_key = "test_fe79860cdbe547bf91b4e7da897a6c92"
test_api_key = os.environ["SCALE_TEST_API_KEY"]

if test_api_key.startswith("test_") or test_api_key.endswith("|test"):
client = scaleapi.ScaleClient(test_api_key, "pytest")
Expand Down Expand Up @@ -392,6 +392,7 @@ def create_a_batch():
callback="http://www.example.com/callback",
batch_name=str(uuid.uuid4()),
project=TEST_PROJECT_NAME,
metadata={"some_key": "some_value"},
)


Expand Down Expand Up @@ -437,6 +438,8 @@ def test_get_batch():
batch2 = client.get_batch(batch.name)
assert batch.name == batch2.name
assert batch2.status == BatchStatus.InProgress.value
# test metadata
assert batch2.metadata["some_key"] == "some_value"


def test_batches():
Expand All @@ -459,6 +462,12 @@ def test_get_batches():
assert total_batches == len(all_batches)


def test_set_batch_metadata():
    # Set metadata on a freshly created batch and check that the Batch
    # returned by the client exposes the new key.
    created = create_a_batch()
    new_metadata = {"new_key": "new_value"}
    updated = client.set_batch_metadata(created.name, new_metadata)
    assert updated.metadata["new_key"] == "new_value"


def test_files_upload():
with open("tests/test_image.png", "rb") as f:
client.upload_file(
Expand Down Expand Up @@ -496,7 +505,9 @@ def test_invite_teammates():
try:
project = client.get_project(STUDIO_TEST_PROJECT)
except ScaleResourceNotFound:
client.create_project(project_name=STUDIO_TEST_PROJECT)
client.create_project(
project_name=STUDIO_TEST_PROJECT, task_type=TaskType.ImageAnnotation
)
STUDIO_BATCH_TEST_NAME = f"studio-test-batch-{current_timestamp}"


Expand Down