Skip to content

Commit 87dd88a

Browse files
committed
Added persistent_volumes config variable
1 parent c0f7d7f commit 87dd88a

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def create_app_wrapper(self):
154154
verify_tls = self.config.verify_tls
155155
local_queue = self.config.local_queue
156156
labels = self.config.labels
157+
persistent_volumes = self.config.persistent_volumes
157158
return generate_appwrapper(
158159
name=name,
159160
namespace=namespace,
@@ -176,6 +177,7 @@ def create_app_wrapper(self):
176177
verify_tls=verify_tls,
177178
local_queue=local_queue,
178179
labels=labels,
180+
persistent_volumes=persistent_volumes,
179181
)
180182

181183
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class ClusterConfiguration:
5353
write_to_file: bool = False
5454
verify_tls: bool = True
5555
labels: dict = field(default_factory=dict)
56+
persistent_volumes: dict = field(default_factory=list)
5657

5758
def __post_init__(self):
5859
if not self.verify_tls:

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,51 @@ def update_image_pull_secrets(spec, image_pull_secrets):
9999
]
100100

101101

102+
def validate_persistent_volumes(persistent_volumes) -> bool:
103+
mandatory_keys = {"name", "mountPath", "claimName"}
104+
105+
if not isinstance(persistent_volumes, list):
106+
raise ValueError("persistent_volumes must be a list")
107+
108+
for pv in persistent_volumes:
109+
if not isinstance(pv, dict):
110+
raise ValueError("Each item in persistent_volumes must be a dict")
111+
112+
missed_keys = mandatory_keys - pv.keys()
113+
if missed_keys:
114+
raise ValueError(f"Missing keys in persistent volume: {missed_keys}")
115+
116+
for key in mandatory_keys:
117+
if not isinstance(pv[key], str) or not pv[key]:
118+
raise ValueError(f"{key} must be a string")
119+
120+
return True
121+
122+
123+
def update_persistent_volume_mounts(spec, persistent_volumes):
124+
validate_persistent_volumes(persistent_volumes)
125+
containers = spec.get("containers")
126+
for pv in persistent_volumes:
127+
persistent_volume = client.V1Volume(
128+
name=pv["name"],
129+
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
130+
claim_name=pv["claimName"]
131+
),
132+
)
133+
134+
persistent_volume = client.ApiClient().sanitize_for_serialization(
135+
persistent_volume
136+
)
137+
spec["volumes"].append(persistent_volume)
138+
139+
for container in containers:
140+
volumeMount = client.V1VolumeMount(
141+
name=pv["name"], mount_path=pv["mountPath"]
142+
)
143+
volumeMount = client.ApiClient().sanitize_for_serialization(volumeMount)
144+
container["volumeMounts"].append(volumeMount)
145+
146+
102147
def update_env(spec, env):
103148
containers = spec.get("containers")
104149
for container in containers:
@@ -139,6 +184,7 @@ def update_nodes(
139184
head_cpus,
140185
head_memory,
141186
head_gpus,
187+
persistent_volumes,
142188
):
143189
if "template" in item.keys():
144190
head = item.get("template").get("spec").get("headGroupSpec")
@@ -154,6 +200,8 @@ def update_nodes(
154200

155201
for comp in [head, worker]:
156202
spec = comp.get("template").get("spec")
203+
if persistent_volumes is not []:
204+
update_persistent_volume_mounts(spec, persistent_volumes)
157205
update_image_pull_secrets(spec, image_pull_secrets)
158206
update_image(spec, image)
159207
update_env(spec, env)
@@ -311,6 +359,7 @@ def generate_appwrapper(
311359
verify_tls: bool,
312360
local_queue: Optional[str],
313361
labels,
362+
persistent_volumes: list[dict[str, str]],
314363
):
315364
user_yaml = read_template(template)
316365
appwrapper_name, cluster_name = gen_names(name)
@@ -338,6 +387,7 @@ def generate_appwrapper(
338387
head_cpus,
339388
head_memory,
340389
head_gpus,
390+
persistent_volumes,
341391
)
342392

343393
augment_labels(item, labels)

0 commit comments

Comments
 (0)