diff --git a/scaleapi/__init__.py b/scaleapi/__init__.py index ae96c64..26523a7 100644 --- a/scaleapi/__init__.py +++ b/scaleapi/__init__.py @@ -4,7 +4,7 @@ from scaleapi.evaluation_tasks import EvaluationTask from scaleapi.exceptions import ScaleInvalidRequest from scaleapi.files import File -from scaleapi.projects import Project +from scaleapi.projects import Project, TaskTemplate from scaleapi.training_tasks import TrainingTask from ._version import __version__ # noqa: F401 @@ -619,6 +619,7 @@ def create_batch( callback: str = "", calibration_batch: bool = False, self_label_batch: bool = False, + metadata: Dict = None, ) -> Batch: """Create a new Batch within a project. https://docs.scale.com/reference#batch-creation @@ -639,6 +640,8 @@ def create_batch( Only applicable for self serve projects. Create a self_label batch by setting the self_label_batch flag to true. + metadata (Dict): + Optional metadata to be stored at the TaskBatch level Returns: Batch: Created batch object @@ -650,6 +653,7 @@ def create_batch( calibration_batch=calibration_batch, self_label_batch=self_label_batch, callback=callback, + metadata=metadata or {}, ) batchdata = self.api.post_request(endpoint, body=payload) return Batch(batchdata, self) @@ -828,6 +832,22 @@ def get_batches( offset += batches.limit has_more = batches.has_more + def set_batch_metadata(self, batch_name: str, metadata: Dict) -> Batch: + """Sets metadata for a TaskBatch. + + Args: + batch_name (str): + Batch name + metadata (Dict): + Metadata to set for TaskBatch + + Returns: + Batch + """ + endpoint = f"batches/{Api.quote_string(batch_name)}/setMetadata" + batchdata = self.api.post_request(endpoint, body=metadata) + return Batch(batchdata, self) + def create_project( self, project_name: str, @@ -934,6 +954,23 @@ def update_project(self, project_name: str, **kwargs) -> Project: projectdata = self.api.post_request(endpoint, body=kwargs) return Project(projectdata, self) + def get_project_template(self, project_name: str) -> TaskTemplate: + """Gets the task template of a project if a template exists. + Throws an error if the project task-type does not support + Task Templates. Currently only TextCollection and Chat task + types support Task Templates. + + Args: + project_name (str): + Project's name + + Returns: + TaskTemplate + """ + endpoint = f"projects/{Api.quote_string(project_name)}/taskTemplates" + template = self.api.get_request(endpoint) + return TaskTemplate(template, self) + def upload_file(self, file: IO, **kwargs) -> File: """Upload file. Refer to Files API Reference: diff --git a/scaleapi/_version.py b/scaleapi/_version.py index d6ce145..f60d06a 100644 --- a/scaleapi/_version.py +++ b/scaleapi/_version.py @@ -1,2 +1,2 @@ -__version__ = "2.15.2" +__version__ = "2.15.3" __package_name__ = "scaleapi" diff --git a/scaleapi/batches.py b/scaleapi/batches.py index d0cb73a..743b2bc 100644 --- a/scaleapi/batches.py +++ b/scaleapi/batches.py @@ -19,6 +19,7 @@ def __init__(self, json, client): self.project = json["project"] self.created_at = json["created_at"] self.project = json["project"] + self.metadata = json["metadata"] self.tasks_pending = None self.tasks_completed = None diff --git a/scaleapi/projects.py b/scaleapi/projects.py index 8a0536f..17147d6 100644 --- a/scaleapi/projects.py +++ b/scaleapi/projects.py @@ -1,3 +1,34 @@ +class TaskTemplate: + """Task Template Object.""" + + def __init__(self, json, client): + self._json = json + self._client = client + self.id = json["id"] + self.project = json["project"] + self.version = json["version"] + self.created_at = json["created_at"] + self.updated_at = json["updated_at"] + self.template_variables = json["template_variables"] + + def __hash__(self): + return hash(self.id) + + def __str__(self): + return f"TaskTemplate(id={self.id}, project={self.project})" + + def __repr__(self): + return f"TaskTemplate({self._json})" + + def get_template_variables(self): + """Returns template variables dictionary""" + return self.template_variables + + def as_dict(self): + """Returns task template object as JSON dictionary""" + return self._json + + class Project: """Project class, containing Project information.""" @@ -25,6 +56,11 @@ def __str__(self): def __repr__(self): return f"Project({self._json})" + def get_template(self) -> TaskTemplate: + """Returns TaskTemplate. + Only works for Chat and TextCollection type.""" + return self._client.get_project_template(self.name) + def as_dict(self): """Returns all attributes as a dictionary""" return self._json diff --git a/tests/test_client.py b/tests/test_client.py index 63d4c67..9bf52ea 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,5 @@ # pylint: disable=missing-function-docstring - +import os import time import uuid from datetime import datetime @@ -21,7 +21,7 @@ try: print(f"SDK Version: {scaleapi.__version__}") - test_api_key = "test_fe79860cdbe547bf91b4e7da897a6c92" + test_api_key = os.environ["SCALE_TEST_API_KEY"] if test_api_key.startswith("test_") or test_api_key.endswith("|test"): client = scaleapi.ScaleClient(test_api_key, "pytest") @@ -392,6 +392,7 @@ def create_a_batch(): callback="http://www.example.com/callback", batch_name=str(uuid.uuid4()), project=TEST_PROJECT_NAME, + metadata={"some_key": "some_value"}, ) @@ -437,6 +438,8 @@ def test_get_batch(): batch2 = client.get_batch(batch.name) assert batch.name == batch2.name assert batch2.status == BatchStatus.InProgress.value + # test metadata + assert batch2.metadata["some_key"] == "some_value" def test_batches(): @@ -459,6 +462,12 @@ def test_get_batches(): assert total_batches == len(all_batches) +def test_set_batch_metadata(): + batch = create_a_batch() + batch = client.set_batch_metadata(batch.name, {"new_key": "new_value"}) + assert batch.metadata["new_key"] == "new_value" + + def test_files_upload(): with open("tests/test_image.png", "rb") as f: client.upload_file( @@ -496,7 +505,9 @@ def test_invite_teammates(): try: project = client.get_project(STUDIO_TEST_PROJECT) except ScaleResourceNotFound: - client.create_project(project_name=STUDIO_TEST_PROJECT) + client.create_project( + project_name=STUDIO_TEST_PROJECT, task_type=TaskType.ImageAnnotation + ) STUDIO_BATCH_TEST_NAME = f"studio-test-batch-{current_timestamp}"