Skip to content

Commit 9a96e45

Browse files
committed
test: Pass MNIST training with CodeFlare SDK on OpenShift
1 parent 49563ef commit 9a96e45

File tree

6 files changed

+225
-82
lines changed

6 files changed

+225
-82
lines changed

test/e2e/mnist.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1-
# In[]
1+
# Copyright 2022 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
215
import os
316

417
import torch
518
from pytorch_lightning import LightningModule, Trainer
619
from pytorch_lightning.callbacks.progress import TQDMProgressBar
7-
from pytorch_lightning.loggers import CSVLogger
820
from torch import nn
921
from torch.nn import functional as F
1022
from torch.utils.data import DataLoader, random_split

test/e2e/mnist_raycluster_sdk.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import sys
2+
3+
from time import sleep
4+
5+
from torchx.specs.api import AppState, is_terminal
6+
7+
from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration
8+
from codeflare_sdk.job.jobs import DDPJobDefinition
9+
10+
# End-to-end driver: brings up a small Ray cluster through the CodeFlare SDK,
# submits the MNIST training script as a DDP job, polls it until it reaches a
# terminal state, prints its logs, and exits 0 only if the job succeeded.
#
# Usage: python mnist_raycluster_sdk.py <namespace>

# Namespace in which the cluster is created, passed as the first CLI argument.
namespace = sys.argv[1]

cluster = Cluster(ClusterConfiguration(
    name='mnist',
    namespace=namespace,
    min_worker=1,
    max_worker=1,
    min_cpus='500m',
    max_cpus=1,
    min_memory=0.5,
    max_memory=1,
    gpu=0,
    instascale=False,
))

cluster.up()

# Everything after up() runs under try/finally so the cluster is always torn
# down, including when the job times out or any intermediate step raises.
try:
    cluster.status()

    cluster.wait_ready()

    cluster.status()

    cluster.details()

    jobdef = DDPJobDefinition(
        name="mnist",
        script="mnist.py",
        scheduler_args={"requirements": "requirements.txt"},
    )
    job = jobdef.submit(cluster)

    elapsed = 0        # seconds spent waiting so far
    timeout = 300      # maximum seconds to wait for a terminal state
    poll_interval = 5  # seconds between status checks
    while True:
        status = job.status()
        if is_terminal(status.state):
            break
        print(status)
        if timeout and elapsed >= timeout:
            raise TimeoutError(f"job has timed out after waiting {timeout}s")
        sleep(poll_interval)
        elapsed += poll_interval

    print(f"Job has completed: {status.state}")

    print(job.logs())
finally:
    cluster.down()

# Only reached when the loop ended on a terminal state; report it to the caller
# through the process exit code.
sys.exit(0 if status.state == AppState.SUCCEEDED else 1)

test/e2e/mnist_raycluster_sdk_test.go

Lines changed: 110 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,57 +23,119 @@ import (
2323

2424
batchv1 "k8s.io/api/batch/v1"
2525
corev1 "k8s.io/api/core/v1"
26+
rbacv1 "k8s.io/api/rbac/v1"
2627
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28+
"k8s.io/apimachinery/pkg/labels"
29+
30+
rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
2731

2832
. "github.com/project-codeflare/codeflare-operator/test/support"
33+
mcadv1beta1 "github.com/project-codeflare/multi-cluster-app-dispatcher/pkg/apis/controller/v1beta1"
2934
)
3035

3136
func TestMNISTRayClusterSDK(t *testing.T) {
3237
test := With(t)
3338
test.T().Parallel()
3439

35-
test.T().Skip("Requires https://github.com/project-codeflare/codeflare-sdk/pull/146")
40+
if !IsOpenShift(test) {
41+
test.T().Skip("Requires https://github.com/project-codeflare/codeflare-sdk/pull/146")
42+
}
3643

3744
// Create a namespace
3845
namespace := test.NewTestNamespace()
3946

40-
// SDK script
41-
sdk := &corev1.ConfigMap{
47+
// Test configuration
48+
configMap := &corev1.ConfigMap{
4249
TypeMeta: metav1.TypeMeta{
4350
APIVersion: corev1.SchemeGroupVersion.String(),
4451
Kind: "ConfigMap",
4552
},
4653
ObjectMeta: metav1.ObjectMeta{
47-
Name: "sdk",
54+
Name: "mnist-raycluster-sdk",
4855
Namespace: namespace.Name,
4956
},
5057
BinaryData: map[string][]byte{
51-
"sdk.py": ReadFile(test, "sdk.py"),
58+
// SDK script
59+
"mnist_raycluster_sdk.py": ReadFile(test, "mnist_raycluster_sdk.py"),
60+
// pip requirements
61+
"requirements.txt": ReadFile(test, "requirements.txt"),
62+
// MNIST training script
63+
"mnist.py": ReadFile(test, "mnist.py"),
5264
},
5365
Immutable: Ptr(true),
5466
}
55-
sdk, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), sdk, metav1.CreateOptions{})
67+
configMap, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), configMap, metav1.CreateOptions{})
5668
test.Expect(err).NotTo(HaveOccurred())
57-
test.T().Logf("Created ConfigMap %s/%s successfully", sdk.Namespace, sdk.Name)
69+
test.T().Logf("Created ConfigMap %s/%s successfully", configMap.Namespace, configMap.Name)
5870

