Skip to content

Commit 76e1ec0

Browse files
committed
removed deepcopy from fuse method
1 parent e7563f6 commit 76e1ec0

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

intermediate_source/fx_conv_bn_fuser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc
150150

151151

152152
def fuse(model: torch.nn.Module) -> torch.nn.Module:
153-
model = copy.deepcopy(model)
153+
model, state_dict = type(model)(), model.state_dict()
154+
model.load_state_dict(state_dict)
155+
model.eval()
154156
# The first step of most FX passes is to symbolically trace our model to
155157
# obtain a `GraphModule`. This is a representation of our original model
156158
# that is functionally identical to our original model, except that we now

0 commit comments

Comments
 (0)