From 48b2722726934a28fae7922453d1197d316edb0f Mon Sep 17 00:00:00 2001 From: Jiri Petrlik Date: Fri, 1 Mar 2024 16:01:30 +0100 Subject: [PATCH] RHOAIENG-3771 - Reduce execution time of E2E tests --- tests/e2e/mnist.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/e2e/mnist.py b/tests/e2e/mnist.py index a99589659..91d1e0d21 100644 --- a/tests/e2e/mnist.py +++ b/tests/e2e/mnist.py @@ -19,7 +19,7 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader, random_split, RandomSampler from torchmetrics import Accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -127,7 +127,7 @@ def setup(self, stage=None): ) def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=BATCH_SIZE) + return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000)) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=BATCH_SIZE) @@ -147,10 +147,11 @@ def test_dataloader(self): trainer = Trainer( accelerator="auto", # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs - max_epochs=5, + max_epochs=3, callbacks=[TQDMProgressBar(refresh_rate=20)], num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)), devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)), + replace_sampler_ddp=False, strategy="ddp", )