Skip to content

Commit 1c85429

Browse files
Use deep semantic comparison in RayCluster validation webhook
1 parent a51f336 commit 1c85429

File tree

3 files changed

+158
-59
lines changed

3 files changed

+158
-59
lines changed

config/webhook/manifests.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ webhooks:
2121
- v1
2222
operations:
2323
- CREATE
24-
- UPDATE
2524
resources:
2625
- rayclusters
2726
sideEffects: None

pkg/controllers/raycluster_webhook.go

Lines changed: 107 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ import (
3333
"github.com/project-codeflare/codeflare-operator/pkg/config"
3434
)
3535

36+
const (
37+
oauthProxyContainerName = "oauth-proxy"
38+
oauthProxyVolumeName = "proxy-tls-secret"
39+
)
40+
3641
// log is for logging in this package.
3742
var rayclusterlog = logf.Log.WithName("raycluster-resource")
3843

@@ -47,7 +52,7 @@ func SetupRayClusterWebhookWithManager(mgr ctrl.Manager, cfg *config.KubeRayConf
4752
Complete()
4853
}
4954

50-
// +kubebuilder:webhook:path=/mutate-ray-io-v1-raycluster,mutating=true,failurePolicy=fail,sideEffects=None,groups=ray.io,resources=rayclusters,verbs=create;update,versions=v1,name=mraycluster.kb.io,admissionReviewVersions=v1
55+
// +kubebuilder:webhook:path=/mutate-ray-io-v1-raycluster,mutating=true,failurePolicy=fail,sideEffects=None,groups=ray.io,resources=rayclusters,verbs=create,versions=v1,name=mraycluster.kb.io,admissionReviewVersions=v1
5156
// +kubebuilder:webhook:path=/validate-ray-io-v1-raycluster,mutating=false,failurePolicy=fail,sideEffects=None,groups=ray.io,resources=rayclusters,verbs=create;update,versions=v1,name=vraycluster.kb.io,admissionReviewVersions=v1
5257

5358
type rayClusterWebhook struct {
@@ -65,18 +70,105 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err
6570
return nil
6671
}
6772

68-
// Check and add OAuth proxy if it does not exist
69-
for _, container := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
70-
if container.Name == "oauth-proxy" {
71-
rayclusterlog.V(2).Info("OAuth sidecar already exists, no patch needed")
72-
return nil
73-
}
73+
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
74+
75+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
76+
77+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
78+
79+
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
80+
81+
return nil
82+
}
83+
84+
func (w *rayClusterWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
85+
rayCluster := obj.(*rayv1.RayCluster)
86+
87+
var warnings admission.Warnings
88+
var allErrors field.ErrorList
89+
90+
allErrors = append(allErrors, validateIngress(rayCluster)...)
91+
92+
return warnings, allErrors.ToAggregate()
93+
}
94+
95+
func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
96+
rayCluster := newObj.(*rayv1.RayCluster)
97+
98+
var warnings admission.Warnings
99+
var allErrors field.ErrorList
100+
101+
if !rayCluster.DeletionTimestamp.IsZero() {
102+
// Object is being deleted, skip validations
103+
return nil, nil
104+
}
105+
106+
allErrors = append(allErrors, validateIngress(rayCluster)...)
107+
allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...)
108+
allErrors = append(allErrors, validateOAuthProxyVolume(rayCluster)...)
109+
allErrors = append(allErrors, validateHeadGroupServiceAccountName(rayCluster)...)
110+
111+
return warnings, allErrors.ToAggregate()
112+
}
113+
114+
func (w *rayClusterWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
115+
// Optional: Add delete validation logic here
116+
return nil, nil
117+
}
118+
119+
func validateOAuthProxyContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
120+
var allErrors field.ErrorList
121+
122+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), byContainerName,
123+
field.NewPath("spec", "headGroupSpec", "template", "spec", "containers"),
124+
"OAuth Proxy container is immutable"); err != nil {
125+
allErrors = append(allErrors, err)
74126
}
75127

