diff --git a/logreg/main.py b/logreg/main.py index 98c3ef00e3..f2781e3830 100755 --- a/logreg/main.py +++ b/logreg/main.py @@ -54,7 +54,7 @@ def get_batch(batch_size=32): # Reset gradients for param in fc.parameters(): - param.grad.zero_() + param.grad.data.zero_() # Forward pass output = l1(fc(batch_x), batch_y) @@ -65,7 +65,7 @@ def get_batch(batch_size=32): # Apply gradients for param in fc.parameters(): - param.data.add_(-1 * param.grad) + param.data.add_(-1 * param.grad.data) # Stop criterion if loss < 0.1: