Skip to content

Commit bb0a0a7

Browse files
committed
add functions for creating ray with oauth proxy in front of the dashboard
Signed-off-by: Kevin <kpostlet@redhat.com>
1 parent 0d9b23c commit bb0a0a7

File tree

4 files changed

+230
-6
lines changed

4 files changed

+230
-6
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@
1818
cluster setup queue, a list of all existing clusters, and the user's working namespace.
1919
"""
2020

21-
from os import stat
21+
from os import stat, environ
2222
from time import sleep
2323
from typing import List, Optional, Tuple, Dict
2424

2525
import openshift as oc
26+
from kubernetes import config
2627
from ray.job_submission import JobSubmissionClient
28+
import urllib3
2729

2830
from ..utils import pretty_print
2931
from ..utils.generate_yaml import generate_appwrapper
32+
from ..utils.openshift_oauth import create_openshift_oauth_objects, delete_openshift_oauth_objects, download_tls_cert
3033
from .config import ClusterConfiguration
3134
from .model import (
3235
AppWrapper,
@@ -37,6 +40,9 @@
3740
)
3841

3942

43+
k8_client = config.new_client_from_config()
44+
45+
4046
class Cluster:
4147
"""
4248
An object for requesting, bringing up, and taking down resources.
@@ -57,6 +63,21 @@ def __init__(self, config: ClusterConfiguration):
5763
self.config = config
5864
self.app_wrapper_yaml = self.create_app_wrapper()
5965
self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
66+
self._client = None
67+
68+
@property
69+
def client(self):
70+
if self._client:
71+
return self._client
72+
if self.config.openshift_oauth:
73+
# user must be logged in to OpenShift
74+
self._client = JobSubmissionClient(
75+
self.cluster_dashboard_uri(),
76+
headers={"Authorization": k8_client.configuration.auth_settings()["BearerToken"]["value"]}
77+
)
78+
else:
79+
self._client = JobSubmissionClient(self.cluster_dashboard_uri())
80+
return self._client
6081

6182
def create_app_wrapper(self):
6283
"""
@@ -102,6 +123,7 @@ def create_app_wrapper(self):
102123
env=env,
103124
local_interactive=local_interactive,
104125
image_pull_secrets=image_pull_secrets,
126+
openshift_oauth=self.config.openshift_oauth,
105127
)
106128

107129
# creates a new cluster with the provided or default spec
@@ -111,6 +133,9 @@ def up(self):
111133
the MCAD queue.
112134
"""
113135
namespace = self.config.namespace
136+
if self.config.openshift_oauth:
137+
create_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
138+
114139
try:
115140
with oc.project(namespace):
116141
oc.invoke("apply", ["-f", self.app_wrapper_yaml])
@@ -146,6 +171,8 @@ def down(self):
146171
print("Cluster not found, have you run cluster.up() yet?")
147172
else:
148173
raise osp
174+
if self.config.openshift_oauth:
175+
delete_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
149176

