Skip to content

Commit c9775e1

Browse files
add task batch metadata and task templates support (#82)
* add task batch metadata and task templates support * reformatting * type fixes * comment length fix * docstrings
1 parent a52dcd6 commit c9775e1

File tree

5 files changed

+90
-5
lines changed

5 files changed

+90
-5
lines changed

scaleapi/__init__.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from scaleapi.evaluation_tasks import EvaluationTask
55
from scaleapi.exceptions import ScaleInvalidRequest
66
from scaleapi.files import File
7-
from scaleapi.projects import Project
7+
from scaleapi.projects import Project, TaskTemplate
88
from scaleapi.training_tasks import TrainingTask
99

1010
from ._version import __version__ # noqa: F401
@@ -619,6 +619,7 @@ def create_batch(
619619
callback: str = "",
620620
calibration_batch: bool = False,
621621
self_label_batch: bool = False,
622+
metadata: Dict = None,
622623
) -> Batch:
623624
"""Create a new Batch within a project.
624625
https://docs.scale.com/reference#batch-creation
@@ -639,6 +640,8 @@ def create_batch(
639640
Only applicable for self serve projects.
640641
Create a self_label batch by setting
641642
the self_label_batch flag to true.
643+
metadata (Dict):
644+
Optional metadata to be stored at the TaskBatch level
642645
643646
Returns:
644647
Batch: Created batch object
@@ -650,6 +653,7 @@ def create_batch(
650653
calibration_batch=calibration_batch,
651654
self_label_batch=self_label_batch,
652655
callback=callback,
656+
metadata=metadata or {},
653657
)
654658
batchdata = self.api.post_request(endpoint, body=payload)
655659
return Batch(batchdata, self)
@@ -828,6 +832,22 @@ def get_batches(
828832
offset += batches.limit
829833
has_more = batches.has_more
830834

835+
def set_batch_metadata(self, batch_name: str, metadata: Dict) -> Batch:
836+
"""Sets metadata for a TaskBatch.
837+
838+
Args:
839+
batch_name (str):
840+
Batch name
841+
metadata (Dict):
842+
Metadata to set for TaskBatch
843+
844+
Returns:
845+
Batch
846+
"""
847+
endpoint = f"batches/{Api.quote_string(batch_name)}/setMetadata"
848+
batchdata = self.api.post_request(endpoint, body=metadata)
849+
return Batch(batchdata, self)
850+
831851
def create_project(
832852
self,
833853
project_name: str,
@@ -934,6 +954,23 @@ def update_project(self, project_name: str, **kwargs) -> Project:
934954
projectdata = self.api.post_request(endpoint, body=kwargs)
935955
return Project(projectdata, self)
936956

957+
def get_project_template(self, project_name: str) -> TaskTemplate:
958+
"""Gets the task template of a project if a template exists.
959+
Throws an error if the project task-type does not support
960+
Task Templates. Currently only TextCollection and Chat task
961+
types support Task Templates.
962+
963+
Args:
964+
project_name (str):
965+
Project's name
966+
967+
Returns:
968+
TaskTemplate
969+
"""
970+
endpoint = f"projects/{Api.quote_string(project_name)}/taskTemplates"
971+
template = self.api.get_request(endpoint)
972+
return TaskTemplate(template, self)
973+
937974
def upload_file(self, file: IO, **kwargs) -> File:
938975
"""Upload file.
939976
Refer to Files API Reference:

scaleapi/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = "2.15.2"
1+
__version__ = "2.15.3"
22
__package_name__ = "scaleapi"

scaleapi/batches.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, json, client):
1919
self.project = json["project"]
2020
self.created_at = json["created_at"]
2121
self.project = json["project"]
22+
self.metadata = json["metadata"]
2223

2324
self.tasks_pending = None
2425
self.tasks_completed = None

scaleapi/projects.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
1+
class TaskTemplate:
2+
"""Task Template Object."""
3+
4+
def __init__(self, json, client):
5+
self._json = json
6+
self._client = client
7+
self.id = json["id"]
8+
self.project = json["project"]
9+
self.version = json["version"]
10+
self.created_at = json["created_at"]
11+
self.updated_at = json["updated_at"]
12+
self.template_variables = json["template_variables"]
13+
14+
def __hash__(self):
15+
return hash(self.id)
16+
17+
def __str__(self):
18+
return f"TaskTemplate(id={self.id}, project={self.project})"
19+
20+
def __repr__(self):
21+
return f"TaskTemplate({self._json})"
22+
23+
def get_template_variables(self):
24+
"""Returns template variables dictionary"""
25+
return self.template_variables
26+
27+
def as_dict(self):
28+
"""Returns task template object as JSON dictionary"""
29+
return self._json
30+
31+
132
class Project:
233
"""Project class, containing Project information."""
334

@@ -25,6 +56,11 @@ def __str__(self):
2556
def __repr__(self):
2657
return f"Project({self._json})"
2758

59+
def get_template(self) -> TaskTemplate:
60+
"""Returns TaskTemplate.
61+
Only works for Chat and TextCollection type."""
62+
return self._client.get_project_template(self.name)
63+
2864
def as_dict(self):
2965
"""Returns all attributes as a dictionary"""
3066
return self._json

tests/test_client.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=missing-function-docstring
2-
2+
import os
33
import time
44
import uuid
55
from datetime import datetime
@@ -21,7 +21,7 @@
2121

2222
try:
2323
print(f"SDK Version: {scaleapi.__version__}")
24-
test_api_key = "test_fe79860cdbe547bf91b4e7da897a6c92"
24+
test_api_key = os.environ["SCALE_TEST_API_KEY"]
2525

2626
if test_api_key.startswith("test_") or test_api_key.endswith("|test"):
2727
client = scaleapi.ScaleClient(test_api_key, "pytest")
@@ -392,6 +392,7 @@ def create_a_batch():
392392
callback="http://www.example.com/callback",
393393
batch_name=str(uuid.uuid4()),
394394
project=TEST_PROJECT_NAME,
395+
metadata={"some_key": "some_value"},
395396
)
396397

397398

@@ -437,6 +438,8 @@ def test_get_batch():
437438
batch2 = client.get_batch(batch.name)
438439
assert batch.name == batch2.name
439440
assert batch2.status == BatchStatus.InProgress.value
441+
# test metadata
442+
assert batch2.metadata["some_key"] == "some_value"
440443

441444

442445
def test_batches():
@@ -459,6 +462,12 @@ def test_get_batches():
459462
assert total_batches == len(all_batches)
460463

461464

465+
def test_set_batch_metadata():
466+
batch = create_a_batch()
467+
batch = client.set_batch_metadata(batch.name, {"new_key": "new_value"})
468+
assert batch.metadata["new_key"] == "new_value"
469+
470+
462471
def test_files_upload():
463472
with open("tests/test_image.png", "rb") as f:
464473
client.upload_file(
@@ -496,7 +505,9 @@ def test_invite_teammates():
496505
try:
497506
project = client.get_project(STUDIO_TEST_PROJECT)
498507
except ScaleResourceNotFound:
499-
client.create_project(project_name=STUDIO_TEST_PROJECT)
508+
client.create_project(
509+
project_name=STUDIO_TEST_PROJECT, task_type=TaskType.ImageAnnotation
510+
)
500511
STUDIO_BATCH_TEST_NAME = f"studio-test-batch-{current_timestamp}"
501512

502513

0 commit comments

Comments
 (0)