Skip to content

Commit d0447a1

Browse files
committed
rework dry_run handling in mnist, mnist_hogwild
1 parent 9d0e6e7 commit d0447a1

File tree

4 files changed

+5
-41
lines changed

4 files changed

+5
-41
lines changed

mnist/main.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def train(args, model, device, train_loader, optimizer, epoch):
4747
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
4848
epoch, batch_idx * len(data), len(train_loader.dataset),
4949
100. * batch_idx / len(train_loader), loss.item()))
50+
if args.dry_run:
51+
break
5052

5153

5254
def test(model, device, test_loader):
@@ -113,26 +115,6 @@ def main():
113115
transform=transform)
114116
dataset2 = datasets.MNIST('../data', train=False,
115117
transform=transform)
116-
if args.dry_run:
117-
from torch.utils.data.sampler import Sampler
118-
119-
class DryRunSampler(Sampler):
120-
r"""Return only two datum from the set of data
121-
"""
122-
123-
def __init__(self, data_source):
124-
self.data_source = data_source
125-
126-
def __iter__(self):
127-
return iter(range(2))
128-
129-
def __len__(self):
130-
return 2
131-
132-
133-
kwargs['sampler'] = DryRunSampler(dataset1)
134-
kwargs['shuffle'] = False
135-
136118
train_loader = torch.utils.data.DataLoader(dataset1,**kwargs)
137119
test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)
138120

mnist_hogwild/main.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,6 @@ def forward(self, x):
5151
return F.log_softmax(x, dim=1)
5252

5353

54-
class DryRunSampler(Sampler):
55-
r"""Return only two datum from the set of data
56-
"""
57-
58-
def __init__(self, data_source):
59-
self.data_source = data_source
60-
61-
def __iter__(self):
62-
return iter(range(2))
63-
64-
def __len__(self):
65-
return 2
66-
67-
6854
if __name__ == '__main__':
6955
args = parser.parse_args()
7056

@@ -85,12 +71,6 @@ def __len__(self):
8571
'pin_memory': True,
8672
})
8773

88-
if args.dry_run:
89-
90-
kwargs['sampler'] = DryRunSampler(dataset1)
91-
kwargs['shuffle'] = False
92-
93-
9474
torch.manual_seed(args.seed)
9575
mp.set_start_method('spawn')
9676

mnist_hogwild/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def train_epoch(epoch, args, model, device, data_loader, optimizer):
3535
print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
3636
pid, epoch, batch_idx * len(data), len(data_loader.dataset),
3737
100. * batch_idx / len(data_loader), loss.item()))
38+
if args.dry_run:
39+
break
3840

3941

4042
def test_epoch(model, device, data_loader):

run_python_examples.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function clean() {
172172
}
173173

174174
function run_all() {
175-
#cpp
175+
# cpp
176176
dcgan
177177
# distributed
178178
fast_neural_style

0 commit comments

Comments
 (0)