150177
def status(
151178
self, print_to_console: bool = True
@@ -254,7 +281,7 @@ def cluster_dashboard_uri(self) -> str:
254281
route = route.out().split(" ")
255282
route = [x for x in route if f"ray-dashboard-{self.config.name}" in x]
256283
route = route[0].strip().strip("'")
257-
return f"http://{route}"
284+
return f"https://{route}"
258285
except:
259286
return "Dashboard route not available yet, have you run cluster.up()?"
260287

@@ -264,14 +291,14 @@ def list_jobs(self) -> List:
264291
"""
265292
dashboard_route = self.cluster_dashboard_uri()
266293
client = JobSubmissionClient(dashboard_route)
267-
return client.list_jobs()
294+
return self.client.list_jobs()
268295

269296
def job_status(self, job_id: str) -> str:
270297
"""
271298
This method accesses the head ray node in your cluster and returns the job status for the provided job id.
272299
"""
273300
dashboard_route = self.cluster_dashboard_uri()
274-
client = JobSubmissionClient(dashboard_route)
301+
client = JobSubmissionClient(dashboard_route,)
275302
return client.get_job_status(job_id)
276303

277304
def job_logs(self, job_id: str) -> str:
@@ -285,7 +312,7 @@ def job_logs(self, job_id: str) -> str:
285312
def torchx_config(
286313
self, working_dir: str = None, requirements: str = None
287314
) -> Dict[str, str]:
288-
dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
315+
dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
289316
to_return = {
290317
"cluster_name": self.config.name,
291318
"dashboard_address": dashboard_address,

src/codeflare_sdk/cluster/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ class ClusterConfiguration:
5050
image: str = "quay.io/project-codeflare/ray:2.5.0-py38-cu116"
5151
local_interactive: bool = False
5252
image_pull_secrets: list = field(default_factory=list)
53+
openshift_oauth: bool = False

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,14 @@
2121
import sys
2222
import argparse
2323
import uuid
24+
from os import urandom
25+
from base64 import b64encode
26+
from urllib3.util import parse_url
27+
2428
import openshift as oc
29+
from kubernetes import client, config
2530

31+
k8_client = config.new_client_from_config()
2632

2733
def read_template(template):
2834
with open(template, "r") as stream:
@@ -44,12 +50,14 @@ def gen_names(name):
4450

4551
def update_dashboard_route(route_item, cluster_name, namespace):
4652
metadata = route_item.get("generictemplate", {}).get("metadata")
47-
metadata["name"] = f"ray-dashboard-{cluster_name}"
53+
metadata["name"] = gen_dashboard_route_name(cluster_name)
4854
metadata["namespace"] = namespace
4955
metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc"
5056
spec = route_item.get("generictemplate", {}).get("spec")
5157
spec["to"]["name"] = f"{cluster_name}-head-svc"
5258

59+
def gen_dashboard_route_name(cluster_name):
60+
return f"ray-dashboard-{cluster_name}"
5361

5462
# ToDo: refactor the update_x_route() functions
5563
def update_rayclient_route(route_item, cluster_name, namespace):
@@ -289,6 +297,64 @@ def write_user_appwrapper(user_yaml, output_file_name):
289297
print(f"Written to: {output_file_name}")
290298

291299

300+
def enable_openshift_oauth(user_yaml, cluster_name, namespace):
301+
tls_mount_location = "/etc/tls/private"
302+
oauth_port = 443
303+
oauth_sa_name = f"{cluster_name}-oauth-proxy"
304+
tls_secret_name = f"{cluster_name}-proxy-tls-secret"
305+
tls_volume_name = "proxy-tls-secret"
306+
port_name = "oauth-proxy"
307+
_,_,host,_,_,_,_ = parse_url(k8_client.configuration.host)
308+
host = host.replace("api.", f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps.")
309+
oauth_sidecar = _create_oauth_sidecar_object(
310+
namespace, tls_mount_location, oauth_port, oauth_sa_name, tls_volume_name, port_name
311+
)
312+
tls_secret_volume = client.V1Volume(
313+
name=tls_volume_name,secret=client.V1SecretVolumeSource(secret_name=tls_secret_name)
314+
)
315+
# allows for setting value of Cluster object when initializing object from an existing AppWrapper on cluster
316+
user_yaml["metadata"]["annotations"] = user_yaml["metadata"].get("annotations", {})
317+
user_yaml["metadata"]["annotations"]["codeflare-sdk-use-oauth"] = "true" # if the user gets an
318+
ray_headgroup_pod = user_yaml["spec"]["resources"]["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]
319+
user_yaml["spec"]["resources"]["GenericItems"].pop(1)
320+
ray_headgroup_pod["serviceAccount"] = oauth_sa_name
321+
ray_headgroup_pod["volumes"] = ray_headgroup_pod.get("volumes", [])
322+
ray_headgroup_pod["volumes"].append(k8_client.sanitize_for_serialization(tls_secret_volume))
323+
ray_headgroup_pod["containers"].append(k8_client.sanitize_for_serialization(oauth_sidecar))
324+
# add volume to headnode
325+
# add sidecar container to ray object
326+
327+
def _create_oauth_sidecar_object(
328+
namespace: str,
329+
tls_mount_location: str,
330+
oauth_port: int,
331+
oauth_sa_name: str,
332+
tls_volume_name: str,
333+
port_name: str
334+
) -> client.V1Container:
335+
return client.V1Container(
336+
args=[
337+
f"--https-address=:{oauth_port}",
338+
"--provider=openshift",
339+
f"--openshift-service-account={oauth_sa_name}",
340+
"--upstream=http://localhost:8265",
341+
f"--tls-cert={tls_mount_location}/tls.crt",
342+
f"--tls-key={tls_mount_location}/tls.key",
343+
"--cookie-secret=SECRET",
344+
# f"--cookie-secret={b64encode(urandom(64)).decode('utf-8')}", # create random string for encrypting cookie
345+
f'--openshift-delegate-urls={{"/":{{"resource":"pods","namespace":"{namespace}","verb":"get"}}}}'
346+
],
347+
image="registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366",
348+
name="oauth-proxy",
349+
ports=[client.V1ContainerPort(container_port=oauth_port,name=port_name)],
350+
resources = client.V1ResourceRequirements(limits=None,requests=None),
351+
volume_mounts=[
352+
client.V1VolumeMount(
353+
mount_path=tls_mount_location,name=tls_volume_name,read_only=True
354+
)
355+
],
356+
)
357+
292358
def generate_appwrapper(
293359
name: str,
294360
namespace: str,
@@ -305,6 +371,7 @@ def generate_appwrapper(
305371
env,
306372
local_interactive: bool,
307373
image_pull_secrets: list,
374+
openshift_oauth: bool,
308375
):
309376
user_yaml = read_template(template)
310377
appwrapper_name, cluster_name = gen_names(name)
@@ -335,6 +402,10 @@ def generate_appwrapper(
335402
enable_local_interactive(resources, cluster_name, namespace)
336403
else:
337404
disable_raycluster_tls(resources["resources"])
405+
406+
if openshift_oauth:
407+
enable_openshift_oauth(user_yaml, cluster_name, namespace)
408+
338409
outfile = appwrapper_name + ".yaml"
339410
write_user_appwrapper(user_yaml, outfile)
340411
return outfile
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from urllib3.util import parse_url
2+
from .generate_yaml import gen_dashboard_route_name
3+
from base64 import b64decode
4+
5+
from kubernetes import config, client
6+
7+
k8_client = config.new_client_from_config()
8+
core_api = client.CoreV1Api(k8_client)
9+
rbac_auth_api = client.RbacAuthorizationV1Api(k8_client)
10+
networking_api = client.NetworkingV1Api(k8_client)
11+
12+
def create_openshift_oauth_objects(cluster_name, namespace):
13+
oauth_port = 443
14+
oauth_sa_name = f"{cluster_name}-oauth-proxy"
15+
tls_secret_name = _gen_tls_secret_name(cluster_name)
16+
service_name = f"{cluster_name}-oauth"
17+
port_name = "oauth-proxy"
18+
host = parse_url(k8_client.configuration.host).host
19+
20+
# replace "^api" with the expected host
21+
host = f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps" + host.lstrip("api")
22+
23+
oauth_crb = client.V1ClusterRoleBinding(
24+
api_version="rbac.authorization.k8s.io/v1", kind="ClusterRoleBinding",
25+
metadata=client.V1ObjectMeta(name=f"{cluster_name}-rb"),
26+
role_ref=client.V1RoleRef(
27+
api_group="rbac.authorization.k8s.io",
28+
kind="ClusterRole",
29+
name="system:auth-delegator",
30+
),
31+
subjects=[client.V1Subject(kind="ServiceAccount", name=oauth_sa_name, namespace=namespace)]
32+
)
33+
oauth_sa = client.V1ServiceAccount(
34+
api_version="v1",
35+
kind="ServiceAccount",
36+
metadata=client.V1ObjectMeta(
37+
name=oauth_sa_name,
38+
namespace=namespace,
39+
annotations={"serviceaccounts.openshift.io/oauth-redirecturi.first": f"https://{host}"}
40+
)
41+
)
42+
oauth_service = _create_oauth_service_obj(
43+
cluster_name, namespace, oauth_port, tls_secret_name, service_name, port_name
44+
)
45+
ingress = _create_oauth_ingress_object(cluster_name, namespace, service_name, port_name, host)
46+
core_api.create_namespaced_service_account(namespace=namespace, body=oauth_sa)
47+
core_api.create_namespaced_service(namespace=namespace, body=oauth_service)
48+
networking_api.create_namespaced_ingress(namespace=namespace, body=ingress)
49+
rbac_auth_api.create_cluster_role_binding(body=oauth_crb)
50+
51+
def _gen_tls_secret_name(cluster_name):
52+
return f"{cluster_name}-proxy-tls-secret"
53+
54+
def delete_openshift_oauth_objects(cluster_name, namespace):
55+
oauth_sa_name = f"{cluster_name}-oauth-proxy"
56+
service_name = f"{cluster_name}-oauth"
57+
core_api.delete_namespaced_service_account(name=oauth_sa_name, namespace=namespace)
58+
core_api.delete_namespaced_service(name=service_name, namespace=namespace)
59+
networking_api.delete_namespaced_ingress(name=f"{cluster_name}-ingress", namespace=namespace)
60+
rbac_auth_api.delete_cluster_role_binding(name= f"{cluster_name}-rb")
61+
62+
def download_tls_cert(cluster_name, namespace, output_file):
63+
b64_tls_cert = core_api.read_namespaced_secret(
64+
name=_gen_tls_secret_name(cluster_name=cluster_name),namespace=namespace
65+
).data['tls.crt']
66+
with open(output_file, "w+") as f:
67+
f.write(b64decode(b64_tls_cert).decode("ascii"))
68+
69+
def _create_oauth_service_obj(
70+
cluster_name: str,
71+
namespace: str,
72+
oauth_port: int,
73+
tls_secret_name: str,
74+
service_name: str,
75+
port_name: str,
76+
) -> client.V1Service:
77+
return client.V1Service(
78+
api_version="v1",
79+
kind="Service",
80+
metadata=client.V1ObjectMeta(
81+
annotations={"service.beta.openshift.io/serving-cert-secret-name": tls_secret_name},
82+
name=service_name,
83+
namespace=namespace
84+
),
85+
spec=client.V1ServiceSpec(
86+
ports=[client.V1ServicePort(name=port_name, protocol="TCP", port=oauth_port, target_port=oauth_port)],
87+
selector={
88+
"app.kubernetes.io/created-by": "kuberay-operator",
89+
"app.kubernetes.io/name": "kuberay",
90+
"ray.io/cluster": cluster_name,
91+
"ray.io/identifier": f"{cluster_name}-head",
92+
"ray.io/node-type": "head",
93+
}
94+
)
95+
)
96+
97+
def _create_oauth_ingress_object(
98+
cluster_name: str,
99+
namespace: str,
100+
service_name: str,
101+
port_name: str,
102+
host: str,
103+
) -> client.V1Ingress:
104+
return client.V1Ingress(
105+
api_version="networking.k8s.io/v1",
106+
kind="Ingress",
107+
metadata=client.V1ObjectMeta(
108+
annotations={"route.openshift.io/termination": "passthrough"},
109+
name=f"{cluster_name}-ingress",
110+
namespace=namespace
111+
),
112+
spec=client.V1IngressSpec(rules=[client.V1IngressRule(
113+
host=host,
114+
http=client.V1HTTPIngressRuleValue(paths=[
115+
client.V1HTTPIngressPath(
116+
backend=client.V1IngressBackend(
117+
service=client.V1IngressServiceBackend(
118+
name=service_name,port=client.V1ServiceBackendPort(name=port_name)
119+
)
120+
),
121+
path_type="ImplementationSpecific"
122+
)
123+
])
124+
)]),
125+
)

0 commit comments

Comments
 (0)