
Commit 935ddb0 ("rebase fixes")
Parent: 547d9fe


intermediate_source/torch_export_tutorial.py

Lines changed: 41 additions & 17 deletions
@@ -282,7 +282,10 @@ def forward(
 model = DynamicModel()
 ep = export(model, (w, x, y, z))
 model(w, x, torch.randn(3, 4), torch.randn(12))
-ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+try:
+    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Basic concepts: symbols and guards
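This hunk establishes the pattern the rest of the commit repeats: export or example-module calls that are intended to fail in the tutorial are wrapped so the traceback is printed instead of aborting the run. A minimal sketch of the pattern, assuming the tutorial aliases the traceback module as ``tb`` (e.g. ``import traceback as tb``) near its imports:

import traceback as tb  # assumed alias; the tutorial presumably imports this near the top

import torch
from torch.export import export

class Toy(torch.nn.Module):  # hypothetical stand-in model, not part of the tutorial
    def forward(self, x):
        return x + 1

# Export specializes on the example input's static shape (3,).
ep = export(Toy(), (torch.randn(3),))

try:
    # A differently shaped input should trip the exported program's shape guards;
    # printing the traceback keeps the tutorial script running.
    ep.module()(torch.randn(5))
except Exception:
    tb.print_exc()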
@@ -411,7 +414,10 @@ def forward(
 # static guard is emitted on a dynamically-marked dimension:
 
 dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
@@ -421,7 +427,10 @@ def forward(
 dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
 dynamic_shapes["x"] = (Dim.STATIC,)
 dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
@@ -439,7 +448,7 @@ def __init__(self):
 
     def forward(self, w, x, y, z):
         assert w.shape[0] <= 512
-        torch._check(x.shape[0] >= 16)
+        torch._check(x.shape[0] >= 4)
         if w.shape[0] == x.shape[0] + 2:
             x0 = x + y
             x1 = self.l(w)
@@ -455,7 +464,10 @@ def forward(self, w, x, y, z):
     "y": (Dim.AUTO, Dim.AUTO),
     "z": (Dim.AUTO,),
 }
-ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 print(ep)
 
 ######################################################################
@@ -485,7 +497,10 @@ def forward(self, w, x, y, z):
         "input": (Dim.AUTO, Dim.STATIC),
     },
 )
-ep.module()(torch.randn(2, 4))
+try:
+    ep.module()(torch.randn(2, 4))
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Named Dims
@@ -539,14 +554,17 @@ def forward(self, x, y):
         return w + torch.ones(4)
 
 dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
-ep = export(
-    Foo(),
-    (torch.randn(6, 4), torch.randn(6, 4)),
-    dynamic_shapes={
-        "x": (dx, d1),
-        "y": (dy, d1),
-    },
-)
+try:
+    ep = export(
+        Foo(),
+        (torch.randn(6, 4), torch.randn(6, 4)),
+        dynamic_shapes={
+            "x": (dx, d1),
+            "y": (dy, d1),
+        },
+    )
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
@@ -688,7 +706,10 @@ def forward(self, x, y):
     torch.tensor(32),
     torch.randn(60),
 )
-export(Foo(), inps)
+try:
+    export(Foo(), inps)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
@@ -700,7 +721,7 @@ class Foo(torch.nn.Module):
     def forward(self, x, y):
         a = x.item()
         torch._check(a >= 0)
-        torch._check(a <= y.shape[0])
+        torch._check(a < y.shape[0])
         return y[a]
 
 inps = (
@@ -732,7 +753,10 @@ def forward(self, x, y):
     torch.tensor(32),
     torch.randn(60),
 )
-export(Foo(), inps, strict=False)
+try:
+    export(Foo(), inps, strict=False)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # For these errors, some basic options you have are:
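
Apart from the try/except wrapping, the commit also adjusts two ``torch._check`` bounds: the lower bound on ``x.shape[0]`` in ``DynamicModel`` drops from 16 to 4, and the data-dependent guard before ``y[a]`` becomes strict. The strict form is what actually makes the indexing safe, since valid indices for a tensor of length ``y.shape[0]`` run from 0 to ``y.shape[0] - 1``. A small illustration with hypothetical values (not part of the diff):

import torch

y = torch.randn(60)
a = 60                 # passes the old, non-strict check: a <= y.shape[0]
# y[a]                 # ...but indexing would still raise IndexError (index 60, size 60)
print(a < y.shape[0])  # False: the new, strict check correctly rules this value out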
