diff --git a/recipes_source/recipes/tuning_guide.py b/recipes_source/recipes/tuning_guide.py index 0f82fb76d3d..dd615714a24 100644 --- a/recipes_source/recipes/tuning_guide.py +++ b/recipes_source/recipes/tuning_guide.py @@ -295,6 +295,10 @@ def fused_gelu(x): torch._C._jit_set_autocast_mode(False) with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16): + # Conv-BatchNorm folding for CNN-based Vision Models should be done with ``torch.fx.experimental.optimization.fuse`` when AMP is used + import torch.fx.experimental.optimization as optimization + # Please note that optimization.fuse need not be called when AMP is not used + model = optimization.fuse(model) model = torch.jit.trace(model, (example_input)) model = torch.jit.freeze(model) # a couple of warm-up runs