Commit 018a7dc

formatting
1 parent ae209cc commit 018a7dc

File tree: 1 file changed (+34, -34 lines)


prototype_source/pt2e_quant_ptq_static.rst

Lines changed: 34 additions & 34 deletions
@@ -435,47 +435,47 @@ Convert the Calibrated Model to a Quantized Model
 quantized_model = convert_pt2e(prepared_model)
 print(quantized_model)

-.. note::
-At this step, we currently have two representations that you can choose from, but exact representation
-we offer in the long term might change based on feedback from PyTorch users.
+At this step, we currently have two representations that you can choose from, but exact representation
+we offer in the long term might change based on feedback from PyTorch users.

-* Q/DQ Representation (default)
+* Q/DQ Representation (default)

-Previous documentation for `representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-
+Previous documentation for `representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-
 Extending-PyTorch-Quantization-to-Custom-Backends.md>`_ all quantized operators are represented as ``dequantize -> fp32_op -> qauntize``.

-.. code-block:: python
-
-def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
-    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
-        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
-    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
-        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
-    weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0]);
-    out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
-    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
-        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
-    return out_i8
-
-* Reference Quantized Model Representation (WIP, expected to be ready at end of August): we have special representation for selected ops (for example, quantized linear), other ops are represented as (``dq -> float32_op -> q``), and ``q/dq`` are decomposed into more primitive operators.
-
-You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.
+.. code-block:: python

-.. code-block:: python
-
-# Reference Quantized Pattern for quantized linear
-def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
-    x_int16 = x_int8.to(torch.int16)
-    weight_int16 = weight_int8.to(torch.int16)
-    acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
-    acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
-    bias_scale = x_scale * weight_scale
-    bias_int32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, bias_fp32, bias_scale / out_scale)
-    out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
-    return out_int8
+def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0]);
+    out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+* Reference Quantized Model Representation (WIP, expected to be ready at end of August): we have special representation for selected ops (for example, quantized linear), other ops are represented as (``dq -> float32_op -> q``), and ``q/dq`` are decomposed into more primitive operators.

+You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.

-See `here <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/representation/rewrite.py>`_ for the most up-to-date reference representations.
+.. code-block:: python
+
+# Reference Quantized Pattern for quantized linear
+def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
+    x_int16 = x_int8.to(torch.int16)
+    weight_int16 = weight_int8.to(torch.int16)
+    acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
+    bias_scale = x_scale * weight_scale
+    bias_int32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_int32 = acc_int32 + bias_int32
+    acc_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale) + output_zero_point
+    out_int8 = torch.ops.aten.clamp(acc_int32, qmin, qmax).to(torch.int8)
+    return out_int8
+
+
+See `here <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/representation/rewrite.py>`_ for the most up-to-date reference representations.


 Checking Model Size and Accuracy Evaluation
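
As a quick sanity check of the updated reference pattern, the standalone sketch below (separate from the tutorial file shown in the diff) compares the two representations on a toy linear layer. The tensor shapes, scales, and zero points are made up for illustration, plain eager ``torch`` ops stand in for the ``quantized_decomposed`` and ``out_dtype`` operators, and an int64 accumulator stands in for the int32 one; the two paths should agree up to off-by-one rounding.

.. code-block:: python

    import torch

    torch.manual_seed(0)

    # Made-up example: a batch of 4 inputs with 8 features, a linear layer with 16 outputs.
    x_int8 = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
    weight_int8 = torch.randint(-128, 128, (16, 8), dtype=torch.int8)
    bias_fp32 = torch.randn(16)

    # Made-up quantization parameters (in a real model these come from calibration).
    x_scale, x_zero_point = 0.02, 3
    weight_scale, weight_zero_point = 0.01, 0
    output_scale, output_zero_point = 0.05, -5
    qmin, qmax = -128, 127

    # Q/DQ representation: dequantize the int8 operands, run the float op, requantize.
    x_fp32 = (x_int8.to(torch.float32) - x_zero_point) * x_scale
    weight_fp32 = (weight_int8.to(torch.float32) - weight_zero_point) * weight_scale
    out_fp32 = torch.addmm(bias_fp32, x_fp32, weight_fp32.t())
    out_qdq_int8 = torch.clamp(
        torch.round(out_fp32 / output_scale) + output_zero_point, qmin, qmax
    ).to(torch.int8)

    # Reference representation: integer matmul on zero-point-shifted operands, bias
    # folded in at the accumulation scale, then a single rescale to the output scale.
    acc = torch.mm(
        x_int8.to(torch.int64) - x_zero_point,
        (weight_int8.to(torch.int64) - weight_zero_point).t(),
    )
    bias_scale = x_scale * weight_scale
    acc = acc + torch.round(bias_fp32 / bias_scale).to(torch.int64)
    out_ref_int8 = torch.clamp(
        torch.round(acc * (bias_scale / output_scale)) + output_zero_point, qmin, qmax
    ).to(torch.int8)

    # Expect at most an off-by-one difference from the extra bias rounding step.
    print((out_qdq_int8.to(torch.int32) - out_ref_int8.to(torch.int32)).abs().max())

Folding the bias in at the accumulation scale (``bias_fp32 / (x_scale * weight_scale)``) before a single rescale is what the updated pattern above does with ``aten.div.Tensor`` and ``aten.mul.Scalar``; it keeps the whole computation in integer arithmetic until the final requantization.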
