Fix formatting in the FX Graph Mode Quantization guide #2362

Merged 9 commits on Jun 1, 2023
27 changes: 12 additions & 15 deletions prototype_source/fx_graph_mode_quant_guide.rst
@@ -4,7 +4,7 @@
**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_

FX Graph Mode Quantization requires a symbolically traceable model.
We use the FX framework (TODO: link) to convert a symbolically traceable nn.Module instance to IR,
We use the FX framework to convert a symbolically traceable nn.Module instance to IR,
and we operate on the IR to execute the quantization passes.
Please post your question about symbolically tracing your model in `PyTorch Discussion Forum <https://discuss.pytorch.org/c/quantization/17>`_
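
As a rough illustration of what "symbolically traceable" means here, the sketch below traces one small module with ``torch.fx.symbolic_trace`` and shows a second one failing because of data-dependent control flow. The ``Traceable`` and ``NotTraceable`` classes are hypothetical and are not part of the guide.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


class Traceable(nn.Module):
    """Hypothetical module: a fixed sequence of tensor ops, no Python branching on values."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class NotTraceable(nn.Module):
    """Hypothetical module: the branch depends on the runtime value of x."""

    def forward(self, x):
        if x.sum() > 0:  # a traced value cannot drive Python control flow
            return x + 1
        return x - 1


# Tracing succeeds and yields a GraphModule whose graph is the IR.
gm = symbolic_trace(Traceable())
print(gm.graph)

# Tracing the second module raises TraceError at the `if` statement.
try:
    symbolic_trace(NotTraceable())
    traced_ok = True
except TraceError:
    traced_ok = False
print("NotTraceable traced:", traced_ok)
```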

@@ -22,16 +22,19 @@ You can use any combination of these options:
b. Write your own observed and quantized submodule


####################################################################
If the code that is not symbolically traceable does not need to be quantized, we have the following two options
to run FX Graph Mode Quantization:
1.a. Symbolically trace only the code that needs to be quantized


Symbolically trace only the code that needs to be quantized
-----------------------------------------------------------------
When the whole model is not symbolically traceable but the submodule we want to quantize is
symbolically traceable, we can run quantization only on that submodule.

before:

.. code:: python

  class M(nn.Module):
      def forward(self, x):
          x = non_traceable_code_1(x)
@@ -42,6 +45,7 @@ before:
after:

.. code:: python

  class FP32Traceable(nn.Module):
      def forward(self, x):
          x = traceable_code(x)
@@ -69,8 +73,7 @@ Note if original model needs to be preserved, you will have to
copy it yourself before calling the quantization APIs.
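
A minimal sketch of this approach, assuming a recent PyTorch release where ``prepare_fx`` takes ``example_inputs``: only the traceable submodule is handed to the quantization API, and the original model is deep-copied first so it is preserved. The module bodies, the ``traceable_submodule`` attribute name, and the input shape are illustrative assumptions.

```python
import copy

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx


class FP32Traceable(nn.Module):
    """Hypothetical traceable part that we want to quantize."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class M(nn.Module):
    """Hypothetical parent whose forward is not symbolically traceable."""

    def __init__(self):
        super().__init__()
        self.traceable_submodule = FP32Traceable()

    def forward(self, x):
        if x.sum() > 0:  # data-dependent branch stays in eager mode
            x = x * 2
        return self.traceable_submodule(x)


model_fp32 = M().eval()
model_q = copy.deepcopy(model_fp32)  # keep the original model untouched

example_inputs = (torch.randn(2, 4),)
qconfig_mapping = get_default_qconfig_mapping()

# Only the submodule is traced and instrumented with observers.
prepared = prepare_fx(model_q.traceable_submodule, qconfig_mapping, example_inputs)
model_q.traceable_submodule = prepared

# Calibration: run representative data through the whole model.
model_q(*example_inputs)
```

Calling ``convert_fx`` on the calibrated submodule and assigning the result back in the same way would complete the flow.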


#####################################################
1.b. Skip symbolically trace the non-traceable code
Skip symbolically trace the non-traceable code
---------------------------------------------------
When we have some non-traceable code in the module, and this part of code doesn’t need to be quantized,
we can factor out this part of the code into a submodule and skip symbolically trace that submodule.
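
One way this might look with the ``PrepareCustomConfig`` API available in recent PyTorch releases (an assumption about the installed version; older releases used a ``prepare_custom_config_dict`` instead). The module classes and the ``non_traceable_submodule`` attribute name are hypothetical.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx


class FP32NonTraceable(nn.Module):
    """Hypothetical factored-out code that cannot be traced and is not quantized."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.non_traceable_submodule = FP32NonTraceable()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        x = self.non_traceable_submodule(x)
        return self.linear(x)


model = M().eval()
example_inputs = (torch.randn(2, 4),)

# Tell the tracer to treat FP32NonTraceable as an opaque leaf module.
prepare_custom_config = PrepareCustomConfig().set_non_traceable_module_classes(
    [FP32NonTraceable]
)
prepared = prepare_fx(
    model,
    get_default_qconfig_mapping(),
    example_inputs,
    prepare_custom_config=prepare_custom_config,
)

# The skipped submodule shows up as a single call_module node in the graph.
call_module_targets = [n.target for n in prepared.graph.nodes if n.op == "call_module"]
print(call_module_targets)
```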
@@ -134,8 +137,7 @@ quantization code:

If the code that is not symbolically traceable needs to be quantized, we have the following two options:

##########################################################
2.a Refactor your code to make it symbolically traceable
Refactor your code to make it symbolically traceable
--------------------------------------------------------
If it is easy to refactor the code and make the code symbolically traceable,
we can refactor the code and remove the use of non-traceable constructs in python.
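
As a small sketch of such a refactor (not the example used in the guide), the hypothetical ``Before`` module below branches on a tensor value, which cannot be traced, while ``After`` expresses the same computation with ``torch.where`` and traces cleanly.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Before(nn.Module):
    """Hypothetical module using Python control flow on a tensor value."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x


class After(nn.Module):
    """Same computation written as a tensor op, so the tracer can record it."""

    def forward(self, x):
        return torch.where(x.sum() > 0, x * 2, x)


gm = symbolic_trace(After())

# Check that both branches of the original behave the same after the refactor.
x_pos = torch.ones(2, 2)
x_neg = -torch.ones(2, 2)
same_pos = torch.equal(gm(x_pos), Before()(x_pos))
same_neg = torch.equal(gm(x_neg), Before()(x_neg))
print(same_pos, same_neg)
```

Note that ``torch.where`` evaluates both branches, so this style only fits cases where computing the unused branch is cheap and free of side effects.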
@@ -167,15 +169,10 @@ after:
return x.permute(0, 2, 1, 3)


quantization code:

This can be combined with other approaches and the quantization code
depends on the model.



#######################################################
2.b. Write your own observed and quantized submodule
Write your own observed and quantized submodule
-----------------------------------------------------

If the non-traceable code can’t be refactored to be symbolically traceable,
@@ -207,8 +204,8 @@ non-traceable logic, wrapped in a module
  class FP32NonTraceable:
      ...


2. Define observed version of FP32NonTraceable
2. Define observed version of
FP32NonTraceable

.. code:: python
