Skip to content

Commit 48b2722

Browse files
committed
RHOAIENG-3771 - Reduce execution time of E2E tests
1 parent de6cdd5 commit 48b2722

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tests/e2e/mnist.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytorch_lightning.callbacks.progress import TQDMProgressBar
2020
from torch import nn
2121
from torch.nn import functional as F
22-
from torch.utils.data import DataLoader, random_split
22+
from torch.utils.data import DataLoader, random_split, RandomSampler
2323
from torchmetrics import Accuracy
2424
from torchvision import transforms
2525
from torchvision.datasets import MNIST
@@ -127,7 +127,7 @@ def setup(self, stage=None):
127127
)
128128

129129
def train_dataloader(self):
130-
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
130+
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))
131131

132132
def val_dataloader(self):
133133
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
@@ -147,10 +147,11 @@ def test_dataloader(self):
147147
trainer = Trainer(
148148
accelerator="auto",
149149
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
150-
max_epochs=5,
150+
max_epochs=3,
151151
callbacks=[TQDMProgressBar(refresh_rate=20)],
152152
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
153153
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
154+
replace_sampler_ddp=False,
154155
strategy="ddp",
155156
)
156157

0 commit comments

Comments
 (0)