Skip to content

Commit f2c6618

Browse files
committed
test: Parameterize PyTorch image
1 parent 06a8659 commit f2c6618

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

test/e2e/mnist_pytorch_mcad_job_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func TestMNISTPyTorchMCAD(t *testing.T) {
7777
Containers: []corev1.Container{
7878
{
7979
Name: "job",
80-
Image: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime",
80+
Image: GetPyTorchImage(),
8181
Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
8282
VolumeMounts: []corev1.VolumeMount{
8383
{

test/support/codeflare.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ import (
2323
// The environment variables hereafter can be used to change the components
2424
// used for testing.
2525
const (
26-
CodeFlareTestSdkVersion = "CODEFLARE_TEST_SDK_VERSION"
27-
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
28-
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
26+
CodeFlareTestSdkVersion = "CODEFLARE_TEST_SDK_VERSION"
27+
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
28+
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
29+
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
2930
)
3031

3132
func GetCodeFlareSDKVersion() string {
@@ -40,6 +41,10 @@ func GetRayImage() string {
4041
return lookupEnvOrDefault(CodeFlareTestRayImage, RayImage)
4142
}
4243

44+
func GetPyTorchImage() string {
45+
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
46+
}
47+
4348
func lookupEnvOrDefault(key, value string) string {
4449
if v, ok := os.LookupEnv(key); ok {
4550
return v

0 commit comments

Comments
 (0)