Skip to content

Commit 63a0a00

Browse files
jeffbolznv authored and arthw committed
vulkan: Optimize binary ops (ggml-org#10270)
Reuse the index calculations across all of src0/src1/dst. Add a shader variant for the case where src0 and src1 have the same dimensions, so the additional modulus operations for src1 aren't needed. Div/mod are slow, so add "fast" div/mod helpers that take a fast path when the calculation isn't needed or can be done more cheaply.
1 parent 6655f9d commit 63a0a00

File tree

9 files changed: 117 additions and 52 deletions

ggml/src/ggml-vulkan.cpp

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -192,9 +192,10 @@ struct vk_device_struct {
192192
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
193193
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
194194
vk_pipeline pipeline_acc_f32;
195-
vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
196-
vk_pipeline pipeline_mul_f32;
197-
vk_pipeline pipeline_div_f32;
195+
vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
196+
vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
197+
vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
198+
vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
198199
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
199200
vk_pipeline pipeline_upscale_f32;
200201
vk_pipeline pipeline_scale_f32;
@@ -1456,13 +1457,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
14561457
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
14571458
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
14581459

1459-
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1460-
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1460+
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1461+
ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1462+
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1463+
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
14611464

14621465
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
14631466

1464-
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1465-
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1467+
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1468+
ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1469+
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1470+
ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
14661471

14671472
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
14681473
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3801,20 +3806,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
38013806
return nullptr;
38023807
case GGML_OP_ADD:
38033808
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3804-
return ctx->device->pipeline_add_f32;
3809+
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
38053810
}
38063811
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
3807-
return ctx->device->pipeline_add_f16_f32_f16;
3812+
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
38083813
}
38093814
return nullptr;
38103815
case GGML_OP_MUL:
38113816
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3812-
return ctx->device->pipeline_mul_f32;
3817+
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
38133818
}
38143819
return nullptr;
38153820
case GGML_OP_DIV:
38163821
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3817-
return ctx->device->pipeline_div_f32;
3822+
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
38183823
}
38193824
return nullptr;
38203825
case GGML_OP_CONCAT:

ggml/src/vulkan-shaders/acc.comp

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
#include "types.comp"
44
#include "generic_binary_head.comp"
55

6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
68
void main() {
79
const uint idx = gl_GlobalInvocationID.x;
810
if (idx >= p.ne) {
@@ -15,10 +17,13 @@ void main() {
1517
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
1618
const uint ox = src1_i % p.nb01;
1719

20+
uint i00, i01, i02, i03;
21+
get_indices(idx, i00, i01, i02, i03);
22+
1823
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
19-
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
24+
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
2025
} else {
21-
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
26+
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
2227
}
2328
}
2429

ggml/src/vulkan-shaders/add.comp

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,29 @@
11
#version 450
22

3+
#extension GL_EXT_shader_16bit_storage : require
4+
35
#include "types.comp"
46
#include "generic_binary_head.comp"
57

8+
const uint num_threads = 256;
9+
10+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
11+
612
void main() {
7-
const uint idx = get_idx();
13+
uint idx = get_idx();
814

9-
if (idx >= p.ne) {
10-
return;
11-
}
15+
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
16+
const uint num_iter = 2;
1217

13-
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
18+
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
19+
if (idx >= p.ne) {
20+
continue;
21+
}
22+
uint i00, i01, i02, i03;
23+
get_indices(idx, i00, i01, i02, i03);
24+
25+
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
26+
27+
idx += num_threads;
28+
}
1429
}

ggml/src/vulkan-shaders/concat.comp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
#include "types.comp"
44
#include "generic_binary_head.comp"
55

