Leverage Advanced Matrix Extensions
===================================

Introduction
------------

Advanced Matrix Extensions (AMX), also known as Intel Advanced Matrix Extensions (Intel AMX), is an extension to the x86 instruction set architecture (ISA).
AMX is designed to improve the performance of deep-learning training and inference on the CPU and is ideal for workloads like natural-language processing, recommendation systems, and image recognition.
AMX supports two data types, INT8 and BFloat16; compared with AVX-512 FP32, they can achieve up to 32x and 16x acceleration, respectively.
For more detailed information on AMX, see the `AMX overview <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ and the `AMX AI solution brief <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/ai-solution-brief.html>`_.

Note: AMX will have FP16 support on the next generation of Xeon.

AMX in PyTorch
--------------

PyTorch leverages AMX for compute-intensive operators with BFloat16 and for quantization with INT8 through its backend oneDNN,
to get higher performance out of the box on x86 CPUs with AMX support.
The operation is fully handled by oneDNN according to the execution code path generated, i.e., when a supported operation is executed by the oneDNN implementation on a hardware platform with AMX support, the AMX instructions are invoked automatically inside oneDNN.
No manual operations are required to enable this feature.
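
As an illustration, a plain BFloat16 matrix multiplication (``matmul`` is one of the BF16 ops listed below) can already benefit with no extra code. This is a minimal sketch; the tensor sizes are arbitrary:

::

  import torch

  # BFloat16 matmul on CPU; on an AMX-capable platform the underlying
  # oneDNN kernel uses AMX tiles automatically.
  a = torch.randn(1024, 1024, dtype=torch.bfloat16)
  b = torch.randn(1024, 1024, dtype=torch.bfloat16)
  c = a @ b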

BF16 CPU ops that can leverage AMX:
"""""""""""""""""""""""""""""""""""

``conv1d``,
``conv2d``,
``conv3d``,
``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``bmm``,
``mm``,
``baddbmm``,
``addmm``,
``addbmm``,
``linear``,
``matmul``,
``_convolution``

Quantization CPU ops that can leverage AMX:
"""""""""""""""""""""""""""""""""""""""""""

``conv1d``,
``conv2d``,
``conv3d``,
``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``linear``

Preliminary requirements to activate AMX support for PyTorch:
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

All of the following instruction sets must be supported by the hardware platform:
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

+---------+----------+----------+----------+-------------+-------------+----------+----------+----------+
| AVX512F | AVX512BW | AVX512VL | AVX512DQ | AVX512_VNNI | AVX512_BF16 | AMX_TILE | AMX_INT8 | AMX_BF16 |
+---------+----------+----------+----------+-------------+-------------+----------+----------+----------+

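If you are unsure whether your machine qualifies, a quick Linux-only check is to look for the corresponding CPU feature flags in ``/proc/cpuinfo``. This snippet is an illustrative sketch, not part of the original recipe; the flag names follow the kernel's naming:

::

  # Check /proc/cpuinfo for the AMX-related CPU feature flags (Linux only).
  required = {
      "avx512f", "avx512bw", "avx512vl", "avx512dq",
      "avx512_vnni", "avx512_bf16",
      "amx_tile", "amx_int8", "amx_bf16",
  }
  with open("/proc/cpuinfo") as f:
      flags = set(f.read().split())
  missing = required - flags
  print("AMX supported" if not missing else f"Missing CPU flags: {sorted(missing)}")
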
Software
''''''''

For Linux:

+----------------------+---------------+
| Linux kernel >= 5.16 | Python >= 3.8 |
+----------------------+---------------+


Guidelines for leveraging AMX with workloads
--------------------------------------------

For the BFloat16 data type:
"""""""""""""""""""""""""""

Using ``torch.cpu.amp`` or ``torch.autocast("cpu")`` will utilize AMX acceleration for the supported operators, as in the snippet below.

::

  model = model.to(memory_format=torch.channels_last)
  with torch.cpu.amp.autocast():
      output = model(input)

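A more complete, runnable sketch follows; the toy convolutional model, input shape, and inference-only setting are illustrative assumptions, not part of the original recipe:

::

  import torch
  import torch.nn as nn

  # Toy model standing in for a real workload.
  model = nn.Sequential(
      nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
      nn.ReLU(),
      nn.AdaptiveAvgPool2d(1),
      nn.Flatten(),
      nn.Linear(64, 10),
  ).eval()

  # Channels-last memory format generally performs better for convolutions.
  model = model.to(memory_format=torch.channels_last)
  x = torch.randn(8, 3, 224, 224).to(memory_format=torch.channels_last)

  # Autocast runs supported ops in BFloat16; on CPUs with AMX_BF16 the
  # underlying oneDNN kernels use AMX tiles automatically.
  with torch.no_grad(), torch.cpu.amp.autocast():
      out = model(x)
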

For quantization:
"""""""""""""""""

Applying quantization would utilize AMX acceleration. A minimal sketch is shown below.

Note: Use the channels-last memory format to get better performance.
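
Here is a minimal dynamic-quantization sketch; the toy model, shapes, and the choice of the ``onednn`` quantized engine are assumptions for illustration, and any INT8 path that reaches oneDNN on AMX-capable hardware is accelerated the same way:

::

  import torch
  import torch.nn as nn

  # Toy model standing in for a real workload.
  model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

  # Select the oneDNN quantized engine so INT8 kernels can dispatch to
  # AMX_INT8 where the hardware supports it (requires a PyTorch build
  # with the onednn quantized engine available).
  torch.backends.quantized.engine = "onednn"

  # Dynamic quantization converts the Linear weights to INT8.
  quantized_model = torch.ao.quantization.quantize_dynamic(
      model, {nn.Linear}, dtype=torch.qint8
  )

  x = torch.randn(32, 64)
  with torch.no_grad():
      out = quantized_model(x)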

For torch.compile:
""""""""""""""""""

When the compiled graph runs the supported operators listed above through their oneDNN implementations, AMX acceleration will be activated.
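
A minimal sketch; the toy model and the use of BFloat16 autocast around the compiled model are illustrative assumptions:

::

  import torch
  import torch.nn as nn

  model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).eval()

  # torch.compile generates an optimized graph; supported ops still execute
  # through oneDNN, so AMX is picked up the same way as in eager mode.
  compiled_model = torch.compile(model)

  x = torch.randn(8, 256)
  with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
      out = compiled_model(x)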


Confirm AMX is being utilized
"""""""""""""""""""""""""""""

Set the environment variable ``ONEDNN_VERBOSE=1`` to get oneDNN verbose output at runtime.
For more detailed information on oneDNN, see `the oneDNN documentation <https://oneapi-src.github.io/oneDNN/index.html>`_.

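Alternatively, verbose mode can be toggled from Python with the ``torch.backends.mkldnn.verbose`` context manager; this is a sketch, and the tiny model and input are placeholders:

::

  import torch
  import torch.nn as nn

  model = nn.Linear(64, 64).eval()
  x = torch.randn(1, 64)

  # VERBOSE_ON prints one line per executed oneDNN primitive, including the
  # implementation name; look for "amx" in that field.
  with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
      with torch.no_grad(), torch.cpu.amp.autocast():
          model(x)
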
For example, the oneDNN verbose output looks like this:

::

  onednn_verbose,info,oneDNN v2.7.3 (commit 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  onednn_verbose,info,cpu,runtime:OpenMP,nthr:128
  onednn_verbose,info,cpu,isa:Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and Intel AMX with bfloat16 and 8-bit integer support
  onednn_verbose,info,gpu,runtime:none
  onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
  onednn_verbose,exec,cpu,reorder,simple:any,undef,src_f32::blocked:a:f0 dst_f32::blocked:a:f0,attr-scratchpad:user ,,2,5.2561
  ...
  onednn_verbose,exec,cpu,convolution,jit:avx512_core_amx_bf16,forward_training,src_bf16::blocked:acdb:f0 wei_bf16:p:blocked:ABcd16b16a2b:f0 bia_f32::blocked:a:f0 dst_bf16::blocked:acdb:f0,attr-scratchpad:user ,alg:convolution_direct,mb7_ic2oc1_ih224oh111kh3sh2dh1ph1_iw224ow111kw3sw2dw1pw1,0.628906
  ...
  onednn_verbose,exec,cpu,matmul,brg:avx512_core_amx_int8,undef,src_s8::blocked:ab:f0 wei_s8:p:blocked:BA16a64b4a:f0 dst_s8::blocked:ab:f0,attr-scratchpad:user ,,1x30522:30522x768:1x768,7.66382
  ...

If the verbose output shows ``avx512_core_amx_bf16`` for BFloat16 or ``avx512_core_amx_int8`` for quantization with INT8, it indicates that AMX is activated.