Skip to content

Prepare Pytorch MNIST test image for disconnected testing #469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/mnist-job-test-image.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Builds the MNIST job test image and pushes it to the project-codeflare
# image registry on quay.io.

name: MNIST Job Test Image

on:
  # Manual trigger.
  workflow_dispatch:
  # Automatic trigger when the image sources change on main.
  push:
    branches: [main]
    paths:
      - 'test/pytorch_mnist_image/**'

jobs:
  push:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # The Makefile evaluates `go env`, so a Go toolchain is installed first.
      - name: Set Go
        uses: actions/setup-go@v3
        with:
          go-version: v1.20

      - name: Login to Quay.io
        uses: redhat-actions/podman-login@v1
        with:
          registry: quay.io
          username: ${{ secrets.QUAY_ID }}
          password: ${{ secrets.QUAY_TOKEN }}

      # The push target depends on the build target, so this builds and pushes.
      - name: Image Build and Push
        run: make image-mnist-job-test-push
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ ENVTEST_K8S_VERSION = 1.24.2
# used to build the manifests.
ENV ?= default

# MNIST job test image: version tag and the full image reference built from it
MNIST_JOB_TEST_VERSION ?= v0.0.2
MNIST_JOB_TEST_IMG ?= $(IMAGE_ORG_BASE)/mnist-job-test:$(MNIST_JOB_TEST_VERSION)

# Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set)
ifeq (,$(shell go env GOBIN))
GOBIN=$(shell go env GOPATH)/bin
Expand Down Expand Up @@ -383,3 +387,11 @@ imports: openshift-goimports ## Organize imports in go files using openshift-goi
# Runs the import verification script, passing it the openshift-goimports tool.
.PHONY: verify-imports
verify-imports: openshift-goimports ## Run import verifications.
	./hack/verify-imports.sh $(OPENSHIFT-GOIMPORTS)

# Builds the MNIST job test image from the sources in test/pytorch_mnist_image.
.PHONY: image-mnist-job-test-build
image-mnist-job-test-build: ## Build container image with the MNIST job.
	podman build --tag $(MNIST_JOB_TEST_IMG) ./test/pytorch_mnist_image

# Pushes the image; the build target is a prerequisite, so it always runs first.
.PHONY: image-mnist-job-test-push
image-mnist-job-test-push: image-mnist-job-test-build ## Push container image with the MNIST job.
	podman push $(MNIST_JOB_TEST_IMG)
18 changes: 18 additions & 0 deletions test/pytorch_mnist_image/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Test image that runs the PyTorch Lightning MNIST training script.
# The MNIST dataset is downloaded while the image is built, so the files
# are already present in the image when the job starts.
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

WORKDIR /test
COPY entrypoint.sh entrypoint.sh

# Install MNIST requirements; --no-cache-dir keeps pip's download cache
# out of the image layer.
COPY mnist_pip_requirements.txt requirements.txt
RUN pip install --no-cache-dir --requirement requirements.txt

# Prepare MNIST script and pre-download the dataset into /test, which is
# the default PATH_DATASETS location read by mnist.py.
COPY mnist.py mnist.py
COPY download_dataset.py download_dataset.py
RUN torchrun download_dataset.py

# Run as a non-root UID/GID from a working directory separate from /test.
USER 65532:65532
WORKDIR /workdir
ENTRYPOINT ["/test/entrypoint.sh"]
21 changes: 21 additions & 0 deletions test/pytorch_mnist_image/download_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2022 IBM, Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from torchvision.datasets import MNIST

# Directory the dataset is written to; defaults to the current directory
# unless PATH_DATASETS is set in the environment.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

# Fetch the training split first, then the test split.
for use_train_split in (True, False):
    MNIST(PATH_DATASETS, train=use_train_split, download=True)
3 changes: 3 additions & 0 deletions test/pytorch_mnist_image/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/sh

# Replace the shell with torchrun instead of running it as a child process:
# signals sent to the container's main process (e.g. SIGTERM on shutdown)
# then reach torchrun directly, and its exit status is the container's.
exec torchrun /test/mnist.py
159 changes: 159 additions & 0 deletions test/pytorch_mnist_image/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2022 IBM, Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import requests
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

