Skip to content

Commit 2b2882c

Browse files
add ability to configure head node
1 parent 2e543ca commit 2b2882c

File tree

4 files changed

+55
-28
lines changed

4 files changed

+55
-28
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def create_app_wrapper(self):
109109

110110
name = self.config.name
111111
namespace = self.config.namespace
112+
head_cpus = self.config.head_cpus
113+
head_memory = self.config.head_memory
114+
head_gpus = self.config.head_gpus
112115
min_cpu = self.config.min_cpus
113116
max_cpu = self.config.max_cpus
114117
min_memory = self.config.min_memory
@@ -126,6 +129,9 @@ def create_app_wrapper(self):
126129
return generate_appwrapper(
127130
name=name,
128131
namespace=namespace,
132+
head_cpus=head_cpus,
133+
head_memory=head_memory,
134+
head_gpus=head_gpus,
129135
min_cpu=min_cpu,
130136
max_cpu=max_cpu,
131137
min_memory=min_memory,

src/codeflare_sdk/cluster/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class ClusterConfiguration:
3434
name: str
3535
namespace: str = None
3636
head_info: list = field(default_factory=list)
37+
head_cpus: int = 2
38+
head_memory: int = 8
39+
head_gpus: int = 0
3740
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
3841
min_cpus: int = 1
3942
max_cpus: int = 1

src/codeflare_sdk/cluster/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class RayCluster:
6969

7070
name: str
7171
status: RayClusterStatus
72+
head_cpus: int
73+
head_mem: str
74+
head_gpu: int
7275
workers: int
7376
worker_mem_min: str
7477
worker_mem_max: str

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -107,35 +107,41 @@ def update_priority(yaml, item, dispatch_priority, priority_val):
107107

108108

109109
def update_custompodresources(
110-
item, min_cpu, max_cpu, min_memory, max_memory, gpu, workers
111-
):
110+
item, min_cpu, max_cpu, min_memory, max_memory, gpu, workers, head_cpus, head_memory, head_gpus):
112111
if "custompodresources" in item.keys():
113112
custompodresources = item.get("custompodresources")
114113
for i in range(len(custompodresources)):
114+
resource = custompodresources[i]
115115
if i == 0:
116116
# Leave head node resources as template default
117-
continue
118-
resource = custompodresources[i]
119-
for k, v in resource.items():
120-
if k == "replicas" and i == 1:
121-
resource[k] = workers
122-
if k == "requests" or k == "limits":
123-
for spec, _ in v.items():
124-
if spec == "cpu":
125-
if k == "limits":
126-
resource[k][spec] = max_cpu
127-
else:
128-
resource[k][spec] = min_cpu
129-
if spec == "memory":
130-
if k == "limits":
131-
resource[k][spec] = str(max_memory) + "G"
132-
else:
133-
resource[k][spec] = str(min_memory) + "G"
134-
if spec == "nvidia.com/gpu":
135-
if i == 0:
136-
resource[k][spec] = 0
137-
else:
138-
resource[k][spec] = gpu
117+
resource["requests"]["cpu"] = head_cpus
118+
resource["limits"]["cpu"] = head_cpus
119+
resource["requests"]["memory"] = str(head_memory) + "G"
120+
resource["limits"]["memory"] = str(head_memory) + "G"
121+
resource["requests"]["nvidia.com/gpu"] = head_gpus
122+
resource["limits"]["nvidia.com/gpu"] = head_gpus
123+
124+
else:
125+
for k, v in resource.items():
126+
if k == "replicas" and i == 1:
127+
resource[k] = workers
128+
if k == "requests" or k == "limits":
129+
for spec, _ in v.items():
130+
if spec == "cpu":
131+
if k == "limits":
132+
resource[k][spec] = max_cpu
133+
else:
134+
resource[k][spec] = min_cpu
135+
if spec == "memory":
136+
if k == "limits":
137+
resource[k][spec] = str(max_memory) + "G"
138+
else:
139+
resource[k][spec] = str(min_memory) + "G"
140+
if spec == "nvidia.com/gpu":
141+
if i == 0:
142+
resource[k][spec] = 0
143+
else:
144+
resource[k][spec] = gpu
139145
else:
140146
sys.exit("Error: malformed template")
141147

@@ -205,11 +211,15 @@ def update_nodes(
205211
instascale,
206212
env,
207213
image_pull_secrets,
214+
head_cpus,
215+
head_memory,
216+
head_gpus,
208217
):
209218
if "generictemplate" in item.keys():
210219
head = item.get("generictemplate").get("spec").get("headGroupSpec")
220+
head["rayStartParams"]["num_gpus"] = str(int(head_gpus))
221+
211222
worker = item.get("generictemplate").get("spec").get("workerGroupSpecs")[0]
212-
213223
# Head counts as first worker
214224
worker["replicas"] = workers
215225
worker["minReplicas"] = workers
@@ -225,7 +235,7 @@ def update_nodes(
225235
update_env(spec, env)
226236
if comp == head:
227237
# TODO: Eventually add head node configuration outside of template
228-
continue
238+
update_resources(spec, head_cpus, head_cpus, head_memory, head_memory, head_gpus)
229239
else:
230240
update_resources(spec, min_cpu, max_cpu, min_memory, max_memory, gpu)
231241

@@ -350,6 +360,9 @@ def write_user_appwrapper(user_yaml, output_file_name):
350360
def generate_appwrapper(
351361
name: str,
352362
namespace: str,
363+
head_cpus: int,
364+
head_memory: int,
365+
head_gpus: int,
353366
min_cpu: int,
354367
max_cpu: int,
355368
min_memory: int,
@@ -375,8 +388,7 @@ def generate_appwrapper(
375388
update_labels(user_yaml, instascale, instance_types)
376389
update_priority(user_yaml, item, dispatch_priority, priority_val)
377390
update_custompodresources(
378-
item, min_cpu, max_cpu, min_memory, max_memory, gpu, workers
379-
)
391+
item, min_cpu, max_cpu, min_memory, max_memory, gpu, workers, head_cpus, head_memory, head_gpus)
380392
update_nodes(
381393
item,
382394
appwrapper_name,
@@ -390,6 +402,9 @@ def generate_appwrapper(
390402
instascale,
391403
env,
392404
image_pull_secrets,
405+
head_cpus,
406+
head_memory,
407+
head_gpus,
393408
)
394409
update_dashboard_route(route_item, cluster_name, namespace)
395410
if local_interactive:

0 commit comments

Comments
 (0)