
Commit faa08fe

add functions for creating ray with oauth proxy in front of the dashboard
Signed-off-by: Kevin <kpostlet@redhat.com>
1 parent: c2013ba

File tree: 5 files changed, +255 -7 lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 34 additions & 3 deletions

@@ -18,15 +18,21 @@
 cluster setup queue, a list of all existing clusters, and the user's working namespace.
 """
 
+<<<<<<< HEAD
 from time import sleep
 from typing import List, Optional, Tuple, Dict
 
+import openshift as oc
+from kubernetes import config
+>>>>>>> bb0a0a7 (add functions for creating ray with oauth proxy in front of the dashboard)
 from ray.job_submission import JobSubmissionClient
+import urllib3
 
 from .auth import config_check, api_config_handler
 from ..utils import pretty_print
 from ..utils.generate_yaml import generate_appwrapper
 from ..utils.kube_api_helpers import _kube_api_error_handling
+from ..utils.openshift_oauth import create_openshift_oauth_objects, delete_openshift_oauth_objects, download_tls_cert
 from .config import ClusterConfiguration
 from .model import (
     AppWrapper,
@@ -41,6 +47,9 @@
 import requests
 
 
+k8_client = config.new_client_from_config()
+
+
 class Cluster:
     """
     An object for requesting, bringing up, and taking down resources.
@@ -61,6 +70,21 @@ def __init__(self, config: ClusterConfiguration):
         self.config = config
         self.app_wrapper_yaml = self.create_app_wrapper()
         self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
+        self._client = None
+
+    @property
+    def client(self):
+        if self._client:
+            return self._client
+        if self.config.openshift_oauth:
+            # user must be logged in to OpenShift
+            self._client = JobSubmissionClient(
+                self.cluster_dashboard_uri(),
+                headers={"Authorization": k8_client.configuration.auth_settings()["BearerToken"]["value"]}
+            )
+        else:
+            self._client = JobSubmissionClient(self.cluster_dashboard_uri())
+        return self._client
 
     def evaluate_dispatch_priority(self):
         priority_class = self.config.dispatch_priority
@@ -141,6 +165,7 @@ def create_app_wrapper(self):
             image_pull_secrets=image_pull_secrets,
             dispatch_priority=dispatch_priority,
             priority_val=priority_val,
+            openshift_oauth=self.config.openshift_oauth,
         )
 
     # creates a new cluster with the provided or default spec
@@ -150,6 +175,9 @@ def up(self):
         the MCAD queue.
         """
         namespace = self.config.namespace
+        if self.config.openshift_oauth:
+            create_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
+
         try:
             config_check()
             api_instance = client.CustomObjectsApi(api_config_handler())
@@ -184,6 +212,9 @@ def down(self):
         except Exception as e: # pragma: no cover
             return _kube_api_error_handling(e)
 
+        if self.config.openshift_oauth:
+            delete_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
+
     def status(
         self, print_to_console: bool = True
     ) -> Tuple[CodeFlareClusterStatus, bool]:
@@ -322,14 +353,14 @@ def list_jobs(self) -> List:
         """
         dashboard_route = self.cluster_dashboard_uri()
         client = JobSubmissionClient(dashboard_route)
-        return client.list_jobs()
+        return self.client.list_jobs()
 
     def job_status(self, job_id: str) -> str:
         """
        This method accesses the head ray node in your cluster and returns the job status for the provided job id.
         """
         dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
+        client = JobSubmissionClient(dashboard_route,)
         return client.get_job_status(job_id)
 
     def job_logs(self, job_id: str) -> str:
@@ -343,7 +374,7 @@ def job_logs(self, job_id: str) -> str:
     def torchx_config(
         self, working_dir: str = None, requirements: str = None
     ) -> Dict[str, str]:
-        dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
+        dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
         to_return = {
             "cluster_name": self.config.name,
             "dashboard_address": dashboard_address,

src/codeflare_sdk/cluster/config.py

Lines changed: 1 addition & 0 deletions

@@ -48,3 +48,4 @@ class ClusterConfiguration:
     local_interactive: bool = False
     image_pull_secrets: list = field(default_factory=list)
     dispatch_priority: str = None
+    openshift_oauth: bool = False
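
A short sketch (not from the commit) showing that the new field is opt-in and defaults to off; the name and namespace values are placeholders.

from codeflare_sdk.cluster.config import ClusterConfiguration

config = ClusterConfiguration(name="raytest", namespace="default")
assert config.openshift_oauth is False  # proxy disabled unless requested
config.openshift_oauth = True           # opt in before constructing Cluster(config)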

src/codeflare_sdk/job/jobs.py

Lines changed: 19 additions & 3 deletions

@@ -21,21 +21,35 @@
 from torchx.runner import get_runner
 from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
 
+from ray.job_submission import JobSubmissionClient
+
+import openshift as oc
+
 if TYPE_CHECKING:
     from ..cluster.cluster import Cluster
     from ..cluster.cluster import get_current_namespace
 
 all_jobs: List["Job"] = []
 torchx_runner = get_runner()
 
-
 class JobDefinition(metaclass=abc.ABCMeta):
     def _dry_run(self, cluster: "Cluster"):
         pass
 
     def submit(self, cluster: "Cluster"):
         pass
 
+    def _get_torchx_runner(self, cluster: "Cluster"):
+        return get_runner(
+            scheduler_params={
+                "ray_client": JobSubmissionClient(
+                    address=cluster.cluster_dashboard_uri(),
+                    headers={"Authorization": f"Bearer {oc.get_auth_token()}"},
+                    verify=cluster.config.openshift_oauth,
+                )
+            }
+        )
+
 
 class Job(metaclass=abc.ABCMeta):
     def status(self):
@@ -171,9 +185,11 @@ def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None
         self.job_definition = job_definition
         self.cluster = cluster
         if self.cluster:
-            self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
+            runner = job_definition._get_torchx_runner(cluster=cluster)
+            self._app_handle = runner.schedule(job_definition._dry_run(cluster))
         else:
-            self._app_handle = torchx_runner.schedule(
+            runner = get_runner()
+            self._app_handle = runner.schedule(
                 job_definition._dry_run_no_cluster()
             )
         all_jobs.append(self)
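
A hedged sketch (not part of this commit) of how the per-cluster runner is exercised when a DDP job is submitted against a cluster; the job name, script, and other DDPJobDefinition arguments are illustrative assumptions, and cluster is a Cluster built with openshift_oauth=True as in the earlier sketch.

from codeflare_sdk.job.jobs import DDPJobDefinition

# DDPJob.__init__ now asks the definition for a torchx runner whose Ray
# scheduler uses a JobSubmissionClient authenticated with the oc token.
job_def = DDPJobDefinition(name="mnist", script="mnist.py")  # placeholder arguments
job = job_def.submit(cluster)  # schedules via the authenticated runner
print(job.status())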

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 76 additions & 1 deletion

@@ -21,10 +21,20 @@
 import sys
 import argparse
 import uuid
+<<<<<<< HEAD
 from kubernetes import client, config
 from .kube_api_helpers import _kube_api_error_handling
 from ..cluster.auth import api_config_handler
+=======
+from os import urandom
+from base64 import b64encode
+from urllib3.util import parse_url
+>>>>>>> bb0a0a7 (add functions for creating ray with oauth proxy in front of the dashboard)
 
+import openshift as oc
+from kubernetes import client, config
+
+k8_client = config.new_client_from_config()
 
 def read_template(template):
     with open(template, "r") as stream:
@@ -46,12 +56,14 @@ def gen_names(name):
 
 def update_dashboard_route(route_item, cluster_name, namespace):
     metadata = route_item.get("generictemplate", {}).get("metadata")
-    metadata["name"] = f"ray-dashboard-{cluster_name}"
+    metadata["name"] = gen_dashboard_route_name(cluster_name)
     metadata["namespace"] = namespace
     metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc"
     spec = route_item.get("generictemplate", {}).get("spec")
     spec["to"]["name"] = f"{cluster_name}-head-svc"
 
+def gen_dashboard_route_name(cluster_name):
+    return f"ray-dashboard-{cluster_name}"
 
 # ToDo: refactor the update_x_route() functions
 def update_rayclient_route(route_item, cluster_name, namespace):
@@ -347,6 +359,64 @@ def write_user_appwrapper(user_yaml, output_file_name):
     print(f"Written to: {output_file_name}")
 
 
+def enable_openshift_oauth(user_yaml, cluster_name, namespace):
+    tls_mount_location = "/etc/tls/private"
+    oauth_port = 443
+    oauth_sa_name = f"{cluster_name}-oauth-proxy"
+    tls_secret_name = f"{cluster_name}-proxy-tls-secret"
+    tls_volume_name = "proxy-tls-secret"
+    port_name = "oauth-proxy"
+    _,_,host,_,_,_,_ = parse_url(k8_client.configuration.host)
+    host = host.replace("api.", f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps.")
+    oauth_sidecar = _create_oauth_sidecar_object(
+        namespace, tls_mount_location, oauth_port, oauth_sa_name, tls_volume_name, port_name
+    )
+    tls_secret_volume = client.V1Volume(
+        name=tls_volume_name,secret=client.V1SecretVolumeSource(secret_name=tls_secret_name)
+    )
+    # allows for setting value of Cluster object when initializing object from an existing AppWrapper on cluster
+    user_yaml["metadata"]["annotations"] = user_yaml["metadata"].get("annotations", {})
+    user_yaml["metadata"]["annotations"]["codeflare-sdk-use-oauth"] = "true" # if the user gets an
+    ray_headgroup_pod = user_yaml["spec"]["resources"]["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]
+    user_yaml["spec"]["resources"]["GenericItems"].pop(1)
+    ray_headgroup_pod["serviceAccount"] = oauth_sa_name
+    ray_headgroup_pod["volumes"] = ray_headgroup_pod.get("volumes", [])
+    ray_headgroup_pod["volumes"].append(k8_client.sanitize_for_serialization(tls_secret_volume))
+    ray_headgroup_pod["containers"].append(k8_client.sanitize_for_serialization(oauth_sidecar))
+    # add volume to headnode
+    # add sidecar container to ray object
+
+def _create_oauth_sidecar_object(
+    namespace: str,
+    tls_mount_location: str,
+    oauth_port: int,
+    oauth_sa_name: str,
+    tls_volume_name: str,
+    port_name: str
+) -> client.V1Container:
+    return client.V1Container(
+        args=[
+            f"--https-address=:{oauth_port}",
+            "--provider=openshift",
+            f"--openshift-service-account={oauth_sa_name}",
+            "--upstream=http://localhost:8265",
+            f"--tls-cert={tls_mount_location}/tls.crt",
+            f"--tls-key={tls_mount_location}/tls.key",
+            "--cookie-secret=SECRET",
+            # f"--cookie-secret={b64encode(urandom(64)).decode('utf-8')}", # create random string for encrypting cookie
+            f'--openshift-delegate-urls={{"/":{{"resource":"pods","namespace":"{namespace}","verb":"get"}}}}'
+        ],
+        image="registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366",
+        name="oauth-proxy",
+        ports=[client.V1ContainerPort(container_port=oauth_port,name=port_name)],
+        resources = client.V1ResourceRequirements(limits=None,requests=None),
+        volume_mounts=[
+            client.V1VolumeMount(
+                mount_path=tls_mount_location,name=tls_volume_name,read_only=True
+            )
+        ],
+    )
+
 def generate_appwrapper(
     name: str,
     namespace: str,
@@ -365,6 +435,7 @@
     image_pull_secrets: list,
     dispatch_priority: str,
     priority_val: int,
+    openshift_oauth: bool,
 ):
     user_yaml = read_template(template)
     appwrapper_name, cluster_name = gen_names(name)
@@ -396,6 +467,10 @@
         enable_local_interactive(resources, cluster_name, namespace)
     else:
         disable_raycluster_tls(resources["resources"])
+
+    if openshift_oauth:
+        enable_openshift_oauth(user_yaml, cluster_name, namespace)
+
     outfile = appwrapper_name + ".yaml"
     write_user_appwrapper(user_yaml, outfile)
     return outfile
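
For illustration only (not from the commit), a sketch of inspecting the OAuth sidecar container that enable_openshift_oauth appends to the Ray head group pod; the namespace and naming arguments are placeholder values mirroring the defaults used above.

from kubernetes import config
from codeflare_sdk.utils.generate_yaml import _create_oauth_sidecar_object

k8_client = config.new_client_from_config()  # requires a valid kubeconfig
sidecar = _create_oauth_sidecar_object(
    namespace="default",                     # placeholder namespace
    tls_mount_location="/etc/tls/private",
    oauth_port=443,
    oauth_sa_name="raytest-oauth-proxy",
    tls_volume_name="proxy-tls-secret",
    port_name="oauth-proxy",
)
# Same serialization step the commit uses before appending the container
# to the head group pod spec inside the AppWrapper.
print(k8_client.sanitize_for_serialization(sidecar))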
