diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py index 90620ceba4e..13a451fe4b2 100644 --- a/intermediate_source/fx_conv_bn_fuser.py +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -104,7 +104,9 @@ def fuse_conv_bn_eval(conv, bn): module `C` such that C(x) == B(A(x)) in inference mode. """ assert(not (conv.training or bn.training)), "Fusion only for eval!" - fused_conv = copy.deepcopy(conv) + fused_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=conv.bias is not None, padding_mode=conv.padding_mode) + fused_conv.load_state_dict(conv.state_dict()) + fused_conv.eval() fused_conv.weight, fused_conv.bias = \ fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, @@ -150,7 +152,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc def fuse(model: torch.nn.Module) -> torch.nn.Module: - model = copy.deepcopy(model) + model, state_dict = type(model)(), model.state_dict() + model.load_state_dict(state_dict) + model.eval() # The first step of most FX passes is to symbolically trace our model to # obtain a `GraphModule`. This is a representation of our original model # that is functionally identical to our original model, except that we now