diff --git a/scaleapi/__init__.py b/scaleapi/__init__.py index 53d232f..838edfb 100644 --- a/scaleapi/__init__.py +++ b/scaleapi/__init__.py @@ -320,7 +320,7 @@ def tasks(self, **kwargs) -> Tasklist: def get_tasks( self, - project_name: str, + project_name: str = None, batch_name: str = None, task_type: TaskType = None, status: TaskStatus = None, @@ -345,7 +345,7 @@ def get_tasks( `task_list = list(get_tasks(...))` Args: - project_name (str): + project_name (str, optional): Project Name batch_name (str, optional): @@ -412,6 +412,11 @@ def get_tasks( Yields Task objects, can be iterated. """ + if not project_name and not batch_name: + raise ValueError( + "At least one of project_name or batch_name must be provided." + ) + next_token = None has_more = True @@ -548,7 +553,7 @@ def get_tasks_count( @staticmethod def _process_tasks_endpoint_args( - project_name: str, + project_name: str = None, batch_name: str = None, task_type: TaskType = None, status: TaskStatus = None, @@ -565,6 +570,11 @@ def _process_tasks_endpoint_args( limited_response: bool = None, ): """Generates args for /tasks endpoint.""" + if not project_name and not batch_name: + raise ValueError( + "At least one of project_name or batch_name must be provided." + ) + tasks_args = { "start_time": created_after, "end_time": created_before, diff --git a/scaleapi/_version.py b/scaleapi/_version.py index b899919..187da73 100644 --- a/scaleapi/_version.py +++ b/scaleapi/_version.py @@ -1,2 +1,2 @@ -__version__ = "2.15.10" +__version__ = "2.15.11" __package_name__ = "scaleapi" diff --git a/tests/test_client.py b/tests/test_client.py index 672fc4f..e3a007d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -503,3 +503,41 @@ def test_list_teammates(): # assert len(new_teammates) >= len( # old_teammates # ) # needs to sleep for teammates list to be updated + + +def test_get_tasks_without_project_name(): + with pytest.raises(ValueError): + list(client.get_tasks()) + + +def test_get_tasks_with_optional_project_name(): + batch = create_a_batch() + tasks = [] + for _ in range(3): + tasks.append(make_a_task(batch=batch.name)) + task_ids = {task.id for task in tasks} + for task in client.get_tasks( + project_name=None, + batch_name=batch.name, + limit=1, + ): + assert task.id in task_ids + + +def test_process_tasks_endpoint_args_with_optional_project_name(): + args = client._process_tasks_endpoint_args(batch_name="test_batch") + assert args["project"] is None + assert args["batch"] == "test_batch" + + +def test_get_tasks_with_batch_name(): + batch = create_a_batch() + tasks = [] + for _ in range(3): + tasks.append(make_a_task(batch=batch.name)) + task_ids = {task.id for task in tasks} + for task in client.get_tasks( + batch_name=batch.name, + limit=1, + ): + assert task.id in task_ids