
Commit 6b31dd0

conv-bn folding should be used for CNN-based Vision Models when AMP is used with oneDNN Graph (#2535)
* Update the tuning guide to reflect Conv-BatchNorm folding when oneDNN Graph is used with AMP. By default, conv-bn folding is not performed for CNN-based vision models when AMP is used with oneDNN Graph; `torch.fx.experimental.optimization.fuse` should be used for such models.

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 3a769b0 commit 6b31dd0


1 file changed: 4 additions, 0 deletions

recipes_source/recipes/tuning_guide.py

Lines changed: 4 additions & 0 deletions
@@ -295,6 +295,10 @@ def fused_gelu(x):
 torch._C._jit_set_autocast_mode(False)
 
 with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
+    # Conv-BatchNorm folding for CNN-based Vision Models should be done with ``torch.fx.experimental.optimization.fuse`` when AMP is used
+    import torch.fx.experimental.optimization as optimization
+    # Please note that optimization.fuse need not be called when AMP is not used
+    model = optimization.fuse(model)
     model = torch.jit.trace(model, (example_input))
     model = torch.jit.freeze(model)
     # a couple of warm-up runs
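
For context, the snippet below is a minimal end-to-end sketch of the pattern the updated tuning guide describes: enable the oneDNN Graph JIT fuser, run under bfloat16 autocast, and fold Conv-BatchNorm explicitly with optimization.fuse before tracing and freezing. The torchvision resnet50 model and the 1x3x224x224 input are illustrative placeholders, not part of this commit.

import torch
import torch.fx.experimental.optimization as optimization
import torchvision.models as models  # illustrative model source, not part of the commit

# Enable the oneDNN Graph JIT fuser and disable the JIT autocast pass,
# as described earlier in the tuning guide.
torch.jit.enable_onednn_fusion(True)
torch._C._jit_set_autocast_mode(False)

model = models.resnet50(weights=None).eval()  # example CNN-based vision model
example_input = torch.rand(1, 3, 224, 224)    # example input shape

with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
    # With AMP, Conv-BatchNorm folding is not applied automatically, so call it explicitly.
    model = optimization.fuse(model)
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # a couple of warm-up runs so the oneDNN Graph fusions are compiled
    model(example_input)
    model(example_input)
    # actual inference
    output = model(example_input)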
