Skip to content

Commit c6a0723

Browse files
committed
Added custom Volumes and Volume Mounts support
1 parent 6798b74 commit c6a0723

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ def create_app_wrapper(self):
154154
verify_tls = self.config.verify_tls
155155
local_queue = self.config.local_queue
156156
labels = self.config.labels
157+
volumes = self.config.volumes
158+
volume_mounts = self.config.volume_mounts
157159
return generate_appwrapper(
158160
name=name,
159161
namespace=namespace,
@@ -176,6 +178,8 @@ def create_app_wrapper(self):
176178
verify_tls=verify_tls,
177179
local_queue=local_queue,
178180
labels=labels,
181+
volumes=volumes,
182+
volume_mounts=volume_mounts,
179183
)
180184

181185
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class ClusterConfiguration:
5353
write_to_file: bool = False
5454
verify_tls: bool = True
5555
labels: dict = field(default_factory=dict)
56+
volumes: list = field(default_factory=list)
57+
volume_mounts: list = field(default_factory=list)
5658

5759
def __post_init__(self):
5860
if not self.verify_tls:

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,26 @@ def update_image_pull_secrets(spec, image_pull_secrets):
9999
]
100100

101101

102+
def update_volume_mounts(spec, volume_mounts: list):
103+
if volume_mounts is []:
104+
return None
105+
106+
containers = spec.get("containers")
107+
for volume_mount in volume_mounts:
108+
for container in containers:
109+
volumeMount = client.ApiClient().sanitize_for_serialization(volume_mount)
110+
container["volumeMounts"].append(volumeMount)
111+
112+
113+
def update_volumes(spec, volumes: list):
114+
if volumes is []:
115+
return None
116+
117+
for volume in volumes:
118+
new_volume = client.ApiClient().sanitize_for_serialization(volume)
119+
spec["volumes"].append(new_volume)
120+
121+
102122
def update_env(spec, env):
103123
containers = spec.get("containers")
104124
for container in containers:
@@ -139,6 +159,8 @@ def update_nodes(
139159
head_cpus,
140160
head_memory,
141161
head_gpus,
162+
volumes,
163+
volume_mounts,
142164
):
143165
if "template" in item.keys():
144166
head = item.get("template").get("spec").get("headGroupSpec")
@@ -154,6 +176,8 @@ def update_nodes(
154176

155177
for comp in [head, worker]:
156178
spec = comp.get("template").get("spec")
179+
update_volume_mounts(spec, volume_mounts)
180+
update_volumes(spec, volumes)
157181
update_image_pull_secrets(spec, image_pull_secrets)
158182
update_image(spec, image)
159183
update_env(spec, env)
@@ -322,6 +346,8 @@ def generate_appwrapper(
322346
verify_tls: bool,
323347
local_queue: Optional[str],
324348
labels,
349+
volumes: list[client.V1Volume],
350+
volume_mounts: list[client.V1VolumeMount],
325351
):
326352
user_yaml = read_template(template)
327353
appwrapper_name, cluster_name = gen_names(name)
@@ -349,6 +375,8 @@ def generate_appwrapper(
349375
head_cpus,
350376
head_memory,
351377
head_gpus,
378+
volumes,
379+
volume_mounts,
352380
)
353381

354382
augment_labels(item, labels)

0 commit comments

Comments
 (0)