59-
// pip requirements
60-
requirements := &corev1.ConfigMap{
71+
// SDK client RBAC
72+
serviceAccount := &corev1.ServiceAccount{
6173
TypeMeta: metav1.TypeMeta{
6274
APIVersion: corev1.SchemeGroupVersion.String(),
63-
Kind: "ConfigMap",
75+
Kind: "ServiceAccount",
6476
},
6577
ObjectMeta: metav1.ObjectMeta{
66-
Name: "requirements",
78+
Name: "sdk-user",
6779
Namespace: namespace.Name,
6880
},
69-
BinaryData: map[string][]byte{
70-
"requirements.txt": ReadFile(test, "requirements.txt"),
81+
}
82+
serviceAccount, err = test.Client().Core().CoreV1().ServiceAccounts(namespace.Name).Create(test.Ctx(), serviceAccount, metav1.CreateOptions{})
83+
test.Expect(err).NotTo(HaveOccurred())
84+
85+
role := &rbacv1.Role{
86+
TypeMeta: metav1.TypeMeta{
87+
APIVersion: rbacv1.SchemeGroupVersion.String(),
88+
Kind: "Role",
89+
},
90+
ObjectMeta: metav1.ObjectMeta{
91+
Name: "sdk",
92+
Namespace: namespace.Name,
93+
},
94+
Rules: []rbacv1.PolicyRule{
95+
{
96+
Verbs: []string{"get", "create", "delete", "list", "patch", "update"},
97+
APIGroups: []string{mcadv1beta1.GroupName},
98+
Resources: []string{"appwrappers"},
99+
},
100+
{
101+
Verbs: []string{"get", "list"},
102+
APIGroups: []string{rayv1alpha1.GroupVersion.Group},
103+
Resources: []string{"rayclusters", "rayclusters/status"},
104+
},
105+
{
106+
Verbs: []string{"get", "list"},
107+
APIGroups: []string{"route.openshift.io"},
108+
Resources: []string{"routes"},
109+
},
71110
},
72-
Immutable: Ptr(true),
73111
}
74-
requirements, err = test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), requirements, metav1.CreateOptions{})
112+
role, err = test.Client().Core().RbacV1().Roles(namespace.Name).Create(test.Ctx(), role, metav1.CreateOptions{})
113+
test.Expect(err).NotTo(HaveOccurred())
114+
115+
roleBinding := &rbacv1.RoleBinding{
116+
TypeMeta: metav1.TypeMeta{
117+
APIVersion: rbacv1.SchemeGroupVersion.String(),
118+
Kind: "RoleBinding",
119+
},
120+
ObjectMeta: metav1.ObjectMeta{
121+
Name: "sdk",
122+
},
123+
RoleRef: rbacv1.RoleRef{
124+
APIGroup: rbacv1.SchemeGroupVersion.Group,
125+
Kind: "Role",
126+
Name: role.Name,
127+
},
128+
Subjects: []rbacv1.Subject{
129+
{
130+
Kind: "ServiceAccount",
131+
APIGroup: corev1.SchemeGroupVersion.Group,
132+
Name: serviceAccount.Name,
133+
Namespace: serviceAccount.Namespace,
134+
},
135+
},
136+
}
137+
_, err = test.Client().Core().RbacV1().RoleBindings(namespace.Name).Create(test.Ctx(), roleBinding, metav1.CreateOptions{})
75138
test.Expect(err).NotTo(HaveOccurred())
76-
test.T().Logf("Created ConfigMap %s/%s successfully", requirements.Namespace, requirements.Name)
77139

