Skip to content

Commit 23c8dcc

Browse files
committed
Added Mtls patch
(cherry picked from commit de2de96fc88022df783b637ccb145d1d73ba66ff) Review changes
1 parent 2fe0a52 commit 23c8dcc

File tree

5 files changed

+261
-7
lines changed

5 files changed

+261
-7
lines changed

config/rbac/role.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ rules:
4444
- subjectaccessreviews
4545
verbs:
4646
- create
47+
- apiGroups:
48+
- config.openshift.io
49+
resources:
50+
- ingresses
51+
verbs:
52+
- get
4753
- apiGroups:
4854
- ""
4955
resources:

main.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ import (
5252
"sigs.k8s.io/yaml"
5353

5454
routev1 "github.com/openshift/api/route/v1"
55+
clientset "github.com/openshift/client-go/config/clientset/versioned"
5556

5657
"github.com/project-codeflare/codeflare-operator/pkg/config"
5758
"github.com/project-codeflare/codeflare-operator/pkg/controllers"
@@ -75,6 +76,8 @@ func init() {
7576
utilruntime.Must(routev1.Install(scheme))
7677
}
7778

79+
// +kubebuilder:rbac:groups=config.openshift.io,resources=ingresses,verbs=get;
80+
7881
func main() {
7982
var configMapName string
8083
flag.StringVar(&configMapName, "config", "codeflare-operator-config",
@@ -117,6 +120,7 @@ func main() {
117120
KubeRay: &config.KubeRayConfiguration{
118121
RayDashboardOAuthEnabled: ptr.To(true),
119122
IngressDomain: "",
123+
MTLSEnabled: ptr.To(true),
120124
},
121125
}
122126

@@ -155,6 +159,13 @@ func main() {
155159
certsReady := make(chan struct{})
156160
exitOnError(setupCertManagement(mgr, namespace, certsReady), "unable to setup cert-controller")
157161

162+
if cfg.KubeRay.IngressDomain == "" {
163+
configClient, err := clientset.NewForConfig(kubeConfig)
164+
exitOnError(err, "unable to create Route Client Set")
165+
cfg.KubeRay.IngressDomain, err = getClusterDomain(ctx, configClient)
166+
exitOnError(err, cfg.KubeRay.IngressDomain)
167+
}
168+
158169
go setupControllers(mgr, kubeClient, cfg, isOpenShift(ctx, kubeClient.DiscoveryClient), certsReady)
159170

160171
setupLog.Info("setting up health endpoints")
@@ -332,3 +343,17 @@ func isOpenShift(ctx context.Context, dc discovery.DiscoveryInterface) bool {
332343
logger.Info("We detected being on Vanilla Kubernetes!")
333344
return false
334345
}
346+
347+
func getClusterDomain(ctx context.Context, configClient *clientset.Clientset) (string, error) {
348+
ingress, err := configClient.ConfigV1().Ingresses().Get(ctx, "cluster", metav1.GetOptions{})
349+
if err != nil {
350+
return "", fmt.Errorf("failed to get Ingress object: %v", err)
351+
}
352+
353+
domain := ingress.Spec.Domain
354+
if domain == "" {
355+
return "", fmt.Errorf("domain is not set in the Ingress object")
356+
}
357+
358+
return domain, nil
359+
}

pkg/config/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type KubeRayConfiguration struct {
3535
RayDashboardOAuthEnabled *bool `json:"rayDashboardOAuthEnabled,omitempty"`
3636

3737
IngressDomain string `json:"ingressDomain"`
38+
39+
MTLSEnabled *bool `json:"mTLSEnabled,omitempty"`
3840
}
3941

4042
type ControllerManager struct {

pkg/controllers/raycluster_webhook.go

Lines changed: 217 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package controllers
1818

1919
import (
2020
"context"
21+
"strconv"
2122

2223
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
2324

@@ -36,6 +37,7 @@ import (
3637
const (
3738
oauthProxyContainerName = "oauth-proxy"
3839
oauthProxyVolumeName = "proxy-tls-secret"
40+
initContainerName = "create-cert"
3941
)
4042

4143
// log is for logging in this package.
@@ -66,17 +68,47 @@ var _ webhook.CustomValidator = &rayClusterWebhook{}
6668
func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) error {
6769
rayCluster := obj.(*rayv1.RayCluster)
6870

69-
if !ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
70-
return nil
71-
}
71+
if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
72+
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
73+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
7274

73-
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
75+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
7476

75-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
77+
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
78+
}
7679

77-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
80+
if ptr.Deref(w.Config.MTLSEnabled, true) {
81+
rayclusterlog.V(2).Info("Adding create-cert Init Containers")
82+
// HeadGroupSpec //
83+
// Append the list of environment variables for the ray-head container
84+
for index := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
85+
for _, envVar := range envVarList() {
86+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[index].Env = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[index].Env, envVar, withEnvVarName(envVar.Name))
87+
}
88+
}
89+
90+
// Append the create-cert Init Container
91+
rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, w.rayHeadInitContainer(rayCluster), withContainerName(initContainerName))
92+
93+
// Append the CA volumes
94+
for _, caVol := range caVolumes(rayCluster) {
95+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
96+
}
97+
// WorkerGroupSpec //
98+
// Append the list of environment variables for the machine-learning container
99+
for index := range rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers {
100+
for _, envVar := range envVarList() {
101+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[index].Env = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[index].Env, envVar, withEnvVarName(envVar.Name))
102+
}
103+
}
104+
// Append the CA volumes
105+
for _, caVol := range caVolumes(rayCluster) {
106+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
107+
}
108+
// Append the create-cert Init Container
109+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(rayCluster), withContainerName(initContainerName))
78110

79-
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
111+
}
80112

