Skip to content

Commit 10b8a36

Browse files
committed
Review changes & list_cluster functions
1 parent 9246b81 commit 10b8a36

File tree

6 files changed

+399
-51
lines changed

6 files changed

+399
-51
lines changed

src/codeflare_sdk/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
RayCluster,
1313
AppWrapper,
1414
get_cluster,
15+
list_all_queued,
16+
list_all_clusters,
1517
)
1618

1719
from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient

src/codeflare_sdk/cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
AppWrapper,
1414
)
1515

16-
from .cluster import Cluster, ClusterConfiguration, get_cluster
16+
from .cluster import Cluster, ClusterConfiguration, get_cluster, list_all_queued, list_all_clusters
1717

1818
from .awload import AWManager

src/codeflare_sdk/cluster/cluster.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -620,17 +620,24 @@ def list_all_clusters(namespace: str, print_to_console: bool = True):
620620
return clusters
621621

622622

623-
def list_all_queued(namespace: str, print_to_console: bool = True):
623+
def list_all_queued(namespace: str, print_to_console: bool = True, mcad: bool = False):
624624
"""
625-
Returns (and prints by default) a list of all currently queued-up AppWrappers
625+
Returns (and prints by default) a list of all currently queued-up Ray Clusters or AppWrappers
626626
in a given namespace.
627627
"""
628-
app_wrappers = _get_app_wrappers(
629-
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
630-
)
631-
if print_to_console:
632-
pretty_print.print_app_wrappers_status(app_wrappers)
633-
return app_wrappers
628+
if mcad:
629+
resources = _get_app_wrappers(
630+
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
631+
)
632+
if print_to_console:
633+
pretty_print.print_app_wrappers_status(resources)
634+
else:
635+
resources = _get_ray_clusters(
636+
namespace, filter=[RayClusterStatus.READY, RayClusterStatus.SUSPENDED]
637+
)
638+
if print_to_console:
639+
pretty_print.print_ray_clusters_status(resources)
640+
return resources
634641

635642

636643
def get_current_namespace(): # pragma: no cover
@@ -913,7 +920,9 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
913920
return None
914921

915922

916-
def _get_ray_clusters(namespace="default") -> List[RayCluster]:
923+
def _get_ray_clusters(
924+
namespace="default", filter: Optional[List[RayClusterStatus]] = None
925+
) -> List[RayCluster]:
917926
list_of_clusters = []
918927
try:
919928
config_check()
@@ -927,8 +936,15 @@ def _get_ray_clusters(namespace="default") -> List[RayCluster]:
927936
except Exception as e: # pragma: no cover
928937
return _kube_api_error_handling(e)
929938

930-
for rc in rcs["items"]:
931-
list_of_clusters.append(_map_to_ray_cluster(rc))
939+
# Get a list of RCs with the filter if it is passed to the function
940+
if filter is not None:
941+
for rc in rcs["items"]:
942+
ray_cluster = _map_to_ray_cluster(rc)
943+
if filter and ray_cluster.status in filter:
944+
list_of_clusters.append(ray_cluster)
945+
else:
946+
for rc in rcs["items"]:
947+
list_of_clusters.append(_map_to_ray_cluster(rc))
932948
return list_of_clusters
933949

