Skip to content

Commit 8250b2f

Browse files
tiangolosoumith
authored andcommitted
Fix blitz neural networks shape mismatch (#213)
1 parent f8c423f commit 8250b2f

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

beginner_source/blitz/neural_networks_tutorial.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def num_flat_features(self, x):
160160

161161
output = net(input)
162162
target = Variable(torch.arange(1, 11)) # a dummy target, for example
163+
target = target.view(1, -1) # make it the same shape as output
163164
criterion = nn.MSELoss()
164165

165166
loss = criterion(output, target)

0 commit comments

Comments
 (0)