 import torch
 import _torch_ipex as core
 
-
 qscheme_dict ={torch.per_tensor_affine:0,
                torch.per_channel_affine:1,
                torch.per_tensor_symmetric:2,
                torch.per_channel_symmetric:3,
                torch.torch.per_channel_affine_float_qparams:4}
 
-class AmpConf(object):
-    def __init__(self, mixed_dtype=torch.bfloat16, configure_file=None, qscheme=torch.per_tensor_affine):
-        self.dtype = mixed_dtype
+class QuantConf(object):
+    def __init__(self, configure_file=None, qscheme=torch.per_tensor_affine):
         self.configure_file = configure_file
 
-        if self.dtype == torch.int8:
-            core.clear_indicators()
-            assert qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric], \
-                "qscheme is only support torch.per_tensor_affine and torch.per_tensor_symmetric now"
-            core.set_int8_qscheme(qscheme_dict[qscheme])
+        core.clear_indicators()
+        assert qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric], \
+            "only torch.per_tensor_affine and torch.per_tensor_symmetric are supported for now"
+        core.set_int8_qscheme(qscheme_dict[qscheme])
 
-        # for int8 path, if user give a exited configure file, load it.
-        if self.configure_file != None and self.dtype == torch.int8:
+        # if the user provides an existing configuration file, load it
+        if self.configure_file is not None:
             if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
                 with open(self.configure_file, 'r') as f:
                     configures = json.load(f)
                     core.load_indicators_file(configures)
             else:
                 assert False, 'Cannot load an empty or nonexistent file, please run the calibration step first'
 
-    # for int8 quantization, will save the date after doing calibration step.
-    def save(self, configure_file, default_recipe=True):
-        core.add_indicators()
+    def save(self, configure_file):
         configures = core.get_int8_configures()
-        if default_recipe:
-            configures = self.get_default_recipe(configures)
         with open(configure_file, 'w') as fp:
             json.dump(configures, fp, indent = 4)
-
-    def get_default_recipe(self, configures):
-        elt_wise = ['relu', 'sigmoid', 'gelu']
-        inplace_ops = ['relu_', 'add_']
-        shape_ops = ['flatten']
-        # get default recipe,
-        # q+dq+conv+q+dq+relu => q+dq+conv+relu
-        # q+dq+op1+q+dq+q+dq+op2+q+dq => q+dq+op1+q+dq+op2+q+dq
-        default_configures = configures
-        num_ops = len(default_configures)
-        for cur_id in range(num_ops):
-            cur_op = default_configures[cur_id]['name']
-            if cur_op == 'dropout':
-                continue
-            inputs = default_configures[cur_id]['inputs_flow']
-            num_input = len(inputs)
-            pre_ops = {}
-            for i_num in range(num_input):
-                inp = inputs[i_num]
-                for pre_id in range(cur_id):
-                    pre_op = default_configures[pre_id]['name']
-                    pre_out = default_configures[pre_id]['outputs_flow']
-                    num_out= len(pre_out)
-                    for o_num in range(num_out):
-                        # pre_op+qu+dequ+qu+dequ+cur_op+qu+dequ -> pre_op+qu+dequ+cur_op+qu+dequ.
-                        # for relu, sigmoid or other elt_wise ops, id pre_op is conv, linear, then
-                        # remove qu+dequ between them for fusion: pre_op+cur_op+qu_dequ.
-                        if pre_out[o_num] == inp:
-                            if (cur_op not in inplace_ops) \
-                                or (cur_op in inplace_ops and \
-                                    (pre_op == 'conv2d' or pre_op == 'conv3d' or pre_op == 'linear')):
-                                if pre_op not in inplace_ops and pre_op != 'dropout':
-                                    default_configures[pre_id]['outputs_quantized'][o_num] = False
-                            if cur_op in elt_wise \
-                                and (pre_op == 'conv2d' or pre_op == 'conv3d' or pre_op == 'linear' or pre_op == 'add'):
-                                default_configures[cur_id]['inputs_quantized'][i_num] = False
-                            if cur_op == 'add':
-                                pre_ops[i_num] = pre_op
-                            if cur_op in shape_ops:
-                                # for pooling case, the input and output always has same scale and zero point,
-                                # if the pooling's post ops is flatten, need sync flatten's input and output's
-                                # scale and zero point to pooling.
-                                if pre_op in ['max_pool2d', 'adaptive_avg_pool2d']:
-                                    default_configures[cur_id]['input_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
-                                    default_configures[cur_id]['input_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
-                                    default_configures[cur_id]['output_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
-                                    default_configures[cur_id]['output_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
-                            if pre_op in shape_ops:
-                                # if pre op is flatten, sync the input's scale and zero point to flatten.
-                                default_configures[cur_id]['input_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
-                                default_configures[cur_id]['input_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
-            # conv   op           conv   op
-            #    \   /               \   /
-            #     q   q               \   q
-            #      \ /        =>       \ /
-            #      dq  dq               \  dq
-            #        \  /                \  /
-            #         add                 add
-            if len(pre_ops) > 0:
-                for key, value in pre_ops.items():
-                    if value == 'conv2d' or value == 'conv3d' or value == 'linear':
-                        default_configures[cur_id]['inputs_quantized'][key] = False
-                        break
-
-            # if add pre_op hasn't conv and linear, not need add q, dq for accuracy.
-            pre_inputs = pre_ops.values()
-            if cur_op == 'add' and \
-                ('conv2d' not in pre_inputs and 'conv3d' not in pre_inputs and 'linear' not in pre_inputs):
-                default_configures[cur_id]['inputs_quantized'][0] = False
-                default_configures[cur_id]['inputs_quantized'][1] = False
-
-        # post process for add, linear, if cur op hasn't post quantized op, i.e. 'outputs_quantized' is True,
-        # for good perfromance, the default recipe:
-        # int8_input -> op -> q -> dq will converted to int8_input -> op.
-        ops_remove_q_dq_after = ['add', 'linear', 'conv2d']
-        # post process for flatten, if flatten's pre-pop and post op are fp32 op, don't need add q and dq
-        # before and after it.
-        ops_remove_q_dq_before_after = ['flatten']
-        for cur_id in range(num_ops):
-            cur_op = default_configures[cur_id]['name']
-            if cur_op in ops_remove_q_dq_after and default_configures[cur_id]['outputs_quantized'][0]:
-                default_configures[cur_id]['outputs_quantized'][0] = False
-            if cur_op in ops_remove_q_dq_before_after and default_configures[cur_id]['inputs_quantized'][0] \
-                and default_configures[cur_id]['outputs_quantized'][0]:
-                default_configures[cur_id]['inputs_quantized'][0] = False
-                default_configures[cur_id]['outputs_quantized'][0] = False
-
-        return default_configures
+        # clear the indicators after saving
+        core.clear_indicators()
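For context, here is a minimal sketch of how the renamed QuantConf class could be driven end to end, assuming the usual calibrate-save-reload flow. Only QuantConf, its configure_file/qscheme arguments, and save() come from this diff; the intel_pytorch_extension import name and the calibration step are assumptions for illustration.

    import torch
    import intel_pytorch_extension as ipex  # package name assumed, not shown in this diff

    # calibration run: let the backend collect int8 indicators (scales/zero points)
    # while representative batches execute, then persist them as JSON
    conf = ipex.QuantConf(qscheme=torch.per_tensor_symmetric)
    # ... run a few calibration batches here under whatever calibration mode
    #     the library provides (not part of this diff) ...
    conf.save("int8_configure.json")

    # inference run: reload the saved indicators from the JSON file
    conf = ipex.QuantConf(configure_file="int8_configure.json")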