76-
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
128+
return allErrors
129+
}
130+
131+
func validateOAuthProxyVolume(rayCluster *rayv1.RayCluster) field.ErrorList {
132+
var allErrors field.ErrorList
77133

78-
newOAuthSidecar := corev1.Container{
79-
Name: "oauth-proxy",
134+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), byVolumeName,
135+
field.NewPath("spec", "headGroupSpec", "template", "spec", "volumes"),
136+
"OAuth Proxy TLS Secret volume is immutable"); err != nil {
137+
allErrors = append(allErrors, err)
138+
}
139+
140+
return allErrors
141+
}
142+
143+
func validateIngress(rayCluster *rayv1.RayCluster) field.ErrorList {
144+
var allErrors field.ErrorList
145+
146+
if pointer.BoolDeref(rayCluster.Spec.HeadGroupSpec.EnableIngress, false) {
147+
allErrors = append(allErrors, field.Invalid(
148+
field.NewPath("spec", "headGroupSpec", "enableIngress"),
149+
rayCluster.Spec.HeadGroupSpec.EnableIngress,
150+
"RayCluster resources with EnableIngress set to true or unspecified is not allowed"))
151+
}
152+
153+
return allErrors
154+
}
155+
156+
func validateHeadGroupServiceAccountName(rayCluster *rayv1.RayCluster) field.ErrorList {
157+
var allErrors field.ErrorList
158+
159+
if rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName != rayCluster.Name+"-oauth-proxy" {
160+
allErrors = append(allErrors, field.Invalid(
161+
field.NewPath("spec", "headGroupSpec", "template", "spec", "serviceAccountName"),
162+
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName,
163+
"RayCluster head group service account is immutable"))
164+
}
165+
166+
return allErrors
167+
}
168+
169+
func oauthProxyContainer(rayCluster *rayv1.RayCluster) corev1.Container {
170+
return corev1.Container{
171+
Name: oauthProxyContainerName,
80172
Image: "registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366",
81173
Ports: []corev1.ContainerPort{
82174
{ContainerPort: 8443, Name: "oauth-proxy"},
@@ -106,59 +198,21 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err
106198
},
107199
VolumeMounts: []corev1.VolumeMount{
108200
{
109-
Name: "proxy-tls-secret",
201+
Name: oauthProxyVolumeName,
110202
MountPath: "/etc/tls/private",
111203
ReadOnly: true,
112204
},
113205
},
114206
}
207+
}
115208

116-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = append(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, newOAuthSidecar)
117-
118-
tlsSecretVolume := corev1.Volume{
119-
Name: "proxy-tls-secret",
209+
func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
210+
return corev1.Volume{
211+
Name: oauthProxyVolumeName,
120212
VolumeSource: corev1.VolumeSource{
121213
Secret: &corev1.SecretVolumeSource{
122214
SecretName: rayCluster.Name + "-proxy-tls-secret",
123215
},
124216
},
125217
}
126-
127-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = append(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, tlsSecretVolume)
128-
129-
// Ensure the service account is set
130-
if rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName == "" {
131-
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
132-
}
133-
134-
return nil
135-
}
136-
137-
func (w *rayClusterWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
138-
raycluster := obj.(*rayv1.RayCluster)
139-
var warnings admission.Warnings
140-
var allErrors field.ErrorList
141-
specPath := field.NewPath("spec")
142-
143-
if pointer.BoolDeref(raycluster.Spec.HeadGroupSpec.EnableIngress, false) {
144-
rayclusterlog.Info("Creating RayCluster resources with EnableIngress set to true or unspecified is not allowed")
145-
allErrors = append(allErrors, field.Invalid(specPath.Child("headGroupSpec").Child("enableIngress"), raycluster.Spec.HeadGroupSpec.EnableIngress, "creating RayCluster resources with EnableIngress set to true or unspecified is not allowed"))
146-
}
147-
148-
return warnings, allErrors.ToAggregate()
149-
}
150-
151-
func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
152-
newRayCluster := newObj.(*rayv1.RayCluster)
153-
if !newRayCluster.DeletionTimestamp.IsZero() {
154-
// Object is being deleted, skip validations
155-
return nil, nil
156-
}
157-
warnings, err := w.ValidateCreate(ctx, newRayCluster)
158-
return warnings, err
159-
}
160-
161-
func (w *rayClusterWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
162-
// Optional: Add delete validation logic here
163-
return nil, nil
164218
}

