@@ -99,6 +99,51 @@ def update_image_pull_secrets(spec, image_pull_secrets):
99
99
]
100
100
101
101
102
+ def validate_persistent_volumes (persistent_volumes ) -> bool :
103
+ mandatory_keys = {"name" , "mountPath" , "claimName" }
104
+
105
+ if not isinstance (persistent_volumes , list ):
106
+ raise ValueError ("persistent_volumes must be a list" )
107
+
108
+ for pv in persistent_volumes :
109
+ if not isinstance (pv , dict ):
110
+ raise ValueError ("Each item in persistent_volumes must be a dict" )
111
+
112
+ missed_keys = mandatory_keys - pv .keys ()
113
+ if missed_keys :
114
+ raise ValueError (f"Missing keys in persistent volume: { missed_keys } " )
115
+
116
+ for key in mandatory_keys :
117
+ if not isinstance (pv [key ], str ) or not pv [key ]:
118
+ raise ValueError (f"{ key } must be a string" )
119
+
120
+ return True
121
+
122
+
123
+ def update_persistent_volume_mounts (spec , persistent_volumes ):
124
+ validate_persistent_volumes (persistent_volumes )
125
+ containers = spec .get ("containers" )
126
+ for pv in persistent_volumes :
127
+ persistent_volume = client .V1Volume (
128
+ name = pv ["name" ],
129
+ persistent_volume_claim = client .V1PersistentVolumeClaimVolumeSource (
130
+ claim_name = pv ["claimName" ]
131
+ ),
132
+ )
133
+
134
+ persistent_volume = client .ApiClient ().sanitize_for_serialization (
135
+ persistent_volume
136
+ )
137
+ spec ["volumes" ].append (persistent_volume )
138
+
139
+ for container in containers :
140
+ volumeMount = client .V1VolumeMount (
141
+ name = pv ["name" ], mount_path = pv ["mountPath" ]
142
+ )
143
+ volumeMount = client .ApiClient ().sanitize_for_serialization (volumeMount )
144
+ container ["volumeMounts" ].append (volumeMount )
145
+
146
+
102
147
def update_env (spec , env ):
103
148
containers = spec .get ("containers" )
104
149
for container in containers :
@@ -139,6 +184,7 @@ def update_nodes(
139
184
head_cpus ,
140
185
head_memory ,
141
186
head_gpus ,
187
+ persistent_volumes ,
142
188
):
143
189
if "template" in item .keys ():
144
190
head = item .get ("template" ).get ("spec" ).get ("headGroupSpec" )
@@ -154,6 +200,8 @@ def update_nodes(
154
200
155
201
for comp in [head , worker ]:
156
202
spec = comp .get ("template" ).get ("spec" )
203
+ if persistent_volumes is not []:
204
+ update_persistent_volume_mounts (spec , persistent_volumes )
157
205
update_image_pull_secrets (spec , image_pull_secrets )
158
206
update_image (spec , image )
159
207
update_env (spec , env )
@@ -311,6 +359,7 @@ def generate_appwrapper(
311
359
verify_tls : bool ,
312
360
local_queue : Optional [str ],
313
361
labels ,
362
+ persistent_volumes : list [dict [str , str ]],
314
363
):
315
364
user_yaml = read_template (template )
316
365
appwrapper_name , cluster_name = gen_names (name )
@@ -338,6 +387,7 @@ def generate_appwrapper(
338
387
head_cpus ,
339
388
head_memory ,
340
389
head_gpus ,
390
+ persistent_volumes ,
341
391
)
342
392
343
393
augment_labels (item , labels )
0 commit comments