Skip to content

Commit 03faf3e

Browse files
remove dropout for inference path (#8)
1 parent aaed398 commit 03faf3e

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

intel_pytorch_extension_py/conf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ def get_default_recipe(self, configures):
7676
# add add
7777
if len(pre_ops) > 0:
7878
for key, value in pre_ops.items():
79-
if value == 'conv2d' or value == 'conv3d':
79+
if value == 'conv2d' or value == 'conv3d' or value == 'linear':
8080
default_configures[cur_id]['inputs_quantized'][key] = False
8181
break
8282

83-
# if add pre_op hasn't conv, not need add q, dq for accuracy.
83+
# if add pre_op hasn't conv and linear, not need add q, dq for accuracy.
8484
pre_inputs = pre_ops.values()
85-
if cur_op == 'add' and ('conv2d' not in pre_inputs and 'conv3d' not in pre_inputs):
85+
if cur_op == 'add' and \
86+
('conv2d' not in pre_inputs and 'conv3d' not in pre_inputs and 'linear' not in pre_inputs):
8687
default_configures[cur_id]['inputs_quantized'][0] = False
8788
default_configures[cur_id]['inputs_quantized'][1] = False
8889

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .fuse_utils import conv_bn_fuse
1+
from .fuse_utils import conv_bn_fuse, remove_dropout
22

intel_pytorch_extension_py/fx/fuse_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
2+
import torch.nn as nn
23
import torch.fx as fx
4+
from torch.fx.node import Argument, Target
35
from torch.nn.utils.fusion import fuse_conv_bn_eval
4-
from typing import Type, Dict, Any, Tuple, Iterable
5-
import torch
6+
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast, Callable
67
import copy
78

89
def _parent_name(target : str) -> Tuple[str, str]:
@@ -63,3 +64,17 @@ def conv_bn_fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
6364
new_graph.erase_node(node)
6465
return fx.GraphModule(fx_model, new_graph)
6566

67+
def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Return a ``torch.fx``-transformed copy of *model* with every
    ``nn.Dropout`` module call removed.

    The model is symbolically traced, then each call to a dropout
    submodule is rewritten to pass its (single) input straight through —
    the identity behaviour dropout already has at inference time.

    Args:
        model: the module to strip dropout layers from.

    Returns:
        A ``fx.GraphModule`` equivalent to *model* minus dropout calls.
    """
    traced = fx.symbolic_trace(model)

    class _StripDropout(torch.fx.Transformer):
        # Replace a dropout module call with its input; every other
        # module call falls through to the default transformation.
        def call_module(self, target: Target, args: Tuple[Argument, ...],
                        kwargs: Dict[str, Any]) -> Any:
            submodule = self.submodules[target]
            if not isinstance(submodule, nn.Dropout):
                return super().call_module(target, args, kwargs)
            # Dropout takes exactly one positional input; forward it.
            assert len(args) == 1
            return args[0]

    return _StripDropout(traced).transform()

0 commit comments

Comments
 (0)