@@ -28,42 +28,72 @@ def quantize(
     """

     with unset_fake_temporarily():
-        if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
-            trt.float32,
-            trt.float16,
-        ):
-            raise ValueError(
-                f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
-            )
+        if isinstance(input_tensor, (torch.Tensor, TRTTensor)):
+            if input_tensor.dtype not in (
+                trt.float32,
+                trt.float16,
+                trt.bfloat16,
+                torch.bfloat16,
+                torch.float16,
+                torch.float32,
+            ):
+                raise ValueError(
+                    f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
+                )
             if num_bits != 8 or exponent_bits not in (0, 4):
                 raise ValueError(
                     f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
                 )
+        else:
+            raise ValueError(
+                f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor"
+            )
+
         if num_bits == 8 and exponent_bits == 0:
+            dtype = trt.DataType.INT8
             max_bound = 127
         elif num_bits == 8 and exponent_bits == 4:
+            dtype = trt.DataType.FP8
             max_bound = 448

         amax = to_torch(amax, None)
+        axis = None
+        # INT8 weight quantization is per-channel quantization (it can have one or multiple amax values)
+        if dtype == trt.DataType.INT8 and amax.numel() > 1:
+            # if amax has more than one element, calculate the axis; otherwise the axis value is ignored
+            amax_init_shape = amax.shape
+            amax = amax.squeeze().data
+            assert (
+                len(amax.shape) == 1
+            ), f"TensorRT does not support multi-axis quantization. {name=} {amax_init_shape=} {amax.shape=}"
+            axis = list(amax_init_shape).index(list(amax.shape)[0])
+            assert (
+                axis == 0
+            ), f"{name=} {amax=} is per-channel quantization, expected axis to be 0, but got {axis=}"
+        else:
+            # INT8 activation and FP8 weight/activation quantization is per-tensor quantization; it can only have a single amax value
+            assert (
+                amax.numel() == 1
+            ), f"{name=} is per-tensor quantization, expected amax is a singular value, but got {amax.shape=}"
         scale = torch.divide(amax, max_bound)
+        scale.masked_fill_(scale == 0, 1.0)
         scale = get_trt_tensor(ctx, scale, name + "_scale")
-        # Add Q node
-        quantize_layer = ctx.net.add_quantize(input_tensor, scale)
-        if num_bits == 8 and exponent_bits == 0:
-            quantize_layer.set_output_type(0, trt.DataType.INT8)
-        elif num_bits == 8 and exponent_bits == 4:
-            quantize_layer.set_output_type(0, trt.DataType.FP8)
+        input_tensor = get_trt_tensor(ctx, input_tensor, name)

+        # Add Q node
+        quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
+        if axis is not None:
+            quantize_layer.axis = axis
         set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
         q_output = quantize_layer.get_output(0)
         # Add DQ node
-        dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+        dequantize_layer = ctx.net.add_dequantize(
+            q_output, scale, output_type=input_tensor.dtype
+        )
+        dequantize_layer.to_type = input_tensor.dtype
+        if axis is not None:
+            dequantize_layer.axis = axis
         set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-        if num_bits == 8 and exponent_bits == 0:
-            dequantize_layer.precision = trt.DataType.INT8
-        elif num_bits == 8 and exponent_bits == 4:
-            # Set DQ layer precision to FP8
-            dequantize_layer.precision = trt.DataType.FP8
         dq_output = dequantize_layer.get_output(0)

         return dq_output
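
For readers skimming the hunk: the new per-channel branch infers the quantization axis purely from the shape of amax, and the masked_fill_ guard keeps an all-zero channel (amax == 0) from producing a zero scale, which would blow up the division inside the Q node. Below is a minimal plain-PyTorch sketch of that derivation, outside TensorRT; the helper name derive_scale_and_axis and the sample shapes are illustrative, not part of the patch:

    import torch

    def derive_scale_and_axis(amax: torch.Tensor, max_bound: int, per_channel: bool):
        # Illustrative re-implementation of the patch's scale/axis logic.
        axis = None
        if per_channel and amax.numel() > 1:
            # e.g. an amax of shape (out_channels, 1) squeezes to (out_channels,);
            # the surviving dimension's position in the original shape is the axis.
            init_shape = amax.shape
            amax = amax.squeeze()
            assert len(amax.shape) == 1, "multi-axis quantization is unsupported"
            axis = list(init_shape).index(amax.shape[0])
        else:
            assert amax.numel() == 1, "per-tensor quantization expects a scalar amax"
        scale = amax / max_bound
        # An all-zero channel yields amax == 0; replace the zero scale with 1.0,
        # as the patch does, so the Q/DQ pair never divides by zero.
        scale = scale.masked_fill(scale == 0, 1.0)
        return scale, axis

    # Example: a (4, 1) per-channel amax for an INT8 weight
    scale, axis = derive_scale_and_axis(torch.tensor([[0.5], [1.0], [0.0], [2.0]]), 127, True)
    print(scale, axis)  # roughly tensor([0.0039, 0.0079, 1.0000, 0.0157]) and axis 0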
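Numerically, the Q/DQ pair the converter emits is fake quantization: the input is divided by scale (amax over max_bound, 127 for INT8 and 448 for FP8 E4M3), rounded and clamped to the target format's range, then multiplied back by the same scale. A rough per-tensor INT8 emulation in plain PyTorch, assuming round-to-nearest and a [-128, 127] clamp; fake_quant_int8 is an illustrative name and TensorRT's exact rounding mode may differ:

    import torch

    def fake_quant_int8(x: torch.Tensor, amax: float) -> torch.Tensor:
        # Emulates the inserted Q -> DQ pair for the per-tensor INT8 case.
        scale = amax / 127
        q = torch.clamp(torch.round(x / scale), -128, 127)  # quantize
        return q * scale                                    # dequantize

    x = torch.tensor([0.1, -0.5, 1.3])
    print(fake_quant_int8(x, amax=1.3))  # each value snaps to one of 256 levels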