78140
job := &batchv1.Job{
79141
TypeMeta: metav1.TypeMeta{
@@ -92,54 +154,62 @@ func TestMNISTRayClusterSDK(t *testing.T) {
92154
Spec: corev1.PodSpec{
93155
Containers: []corev1.Container{
94156
{
95-
Name: "sdk",
157+
Name: "test",
96158
Image: "quay.io/opendatahub/notebooks:jupyter-minimal-ubi8-python-3.8-4c8f26e",
97-
Command: []string{"/bin/sh", "-c", "pip install -r /test/runtime/requirements.txt && python /test/job/sdk.py"},
159+
Command: []string{"/bin/sh", "-c", "pip install codeflare-sdk==0.4.4 && cp /test/* . && python mnist_raycluster_sdk.py" + " " + namespace.Name},
98160
VolumeMounts: []corev1.VolumeMount{
99161
{
100-
Name: "sdk",
101-
MountPath: "/test/job",
102-
},
103-
{
104-
Name: "requirements",
105-
MountPath: "/test/runtime",
162+
Name: "test",
163+
MountPath: "/test",
106164
},
107165
},
108166
},
109167
},
110168
Volumes: []corev1.Volume{
111169
{
112-
Name: "sdk",
113-
VolumeSource: corev1.VolumeSource{
114-
ConfigMap: &corev1.ConfigMapVolumeSource{
115-
LocalObjectReference: corev1.LocalObjectReference{
116-
Name: sdk.Name,
117-
},
118-
},
119-
},
120-
},
121-
{
122-
Name: "requirements",
170+
Name: "test",
123171
VolumeSource: corev1.VolumeSource{
124172
ConfigMap: &corev1.ConfigMapVolumeSource{
125173
LocalObjectReference: corev1.LocalObjectReference{
126-
Name: requirements.Name,
174+
Name: configMap.Name,
127175
},
128176
},
129177
},
130178
},
131179
},
132-
RestartPolicy: corev1.RestartPolicyNever,
180+
RestartPolicy: corev1.RestartPolicyNever,
181+
ServiceAccountName: serviceAccount.Name,
133182
},
134183
},
135184
},
136185
}
137186
job, err = test.Client().Core().BatchV1().Jobs(namespace.Name).Create(test.Ctx(), job, metav1.CreateOptions{})
138187
test.Expect(err).NotTo(HaveOccurred())
188+
test.T().Logf("Created Job %s/%s successfully", job.Namespace, job.Name)
139189

140190
defer JobTroubleshooting(test, job)
141191

142-
test.T().Logf("Waiting for Job %s/%s to complete successfully", job.Namespace, job.Name)
143-
test.Eventually(Job(test, job.Namespace, job.Name), TestTimeoutMedium).
144-
Should(WithTransform(ConditionStatus(batchv1.JobComplete), Equal(corev1.ConditionTrue)))
192+
test.T().Logf("Waiting for Job %s/%s to complete", job.Namespace, job.Name)
193+
test.Eventually(Job(test, job.Namespace, job.Name), TestTimeoutLong).Should(
194+
Or(
195+
WithTransform(ConditionStatus(batchv1.JobComplete), Equal(corev1.ConditionTrue)),
196+
WithTransform(ConditionStatus(batchv1.JobFailed), Equal(corev1.ConditionTrue)),
197+
))
198+
199+
// Refresh the job to get the generated pod selector
200+
job = GetJob(test, job.Namespace, job.Name)
201+
202+
// Get the job Pod
203+
pods := GetPods(test, job.Namespace, metav1.ListOptions{
204+
LabelSelector: labels.FormatLabels(job.Spec.Selector.MatchLabels)},
205+
)
206+
test.Expect(pods).To(HaveLen(1))
207+
208+
// Print the job logs
209+
test.T().Logf("Printing Job %s/%s logs", job.Namespace, job.Name)
210+
test.T().Log(GetPodLogs(test, &pods[0], corev1.PodLogOptions{}))
211+
212+
// Assert the job has completed successfully
213+
test.T().Logf("Checking the Job %s/%s has completed successfully", job.Namespace, job.Name)
214+
test.Expect(job).To(WithTransform(ConditionStatus(batchv1.JobComplete), Equal(corev1.ConditionTrue)))
145215
}

test/e2e/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
codeflare-sdk==0.4.4
1+
pytorch_lightning==1.5.10
2+
torchmetrics==0.9.1
3+
torchvision==0.12.0

test/e2e/sdk.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

test/support/openshift.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
Copyright 2023.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package support
18+
19+
import (
20+
"github.com/onsi/gomega"
21+
22+
"k8s.io/apimachinery/pkg/api/errors"
23+
)
24+
25+
func IsOpenShift(test Test) bool {
26+
test.T().Helper()
27+
_, err := test.Client().Core().Discovery().ServerResourcesForGroupVersion("image.openshift.io/v1")
28+
if err != nil && errors.IsNotFound(err) {
29+
return false
30+
}
31+
test.Expect(err).NotTo(gomega.HaveOccurred())
32+
return true
33+
}

0 commit comments

Comments
 (0)