6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
68
void main() {
79
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
810
const int dim = p.param3;

ggml/src/vulkan-shaders/div.comp

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,25 @@
33
#include "types.comp"
44
#include "generic_binary_head.comp"
55

6+
const uint num_threads = 256;
7+
8+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
9+
610
void main() {
7-
const uint idx = get_idx();
11+
uint idx = get_idx();
812

9-
if (idx >= p.ne) {
10-
return;
11-
}
13+
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
14+
const uint num_iter = 2;
1215

13-
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
16+
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
17+
if (idx >= p.ne) {
18+
continue;
19+
}
20+
uint i00, i01, i02, i03;
21+
get_indices(idx, i00, i01, i02, i03);
22+
23+
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
24+
25+
idx += num_threads;
26+
}
1427
}
ggml/src/vulkan-shaders/generic_binary_head.comp

Lines changed: 32 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#extension GL_EXT_shader_16bit_storage : require
2+
#extension GL_EXT_control_flow_attributes : require
23

34
layout (push_constant) uniform parameter
45
{
@@ -10,43 +11,50 @@ layout (push_constant) uniform parameter
1011
float param1; float param2; int param3;
1112
} p;
1213

13-
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
14-
1514
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1615
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
1716
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
1817

18+
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
19+
layout(constant_id = 0) const bool norepeat = false;
20+
1921
uint get_idx() {
2022
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
2123
}
2224

23-
uint src0_idx(uint idx) {
24-
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
25-
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
26-
const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
27-
const uint i02_offset = i02*p.ne01*p.ne00;
28-
const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
29-
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
30-
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
25+
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
26+
uint fastmod(uint a, uint b) {
27+
if ((b & (b-1)) == 0) {
28+
return a & (b-1);
29+
}
30+
return a % b;
3131
}
3232

33-
uint src1_idx(uint idx) {
34-
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
33+
uint fastdiv(uint a, uint b) {
34+
return (a < b) ? 0 : (a / b);
35+
}
36+
37+
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
38+
i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
3539
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
36-
const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
40+
i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
3741
const uint i02_offset = i02*p.ne01*p.ne00;
38-
const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
39-
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
42+
i01 = (idx - i03_offset - i02_offset) / p.ne00;
43+
i00 = idx - i03_offset - i02_offset - i01*p.ne00;
44+
}
45+
46+
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
47+
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
48+
}
4049

41-
return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
50+
uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
51+
if (norepeat) {
52+
return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
53+
} else {
54+
return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
55+
}
4256
}
4357

44-
uint dst_idx(uint idx) {
45-
const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
46-
const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
47-
const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
48-
const uint i22_offset = i22*p.ne21*p.ne20;
49-
const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
50-
const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
51-
return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
58+
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
59+
return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
5260
}

ggml/src/vulkan-shaders/get_rows.comp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
#include "types.comp"
44
#include "generic_binary_head.comp"
55

6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
68
void main() {
79
const uint i00 = gl_GlobalInvocationID.x;
810
const uint i10 = gl_GlobalInvocationID.y;

ggml/src/vulkan-shaders/get_rows_quant.comp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
#include "generic_binary_head.comp"
55
#include "dequant_funcs.comp"
66

7+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
8+
79
void main() {
810
const uint i00 = (gl_GlobalInvocationID.x)*2;
911
const uint i10 = gl_GlobalInvocationID.y;

ggml/src/vulkan-shaders/mul.comp

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,25 @@
33
#include "types.comp"
44
#include "generic_binary_head.comp"
55

6+
const uint num_threads = 256;
7+
8+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
9+
610
void main() {
7-
const uint idx = get_idx();
11+
uint idx = get_idx();
812

9-
if (idx >= p.ne) {
10-
return;
11-
}
13+
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
14+
const uint num_iter = 2;
1215

13-
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
16+
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
17+
if (idx >= p.ne) {
18+
continue;
19+
}
20+
uint i00, i01, i02, i03;
21+
get_indices(idx, i00, i01, i02, i03);
22+
23+
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
24+
25+
idx += num_threads;
26+
}
1427
}

0 commit comments

Comments
 (0)