# Working directory and dataset directory, each overridable via the environment.
PATH_WORKDIR = os.getenv("PATH_WORKDIR", ".")
PATH_DATASETS = os.getenv("PATH_DATASETS", "/test")
# Use the larger batch size only when a CUDA device is available.
BATCH_SIZE = 64 if not torch.cuda.is_available() else 256
# %%

# Print the MASTER_ADDR / MASTER_PORT values this process sees (None if unset).
print("prior to running the trainer")
print("MASTER_ADDR: is ", os.environ.get("MASTER_ADDR"))
print("MASTER_PORT: is ", os.environ.get("MASTER_PORT"))

class LitMNIST(LightningModule):
    """Fully connected MNIST classifier bundled with its own data hooks.

    The network flattens each image and passes it through two hidden
    linear layers (ReLU + dropout) and a final linear layer with one
    output per class; ``forward`` returns log-softmax scores.

    Args:
        data_dir: Stored on the instance as ``self.data_dir``. The data
            hooks below read the datasets from the module-level
            ``PATH_DATASETS``, not from this attribute.
        hidden_size: Width of both hidden layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, data_dir=PATH_WORKDIR, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Keep the constructor arguments on the instance.
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Dataset-specific constants.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        dim_a, dim_b, dim_c = self.dims
        n_inputs = dim_a * dim_b * dim_c

        # Convert images to tensors, then normalize them.
        preprocessing = [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        self.transform = transforms.Compose(preprocessing)

        # The network itself; layers are created in forward order.
        layers = [
            nn.Flatten(),
            nn.Linear(n_inputs, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        ]
        self.model = nn.Sequential(*layers)

        # Separate accuracy metrics for validation and test.
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, x):
        """Return log-softmax class scores for the batch ``x``."""
        return F.log_softmax(self.model(x), dim=1)

    def _loss_and_predictions(self, batch):
        """Return ``(nll_loss, argmax predictions, targets)`` for one batch."""
        inputs, targets = batch
        log_probs = self(inputs)
        loss = F.nll_loss(log_probs, targets)
        predictions = torch.argmax(log_probs, dim=1)
        return loss, predictions, targets

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one training batch."""
        inputs, targets = batch
        return F.nll_loss(self(inputs), targets)

    def validation_step(self, batch, batch_idx):
        """Update validation accuracy and log ``val_loss`` / ``val_acc``."""
        loss, predictions, targets = self._loss_and_predictions(batch)
        self.val_accuracy.update(predictions, targets)

        # Both values are also shown in the progress bar.
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Update test accuracy and log ``test_loss`` / ``test_acc``."""
        loss, predictions, targets = self._loss_and_predictions(batch)
        self.test_accuracy.update(predictions, targets)

        # Both values are also shown in the progress bar.
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    # ------------------
    # Data related hooks
    # ------------------

    def prepare_data(self):
        """Request both MNIST splits under ``PATH_DATASETS`` with download enabled."""
        for use_train_split in (True, False):
            MNIST(PATH_DATASETS, train=use_train_split, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them if ``None``)."""
        # Train/validation datasets: 55000 / 5000 random split of the train set.
        if stage is None or stage == "fit":
            full_train_set = MNIST(PATH_DATASETS, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                full_train_set, [55000, 5000]
            )

        # Test dataset.
        if stage is None or stage == "test":
            self.mnist_test = MNIST(
                PATH_DATASETS, train=False, transform=self.transform
            )

    def train_dataloader(self):
        """Return the loader over the training subset."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return the loader over the validation subset."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return the loader over the test set."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


# Instantiate the model to train

model = LitMNIST()

# Read the node count and the device count from the environment, falling
# back to 1 for each when the variable is unset, and echo them to stdout.
num_nodes = int(os.environ.get("GROUP_WORLD_SIZE", 1))
print("GROUP: ", num_nodes)
devices = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
print("LOCAL: ", devices)

# Initialize a trainer
trainer = Trainer(
    accelerator="auto",
    strategy="ddp",
    num_nodes=num_nodes,
    devices=devices,
    max_epochs=5,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

# Train the model ⚡
trainer.fit(model)
3 changes: 3 additions & 0 deletions test/pytorch_mnist_image/mnist_pip_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pytorch_lightning==1.5.10
torchmetrics==0.9.1
torchvision==0.12.0