Skip to content

Commit 904109e

Browse files
authored
vulkan: fix group_norm (#10496)
Fix bad calculation of the end of the range. Add a backend test that covers the bad case (taken from stable diffusion). Fixes leejet/stable-diffusion.cpp#439.
1 parent 45abe0f commit 904109e

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7157,7 +7157,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
71577157
const int32_t max_period = tensor->op_params[1];
71587158
tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
71597159
} else if (tensor->op == GGML_OP_POOL_2D) {
7160-
enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
7160+
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
71617161
const int32_t k0 = tensor->op_params[1];
71627162
const int32_t k1 = tensor->op_params[2];
71637163
const int32_t s0 = tensor->op_params[3];

ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void main() {
1919

2020
const uint tid = gl_LocalInvocationID.x;
2121
const uint start = gl_WorkGroupID.x * group_size + tid;
22-
const uint end = start + group_size;
22+
const uint end = (gl_WorkGroupID.x + 1) * group_size;
2323

2424
tmp[tid] = 0.0f;
2525

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3796,7 +3796,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
37963796
test_cases.emplace_back(new test_upscale());
37973797
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
37983798
test_cases.emplace_back(new test_upscale_ext());
3799-
test_cases.emplace_back(new test_group_norm());
3799+
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
3800+
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
38003801
test_cases.emplace_back(new test_acc());
38013802
test_cases.emplace_back(new test_pad());
38023803
test_cases.emplace_back(new test_arange());

0 commit comments

Comments
 (0)