@@ -192,9 +192,10 @@ struct vk_device_struct {
192
192
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
193
193
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
194
194
vk_pipeline pipeline_acc_f32;
195
- vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
196
- vk_pipeline pipeline_mul_f32;
197
- vk_pipeline pipeline_div_f32;
195
+ vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
196
+ vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
197
+ vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
198
+ vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
198
199
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
199
200
vk_pipeline pipeline_upscale_f32;
200
201
vk_pipeline pipeline_scale_f32;
@@ -1456,13 +1457,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
1456
1457
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1457
1458
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1458
1459
1459
- ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1460
- ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1460
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1461
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1462
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1463
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1461
1464
1462
1465
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1463
1466
1464
- ggml_vk_create_pipeline (device, device->pipeline_mul_f32 , " mul_f32" , mul_f32_len, mul_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1465
- ggml_vk_create_pipeline (device, device->pipeline_div_f32 , " div_f32" , div_f32_len, div_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1467
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1468
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1469
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1470
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1466
1471
1467
1472
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1468
1473
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3801,20 +3806,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3801
3806
return nullptr;
3802
3807
case GGML_OP_ADD:
3803
3808
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3804
- return ctx->device ->pipeline_add_f32 ;
3809
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
3805
3810
}
3806
3811
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
3807
- return ctx->device ->pipeline_add_f16_f32_f16 ;
3812
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
3808
3813
}
3809
3814
return nullptr;
3810
3815
case GGML_OP_MUL:
3811
3816
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3812
- return ctx->device ->pipeline_mul_f32 ;
3817
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
3813
3818
}
3814
3819
return nullptr;
3815
3820
case GGML_OP_DIV:
3816
3821
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3817
- return ctx->device ->pipeline_div_f32 ;
3822
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
3818
3823
}
3819
3824
return nullptr;
3820
3825
case GGML_OP_CONCAT:
0 commit comments