Skip to content

Commit 06d83b1

Browse files
committed
Added Mtls patch
(cherry picked from commit de2de96fc88022df783b637ccb145d1d73ba66ff) Review changes
1 parent 7c64408 commit 06d83b1

File tree

5 files changed

+264
-7
lines changed

5 files changed

+264
-7
lines changed

config/rbac/role.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ rules:
1717
- subjectaccessreviews
1818
verbs:
1919
- create
20+
- apiGroups:
21+
- config.openshift.io
22+
resources:
23+
- ingresses
24+
verbs:
25+
- get
2026
- apiGroups:
2127
- ""
2228
resources:

main.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import (
4747
"sigs.k8s.io/yaml"
4848

4949
routev1 "github.com/openshift/api/route/v1"
50+
clientset "github.com/openshift/client-go/config/clientset/versioned"
5051

5152
"github.com/project-codeflare/codeflare-operator/pkg/config"
5253
"github.com/project-codeflare/codeflare-operator/pkg/controllers"
@@ -72,6 +73,8 @@ func init() {
7273
utilruntime.Must(routev1.Install(scheme))
7374
}
7475

76+
// +kubebuilder:rbac:groups=config.openshift.io,resources=ingresses,verbs=get;
77+
7578
func main() {
7679
var configMapName string
7780
flag.StringVar(&configMapName, "config", "codeflare-operator-config",
@@ -116,6 +119,7 @@ func main() {
116119
KubeRay: &config.KubeRayConfiguration{
117120
RayDashboardOAuthEnabled: pointer.Bool(true),
118121
IngressDomain: "",
122+
MTLSEnabled: pointer.Bool(true),
119123
},
120124
}
121125

@@ -150,6 +154,12 @@ func main() {
150154
OpenShift := isOpenShift(ctx, kubeClient.DiscoveryClient)
151155

152156
if OpenShift {
157+
if cfg.KubeRay.IngressDomain == "" {
158+
configClient, err := clientset.NewForConfig(kubeConfig)
159+
exitOnError(err, "unable to create Route Client Set")
160+
cfg.KubeRay.IngressDomain, err = getClusterDomain(ctx, configClient)
161+
exitOnError(err, cfg.KubeRay.IngressDomain)
162+
}
153163
// TODO: setup the RayCluster webhook on vanilla Kubernetes
154164
exitOnError(controllers.SetupRayClusterWebhookWithManager(mgr, cfg.KubeRay), "error setting up RayCluster webhook")
155165
}
@@ -274,3 +284,17 @@ func isOpenShift(ctx context.Context, dc discovery.DiscoveryInterface) bool {
274284
logger.Info("We detected being on Vanilla Kubernetes!")
275285
return false
276286
}
287+
288+
func getClusterDomain(ctx context.Context, configClient *clientset.Clientset) (string, error) {
289+
ingress, err := configClient.ConfigV1().Ingresses().Get(ctx, "cluster", metav1.GetOptions{})
290+
if err != nil {
291+
return "", fmt.Errorf("failed to get Ingress object: %v", err)
292+
}
293+
294+
domain := ingress.Spec.Domain
295+
if domain == "" {
296+
return "", fmt.Errorf("domain is not set in the Ingress object")
297+
}
298+
299+
return domain, nil
300+
}

pkg/config/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type KubeRayConfiguration struct {
3535
RayDashboardOAuthEnabled *bool `json:"rayDashboardOAuthEnabled,omitempty"`
3636

3737
IngressDomain string `json:"ingressDomain"`
38+
39+
MTLSEnabled *bool `json:"mTLSEnabled,omitempty"`
3840
}
3941

4042
type ControllerManager struct {

pkg/controllers/raycluster_webhook.go

Lines changed: 221 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package controllers
1818

1919
import (
2020
"context"
21+
"strconv"
2122

2223
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
2324

@@ -36,6 +37,7 @@ import (
3637
const (
3738
oauthProxyContainerName = "oauth-proxy"
3839
oauthProxyVolumeName = "proxy-tls-secret"
40+
initContainerName = "create-cert"
3941
)
4042

4143
// log is for logging in this package.
@@ -66,17 +68,51 @@ var _ webhook.CustomValidator = &rayClusterWebhook{}
6668
func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) error {
6769
rayCluster := obj.(*rayv1.RayCluster)
6870

69-
if !pointer.BoolDeref(w.Config.RayDashboardOAuthEnabled, true) {
70-
return nil
71-
}
72-
7371
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
7472

75-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
73+
if pointer.BoolDeref(w.Config.RayDashboardOAuthEnabled, true) {
74+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
7675

77-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
76+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
7877

79-
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
78+
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
79+
}
80+
81+
if pointer.BoolDeref(w.Config.MTLSEnabled, true) {
82+
// HeadGroupSpec //
83+
// Append the list of environment variables for the ray-head container
84+
for index, container := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
85+
if container.Name == "ray-head" {
86+
for _, envVar := range envVarList() {
87+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[index].Env = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[index].Env, envVar, withEnvVarName(envVar.Name))
88+
}
89+
}
90+
}
91+
92+
// Append the create-cert Init Container
93+
rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, w.rayHeadInitContainer(rayCluster), withContainerName(initContainerName))
94+
95+
// Append the CA volumes
96+
for _, caVol := range caVolumes(rayCluster) {
97+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
98+
}
99+
// WorkerGroupSpec //
100+
// Append the list of environment variables for the machine-learning container
101+
for index, container := range rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers {
102+
if container.Name == "machine-learning" {
103+
for _, envVar := range envVarList() {
104+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[index].Env = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[index].Env, envVar, withEnvVarName(envVar.Name))
105+
}
106+
}
107+
}
108+
// Append the CA volumes
109+
for _, caVol := range caVolumes(rayCluster) {
110+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
111+
}
112+
// Append the create-cert Init Container
113+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(rayCluster), withContainerName(initContainerName))
114+
115+
}
80116

81117
return nil
82118
}
@@ -216,3 +252,181 @@ func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
216252
},
217253
}
218254
}
255+
256+
func initCaVolumeMounts() []corev1.VolumeMount {
257+
keyVolumes := []corev1.VolumeMount{
258+
{
259+
Name: "ca-vol",
260+
MountPath: "/home/ray/workspace/ca",
261+
ReadOnly: true,
262+
},
263+
{
264+
Name: "server-cert",
265+
MountPath: "/home/ray/workspace/tls",
266+
ReadOnly: false,
267+
},
268+
}
269+
return keyVolumes
270+
}
271+
272+
func envVarList() []corev1.EnvVar {
273+
envList := []corev1.EnvVar{
274+
{
275+
Name: "MY_POD_IP",
276+
ValueFrom: &corev1.EnvVarSource{
277+
FieldRef: &corev1.ObjectFieldSelector{
278+
FieldPath: "status.podIP",
279+
},
280+
},
281+
},
282+
{
283+
Name: "RAY_USE_TLS",
284+
Value: "1",
285+
},
286+
{
287+
Name: "RAY_TLS_SERVER_CERT",
288+
Value: "/home/ray/workspace/tls/server.crt",
289+
},
290+
{
291+
Name: "RAY_TLS_SERVER_KEY",
292+
Value: "/home/ray/workspace/tls/server.key",
293+
},
294+
{
295+
Name: "RAY_TLS_CA_CERT",
296+
Value: "/home/ray/workspace/tls/ca.crt",
297+
},
298+
}
299+
return envList
300+
}
301+
302+
func caVolumes(rayCluster *rayv1.RayCluster) []corev1.Volume {
303+
secretName := `ca-secret-` + rayCluster.Name
304+
caVolumes := []corev1.Volume{
305+
{
306+
Name: "ca-vol",
307+
VolumeSource: corev1.VolumeSource{
308+
Secret: &corev1.SecretVolumeSource{
309+
SecretName: secretName,
310+
},
311+
},
312+
},
313+
{
314+
Name: "server-cert",
315+
VolumeSource: corev1.VolumeSource{
316+
EmptyDir: &corev1.EmptyDirVolumeSource{},
317+
},
318+
},
319+
}
320+
return caVolumes
321+
}
322+
323+
func (w *rayClusterWebhook) rayHeadInitContainer(rayCluster *rayv1.RayCluster) corev1.Container {
324+
rayClientRoute := "rayclient-" + rayCluster.Name + "-" + rayCluster.Namespace + "." + w.Config.IngressDomain
325+
// Service name for basic interactive
326+
svcDomain := rayCluster.Name + "-head-svc." + rayCluster.Namespace + ".svc"
327+
328+
initContainerHead := corev1.Container{
329+
Name: "create-cert",
330+
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
331+
Command: []string{
332+
"sh",
333+
"-c",
334+
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)\nDNS.5 = ` + rayClientRoute + `\nDNS.6 = ` + svcDomain + `">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
335+
},
336+
VolumeMounts: initCaVolumeMounts(),
337+
}
338+
return initContainerHead
339+
}
340+
341+
func rayWorkerInitContainer(rayCluster *rayv1.RayCluster) corev1.Container {
342+
initContainerWorker := corev1.Container{
343+
Name: "create-cert",
344+
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
345+
Command: []string{
346+
"sh",
347+
"-c",
348+
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
349+
},
350+
VolumeMounts: initCaVolumeMounts(),
351+
}
352+
return initContainerWorker
353+
}
354+
355+
func (w *rayClusterWebhook) validateHeadInitContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
356+
var allErrors field.ErrorList
357+
358+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, w.rayHeadInitContainer(rayCluster), byContainerName,
359+
field.NewPath("spec", "headGroupSpec", "template", "spec", "initContainers"),
360+
"create-cert Init Container is immutable"); err != nil {
361+
allErrors = append(allErrors, err)
362+
}
363+
364+
return allErrors
365+
}
366+
367+
func (w *rayClusterWebhook) validateWorkerInitContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
368+
var allErrors field.ErrorList
369+
370+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(rayCluster), byContainerName,
371+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "initContainers"),
372+
"create-cert Init Container is immutable"); err != nil {
373+
allErrors = append(allErrors, err)
374+
}
375+
376+
return allErrors
377+
}
378+
379+
func validateCaVolumes(rayCluster *rayv1.RayCluster) field.ErrorList {
380+
var allErrors field.ErrorList
381+
382+
for _, caVol := range caVolumes(rayCluster) {
383+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, byVolumeName,
384+
field.NewPath("spec", "headGroupSpec", "template", "spec", "volumes"),
385+
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
386+
allErrors = append(allErrors, err)
387+
}
388+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, byVolumeName,
389+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "volumes"),
390+
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
391+
allErrors = append(allErrors, err)
392+
}
393+
}
394+
395+
return allErrors
396+
}
397+
398+
func validateEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
399+
var allErrors field.ErrorList
400+
item := 0
401+
for index, container := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
402+
if container.Name == "ray-head" {
403+
item = index
404+
break
405+
}
406+
}
407+
408+
for _, envVar := range envVarList() {
409+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[item].Env, envVar, byEnvVarName,
410+
field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(item), "env"),
411+
"RAY_TLS environment variables are immutable"); err != nil {
412+
allErrors = append(allErrors, err)
413+
}
414+
}
415+
416+
for index, container := range rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers {
417+
if container.Name == "machine-learning" {
418+
item = index
419+
break
420+
}
421+
}
422+
423+
for _, envVar := range envVarList() {
424+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[item].Env, envVar, byEnvVarName,
425+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "containers", strconv.Itoa(item), "env"),
426+
"RAY_TLS environment variables are immutable"); err != nil {
427+
allErrors = append(allErrors, err)
428+
}
429+
}
430+
431+
return allErrors
432+
}

pkg/controllers/support.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,14 @@ func withVolumeName(name string) compare[corev1.Volume] {
140140
return v1.Name == name
141141
}
142142
}
143+
144+
var byEnvVarName = compare[corev1.EnvVar](
145+
func(e1, e2 corev1.EnvVar) bool {
146+
return e1.Name == e2.Name
147+
})
148+
149+
func withEnvVarName(name string) compare[corev1.EnvVar] {
150+
return func(e1, e2 corev1.EnvVar) bool {
151+
return e1.Name == name
152+
}
153+
}

0 commit comments

Comments
 (0)