Skip to content

Commit e2cc8e1

Browse files
committed
New ClusterConfiguration parameter for user labels
1 parent 59cbccc commit e2cc8e1

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def create_app_wrapper(self):
187187
write_to_file = self.config.write_to_file
188188
verify_tls = self.config.verify_tls
189189
local_queue = self.config.local_queue
190+
user_labels = self.config.user_labels
190191
return generate_appwrapper(
191192
name=name,
192193
namespace=namespace,
@@ -211,6 +212,7 @@ def create_app_wrapper(self):
211212
write_to_file=write_to_file,
212213
verify_tls=verify_tls,
213214
local_queue=local_queue,
215+
user_labels=user_labels,
214216
)
215217

216218
# creates a new cluster with the provided or default spec
@@ -453,9 +455,7 @@ def cluster_dashboard_uri(self) -> str:
453455
"name"
454456
] == f"ray-dashboard-{self.config.name}" or route["metadata"][
455457
"name"
456-
].startswith(
457-
f"{self.config.name}-ingress"
458-
):
458+
].startswith(f"{self.config.name}-ingress"):
459459
protocol = "https" if route["spec"].get("tls") else "http"
460460
return f"{protocol}://{route['spec']['host']}"
461461
else:

src/codeflare_sdk/cluster/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class ClusterConfiguration:
5454
dispatch_priority: str = None
5555
write_to_file: bool = False
5656
verify_tls: bool = True
57+
user_labels: dict = field(default_factory=dict)
5758

5859
def __post_init__(self):
5960
if not self.verify_tls:

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,11 @@ def get_default_kueue_name(namespace: str):
309309

310310

311311
def write_components(
312-
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
312+
user_yaml: dict,
313+
output_file_name: str,
314+
namespace: str,
315+
local_queue: Optional[str],
316+
user_labels: dict,
313317
):
314318
# Create the directory if it doesn't exist
315319
directory_path = os.path.dirname(output_file_name)
@@ -331,6 +335,8 @@ def write_components(
331335
]
332336
labels = component["generictemplate"]["metadata"]["labels"]
333337
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
338+
for key in user_labels:
339+
labels.update({key: user_labels[key]})
334340
outfile.write("---\n")
335341
yaml.dump(
336342
component["generictemplate"], outfile, default_flow_style=False
@@ -339,7 +345,11 @@ def write_components(
339345

340346

341347
def load_components(
342-
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
348+
user_yaml: dict,
349+
name: str,
350+
namespace: str,
351+
local_queue: Optional[str],
352+
user_labels: dict,
343353
):
344354
component_list = []
345355
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
@@ -355,6 +365,8 @@ def load_components(
355365
]
356366
labels = component["generictemplate"]["metadata"]["labels"]
357367
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
368+
for key in user_labels:
369+
labels.update({key: user_labels[key]})
358370
component_list.append(component["generictemplate"])
359371

360372
resources = "---\n" + "---\n".join(
@@ -395,6 +407,7 @@ def generate_appwrapper(
395407
write_to_file: bool,
396408
verify_tls: bool,
397409
local_queue: Optional[str],
410+
user_labels,
398411
):
399412
user_yaml = read_template(template)
400413
appwrapper_name, cluster_name = gen_names(name)
@@ -446,11 +459,13 @@ def generate_appwrapper(
446459
if mcad:
447460
write_user_appwrapper(user_yaml, outfile)
448461
else:
449-
write_components(user_yaml, outfile, namespace, local_queue)
462+
write_components(user_yaml, outfile, namespace, local_queue, user_labels)
450463
return outfile
451464
else:
452465
if mcad:
453466
user_yaml = load_appwrapper(user_yaml, name)
454467
else:
455-
user_yaml = load_components(user_yaml, name, namespace, local_queue)
468+
user_yaml = load_components(
469+
user_yaml, name, namespace, local_queue, user_labels
470+
)
456471
return user_yaml

0 commit comments

Comments
 (0)