
Commit b344439

ggerganov and Acly authored

sync : ggml (#13268)

* vulkan : kernels for depthwise 2D convolution (CONV_2D_DW) (ggml/1204)

* vulkan : add kernels for depthwise 2d convolution (OP_CONV_2D_DW)

* review: remove src_x/y < 0 checks; add performance tests

* sync : ggml

ggml-ci

* vulkan : fix lint (#0)

---------

Co-authored-by: Acly <aclysia@gmail.com>

1 parent a75cb30 commit b344439

File tree

5 files changed: +225 -1 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 66 additions & 0 deletions
@@ -389,6 +389,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -701,6 +703,24 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_conv2d_dw_push_constants {
+    uint32_t ne;
+    uint32_t batches;
+    uint32_t channels;
+    uint32_t dst_w;
+    uint32_t dst_h;
+    uint32_t src_w;
+    uint32_t src_h;
+    uint32_t knl_w;
+    uint32_t knl_h;
+    int32_t stride_x;
+    int32_t stride_y;
+    int32_t pad_x;
+    int32_t pad_y;
+    int32_t dilation_x;
+    int32_t dilation_y;
+};
+
 struct vk_op_upscale_push_constants {
     uint32_t ne; uint32_t a_offset; uint32_t d_offset;
     uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
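For orientation, the dst_w/dst_h fields are not free parameters: they follow from the source size and the stride/pad/dilation settings by the usual convolution geometry. A minimal sketch of that arithmetic (illustrative helper, not part of the patch):

#include <cstdint>

// Output extent of a convolution along one axis, given the input extent,
// kernel extent, stride, padding and dilation (standard formula).
static uint32_t conv_out_size(uint32_t in, uint32_t knl,
                              int32_t stride, int32_t pad, int32_t dilation) {
    const int32_t eff_knl = dilation * ((int32_t) knl - 1) + 1; // dilated kernel extent
    return (uint32_t) (((int32_t) in + 2 * pad - eff_knl) / stride + 1);
}

For example, the {32, 8, 64, 1} eval case added below (3x3 kernel, stride 2, padding 1, dilation 1) gives dst_w = (32 + 2 - 3) / 2 + 1 = 16.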
@@ -2610,6 +2630,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -6137,6 +6160,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_leaky_relu_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_2D_DW:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            if (ggml_is_contiguous(src1)) {
+                return ctx->device->pipeline_conv2d_dw_whcn_f32;
+            } else if (ggml_is_contiguous_channels(src1)) {
+                return ctx->device->pipeline_conv2d_dw_cwhn_f32;
+            }
+        }
+        return nullptr;
     default:
         return nullptr;
     }
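The selector distinguishes the two memory layouts only; both pipelines compute the same op. As an illustrative sketch (these helpers are not patch code; the real checks are ggml_is_contiguous() and ggml_is_contiguous_channels()), the flat offset of element (x, y, c, n) in a tensor with ne = {W, H, C, N} is:

#include <cstdint>

// WHCN: width is the fastest-varying dimension, channels the slowest per image.
static uint32_t idx_whcn(uint32_t x, uint32_t y, uint32_t c, uint32_t n,
                         uint32_t W, uint32_t H, uint32_t C) {
    return ((n * C + c) * H + y) * W + x;
}

// CWHN: channels are the fastest-varying dimension ("channel-most-contiguous").
static uint32_t idx_cwhn(uint32_t x, uint32_t y, uint32_t c, uint32_t n,
                         uint32_t W, uint32_t H, uint32_t C) {
    return ((n * H + y) * W + x) * C + c;
}

These are exactly the address computations used by conv_2d_dw_whcn() and conv_2d_dw_cwhn() in the new shader below.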
@@ -6163,6 +6195,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_CONV_2D_DW:
         return true;
     default:
         return false;
@@ -6459,6 +6492,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_UNARY:
+    case GGML_OP_CONV_2D_DW:
         {
             const uint32_t ne = ggml_nelements(dst);
             if (ne > 262144) {
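CONV_2D_DW reuses the generic flat-element dispatch of ggml_vk_op_f32, whose body is elided in this hunk. A sketch of the split, assuming the shader's fixed local size of 512 and the 512 * 512 = 262144 plane its main() uses to rebuild a linear index (the exact ggml code may differ in detail):

#include <cstdint>

struct dispatch_3d { uint32_t x, y, z; };

// Spread ne flat elements over a 3D dispatch so the shader can reconstruct
// idx = z * 262144 + y * 512 + x and bounds-check it against ne.
static dispatch_3d split_elements(uint32_t ne) {
    if (ne > 262144) {
        return { 512, 512, (ne + 262143) / 262144 };
    } else if (ne > 512) {
        return { 512, (ne + 511) / 512, 1 };
    }
    return { ne, 1, 1 };
}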
@@ -7245,6 +7279,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    vk_op_conv2d_dw_push_constants p{};
+    p.ne = ggml_nelements(dst);
+    p.channels = dst->ne[2];
+    p.batches = dst->ne[3];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.src_w = src1->ne[0];
+    p.src_h = src1->ne[1];
+    p.knl_w = src0->ne[0];
+    p.knl_h = src0->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(src0->ne[3] == p.channels);
+    GGML_ASSERT(src1->ne[3] == p.batches);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
@@ -8265,6 +8323,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -8328,6 +8387,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -8501,6 +8561,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_2D_DW:
+        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -8622,6 +8686,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -9599,6 +9664,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+#version 450
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+    uint ne;
+    uint batches;
+    uint channels;
+    uint dst_w;
+    uint dst_h;
+    uint src_w;
+    uint src_h;
+    uint knl_w;
+    uint knl_h;
+    int stride_x;
+    int stride_y;
+    int pad_x;
+    int pad_y;
+    int dilation_x;
+    int dilation_y;
+} p;
+
+layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
+layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
+    uint i0 = idx / p.dst_w;
+    uint dst_x = idx - i0 * p.dst_w;
+    uint i1 = i0 / p.dst_h;
+    uint dst_y = i0 - i1 * p.dst_h;
+    uint n = i1 / p.channels;
+    uint c = i1 - n * p.channels;
+
+    uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
+    uint knl_i = c * p.knl_h * p.knl_w;
+
+    FLOAT_TYPE sum = 0.0;
+    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+            continue;
+        }
+        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+                continue;
+            }
+            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
+            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
+            sum = fma(v, k, sum);
+        }
+    }
+    return sum;
+}
+
+FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
+    uint i0 = idx / p.channels;
+    uint c = idx - i0 * p.channels;
+    uint i1 = i0 / p.dst_w;
+    uint dst_x = i0 - i1 * p.dst_w;
+    uint n = i1 / p.dst_h;
+    uint dst_y = i1 - n * p.dst_h;
+
+    uint src_i = n * p.channels * p.src_h * p.src_w;
+    uint src_row = p.src_w * p.channels;
+    uint knl_row = p.knl_w * p.channels;
+
+    FLOAT_TYPE sum = 0.0;
+    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+            continue;
+        }
+        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+                continue;
+            }
+            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
+            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_y * knl_row + knl_x * p.channels + c]);
+            sum = fma(v, k, sum);
+        }
+    }
+    return sum;
+}
+
+void main() {
+    uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+    if (idx >= p.ne) {
+        return;
+    }
+
+    FLOAT_TYPE result =
+#ifdef WHCN
+        conv_2d_dw_whcn(idx);
+#else
+        conv_2d_dw_cwhn(idx);
+#endif
+    dst_data[idx] = D_TYPE(result);
+}
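For reading convenience, a plain C++ restatement of conv_2d_dw_whcn() above (a sketch, not code from the patch). It preserves the trick called out in the shader comments: source coordinates are computed as unsigned, so a logically negative coordinate wraps to a huge value and a single >= bound test rejects both out-of-range sides, which is what made the explicit src_x/y < 0 checks removable during review:

#include <cstdint>
#include <vector>

// CPU reference for one output element of the depthwise convolution,
// WHCN memory layout; mirrors the indexing of the GLSL above 1:1.
static float conv2d_dw_whcn_ref(
        const std::vector<float> & knl, const std::vector<float> & src,
        uint32_t idx, uint32_t channels,
        uint32_t src_w, uint32_t src_h, uint32_t dst_w, uint32_t dst_h,
        uint32_t knl_w, uint32_t knl_h,
        int32_t stride_x, int32_t stride_y, int32_t pad_x, int32_t pad_y,
        int32_t dilation_x, int32_t dilation_y) {
    const uint32_t i0    = idx / dst_w;      // unpack idx -> (dst_x, dst_y, c, n)
    const uint32_t dst_x = idx - i0 * dst_w;
    const uint32_t i1    = i0 / dst_h;
    const uint32_t dst_y = i0 - i1 * dst_h;
    const uint32_t n     = i1 / channels;
    const uint32_t c     = i1 - n * channels;

    const uint32_t src_i = (n * channels + c) * src_h * src_w;
    const uint32_t knl_i = c * knl_h * knl_w;

    float sum = 0.0f;
    for (uint32_t ky = 0; ky < knl_h; ++ky) {
        const uint32_t sy = dst_y * stride_y + ky * dilation_y - pad_y;
        if (sy >= src_h) continue;           // also catches sy < 0 (unsigned wrap)
        for (uint32_t kx = 0; kx < knl_w; ++kx) {
            const uint32_t sx = dst_x * stride_x + kx * dilation_x - pad_x;
            if (sx >= src_w) continue;       // also catches sx < 0 (unsigned wrap)
            sum += src[src_i + sy * src_w + sx] * knl[knl_i + ky * knl_w + kx];
        }
    }
    return sum;
}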

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
@@ -584,6 +584,9 @@ void process_shaders() {
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+    string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }

scripts/sync-ggml.last

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-f3a375f20bf56860b30e7c511d03593a1e393345
+0482de9c63b9134eb462c7732888c0ee0dbc2755

tests/test-backend-ops.cpp

Lines changed: 50 additions & 0 deletions
@@ -2765,6 +2765,48 @@ struct test_im2col : public test_case {
     }
 };
 
+// GGML_OP_CONV_2D_DW
+struct test_conv_2d_dw : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const int stride;
+    const int padding;
+    const int dilation;
+    const bool cwhn;
+
+    std::string vars() override {
+        return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
+    }
+
+    test_conv_2d_dw(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},
+            std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16},
+            int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
+        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out = ggml_conv_2d_dw_direct(
+            ctx, kernel, input,
+            stride, stride, padding, padding, dilation, dilation);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
@@ -3975,6 +4017,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
     // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
 
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
+
     test_cases.emplace_back(new test_conv_transpose_1d());
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
@@ -4549,6 +4596,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
+
     return test_cases;
 }
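A minimal sketch of driving the new op through the public ggml API, mirroring build_graph() above. It assumes CPU evaluation via ggml_graph_compute_with_ctx() from ggml-cpu.h, which is standard ggml but not part of this diff:

#include "ggml.h"
#include "ggml-cpu.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64u * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // WHCN input: 17x34 pixels, 9 channels, batch 1; one 3x3 filter per channel.
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 17, 34, 9, 1);
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,  3,  3, 1, 9);
    // ... fill input->data and kernel->data here ...

    // stride 1, padding 0, dilation 1 -> out is 15x32x9x1.
    struct ggml_tensor * out = ggml_conv_2d_dw_direct(ctx, kernel, input, 1, 1, 0, 0, 1, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 4);

    ggml_free(ctx);
    return 0;
}

The eval and perf cases added here can also be run in isolation through test-backend-ops' op filter (e.g. `-o CONV_2D_DW`).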
