From 0fc6fb6b7073d29b83b3313d90733818aa291fd5 Mon Sep 17 00:00:00 2001
From: Mustafa
Date: Sat, 4 Nov 2023 01:41:59 -0500
Subject: [PATCH 1/2] removed deepcopy from fuse method

---
 intermediate_source/fx_conv_bn_fuser.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py
index 90620ceba4e..6bc12bf5362 100644
--- a/intermediate_source/fx_conv_bn_fuser.py
+++ b/intermediate_source/fx_conv_bn_fuser.py
@@ -150,7 +150,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc
 
 
 def fuse(model: torch.nn.Module) -> torch.nn.Module:
-    model = copy.deepcopy(model)
+    model, state_dict = type(model)(), model.state_dict()
+    model.load_state_dict(state_dict)
+    model.eval()
     # The first step of most FX passes is to symbolically trace our model to
     # obtain a `GraphModule`. This is a representation of our original model
     # that is functionally identical to our original model, except that we now

From c8436a4f74df999b1314f4e11e1b342af3dcf2f6 Mon Sep 17 00:00:00 2001
From: Mustafa
Date: Sat, 4 Nov 2023 01:42:16 -0500
Subject: [PATCH 2/2] removed deepcopy from fuse_conv_bn_eval method

---
 intermediate_source/fx_conv_bn_fuser.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py
index 6bc12bf5362..13a451fe4b2 100644
--- a/intermediate_source/fx_conv_bn_fuser.py
+++ b/intermediate_source/fx_conv_bn_fuser.py
@@ -104,7 +104,9 @@ def fuse_conv_bn_eval(conv, bn):
     module `C` such that C(x) == B(A(x)) in inference mode.
     """
     assert(not (conv.training or bn.training)), "Fusion only for eval!"
-    fused_conv = copy.deepcopy(conv)
+    fused_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
+    fused_conv.load_state_dict(conv.state_dict())
+    fused_conv.eval()
 
     fused_conv.weight, fused_conv.bias = \
         fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,