19
19
from pytorch_lightning .callbacks .progress import TQDMProgressBar
20
20
from torch import nn
21
21
from torch .nn import functional as F
22
- from torch .utils .data import DataLoader , random_split
22
+ from torch .utils .data import DataLoader , random_split , RandomSampler
23
23
from torchmetrics import Accuracy
24
24
from torchvision import transforms
25
25
from torchvision .datasets import MNIST
@@ -127,7 +127,7 @@ def setup(self, stage=None):
127
127
)
128
128
129
129
def train_dataloader (self ):
130
- return DataLoader (self .mnist_train , batch_size = BATCH_SIZE )
130
+ return DataLoader (self .mnist_train , batch_size = BATCH_SIZE , sampler = RandomSampler ( self . mnist_train , num_samples = 1000 ) )
131
131
132
132
def val_dataloader (self ):
133
133
return DataLoader (self .mnist_val , batch_size = BATCH_SIZE )
@@ -147,10 +147,11 @@ def test_dataloader(self):
147
147
trainer = Trainer (
148
148
accelerator = "auto" ,
149
149
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
150
- max_epochs = 5 ,
150
+ max_epochs = 3 ,
151
151
callbacks = [TQDMProgressBar (refresh_rate = 20 )],
152
152
num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 )),
153
153
devices = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 )),
154
+ replace_sampler_ddp = False ,
154
155
strategy = "ddp" ,
155
156
)
156
157
0 commit comments