Commit 54a5a12

Minor fixes and unit test additions
1 parent: 2b2882c · commit: 54a5a12

File tree: 3 files changed (+47, -7 lines)


src/codeflare_sdk/cluster/cluster.py

Lines changed: 12 additions & 0 deletions
@@ -614,6 +614,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
         worker_gpu=0,  # hard to detect currently how many gpus, can override it with what the user asked for
         namespace=rc["metadata"]["namespace"],
         dashboard=ray_route,
+        head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
+            "resources"
+        ]["limits"]["cpu"],
+        head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
+            "resources"
+        ]["limits"]["memory"],
+        head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
+            "resources"
+        ]["limits"]["nvidia.com/gpu"],
     )
 
 
@@ -644,6 +653,9 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
         worker_gpu=cluster.config.num_gpus,
         namespace=cluster.config.namespace,
         dashboard=cluster.cluster_dashboard_uri(),
+        head_cpus=cluster.config.head_cpus,
+        head_mem=cluster.config.head_memory,
+        head_gpu=cluster.config.head_gpus,
     )
     if ray.status == CodeFlareClusterStatus.READY:
         ray.status = RayClusterStatus.READY
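The cluster.py change surfaces the head node's resources on the RayCluster object: _map_to_ray_cluster reads them from the custom resource at spec.headGroupSpec.template.spec.containers[0].resources.limits, while _copy_to_ray takes them from the ClusterConfiguration. As a rough, hypothetical sketch of that lookup path (the helper name and sample dict below are illustrative only, not SDK code):

```python
# Hypothetical helper: walks the same path the SDK indexes directly, but with
# .get() fallbacks so missing keys return an empty dict instead of raising KeyError.
def head_limits(rc: dict) -> dict:
    containers = (
        rc.get("spec", {})
        .get("headGroupSpec", {})
        .get("template", {})
        .get("spec", {})
        .get("containers", [])
    )
    head = containers[0] if containers else {}
    return head.get("resources", {}).get("limits", {})


# Made-up RayCluster fragment, just to show the shape of the data:
rc = {
    "spec": {
        "headGroupSpec": {
            "template": {
                "spec": {
                    "containers": [
                        {
                            "resources": {
                                "limits": {"cpu": 2, "memory": "8G", "nvidia.com/gpu": 0}
                            }
                        }
                    ]
                }
            }
        }
    }
}

limits = head_limits(rc)
print(limits["cpu"], limits["memory"], limits["nvidia.com/gpu"])  # 2 8G 0
```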

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 29 additions & 7 deletions
@@ -107,7 +107,17 @@ def update_priority(yaml, item, dispatch_priority, priority_val):
 
 
 def update_custompodresources(
-    item, min_cpu, max_cpu, min_memory, max_memory, gpu, workers, head_cpus, head_memory, head_gpus):
+    item,
+    min_cpu,
+    max_cpu,
+    min_memory,
+    max_memory,
+    gpu,
+    workers,
+    head_cpus,
+    head_memory,
+    head_gpus,
+):
     if "custompodresources" in item.keys():
         custompodresources = item.get("custompodresources")
         for i in range(len(custompodresources)):

@@ -120,8 +130,8 @@ def update_custompodresources(
                 resource["limits"]["memory"] = str(head_memory) + "G"
                 resource["requests"]["nvidia.com/gpu"] = head_gpus
                 resource["limits"]["nvidia.com/gpu"] = head_gpus
-
-            else:
+
+            else:
                 for k, v in resource.items():
                     if k == "replicas" and i == 1:
                         resource[k] = workers

@@ -217,8 +227,8 @@ def update_nodes(
 ):
     if "generictemplate" in item.keys():
         head = item.get("generictemplate").get("spec").get("headGroupSpec")
-        head["rayStartParams"]["num_gpus"] = str(int(head_gpus))
-
+        head["rayStartParams"]["num-gpus"] = str(int(head_gpus))
+
         worker = item.get("generictemplate").get("spec").get("workerGroupSpecs")[0]
         # Head counts as first worker
         worker["replicas"] = workers

@@ -235,7 +245,9 @@ def update_nodes(
             update_env(spec, env)
             if comp == head:
                 # TODO: Eventually add head node configuration outside of template
-                update_resources(spec, head_cpus, head_cpus, head_memory, head_memory, head_gpus)
+                update_resources(
+                    spec, head_cpus, head_cpus, head_memory, head_memory, head_gpus
+                )
             else:
                 update_resources(spec, min_cpu, max_cpu, min_memory, max_memory, gpu)
 

@@ -388,7 +400,17 @@ def generate_appwrapper(
     update_labels(user_yaml, instascale, instance_types)
     update_priority(user_yaml, item, dispatch_priority, priority_val)
     update_custompodresources(
-        item, min_cpu, max_cpu, min_memory, max_memory, gpu, workers, head_cpus, head_memory, head_gpus)
+        item,
+        min_cpu,
+        max_cpu,
+        min_memory,
+        max_memory,
+        gpu,
+        workers,
+        head_cpus,
+        head_memory,
+        head_gpus,
+    )
     update_nodes(
         item,
         appwrapper_name,
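Besides reflowing the long update_custompodresources signature and call to one argument per line, the generate_yaml.py change corrects the head group's rayStartParams key from num_gpus to num-gpus, the dashed form Ray's start parameters use. A minimal, hypothetical sketch of that rename (the template fragment below is illustrative, not the SDK's real generic template):

```python
# Illustrative headGroupSpec fragment; only the rayStartParams key matters here.
head = {"rayStartParams": {"dashboard-host": "0.0.0.0", "num_gpus": "0"}}

head_gpus = 1

# The fix: write the dashed key Ray understands and drop the stale underscore key.
head["rayStartParams"].pop("num_gpus", None)
head["rayStartParams"]["num-gpus"] = str(int(head_gpus))

print(head["rayStartParams"])  # {'dashboard-host': '0.0.0.0', 'num-gpus': '1'}
```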

tests/unit_test.py

Lines changed: 6 additions & 0 deletions
@@ -525,6 +525,9 @@ def test_ray_details(mocker, capsys):
         worker_gpu=0,
         namespace="ns",
         dashboard="fake-uri",
+        head_cpus=2,
+        head_mem=8,
+        head_gpu=0,
     )
     mocker.patch(
         "codeflare_sdk.cluster.cluster.Cluster.status",

@@ -1685,6 +1688,9 @@ def test_cluster_status(mocker):
         worker_gpu=0,
         namespace="ns",
         dashboard="fake-uri",
+        head_cpus=2,
+        head_mem=8,
+        head_gpu=0,
     )
     cf = Cluster(ClusterConfiguration(name="test", namespace="ns"))
     mocker.patch("codeflare_sdk.cluster.cluster._app_wrapper_status", return_value=None)
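Both unit tests now build the mocked RayCluster with the three new keyword arguments, so any other fixture constructing one needs head_cpus, head_mem, and head_gpu as well. A self-contained sketch of that pattern (the trimmed stand-in dataclass below is hypothetical; it is not the SDK's full RayCluster definition):

```python
from dataclasses import dataclass


# Hypothetical, trimmed stand-in for the SDK's RayCluster dataclass;
# only the fields this commit touches are listed.
@dataclass
class FakeRayCluster:
    namespace: str
    dashboard: str
    worker_gpu: int
    head_cpus: int
    head_mem: int
    head_gpu: int


def test_head_fields_round_trip():
    rc = FakeRayCluster(
        namespace="ns",
        dashboard="fake-uri",
        worker_gpu=0,
        head_cpus=2,
        head_mem=8,
        head_gpu=0,
    )
    assert (rc.head_cpus, rc.head_mem, rc.head_gpu) == (2, 8, 0)
```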
