@@ -107,35 +107,41 @@ def update_priority(yaml, item, dispatch_priority, priority_val):
107
107
108
108
109
109
def update_custompodresources(
    item,
    min_cpu,
    max_cpu,
    min_memory,
    max_memory,
    gpu,
    workers,
    head_cpus,
    head_memory,
    head_gpus,
):
    """Overwrite the ``custompodresources`` entries of ``item`` in place.

    The first entry is treated as the head node and receives the ``head_*``
    values for both its requests and its limits.  Every later entry is
    treated as a worker entry: only resource keys already present in its
    ``requests``/``limits`` mappings are overwritten (``min_*`` values go to
    ``requests``, ``max_*`` values go to ``limits``).  ``replicas`` is only
    rewritten on the second entry (index 1), and only if the key is present.

    Memory values are written as strings with a ``"G"`` suffix appended.

    Args:
        item: Mapping expected to contain a ``"custompodresources"`` list.
        min_cpu: CPU value written to worker ``requests``.
        max_cpu: CPU value written to worker ``limits``.
        min_memory: Memory amount written to worker ``requests``.
        max_memory: Memory amount written to worker ``limits``.
        gpu: GPU count written to worker ``requests`` and ``limits``.
        workers: Replica count written to the entry at index 1.
        head_cpus: CPU value written to head ``requests`` and ``limits``.
        head_memory: Memory amount written to head ``requests`` and ``limits``.
        head_gpus: GPU count written to head ``requests`` and ``limits``.

    Raises:
        SystemExit: If ``item`` has no ``"custompodresources"`` key.
        KeyError: If the head entry lacks a ``requests`` or ``limits`` mapping.
    """
    if "custompodresources" not in item:
        sys.exit("Error: malformed template")

    for index, resource in enumerate(item.get("custompodresources")):
        if index == 0:
            # Head node: requests and limits are pinned to the same values.
            for section in ("requests", "limits"):
                resource[section]["cpu"] = head_cpus
                resource[section]["memory"] = str(head_memory) + "G"
                resource[section]["nvidia.com/gpu"] = head_gpus
            continue

        # Only the first worker entry carries the replica count.
        if index == 1 and "replicas" in resource:
            resource["replicas"] = workers

        for section in ("requests", "limits"):
            if section not in resource:
                continue
            specs = resource[section]
            is_limit = section == "limits"
            # Only overwrite keys the template already declares; never add
            # new resource kinds to a worker entry.
            if "cpu" in specs:
                specs["cpu"] = max_cpu if is_limit else min_cpu
            if "memory" in specs:
                memory = max_memory if is_limit else min_memory
                specs["memory"] = str(memory) + "G"
            if "nvidia.com/gpu" in specs:
                specs["nvidia.com/gpu"] = gpu
141
147
@@ -205,11 +211,15 @@ def update_nodes(
205
211
instascale ,
206
212
env ,
207
213
image_pull_secrets ,
214
+ head_cpus ,
215
+ head_memory ,
216
+ head_gpus ,
208
217
):
209
218
if "generictemplate" in item .keys ():
210
219
head = item .get ("generictemplate" ).get ("spec" ).get ("headGroupSpec" )
220
+ head ["rayStartParams" ]["num_gpus" ] = str (int (head_gpus ))
221
+
211
222
worker = item .get ("generictemplate" ).get ("spec" ).get ("workerGroupSpecs" )[0 ]
212
-
213
223
# Head counts as first worker
214
224
worker ["replicas" ] = workers
215
225
worker ["minReplicas" ] = workers
@@ -225,7 +235,7 @@ def update_nodes(
225
235
update_env (spec , env )
226
236
if comp == head :
227
237
# TODO: Eventually add head node configuration outside of template
228
- continue
238
+ update_resources ( spec , head_cpus , head_cpus , head_memory , head_memory , head_gpus )
229
239
else :
230
240
update_resources (spec , min_cpu , max_cpu , min_memory , max_memory , gpu )
231
241
@@ -350,6 +360,9 @@ def write_user_appwrapper(user_yaml, output_file_name):
350
360
def generate_appwrapper (
351
361
name : str ,
352
362
namespace : str ,
363
+ head_cpus : int ,
364
+ head_memory : int ,
365
+ head_gpus : int ,
353
366
min_cpu : int ,
354
367
max_cpu : int ,
355
368
min_memory : int ,
@@ -375,8 +388,7 @@ def generate_appwrapper(
375
388
update_labels (user_yaml , instascale , instance_types )
376
389
update_priority (user_yaml , item , dispatch_priority , priority_val )
377
390
update_custompodresources (
378
- item , min_cpu , max_cpu , min_memory , max_memory , gpu , workers
379
- )
391
+ item , min_cpu , max_cpu , min_memory , max_memory , gpu , workers , head_cpus , head_memory , head_gpus )
380
392
update_nodes (
381
393
item ,
382
394
appwrapper_name ,
@@ -390,6 +402,9 @@ def generate_appwrapper(
390
402
instascale ,
391
403
env ,
392
404
image_pull_secrets ,
405
+ head_cpus ,
406
+ head_memory ,
407
+ head_gpus ,
393
408
)
394
409
update_dashboard_route (route_item , cluster_name , namespace )
395
410
if local_interactive :
0 commit comments