Skip to content

Commit a8a1f33

Browse files
authored
Vulkan: Add DP4A MMQ and Q8_1 quantization shader (#12135)
* Vulkan: Add DP4A MMQ and Q8_1 quantization shader * Add q4_0 x q8_1 matrix matrix multiplication support * Vulkan: Add int8 coopmat MMQ support * Vulkan: Add q4_1, q5_0 and q5_1 quants, improve integer dot code * Add GL_EXT_integer_dot_product check * Remove ggml changes, fix mmq pipeline picker * Remove ggml changes, restore Intel coopmat behaviour * Fix glsl compile attempt when integer vec dot is not supported * Remove redundant code, use non-saturating integer dot, enable all matmul sizes for mmq * Remove redundant comment * Fix integer dot check * Fix compile issue with unsupported int dot glslc * Update Windows build Vulkan SDK version
1 parent 1790e73 commit a8a1f33

File tree

10 files changed

+1146
-95
lines changed

10 files changed

+1146
-95
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ jobs:
803803
env:
804804
OPENBLAS_VERSION: 0.3.23
805805
SDE_VERSION: 9.33.0-2024-01-07
806-
VULKAN_VERSION: 1.4.304.1
806+
VULKAN_VERSION: 1.4.309.0
807807

808808
strategy:
809809
matrix:

ggml/src/ggml-vulkan/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ if (Vulkan_FOUND)
6969
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
7070
endif()
7171

72+
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
73+
# If it's not, there will be an error to stderr.
74+
# If it's supported, set a define to indicate that we should compile those shaders
75+
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
76+
OUTPUT_VARIABLE glslc_output
77+
ERROR_VARIABLE glslc_error)
78+
79+
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
80+
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
81+
else()
82+
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
83+
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
84+
endif()
85+
7286
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
7387
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
7488

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 440 additions & 80 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void main() {
212212
#else
213213
ACC_TYPE sums[WMITER * TM * WNITER * TN];
214214
FLOAT_TYPE cache_a[WMITER * TM];
215-
FLOAT_TYPE cache_b[WNITER * TN];
215+
FLOAT_TYPE cache_b[TN];
216216

217217
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
218218
sums[i] = ACC_TYPE(0.0f);
@@ -744,16 +744,14 @@ void main() {
744744
}
745745
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
746746
[[unroll]] for (uint j = 0; j < TN; j++) {
747-
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
747+
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
748748
}
749-
}
750749

751-
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
752750
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
753751
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
754752
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
755753
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
756-
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
754+
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
757755
}
758756
}
759757
}

0 commit comments

Comments
 (0)