Closed
Description
For the code in the tutorial, the print(traced_cell.code) output shown in the docs at https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html is:
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = torch.add((self.linear).forward(input, ), h, alpha=1)
  _1 = torch.tanh(_0)
  return (_1, _1)
But I think the correct print(traced_cell.code) output should look more like:
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.dg
  _1 = (self.linear).forward(input, )
  _2 = (_0).forward(_1, )
  _3 = torch.tanh(torch.add(_2, h, alpha=1))
  return (_3, _3)
This misleading output may be related to something subtle in how tracing works. If anyone knows the real answer, please help!
Continued:
I find that the output seems nondeterministic; the image below shows my results from two attempts, and the outputs differ.
I also tested the code below:
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
x, h = torch.ones(5, 4) * -1, torch.ones(5, 4)
# x, h = torch.rand(5, 4), torch.rand(5, 4)

print("########################################################")
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
print("########################################################")

x2, h2 = torch.ones(5, 4), torch.ones(5, 4)
print("########################################################")
traced2_cell = torch.jit.trace(my_cell, (x2, h2))
print(traced2_cell.code)
print("########################################################")
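If I understand tracing correctly (this is my assumption, not something confirmed in the docs I quoted), torch.jit.trace only records the operations actually executed for the example inputs, so the data-dependent if/else in MyDecisionGate gets baked in as whichever branch the example input took; that would explain why different inputs produce different .code. A minimal sketch of this, isolating just the gate and comparing against torch.jit.script, which compiles the Python source and should preserve the branch:

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()

# Tracing specializes on the branch taken by the example input
# (PyTorch emits a TracerWarning about the data-dependent control flow):
traced_neg = torch.jit.trace(gate, torch.full((5, 4), -1.0))  # takes the -x branch
traced_pos = torch.jit.trace(gate, torch.ones(5, 4))          # takes the x branch
print(traced_neg.code)  # no if/else here, only the negated path
print(traced_pos.code)  # no if/else here either, only the identity path

# Scripting compiles the source instead of recording a run,
# so the if/else should survive in the generated code:
scripted = torch.jit.script(gate)
print(scripted.code)
```

If that assumption holds, the two traced .code printouts differ from each other while the scripted version keeps the branch, which would make the varying output expected behavior for trace rather than a bug.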