diff --git a/beginner_source/Intro_to_TorchScript_tutorial.py b/beginner_source/Intro_to_TorchScript_tutorial.py index 1c7502b8a46..d9b0023a6ab 100644 --- a/beginner_source/Intro_to_TorchScript_tutorial.py +++ b/beginner_source/Intro_to_TorchScript_tutorial.py @@ -31,7 +31,7 @@ """ -import torch # This is all you need to use both PyTorch and TorchScript! +import torch # This is all you need to use both PyTorch and TorchScript! print(torch.__version__) @@ -125,11 +125,11 @@ def forward(self, x, h): # class MyDecisionGate(torch.nn.Module): - def forward(self, x): - if x.sum() > 0: - return x - else: - return -x + def forward(self, x): + if x.sum() > 0: + return x + else: + return -x class MyCell(torch.nn.Module): def __init__(self): @@ -256,11 +256,11 @@ def forward(self, x, h): # class MyDecisionGate(torch.nn.Module): - def forward(self, x): - if x.sum() > 0: - return x - else: - return -x + def forward(self, x): + if x.sum() > 0: + return x + else: + return -x class MyCell(torch.nn.Module): def __init__(self, dg): @@ -342,13 +342,13 @@ def forward(self, xs): # class WrapRNN(torch.nn.Module): - def __init__(self): - super(WrapRNN, self).__init__() - self.loop = torch.jit.script(MyRNNLoop()) + def __init__(self): + super(WrapRNN, self).__init__() + self.loop = torch.jit.script(MyRNNLoop()) - def forward(self, xs): - y, h = self.loop(xs) - return torch.relu(y) + def forward(self, xs): + y, h = self.loop(xs) + return torch.relu(y) traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4))) print(traced.code)