Skip to content

Commit e7a93b6

Browse files
sutaakarFiona-Waters
authored andcommitted
Refactor Machine pool functions
1 parent 348bd00 commit e7a93b6

File tree

7 files changed

+143
-108
lines changed

7 files changed

+143
-108
lines changed

test/e2e/instascale.go renamed to test/e2e/instascale_app_wrapper.go

Lines changed: 5 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,55 +9,10 @@ import (
99
"k8s.io/apimachinery/pkg/api/resource"
1010
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1111

12-
ocmsdk "github.com/openshift-online/ocm-sdk-go"
13-
1412
. "github.com/project-codeflare/codeflare-operator/test/support"
1513
)
1614

17-
func TestConfig(test Test, namespace string) (*corev1.ConfigMap, error) {
18-
// Test configuration
19-
configMap := &corev1.ConfigMap{
20-
TypeMeta: metav1.TypeMeta{
21-
APIVersion: corev1.SchemeGroupVersion.String(),
22-
Kind: "ConfigMap",
23-
},
24-
ObjectMeta: metav1.ObjectMeta{
25-
Name: "mnist-mcad",
26-
Namespace: namespace,
27-
},
28-
BinaryData: map[string][]byte{
29-
// pip requirements
30-
"requirements.txt": ReadFile(test, "mnist_pip_requirements.txt"),
31-
// MNIST training script
32-
"mnist.py": ReadFile(test, "mnist.py"),
33-
},
34-
Immutable: Ptr(true),
35-
}
36-
37-
config, err := test.Client().Core().CoreV1().ConfigMaps(namespace).Create(test.Ctx(), configMap, metav1.CreateOptions{})
38-
test.Expect(err).NotTo(HaveOccurred())
39-
test.T().Logf("Created ConfigMap %s/%s successfully", config.Namespace, config.Name)
40-
41-
return configMap, err
42-
}
43-
44-
func CreateConnection(test Test) (*ocmsdk.Connection, error) {
45-
instascaleOCMSecret, err := test.Client().Core().CoreV1().Secrets("default").Get(test.Ctx(), "instascale-ocm-secret", metav1.GetOptions{})
46-
if err != nil {
47-
test.T().Errorf("unable to retrieve instascale-ocm-secret - Error : %v", err)
48-
}
49-
test.Expect(err).NotTo(HaveOccurred())
50-
ocmToken := string(instascaleOCMSecret.Data["token"])
51-
test.T().Logf("Retrieved Secret %s successfully", instascaleOCMSecret.Name)
52-
53-
connection, err := CreateOCMConnection(ocmToken)
54-
if err != nil {
55-
test.T().Errorf("Unable to create ocm connection - Error : %v", err)
56-
}
57-
return connection, err
58-
}
59-
60-
func JobAppwrapperSetup(test Test, namespace *corev1.Namespace, config *corev1.ConfigMap) (*batchv1.Job, *mcadv1beta1.AppWrapper, error) {
15+
func createInstaScaleJobAppWrapper(test Test, namespace *corev1.Namespace, config *corev1.ConfigMap) (*batchv1.Job, *mcadv1beta1.AppWrapper, error) {
6116
// Batch Job
6217
job := &batchv1.Job{
6318
TypeMeta: metav1.TypeMeta{
@@ -78,7 +33,7 @@ func JobAppwrapperSetup(test Test, namespace *corev1.Namespace, config *corev1.C
7833
Name: "job",
7934
Image: GetPyTorchImage(),
8035
Env: []corev1.EnvVar{
81-
corev1.EnvVar{Name: "PYTHONUSERBASE", Value: "/test2"},
36+
{Name: "PYTHONUSERBASE", Value: "/workdir"},
8237
},
8338
Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
8439
Args: []string{"$PYTHONUSERBASE"},
@@ -92,7 +47,7 @@ func JobAppwrapperSetup(test Test, namespace *corev1.Namespace, config *corev1.C
9247
MountPath: "/workdir",
9348
},
9449
},
95-
WorkingDir: "workdir",
50+
WorkingDir: "/workdir",
9651
},
9752
},
9853
Volumes: []corev1.Volume{
@@ -125,7 +80,7 @@ func JobAppwrapperSetup(test Test, namespace *corev1.Namespace, config *corev1.C
12580
Name: "test-instascale",
12681
Namespace: namespace.Name,
12782
Labels: map[string]string{
128-
"orderedinstance": "m5.xlarge_g4dn.xlarge",
83+
"orderedinstance": "g4dn.xlarge",
12984
},
13085
},
13186
Spec: mcadv1beta1.AppWrapperSpec{
@@ -170,7 +125,7 @@ func JobAppwrapperSetup(test Test, namespace *corev1.Namespace, config *corev1.C
170125
test.Expect(err).NotTo(HaveOccurred())
171126
test.T().Logf("AppWrapper created successfully %s/%s", aw.Namespace, aw.Name)
172127

173-
test.Eventually(AppWrapper(test, namespace, aw.Name), TestTimeoutShort).
128+
test.Eventually(AppWrapper(test, namespace, aw.Name), TestTimeoutGpuProvisioning).
174129
Should(WithTransform(AppWrapperState, Equal(mcadv1beta1.AppWrapperStateActive)))
175130

176131
return job, aw, err

test/e2e/instascale_machinepool_test.go

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package e2e
22

33
import (
44
"testing"
5-
"time"
65

76
. "github.com/onsi/gomega"
87
mcadv1beta1 "github.com/project-codeflare/multi-cluster-app-dispatcher/pkg/apis/controller/v1beta1"
@@ -21,33 +20,31 @@ func TestInstascaleMachinePool(t *testing.T) {
2120
namespace := test.NewTestNamespace()
2221

2322
// Test configuration
24-
config, err := TestConfig(test, namespace.Name)
25-
test.Expect(err).NotTo(HaveOccurred())
23+
testConfigData := map[string][]byte{
24+
// pip requirements
25+
"requirements.txt": ReadFile(test, "mnist_pip_requirements.txt"),
26+
// MNIST training script
27+
"mnist.py": ReadFile(test, "mnist.py"),
28+
}
29+
cm := CreateConfigMap(test, namespace.Name, testConfigData)
2630

2731
//create OCM connection
28-
connection, err := CreateConnection(test)
29-
test.Expect(err).NotTo(HaveOccurred())
32+
connection := CreateOCMConnection(test)
3033

3134
defer connection.Close()
3235

3336
// check existing cluster machine pool resources
3437
// look for machine pool with aw name - expect not to find it
35-
foundMachinePool, err := CheckMachinePools(connection, TestName)
36-
test.Expect(err).NotTo(HaveOccurred())
37-
test.Expect(foundMachinePool).To(BeFalse())
38+
test.Expect(GetMachinePools(test, connection)).
39+
ShouldNot(ContainElement(WithTransform(MachinePoolId, Equal("test-instascale-g4dn-xlarge"))))
3840

3941
// Setup batch job and AppWrapper
40-
job, aw, err := JobAppwrapperSetup(test, namespace, config)
42+
job, aw, err := createInstaScaleJobAppWrapper(test, namespace, cm)
4143
test.Expect(err).NotTo(HaveOccurred())
4244

43-
// time.Sleep is used twice throughout the test, each for 30 seconds. Can look into using sync package waitGroup instead if that makes more sense
44-
// wait for required resources to scale up before checking them again
45-
time.Sleep(TestTimeoutMedium)
46-
4745
// look for machine pool with aw name - expect to find it
48-
foundMachinePool, err = CheckMachinePools(connection, TestName)
49-
test.Expect(err).NotTo(HaveOccurred())
50-
test.Expect(foundMachinePool).To(BeTrue())
46+
test.Eventually(MachinePools(test, connection), TestTimeoutLong).
47+
Should(ContainElement(WithTransform(MachinePoolId, Equal("test-instascale-g4dn-xlarge"))))
5148

5249
// Assert that the job has completed
5350
test.T().Logf("Waiting for Job %s/%s to complete", job.Namespace, job.Name)
@@ -64,12 +61,8 @@ func TestInstascaleMachinePool(t *testing.T) {
6461
test.Eventually(AppWrapper(test, namespace, aw.Name), TestTimeoutShort).
6562
Should(WithTransform(AppWrapperState, Equal(mcadv1beta1.AppWrapperStateCompleted)))
6663

67-
// allow time for the resources to scale down before checking them again
68-
time.Sleep(TestTimeoutMedium)
69-
7064
// look for machine pool with aw name - expect not to find it
71-
foundMachinePool, err = CheckMachinePools(connection, TestName)
72-
test.Expect(err).NotTo(HaveOccurred())
73-
test.Expect(foundMachinePool).To(BeFalse())
65+
test.Eventually(MachinePools(test, connection), TestTimeoutLong).
66+
ShouldNot(ContainElement(WithTransform(MachinePoolId, Equal("test-instascale-g4dn-xlarge"))))
7467

7568
}

test/support/clusterpools.go

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"strings"
88

9+
"github.com/onsi/gomega"
910
ocmsdk "github.com/openshift-online/ocm-sdk-go"
1011
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
1112
)
@@ -15,42 +16,21 @@ var (
1516
TestName string = "test-instascale"
1617
)
1718

18-
func CreateOCMConnection(secret string) (*ocmsdk.Connection, error) {
19-
logger, err := ocmsdk.NewGoLoggerBuilder().
20-
Debug(false).
21-
Build()
22-
if err != nil {
23-
fmt.Fprintf(os.Stderr, "Can't build logger: %v\n", err)
24-
return nil, err
25-
}
26-
connection, err := ocmsdk.NewConnectionBuilder().
27-
Logger(logger).
28-
Tokens(string(secret)).
29-
Build()
30-
if err != nil || connection == nil {
31-
fmt.Fprintf(os.Stderr, "Can't build connection: %v\n", err)
32-
return nil, err
19+
func MachinePools(t Test, connection *ocmsdk.Connection) func(g gomega.Gomega) []*cmv1.MachinePool {
20+
return func(g gomega.Gomega) []*cmv1.MachinePool {
21+
machinePoolsListResponse, err := connection.ClustersMgmt().V1().Clusters().Cluster(ClusterID).MachinePools().List().Send()
22+
g.Expect(err).NotTo(gomega.HaveOccurred())
23+
return machinePoolsListResponse.Items().Slice()
3324
}
34-
35-
return connection, nil
3625
}
3726

38-
func CheckMachinePools(connection *ocmsdk.Connection, awName string) (foundMachinePool bool, err error) {
39-
machinePoolsConnection := connection.ClustersMgmt().V1().Clusters().Cluster(ClusterID).MachinePools().List()
40-
machinePoolsListResponse, err := machinePoolsConnection.Send()
41-
if err != nil {
42-
return false, fmt.Errorf("unable to send request, error: %v", err)
43-
}
44-
machinePoolsList := machinePoolsListResponse.Items()
45-
machinePoolsList.Range(func(index int, item *cmv1.MachinePool) bool {
46-
instanceName, _ := item.GetID()
47-
if strings.Contains(instanceName, awName) {
48-
foundMachinePool = true
49-
}
50-
return true
51-
})
27+
func GetMachinePools(t Test, connection *ocmsdk.Connection) []*cmv1.MachinePool {
28+
t.T().Helper()
29+
return MachinePools(t, connection)(t)
30+
}
5231

53-
return foundMachinePool, err
32+
func MachinePoolId(machinePool *cmv1.MachinePool) string {
33+
return machinePool.ID()
5434
}
5535

5636
func CheckNodePools(connection *ocmsdk.Connection, awName string) (foundNodePool bool, err error) {

test/support/codeflare.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ const (
3030
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
3131

3232
// The testing output directory, to write output files into.
33-
3433
CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR"
34+
35+
// The name of a secret containing InstaScale OCM token.
36+
InstaScaleOcmSecretName = "INSTASCALE_OCM_SECRET_NAME"
37+
// The namespace where a secret containing InstaScale OCM token is stored.
38+
InstaScaleOcmSecretNamespace = "INSTASCALE_OCM_SECRET_NAMESPACE"
3539
)
3640

3741
func GetCodeFlareSDKVersion() string {
@@ -50,6 +54,14 @@ func GetPyTorchImage() string {
5054
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
5155
}
5256

57+
func GetInstaScaleOcmSecretName() string {
58+
return lookupEnvOrDefault(InstaScaleOcmSecretName, "instascale-ocm-secret")
59+
}
60+
61+
func GetInstaScaleOcmSecretNamespace() string {
62+
return lookupEnvOrDefault(InstaScaleOcmSecretNamespace, "default")
63+
}
64+
5365
func lookupEnvOrDefault(key, value string) string {
5466
if v, ok := os.LookupEnv(key); ok {
5567
return v

test/support/config_map.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
Copyright 2023.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package support
18+
19+
import (
20+
"github.com/onsi/gomega"
21+
corev1 "k8s.io/api/core/v1"
22+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
23+
)
24+
25+
func CreateConfigMap(t Test, namespace string, content map[string][]byte) *corev1.ConfigMap {
26+
configMap := &corev1.ConfigMap{
27+
TypeMeta: metav1.TypeMeta{
28+
APIVersion: corev1.SchemeGroupVersion.String(),
29+
Kind: "ConfigMap",
30+
},
31+
ObjectMeta: metav1.ObjectMeta{
32+
GenerateName: "config-",
33+
Namespace: namespace,
34+
},
35+
BinaryData: content,
36+
Immutable: Ptr(true),
37+
}
38+
39+
configMap, err := t.Client().Core().CoreV1().ConfigMaps(namespace).Create(t.Ctx(), configMap, metav1.CreateOptions{})
40+
t.Expect(err).NotTo(gomega.HaveOccurred())
41+
t.T().Logf("Created ConfigMap %s/%s successfully", configMap.Namespace, configMap.Name)
42+
43+
return configMap
44+
}

test/support/ocm.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
Copyright 2023.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package support
18+
19+
import (
20+
"fmt"
21+
"os"
22+
23+
"github.com/onsi/gomega"
24+
ocmsdk "github.com/openshift-online/ocm-sdk-go"
25+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
26+
)
27+
28+
func CreateOCMConnection(test Test) *ocmsdk.Connection {
29+
instascaleOCMSecret, err := test.Client().Core().CoreV1().Secrets(GetInstaScaleOcmSecretNamespace()).Get(test.Ctx(), GetInstaScaleOcmSecretName(), metav1.GetOptions{})
30+
test.Expect(err).NotTo(gomega.HaveOccurred())
31+
32+
ocmToken := string(instascaleOCMSecret.Data["token"])
33+
test.T().Logf("Retrieved Secret %s/%s successfully", instascaleOCMSecret.Namespace, instascaleOCMSecret.Name)
34+
35+
connection, err := buildOCMConnection(ocmToken)
36+
test.Expect(err).NotTo(gomega.HaveOccurred())
37+
return connection
38+
}
39+
40+
func buildOCMConnection(secret string) (*ocmsdk.Connection, error) {
41+
connection, err := ocmsdk.NewConnectionBuilder().
42+
Tokens(secret).
43+
Build()
44+
if err != nil || connection == nil {
45+
fmt.Fprintf(os.Stderr, "Can't build connection: %v\n", err)
46+
return nil, err
47+
}
48+
49+
return connection, nil
50+
}

test/support/support.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ import (
3030
var (
3131
ApplyOptions = metav1.ApplyOptions{FieldManager: "codeflare-test", Force: true}
3232

33-
TestTimeoutShort = 1 * time.Minute
34-
TestTimeoutMedium = 2 * time.Minute
35-
TestTimeoutLong = 5 * time.Minute
33+
TestTimeoutShort = 1 * time.Minute
34+
TestTimeoutMedium = 2 * time.Minute
35+
TestTimeoutLong = 5 * time.Minute
36+
TestTimeoutGpuProvisioning = 30 * time.Minute
3637
)
3738

3839
func init() {

0 commit comments

Comments
 (0)