Commit 6bd6f31

add functions for creating ray with oauth proxy in front of the dashboard
Signed-off-by: Kevin <kpostlet@redhat.com>
1 parent c2013ba commit 6bd6f31

File tree

6 files changed: +290 -30 lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 44 additions & 13 deletions
@@ -21,12 +21,16 @@
 from time import sleep
 from typing import List, Optional, Tuple, Dict
 
+import openshift as oc
+from kubernetes import config
 from ray.job_submission import JobSubmissionClient
+import urllib3
 
 from .auth import config_check, api_config_handler
 from ..utils import pretty_print
 from ..utils.generate_yaml import generate_appwrapper
 from ..utils.kube_api_helpers import _kube_api_error_handling
+from ..utils.openshift_oauth import create_openshift_oauth_objects, delete_openshift_oauth_objects, download_tls_cert
 from .config import ClusterConfiguration
 from .model import (
     AppWrapper,
@@ -41,6 +45,9 @@
 import requests
 
 
+k8_client = config.new_client_from_config()
+
+
 class Cluster:
     """
     An object for requesting, bringing up, and taking down resources.
@@ -61,6 +68,21 @@ def __init__(self, config: ClusterConfiguration):
         self.config = config
         self.app_wrapper_yaml = self.create_app_wrapper()
         self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
+        self._client = None
+
+    @property
+    def client(self):
+        if self._client:
+            return self._client
+        if self.config.openshift_oauth:
+            self._client = JobSubmissionClient(
+                self.cluster_dashboard_uri(),
+                headers={"Authorization": k8_client.configuration.auth_settings()["BearerToken"]["value"]},
+                verify=False,
+            )
+        else:
+            self._client = JobSubmissionClient(self.cluster_dashboard_uri())
+        return self._client
 
     def evaluate_dispatch_priority(self):
         priority_class = self.config.dispatch_priority
@@ -141,6 +163,7 @@ def create_app_wrapper(self):
             image_pull_secrets=image_pull_secrets,
             dispatch_priority=dispatch_priority,
             priority_val=priority_val,
+            openshift_oauth=self.config.openshift_oauth,
         )
 
     # creates a new cluster with the provided or default spec
@@ -150,6 +173,9 @@ def up(self):
         the MCAD queue.
         """
         namespace = self.config.namespace
+        if self.config.openshift_oauth:
+            create_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
+
         try:
             config_check()
             api_instance = client.CustomObjectsApi(api_config_handler())
@@ -184,6 +210,9 @@ def down(self):
         except Exception as e:  # pragma: no cover
             return _kube_api_error_handling(e)
 
+        if self.config.openshift_oauth:
+            delete_openshift_oauth_objects(cluster_name=self.config.name, namespace=namespace)
+
     def status(
         self, print_to_console: bool = True
     ) -> Tuple[CodeFlareClusterStatus, bool]:
@@ -252,7 +281,13 @@ def status(
         return status, ready
 
     def is_dashboard_ready(self) -> bool:
-        response = requests.get(self.cluster_dashboard_uri(), timeout=5)
+        try:
+            response = requests.get(
+                self.cluster_dashboard_uri(), headers=self.client._headers, timeout=5, verify=self.client._verify
+            )
+        except requests.exceptions.SSLError:
+            # SSL exception occurs when oauth ingress has been created but cluster is not up
+            return False
         if response.status_code == 200:
             return True
         else:
@@ -311,7 +346,8 @@ def cluster_dashboard_uri(self) -> str:
             return _kube_api_error_handling(e)
 
         for route in routes["items"]:
-            if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}":
+            if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}" or \
+                    route["metadata"]["name"].startswith(f"{self.config.name}-ingress"):
                 protocol = "https" if route["spec"].get("tls") else "http"
                 return f"{protocol}://{route['spec']['host']}"
         return "Dashboard route not available yet, have you run cluster.up()?"
@@ -320,30 +356,24 @@ def list_jobs(self) -> List:
         """
         This method accesses the head ray node in your cluster and lists the running jobs.
         """
-        dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
-        return client.list_jobs()
+        return self.client.list_jobs()
 
     def job_status(self, job_id: str) -> str:
         """
         This method accesses the head ray node in your cluster and returns the job status for the provided job id.
         """
-        dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
-        return client.get_job_status(job_id)
+        return self.client.get_job_status(job_id)
 
     def job_logs(self, job_id: str) -> str:
        """
        This method accesses the head ray node in your cluster and returns the logs for the provided job id.
        """
-        dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
-        return client.get_job_logs(job_id)
+        return self.client.get_job_logs(job_id)
 
     def torchx_config(
         self, working_dir: str = None, requirements: str = None
     ) -> Dict[str, str]:
-        dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
+        dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
         to_return = {
             "cluster_name": self.config.name,
             "dashboard_address": dashboard_address,
@@ -587,7 +617,8 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
         )
     ray_route = None
     for route in routes["items"]:
-        if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}":
+        if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}" or \
+                route["metadata"]["name"].startswith(f"{rc['metadata']['name']}-ingress"):
            protocol = "https" if route["spec"].get("tls") else "http"
            ray_route = f"{protocol}://{route['spec']['host']}"
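In practice the new openshift_oauth flag and the lazy client property combine as in the following minimal caller-side sketch; the cluster name and namespace are placeholders, and an authenticated kubeconfig/oc session is assumed:

    from codeflare_sdk.cluster.cluster import Cluster
    from codeflare_sdk.cluster.config import ClusterConfiguration

    # openshift_oauth=True makes up() create the OAuth proxy objects and
    # routes dashboard traffic through the oauth-proxy sidecar.
    cluster = Cluster(ClusterConfiguration(
        name="raytest",        # placeholder
        namespace="default",   # placeholder
        openshift_oauth=True,
    ))
    cluster.up()

    # list_jobs()/job_status()/job_logs() now go through the cached
    # JobSubmissionClient, which carries the bearer token from the local
    # kube config when the OAuth proxy is enabled.
    print(cluster.list_jobs())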

src/codeflare_sdk/cluster/config.py

Lines changed: 1 addition & 0 deletions
@@ -48,3 +48,4 @@ class ClusterConfiguration:
     local_interactive: bool = False
     image_pull_secrets: list = field(default_factory=list)
     dispatch_priority: str = None
+    openshift_oauth: bool = False  # NOTE: to use the user must have permission to create ClusterRoleBindings
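Because the flag requires permission to create ClusterRoleBindings, a caller can check that precondition up front. A sketch using the standard Kubernetes self-subject access-review API (not part of this commit):

    from kubernetes import client, config

    api = client.AuthorizationV1Api(config.new_client_from_config())
    review = client.V1SelfSubjectAccessReview(
        spec=client.V1SelfSubjectAccessReviewSpec(
            resource_attributes=client.V1ResourceAttributes(
                group="rbac.authorization.k8s.io",
                resource="clusterrolebindings",
                verb="create",
            )
        )
    )
    # allowed is True only when the current user may create
    # ClusterRoleBindings, which openshift_oauth=True requires.
    allowed = api.create_self_subject_access_review(review).status.allowed
    print(allowed)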

src/codeflare_sdk/job/jobs.py

Lines changed: 23 additions & 14 deletions
@@ -18,16 +18,20 @@
 from pathlib import Path
 
 from torchx.components.dist import ddp
-from torchx.runner import get_runner
+from torchx.runner import get_runner, Runner
+from torchx.schedulers.ray_scheduler import RayScheduler
 from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
 
+from ray.job_submission import JobSubmissionClient
+
+import openshift as oc
+
 if TYPE_CHECKING:
     from ..cluster.cluster import Cluster
 from ..cluster.cluster import get_current_namespace
+from ..utils.openshift_oauth import download_tls_cert
 
 all_jobs: List["Job"] = []
-torchx_runner = get_runner()
-
 
 class JobDefinition(metaclass=abc.ABCMeta):
     def _dry_run(self, cluster: "Cluster"):
@@ -92,7 +96,9 @@ def __init__(
 
     def _dry_run(self, cluster: "Cluster"):
         j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}"  # # of proc. = # of gpus
-        return torchx_runner.dryrun(
+        runner = get_runner(ray_client=cluster.client)
+        runner._scheduler_instances["ray"] = RayScheduler(session_name=runner._name, ray_client=cluster.client)
+        return runner.dryrun(
             app=ddp(
                 *self.script_args,
                 script=self.script,
@@ -116,7 +122,7 @@ def _dry_run(self, cluster: "Cluster"):
             scheduler=cluster.torchx_scheduler,
             cfg=cluster.torchx_config(**self.scheduler_args),
             workspace=self.workspace,
-        )
+        ), runner
 
     def _missing_spec(self, spec: str):
         raise ValueError(f"Job definition missing arg: {spec}")
@@ -125,7 +131,8 @@ def _dry_run_no_cluster(self):
         if self.scheduler_args is not None:
             if self.scheduler_args.get("namespace") is None:
                 self.scheduler_args["namespace"] = get_current_namespace()
-        return torchx_runner.dryrun(
+        runner = get_runner()
+        return runner.dryrun(
             app=ddp(
                 *self.script_args,
                 script=self.script,
@@ -160,7 +167,7 @@ def _dry_run_no_cluster(self):
             scheduler="kubernetes_mcad",
             cfg=self.scheduler_args,
             workspace="",
-        )
+        ), runner
 
     def submit(self, cluster: "Cluster" = None) -> "Job":
         return DDPJob(self, cluster)
@@ -171,18 +178,20 @@ def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None
         self.job_definition = job_definition
         self.cluster = cluster
         if self.cluster:
-            self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
+            definition, runner = job_definition._dry_run(cluster)
+            self._app_handle = runner.schedule(definition)
+            self._runner = runner
         else:
-            self._app_handle = torchx_runner.schedule(
-                job_definition._dry_run_no_cluster()
-            )
+            definition, runner = job_definition._dry_run_no_cluster()
+            self._app_handle = runner.schedule(definition)
+            self._runner = runner
         all_jobs.append(self)
 
     def status(self) -> str:
-        return torchx_runner.status(self._app_handle)
+        return self._runner.status(self._app_handle)
 
     def logs(self) -> str:
-        return "".join(torchx_runner.log_lines(self._app_handle, None))
+        return "".join(self._runner.log_lines(self._app_handle, None))
 
     def cancel(self):
-        torchx_runner.cancel(self._app_handle)
+        self._runner.cancel(self._app_handle)
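The net effect is that each DDPJob now carries its own runner wired to cluster.client instead of sharing the old module-level torchx_runner. The caller-side flow is unchanged; a sketch where the job name and script are placeholders:

    from codeflare_sdk.job.jobs import DDPJobDefinition

    job_def = DDPJobDefinition(
        name="mnisttest",   # placeholder
        script="mnist.py",  # placeholder
    )
    # submit() -> _dry_run(cluster) now returns (definition, runner); the
    # runner's RayScheduler reuses cluster.client, so submission passes
    # through the OAuth proxy whenever openshift_oauth=True.
    job = job_def.submit(cluster)
    print(job.status())
    print(job.logs())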

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 72 additions & 1 deletion
@@ -24,7 +24,14 @@
 from kubernetes import client, config
 from .kube_api_helpers import _kube_api_error_handling
 from ..cluster.auth import api_config_handler
+from os import urandom
+from base64 import b64encode
+from urllib3.util import parse_url
 
+import openshift as oc
+from kubernetes import client, config
+
+k8_client = config.new_client_from_config()
 
 def read_template(template):
     with open(template, "r") as stream:
@@ -46,12 +53,14 @@ def gen_names(name):
 
 def update_dashboard_route(route_item, cluster_name, namespace):
     metadata = route_item.get("generictemplate", {}).get("metadata")
-    metadata["name"] = f"ray-dashboard-{cluster_name}"
+    metadata["name"] = gen_dashboard_route_name(cluster_name)
     metadata["namespace"] = namespace
     metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc"
     spec = route_item.get("generictemplate", {}).get("spec")
     spec["to"]["name"] = f"{cluster_name}-head-svc"
 
+def gen_dashboard_route_name(cluster_name):
+    return f"ray-dashboard-{cluster_name}"
 
 # ToDo: refactor the update_x_route() functions
 def update_rayclient_route(route_item, cluster_name, namespace):
@@ -347,6 +356,63 @@ def write_user_appwrapper(user_yaml, output_file_name):
     print(f"Written to: {output_file_name}")
 
 
+def enable_openshift_oauth(user_yaml, cluster_name, namespace):
+    tls_mount_location = "/etc/tls/private"
+    oauth_port = 443
+    oauth_sa_name = f"{cluster_name}-oauth-proxy"
+    tls_secret_name = f"{cluster_name}-proxy-tls-secret"
+    tls_volume_name = "proxy-tls-secret"
+    port_name = "oauth-proxy"
+    _, _, host, _, _, _, _ = parse_url(k8_client.configuration.host)
+    host = host.replace("api.", f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps.")
+    oauth_sidecar = _create_oauth_sidecar_object(
+        namespace, tls_mount_location, oauth_port, oauth_sa_name, tls_volume_name, port_name
+    )
+    tls_secret_volume = client.V1Volume(
+        name=tls_volume_name, secret=client.V1SecretVolumeSource(secret_name=tls_secret_name)
+    )
+    # allows for setting value of Cluster object when initializing object from an existing AppWrapper on cluster
+    user_yaml["metadata"]["annotations"] = user_yaml["metadata"].get("annotations", {})
+    user_yaml["metadata"]["annotations"]["codeflare-sdk-use-oauth"] = "true"  # if the user gets an
+    ray_headgroup_pod = user_yaml["spec"]["resources"]["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]
+    user_yaml["spec"]["resources"]["GenericItems"].pop(1)
+    ray_headgroup_pod["serviceAccount"] = oauth_sa_name
+    # add volume to headnode
+    ray_headgroup_pod["volumes"] = ray_headgroup_pod.get("volumes", [])
+    ray_headgroup_pod["volumes"].append(k8_client.sanitize_for_serialization(tls_secret_volume))
+    # add sidecar container to ray object
+    ray_headgroup_pod["containers"].append(k8_client.sanitize_for_serialization(oauth_sidecar))
+
+
+def _create_oauth_sidecar_object(
+    namespace: str,
+    tls_mount_location: str,
+    oauth_port: int,
+    oauth_sa_name: str,
+    tls_volume_name: str,
+    port_name: str,
+) -> client.V1Container:
+    return client.V1Container(
+        args=[
+            f"--https-address=:{oauth_port}",
+            "--provider=openshift",
+            f"--openshift-service-account={oauth_sa_name}",
+            "--upstream=http://localhost:8265",
+            f"--tls-cert={tls_mount_location}/tls.crt",
+            f"--tls-key={tls_mount_location}/tls.key",
+            f"--cookie-secret={b64encode(urandom(64)).decode('utf-8')}",  # create random string for encrypting cookie
+            f'--openshift-delegate-urls={{"/":{{"resource":"pods","namespace":"{namespace}","verb":"get"}}}}',
+        ],
+        image="registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366",
+        name="oauth-proxy",
+        ports=[client.V1ContainerPort(container_port=oauth_port, name=port_name)],
+        resources=client.V1ResourceRequirements(limits=None, requests=None),
+        volume_mounts=[
+            client.V1VolumeMount(
+                mount_path=tls_mount_location, name=tls_volume_name, read_only=True
+            )
+        ],
+    )
+
@@ -365,6 +431,7 @@ def generate_appwrapper(
     image_pull_secrets: list,
     dispatch_priority: str,
     priority_val: int,
+    openshift_oauth: bool,
 ):
     user_yaml = read_template(template)
     appwrapper_name, cluster_name = gen_names(name)
@@ -396,6 +463,10 @@ def generate_appwrapper(
         enable_local_interactive(resources, cluster_name, namespace)
     else:
         disable_raycluster_tls(resources["resources"])
+
+    if openshift_oauth:
+        enable_openshift_oauth(user_yaml, cluster_name, namespace)
+
     outfile = appwrapper_name + ".yaml"
     write_user_appwrapper(user_yaml, outfile)
     return outfile
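The host rewrite in enable_openshift_oauth relies on the OpenShift convention that the API server is reachable at api.<cluster-domain> while routes are exposed under apps.<cluster-domain>. A standalone illustration with a made-up cluster domain:

    from urllib3.util import parse_url

    # parse_url yields a 7-field Url tuple:
    # (scheme, auth, host, port, path, query, fragment)
    _, _, host, _, _, _, _ = parse_url("https://api.mycluster.example.com:6443")

    # For cluster "test" in namespace "default", gen_dashboard_route_name()
    # gives "ray-dashboard-test", so the dashboard route host becomes:
    host = host.replace("api.", "ray-dashboard-test-default.apps.")
    print(host)  # ray-dashboard-test-default.apps.mycluster.example.com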
