@@ -282,7 +282,10 @@ def forward(
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
-ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+try:
+    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+except Exception:
+    tb.print_exc()

######################################################################
# Basic concepts: symbols and guards
@@ -411,7 +414,10 @@ def forward(
# static guard is emitted on a dynamically-marked dimension:

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()

######################################################################
# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
@@ -421,7 +427,10 @@ def forward(
dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()

######################################################################
# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
@@ -439,7 +448,7 @@ def __init__(self):

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
-        torch._check(x.shape[0] >= 16)
+        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
@@ -455,7 +464,10 @@ def forward(self, w, x, y, z):
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
-ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
print(ep)

######################################################################
@@ -485,7 +497,10 @@ def forward(self, w, x, y, z):
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
-ep.module()(torch.randn(2, 4))
+try:
+    ep.module()(torch.randn(2, 4))
+except Exception:
+    tb.print_exc()

######################################################################
# Named Dims
@@ -539,14 +554,17 @@ def forward(self, x, y):
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
-ep = export(
-    Foo(),
-    (torch.randn(6, 4), torch.randn(6, 4)),
-    dynamic_shapes={
-        "x": (dx, d1),
-        "y": (dy, d1),
-    },
-)
+try:
+    ep = export(
+        Foo(),
+        (torch.randn(6, 4), torch.randn(6, 4)),
+        dynamic_shapes={
+            "x": (dx, d1),
+            "y": (dy, d1),
+        },
+    )
+except Exception:
+    tb.print_exc()

######################################################################
# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
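In practice, that copy-paste step amounts to editing the specification above with whatever relations the error message suggests. A hypothetical sketch for the Foo example (the concrete relations below are illustrative assumptions, not output reproduced from export): if the fixes were to say that d1 specializes to 4 and that dy must equal dx, the spec would become:

    # Hypothetical result of applying suggested fixes to the Foo example above
    # (the specific relations are assumptions for illustration, not captured output).
    dx = torch.export.Dim("dx")
    d1 = 4                      # e.g. a fix of the form "d1 = 4": specialize to a constant
    dynamic_shapes = {
        "x": (dx, d1),
        "y": (dx, d1),          # e.g. a fix of the form "dy = dx": both leading dims share one symbol
    }
    ep = export(Foo(), (torch.randn(6, 4), torch.randn(6, 4)), dynamic_shapes=dynamic_shapes)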
@@ -688,7 +706,10 @@ def forward(self, x, y):
    torch.tensor(32),
    torch.randn(60),
)
-export(Foo(), inps)
+try:
+    export(Foo(), inps)
+except Exception:
+    tb.print_exc()

######################################################################
# Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
@@ -700,7 +721,7 @@ class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
-        torch._check(a <= y.shape[0])
+        torch._check(a < y.shape[0])
        return y[a]

inps = (
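The one-character change in the hunk above tightens the data-dependent bound: ``return y[a]`` is only in range when ``a < y.shape[0]``, so ``torch._check(a <= y.shape[0])`` would still admit the out-of-bounds value ``a == y.shape[0]``. A minimal plain-tensor sketch of that boundary case (illustration only, not part of the change):

    import torch

    y = torch.randn(60)
    a = 60                  # boundary value that `a <= y.shape[0]` would still admit
    print(a <= y.shape[0])  # True  -- the old bound lets an out-of-range index through
    print(a < y.shape[0])   # False -- the corrected bound rules it out
    # y[a] here would raise IndexError: index 60 is out of bounds for dimension 0 with size 60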
@@ -732,7 +753,10 @@ def forward(self, x, y):
    torch.tensor(32),
    torch.randn(60),
)
-export(Foo(), inps, strict=False)
+try:
+    export(Foo(), inps, strict=False)
+except Exception:
+    tb.print_exc()

######################################################################
# For these errors, some basic options you have are:
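Taken together, the hunks wrap each intentionally failing call in the same try/except pattern so the tutorial script keeps running and prints the traceback instead of aborting. A self-contained sketch of that pattern, assuming the file imports the traceback module under the alias ``tb`` (that import is not shown in these hunks):

    import traceback as tb   # assumed alias -- the actual import lives outside the hunks shown here

    import torch
    from torch.export import export


    class Foo(torch.nn.Module):
        def forward(self, x):
            return x + 1


    # With no dynamic_shapes spec, export specializes the input to its traced shape (4, 4),
    # so calling the exported module with a different leading dim is expected to fail.
    ep = export(Foo(), (torch.randn(4, 4),))
    try:
        ep.module()(torch.randn(2, 4))   # shape mismatch against the specialized guard
    except Exception:
        tb.print_exc()                   # show the error and keep going, as in the hunks above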