81113
return nil
82114
}
@@ -225,3 +257,181 @@ func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
225257
},
226258
}
227259
}
260+
261+
func initCaVolumeMounts() []corev1.VolumeMount {
262+
keyVolumes := []corev1.VolumeMount{
263+
{
264+
Name: "ca-vol",
265+
MountPath: "/home/ray/workspace/ca",
266+
ReadOnly: true,
267+
},
268+
{
269+
Name: "server-cert",
270+
MountPath: "/home/ray/workspace/tls",
271+
ReadOnly: false,
272+
},
273+
}
274+
return keyVolumes
275+
}
276+
277+
func envVarList() []corev1.EnvVar {
278+
envList := []corev1.EnvVar{
279+
{
280+
Name: "MY_POD_IP",
281+
ValueFrom: &corev1.EnvVarSource{
282+
FieldRef: &corev1.ObjectFieldSelector{
283+
FieldPath: "status.podIP",
284+
},
285+
},
286+
},
287+
{
288+
Name: "RAY_USE_TLS",
289+
Value: "1",
290+
},
291+
{
292+
Name: "RAY_TLS_SERVER_CERT",
293+
Value: "/home/ray/workspace/tls/server.crt",
294+
},
295+
{
296+
Name: "RAY_TLS_SERVER_KEY",
297+
Value: "/home/ray/workspace/tls/server.key",
298+
},
299+
{
300+
Name: "RAY_TLS_CA_CERT",
301+
Value: "/home/ray/workspace/tls/ca.crt",
302+
},
303+
}
304+
return envList
305+
}
306+
307+
func caVolumes(rayCluster *rayv1.RayCluster) []corev1.Volume {
308+
secretName := `ca-secret-` + rayCluster.Name
309+
caVolumes := []corev1.Volume{
310+
{
311+
Name: "ca-vol",
312+
VolumeSource: corev1.VolumeSource{
313+
Secret: &corev1.SecretVolumeSource{
314+
SecretName: secretName,
315+
},
316+
},
317+
},
318+
{
319+
Name: "server-cert",
320+
VolumeSource: corev1.VolumeSource{
321+
EmptyDir: &corev1.EmptyDirVolumeSource{},
322+
},
323+
},
324+
}
325+
return caVolumes
326+
}
327+
328+
func (w *rayClusterWebhook) rayHeadInitContainer(rayCluster *rayv1.RayCluster) corev1.Container {
329+
rayClientRoute := "rayclient-" + rayCluster.Name + "-" + rayCluster.Namespace + "." + w.Config.IngressDomain
330+
// Service name for basic interactive
331+
svcDomain := rayCluster.Name + "-head-svc." + rayCluster.Namespace + ".svc"
332+
333+
initContainerHead := corev1.Container{
334+
Name: "create-cert",
335+
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
336+
Command: []string{
337+
"sh",
338+
"-c",
339+
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)\nDNS.5 = ` + rayClientRoute + `\nDNS.6 = ` + svcDomain + `">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
340+
},
341+
VolumeMounts: initCaVolumeMounts(),
342+
}
343+
return initContainerHead
344+
}
345+
346+
func rayWorkerInitContainer(rayCluster *rayv1.RayCluster) corev1.Container {
347+
initContainerWorker := corev1.Container{
348+
Name: "create-cert",
349+
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
350+
Command: []string{
351+
"sh",
352+
"-c",
353+
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
354+
},
355+
VolumeMounts: initCaVolumeMounts(),
356+
}
357+
return initContainerWorker
358+
}
359+
360+
func (w *rayClusterWebhook) validateHeadInitContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
361+
var allErrors field.ErrorList
362+
363+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, w.rayHeadInitContainer(rayCluster), byContainerName,
364+
field.NewPath("spec", "headGroupSpec", "template", "spec", "initContainers"),
365+
"create-cert Init Container is immutable"); err != nil {
366+
allErrors = append(allErrors, err)
367+
}
368+
369+
return allErrors
370+
}
371+
372+
func (w *rayClusterWebhook) validateWorkerInitContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
373+
var allErrors field.ErrorList
374+
375+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(rayCluster), byContainerName,
376+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "initContainers"),
377+
"create-cert Init Container is immutable"); err != nil {
378+
allErrors = append(allErrors, err)
379+
}
380+
381+
return allErrors
382+
}
383+
384+
func validateCaVolumes(rayCluster *rayv1.RayCluster) field.ErrorList {
385+
var allErrors field.ErrorList
386+
387+
for _, caVol := range caVolumes(rayCluster) {
388+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, byVolumeName,
389+
field.NewPath("spec", "headGroupSpec", "template", "spec", "volumes"),
390+
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
391+
allErrors = append(allErrors, err)
392+
}
393+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, byVolumeName,
394+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "volumes"),
395+
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
396+
allErrors = append(allErrors, err)
397+
}
398+
}
399+
400+
return allErrors
401+
}
402+
403+
func validateEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
404+
var allErrors field.ErrorList
405+
item := 0
406+
for index, container := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
407+
if container.Name == "ray-head" {
408+
item = index
409+
break
410+
}
411+
}
412+
413+
for _, envVar := range envVarList() {
414+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[item].Env, envVar, byEnvVarName,
415+
field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(item), "env"),
416+
"RAY_TLS environment variables are immutable"); err != nil {
417+
allErrors = append(allErrors, err)
418+
}
419+
}
420+
421+
for index, container := range rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers {
422+
if container.Name == "machine-learning" {
423+
item = index
424+
break
425+
}
426+
}
427+
428+
for _, envVar := range envVarList() {
429+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[item].Env, envVar, byEnvVarName,
430+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "containers", strconv.Itoa(item), "env"),
431+
"RAY_TLS environment variables are immutable"); err != nil {
432+
allErrors = append(allErrors, err)
433+
}
434+
}
435+
436+
return allErrors
437+
}

pkg/controllers/support.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,14 @@ func withVolumeName(name string) compare[corev1.Volume] {
140140
return v1.Name == name
141141
}
142142
}
143+
144+
var byEnvVarName = compare[corev1.EnvVar](
145+
func(e1, e2 corev1.EnvVar) bool {
146+
return e1.Name == e2.Name
147+
})
148+
149+
func withEnvVarName(name string) compare[corev1.EnvVar] {
150+
return func(e1, e2 corev1.EnvVar) bool {
151+
return e1.Name == name
152+
}
153+
}

0 commit comments

Comments
 (0)