|
| 1 | +import argparse |
| 2 | +import sys |
| 3 | +from collections import defaultdict |
| 4 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 5 | + |
| 6 | +import scaleapi |
| 7 | +from scaleapi.exceptions import ScaleException, ScaleUnauthorized |
| 8 | + |
| 9 | +# Script that takes in an array of batch names (split by comma) and |
| 10 | +# applies a bulk action to cancel all tasks in each batch. |
| 11 | +# By default, this script makes 50 concurrent API calls. |
| 12 | + |
| 13 | +# Example: |
| 14 | +# python cancel_batch.py --api_key "SCALE_API_KEY" |
| 15 | +# --batches "batch1,batch2" --clear "True" |
| 16 | + |
# Tune this to change how many API calls are made concurrently.
MAX_WORKERS = 50
| 19 | + |
| 20 | + |
def cancel_batch(client, batch_name, clear_unique_id):
    """Cancel every task in the named batch and print a per-status summary.

    Missing batches are reported and skipped rather than raising, so the
    caller can continue with the rest of its batch list.
    """
    print(f"\nProcessing Batch: {batch_name}")
    try:
        batch = client.get_batch(batch_name)
    except ScaleException:
        print(f"-ERROR: Batch {batch_name} not found.")
        return

    total_tasks = client.get_tasks_count(
        project_name=batch.project, batch_name=batch.name
    )
    print(f"-Batch {batch.name} contains {total_tasks} tasks.")

    status_counts = defaultdict(int)
    futures = []
    task_iter = client.get_tasks(project_name=batch.project, batch_name=batch.name)

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Fan the cancellations out across the pool; log progress every 1000.
        for seen, task in enumerate(task_iter, start=1):
            if seen % 1000 == 0:
                print(f"-Processing Task # {seen}")
            futures.append(
                executor.submit(
                    cancel_task_with_response, client, task, clear_unique_id
                )
            )

        # Tally results as workers finish, in completion order.
        for future in as_completed(futures):
            outcome = future.result()
            status_counts[outcome["status"]] += 1

    for status, count in status_counts.items():
        print(f"--{status}: {count} tasks")
| 57 | + |
| 58 | + |
def cancel_task_with_response(client: scaleapi.ScaleClient, task, clear_unique_id):
    """Cancel a single task and report the outcome as a small dict.

    Args:
        client: Authenticated Scale API client.
        task: Task object exposing ``id`` and ``as_dict()``.
        clear_unique_id: When truthy, also remove the task's unique_id on cancel.

    Returns:
        dict: ``{"task": <task id>, "status": <resulting status or error label>}``.
    """
    # NOTE: parameter was previously misspelled "clear_unique_ud"; all call
    # sites pass it positionally, so renaming it is backward-compatible.
    task_status = task.as_dict()["status"]
    # Already-finished tasks cannot be cancelled; report their status as-is.
    if task_status in ["completed", "canceled"]:
        return {"task": task.id, "status": task_status}

    try:
        cancelled = client.cancel_task(task.id, clear_unique_id)
        return {"task": cancelled.id, "status": cancelled.as_dict()["status"]}
    except ScaleException:
        # The API refused the cancellation (e.g. state changed server-side).
        return {"task": task.id, "status": "Can not cancel"}
    except Exception as err:  # keep the batch run going on any other failure
        print(err)
        return {"task": task.id, "status": "Errored"}
| 72 | + |
| 73 | + |
def _str_to_bool(value):
    """argparse type: map common true/false spellings to a real bool.

    ``type=bool`` is a well-known argparse pitfall: ``bool("False")`` is True
    because any non-empty string is truthy, so ``--clear "False"`` (the usage
    shown in this script's header) would still clear unique ids. This parser
    interprets the text instead.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional argument list for testing; defaults to ``sys.argv[1:]``
            via argparse's own default when None.

    Returns:
        argparse.Namespace with ``api_key``, ``batches`` and ``clear``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--api_key", required=True, help="Please provide Scale API Key")
    ap.add_argument(
        "--batches", required=True, help="Please enter batch names separated by a comma"
    )
    ap.add_argument(
        "--clear",
        type=_str_to_bool,
        help="Set to True if you want to remove unique_id upon cancel",
    )
    return ap.parse_args(argv)
| 86 | + |
| 87 | + |
def main():
    """Entry point: verify the API key, then cancel each requested batch."""
    args = get_args()
    clear_unique_id = args.clear if args.clear else False

    client = scaleapi.ScaleClient(args.api_key)

    # Smoke-test the API key with a cheap call before doing any real work.
    try:
        client.projects()
    except ScaleUnauthorized as err:
        print(err.message)
        sys.exit(1)

    batch_names = [name.strip() for name in args.batches.split(",")]
    for batch_name in batch_names:
        cancel_batch(client, batch_name, clear_unique_id)
| 106 | + |
| 107 | + |
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
0 commit comments