934950

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
(in the cluster sub-module) for AppWrapper generation.
1818
"""
1919

20+
from typing import Optional
2021
import typing
2122
import yaml
2223
import sys
@@ -635,39 +636,34 @@ def _create_oauth_sidecar_object(
635636
)
636637

637638

638-
def get_default_kueue_name(local_queue: str, namespace: str):
639+
def get_default_kueue_name(namespace: str):
639640
# If the local queue is set, use it. Otherwise, try to use the default queue.
640-
if local_queue is not None:
641-
return local_queue
642-
else:
643-
try:
644-
config_check()
645-
api_instance = client.CustomObjectsApi(api_config_handler())
646-
local_queues = api_instance.list_namespaced_custom_object(
647-
group="kueue.x-k8s.io",
648-
version="v1beta1",
649-
namespace=namespace,
650-
plural="localqueues",
651-
)
652-
except Exception as e: # pragma: no cover
653-
return _kube_api_error_handling(e)
654-
for lq in local_queues["items"]:
655-
if (
656-
"annotations" in lq["metadata"]
657-
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
658-
and lq["metadata"]["annotations"][
659-
"kueue.x-k8s.io/default-queue"
660-
].lower()
661-
== "true"
662-
):
663-
return lq["metadata"]["name"]
664-
raise ValueError(
665-
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
641+
try:
642+
config_check()
643+
api_instance = client.CustomObjectsApi(api_config_handler())
644+
local_queues = api_instance.list_namespaced_custom_object(
645+
group="kueue.x-k8s.io",
646+
version="v1beta1",
647+
namespace=namespace,
648+
plural="localqueues",
666649
)
650+
except Exception as e: # pragma: no cover
651+
return _kube_api_error_handling(e)
652+
for lq in local_queues["items"]:
653+
if (
654+
"annotations" in lq["metadata"]
655+
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
656+
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
657+
== "true"
658+
):
659+
return lq["metadata"]["name"]
660+
raise ValueError(
661+
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
662+
)
667663

668664

669665
def write_components(
670-
user_yaml: dict, output_file_name: str, namespace: str, local_queue: str
666+
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
671667
):
672668
# Create the directory if it doesn't exist
673669
directory_path = os.path.dirname(output_file_name)
@@ -676,6 +672,7 @@ def write_components(
676672

677673
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
678674
open(output_file_name, "w").close()
675+
lq_name = local_queue or get_default_kueue_name(namespace)
679676
with open(output_file_name, "a") as outfile:
680677
for component in components:
681678
if "generictemplate" in component:
@@ -687,13 +684,7 @@ def write_components(
687684
"workload.codeflare.dev/appwrapper"
688685
]
689686
labels = component["generictemplate"]["metadata"]["labels"]
690-
labels.update(
691-
{
692-
"kueue.x-k8s.io/queue-name": get_default_kueue_name(
693-
local_queue, namespace
694-
)
695-
}
696-
)
687+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
697688
outfile.write("---\n")
698689
yaml.dump(
699690
component["generictemplate"], outfile, default_flow_style=False
@@ -748,7 +739,7 @@ def generate_appwrapper(
748739
ingress_domain: str,
749740
ingress_options: dict,
750741
write_to_file: bool,
751-
local_queue: str,
742+
local_queue: Optional[str],
752743
):
753744
user_yaml = read_template(template)
754745
appwrapper_name, cluster_name = gen_names(name)

src/codeflare_sdk/utils/pretty_print.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,30 @@ def print_app_wrappers_status(app_wrappers: List[AppWrapper], starting: bool = F
5656
console.print(Panel.fit(table))
5757

5858

59+
def print_ray_clusters_status(app_wrappers: List[AppWrapper], starting: bool = False):
60+
if not app_wrappers:
61+
print_no_resources_found()
62+
return # shortcircuit
63+
64+
console = Console()
65+
table = Table(
66+
box=box.ASCII_DOUBLE_HEAD,
67+
title="[bold] :rocket: Cluster Queue Status :rocket:",
68+
)
69+
table.add_column("Name", style="cyan", no_wrap=True)
70+
table.add_column("Status", style="magenta")
71+
72+
for app_wrapper in app_wrappers:
73+
name = app_wrapper.name
74+
status = app_wrapper.status.value
75+
if starting:
76+
status += " (starting)"
77+
table.add_row(name, status)
78+
table.add_row("") # empty row for spacing
79+
80+
console.print(Panel.fit(table))
81+
82+
5983
def print_cluster_status(cluster: RayCluster):
6084
"Pretty prints the status of a passed-in cluster"
6185
if not cluster:

0 commit comments

Comments
 (0)