pkg/controllers/support.go

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package controllers
33
import (
44
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
55

6+
corev1 "k8s.io/api/core/v1"
67
networkingv1 "k8s.io/api/networking/v1"
7-
"k8s.io/apimachinery/pkg/types"
8+
"k8s.io/apimachinery/pkg/api/equality"
89
"k8s.io/apimachinery/pkg/util/intstr"
10+
"k8s.io/apimachinery/pkg/util/validation/field"
911
v1 "k8s.io/client-go/applyconfigurations/meta/v1"
1012
networkingv1ac "k8s.io/client-go/applyconfigurations/networking/v1"
1113

@@ -29,7 +31,6 @@ func desiredRayClientRoute(cluster *rayv1.RayCluster) *routeapply.RouteApplyConf
2931
)
3032
}
3133

32-
// Create an Ingress object for the RayCluster
3334
func desiredRayClientIngress(cluster *rayv1.RayCluster, ingressHost string) *networkingv1ac.IngressApplyConfiguration {
3435
return networkingv1ac.Ingress(rayClientNameFromCluster(cluster), cluster.Namespace).
3536
WithLabels(map[string]string{"ray.io/cluster-name": cluster.Name}).
@@ -42,7 +43,7 @@ func desiredRayClientIngress(cluster *rayv1.RayCluster, ingressHost string) *net
4243
WithAPIVersion(cluster.APIVersion).
4344
WithKind(cluster.Kind).
4445
WithName(cluster.Name).
45-
WithUID(types.UID(cluster.UID))).
46+
WithUID(cluster.UID)).
4647
WithSpec(networkingv1ac.IngressSpec().
4748
WithIngressClassName("nginx").
4849
WithRules(networkingv1ac.IngressRule().
@@ -65,15 +66,14 @@ func desiredRayClientIngress(cluster *rayv1.RayCluster, ingressHost string) *net
6566
)
6667
}
6768

68-
// Create an Ingress object for the RayCluster
6969
func desiredClusterIngress(cluster *rayv1.RayCluster, ingressHost string) *networkingv1ac.IngressApplyConfiguration {
7070
return networkingv1ac.Ingress(dashboardNameFromCluster(cluster), cluster.Namespace).
7171
WithLabels(map[string]string{"ray.io/cluster-name": cluster.Name}).
7272
WithOwnerReferences(v1.OwnerReference().
7373
WithAPIVersion(cluster.APIVersion).
7474
WithKind(cluster.Kind).
7575
WithName(cluster.Name).
76-
WithUID(types.UID(cluster.UID))).
76+
WithUID(cluster.UID)).
7777
WithSpec(networkingv1ac.IngressSpec().
7878
WithRules(networkingv1ac.IngressRule().
7979
WithHost(ingressHost). // Full Hostname
@@ -94,3 +94,49 @@ func desiredClusterIngress(cluster *rayv1.RayCluster, ingressHost string) *netwo
9494
),
9595
)
9696
}
97+
98+
type compare[T any] func(T, T) bool
99+
100+
func upsert[T any](items []T, item T, predicate compare[T]) []T {
101+
for i, t := range items {
102+
if predicate(t, item) {
103+
items[i] = item
104+
return items
105+
}
106+
}
107+
return append(items, item)
108+
}
109+
110+
func contains[T any](items []T, item T, predicate compare[T], path *field.Path, msg string) *field.Error {
111+
for _, t := range items {
112+
if predicate(t, item) {
113+
if equality.Semantic.DeepDerivative(item, t) {
114+
return nil
115+
}
116+
return field.Invalid(path, t, msg)
117+
}
118+
}
119+
return field.Required(path, msg)
120+
}
121+
122+
var byContainerName = compare[corev1.Container](
123+
func(c1, c2 corev1.Container) bool {
124+
return c1.Name == c2.Name
125+
})
126+
127+
func withContainerName(name string) compare[corev1.Container] {
128+
return func(c1, c2 corev1.Container) bool {
129+
return c1.Name == name
130+
}
131+
}
132+
133+
var byVolumeName = compare[corev1.Volume](
134+
func(v1, v2 corev1.Volume) bool {
135+
return v1.Name == v2.Name
136+
})
137+
138+
func withVolumeName(name string) compare[corev1.Volume] {
139+
return func(v1, v2 corev1.Volume) bool {
140+
return v1.Name == name
141+
}
142+
}

0 commit comments

Comments
 (0)