Skip to content
This repository was archived by the owner on Aug 7, 2019. It is now read-only.

Commit ebabd44

Browse files
authored
Merge pull request #4 from yf225/debug
Fix numpy_extensions_tutorial.py
2 parents d8ebba0 + cca6b48 commit ebabd44

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

advanced_source/numpy_extensions_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def backward(ctx, grad_output):
103103
# the previous line can be expressed equivalently as:
104104
# grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
105105
grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
106-
return torch.from_numpy(grad_input), torch.from_numpy(grad_filter), torch.from_numpy(grad_bias)
106+
return torch.from_numpy(grad_input), torch.from_numpy(grad_filter).to(torch.float), torch.from_numpy(grad_bias).to(torch.float)
107107

108108

109109
class ScipyConv2d(Module):

0 commit comments

Comments
 (0)