@@ -311,6 +311,7 @@ The purpose for calibration is to run through some sample examples that is repre
 the statistics of the Tensors and we can later use this information to calculate quantization parameters.
 
 .. code:: python
+
     def calibrate(model, data_loader):
         model.eval()
         with torch.no_grad():
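
The hunk above cuts the calibration function short. For readers of this diff, a minimal sketch of what the complete loop and its invocation plausibly look like, assuming the ``prepared_model`` and ``data_loader_test`` objects produced in the tutorial's earlier steps:

.. code:: python

    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            # Run representative samples through the prepared model so the
            # inserted observers can record activation statistics.
            for image, target in data_loader:
                model(image)

    # Assumed invocation; prepared_model and data_loader_test come from
    # earlier steps of the tutorial, not from this hunk.
    calibrate(prepared_model, data_loader_test)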
@@ -320,17 +321,19 @@ the statistics of the Tensors and we can later use this information to calculate
 
 7. Convert the Model to a Quantized Model
 -----------------------------------------
-``convert_fx`` takes a calibrated model and produces a quantized model.
+``convert_fx`` takes a calibrated model and produces a quantized model.
 
 .. code:: python
-    quantized_model = convert_fx(prepared_model)
+
+    quantized_model = convert_fx(prepared_model)
     print(quantized_model)
-
+
 8. Evaluation
 -------------
 We can now print the size and accuracy of the quantized model.
 
 .. code:: python
+
     print("Size of model before quantization")
     print_size_of_model(float_model)
     print("Size of model after quantization")
@@ -372,6 +375,7 @@ we'll first call fuse explicitly to fuse the conv and bn in the model:
 Note that ``fuse_fx`` only works in eval mode.
 
 .. code:: python
+
     fused = fuse_fx(float_model)
 
     conv1_weight_after_fuse = fused.conv1[0].weight[0]
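
The hunk ends just after extracting the fused weight. One plausible continuation, using the ``conv1_weight_after_fuse`` and ``float_model`` names from the snippet above, compares it against the original conv weight to show the effect of folding batch norm into the convolution:

.. code:: python

    # Compare the fused conv1 weight with the pre-fusion weight; a nonzero
    # difference reflects the batch-norm scale folded into the convolution.
    conv1_weight_before_fuse = float_model.conv1.weight[0]
    print(torch.max(torch.abs(conv1_weight_after_fuse - conv1_weight_before_fuse)))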
@@ -383,6 +387,7 @@ Note that ``fuse_fx`` only works in eval mode.
 --------------------------------------------------------------------
 
 .. code:: python
+
     scripted_float_model_file = "resnet18_scripted.pth"
 
     print("Size of baseline model")
@@ -397,6 +402,7 @@ quantized in eager mode. FX graph mode and eager mode produce very similar quant
 so the expectation is that the accuracy and speedup are similar as well.
 
 .. code:: python
+
     print("Size of Fx graph mode quantized model")
     print_size_of_model(quantized_model)
     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
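
The hunk stops at the ``evaluate`` call. A plausible next line reports the returned metrics; this assumes, as in common versions of this tutorial, that ``evaluate`` returns accuracy meters exposing an ``.avg`` attribute:

.. code:: python

    # Assumed: top1/top5 are accuracy meters with an .avg attribute.
    print("Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))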