@@ -228,6 +228,8 @@ struct vk_device_struct {
228
228
vk_pipeline pipeline_repeat_f32;
229
229
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
230
230
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
231
// Copy pipelines between F32 and quantized types, one slot per ggml_type value.
// ggml_vk_get_cpy_pipeline indexes pipeline_cpy_f32_quant by `to` and
// pipeline_cpy_quant_f32 by `src->type`. Only the Q4_0, Q4_1, Q5_0, Q5_1, Q8_0
// and IQ4_NL slots are created in the visible part of ggml_vk_load_shaders.
+ vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
232
+ vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
231
233
vk_pipeline pipeline_norm_f32;
232
234
vk_pipeline pipeline_group_norm_f32;
233
235
vk_pipeline pipeline_rms_norm_f32;
@@ -1965,6 +1967,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
1965
1967
ggml_vk_create_pipeline (device, device->pipeline_contig_cpy_f32_f16 , " contig_cpy_f32_f16" , contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
1966
1968
ggml_vk_create_pipeline (device, device->pipeline_contig_cpy_f16_f16 , " contig_cpy_f16_f16" , contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
1967
1969
1970
// F32 -> quantized copy pipelines, one per supported destination type.
// The arguments match the contig_cpy_* calls above, except that the first
// element of the {x, 1, 1} triple is the type's ggml_blck_size instead of 512.
// NOTE(review): that triple is presumably the workgroup size - confirm against
// ggml_vk_create_pipeline.
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q4_0], " cpy_f32_q4_0" , cpy_f32_q4_0_len, cpy_f32_q4_0_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_0), 1 , 1 }, {}, 1 );
1971
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q4_1], " cpy_f32_q4_1" , cpy_f32_q4_1_len, cpy_f32_q4_1_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_1), 1 , 1 }, {}, 1 );
1972
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q5_0], " cpy_f32_q5_0" , cpy_f32_q5_0_len, cpy_f32_q5_0_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_0), 1 , 1 }, {}, 1 );
1973
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q5_1], " cpy_f32_q5_1" , cpy_f32_q5_1_len, cpy_f32_q5_1_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_1), 1 , 1 }, {}, 1 );
1974
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q8_0], " cpy_f32_q8_0" , cpy_f32_q8_0_len, cpy_f32_q8_0_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q8_0), 1 , 1 }, {}, 1 );
1975
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_IQ4_NL], " cpy_f32_iq4_nl" , cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_IQ4_NL), 1 , 1 }, {}, 1 );
1976
+
1977
// Quantized -> F32 copy pipelines for the same six types, stored in the slot
// of the source type.
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q4_0], " cpy_q4_0_f32" , cpy_q4_0_f32_len, cpy_q4_0_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_0), 1 , 1 }, {}, 1 );
1978
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q4_1], " cpy_q4_1_f32" , cpy_q4_1_f32_len, cpy_q4_1_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_1), 1 , 1 }, {}, 1 );
1979
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q5_0], " cpy_q5_0_f32" , cpy_q5_0_f32_len, cpy_q5_0_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_0), 1 , 1 }, {}, 1 );
1980
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q5_1], " cpy_q5_1_f32" , cpy_q5_1_f32_len, cpy_q5_1_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_1), 1 , 1 }, {}, 1 );
1981
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q8_0], " cpy_q8_0_f32" , cpy_q8_0_f32_len, cpy_q8_0_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q8_0), 1 , 1 }, {}, 1 );
1982
+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_IQ4_NL], " cpy_iq4_nl_f32" , cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_IQ4_NL), 1 , 1 }, {}, 1 );
1983
+
1968
1984
ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
1969
1985
ggml_vk_create_pipeline (device, device->pipeline_add_f32_norepeat , " add_f32_norepeat" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
1970
1986
ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
@@ -3689,6 +3705,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
3689
3705
return ctx->device ->pipeline_cpy_f16_f16 ;
3690
3706
}
3691
3707
}
3708
// F32 source: when `to` is one of the six listed quantized types, return the
// F32 -> quant copy pipeline stored in the slot for `to`. Any other `to`
// falls through to the checks below.
+ if (src->type == GGML_TYPE_F32) {
3709
+ switch (to) {
3710
+ case GGML_TYPE_Q4_0:
3711
+ case GGML_TYPE_Q4_1:
3712
+ case GGML_TYPE_Q5_0:
3713
+ case GGML_TYPE_Q5_1:
3714
+ case GGML_TYPE_Q8_0:
3715
+ case GGML_TYPE_IQ4_NL:
3716
+ return ctx->device ->pipeline_cpy_f32_quant [to];
3717
+ default :
3718
+ break ;
3719
+ }
3720
+ }
3721
+
3722
// F32 target: when the source type is one of the six listed quantized types,
// return the quant -> F32 copy pipeline stored in the slot for `src->type`.
// Anything else falls through to the "Missing CPY op" message and GGML_ABORT
// below.
+ if (to == GGML_TYPE_F32) {
3723
+ switch (src->type ) {
3724
+ case GGML_TYPE_Q4_0:
3725
+ case GGML_TYPE_Q4_1:
3726
+ case GGML_TYPE_Q5_0:
3727
+ case GGML_TYPE_Q5_1:
3728
+ case GGML_TYPE_Q8_0:
3729
+ case GGML_TYPE_IQ4_NL:
3730
+ return ctx->device ->pipeline_cpy_quant_f32 [src->type ];
3731
+ default :
3732
+ break ;
3733
+ }
3734
+ }
3692
3735
3693
3736
std::cerr << " Missing CPY op for types: " << ggml_type_name (src->type ) << " " << ggml_type_name (to) << std::endl;
3694
3737
GGML_ABORT (" fatal error" );
@@ -5160,7 +5203,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5160
5203
}
5161
5204
std::cerr << " ), (" << dst << " , name=" << dst->name << " , type=" << dst->type << " , ne0=" << dst->ne [0 ] << " , ne1=" << dst->ne [1 ] << " , ne2=" << dst->ne [2 ] << " , ne3=" << dst->ne [3 ] << " , nb0=" << dst->nb [0 ] << " , nb1=" << dst->nb [1 ] << " , nb2=" << dst->nb [2 ] << " , nb3=" << dst->nb [3 ];
5162
5205
std::cerr << " ), " << ggml_op_name (op) << " , " << (dryrun ? " dryrun" : " " ) << " )" );
5163
// Quantized src0/src1 types are rejected for every op except GGML_OP_GET_ROWS
// and, in the added (+) version of the assertion, GGML_OP_CPY. src1 is only
// checked when it is non-null.
- GGML_ASSERT (op == GGML_OP_GET_ROWS || (!ggml_is_quantized (src0->type ) && (src1 == nullptr || !ggml_is_quantized (src1->type )))); // NOLINT
5206
+ GGML_ASSERT (op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized (src0->type ) && (src1 == nullptr || !ggml_is_quantized (src1->type )))); // NOLINT
5164
5207
GGML_ASSERT (ggml_vk_op_supports_incontiguous (op) || ggml_vk_dim01_contiguous (src0)); // NOLINT
5165
5208
GGML_ASSERT (dst->buffer != nullptr );
5166
5209
const uint64_t ne00 = src0->ne [0 ];
@@ -7905,12 +7948,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
7905
7948
{
7906
7949
ggml_type src0_type = op->src [0 ]->type ;
7907
7950
ggml_type src1_type = op->src [1 ] != nullptr ? op->src [1 ]->type : src0_type;
7908
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
7909
- return true ;
7951
+
7952
// src0 is F32: report support when src1_type is F32, F16 or one of the six
// listed quantized types. Other src1 types fall through to the later checks.
+ if (src0_type == GGML_TYPE_F32) {
7953
+ switch (src1_type) {
7954
+ case GGML_TYPE_F32:
7955
+ case GGML_TYPE_F16:
7956
+ case GGML_TYPE_Q4_0:
7957
+ case GGML_TYPE_Q4_1:
7958
+ case GGML_TYPE_Q5_0:
7959
+ case GGML_TYPE_Q5_1:
7960
+ case GGML_TYPE_Q8_0:
7961
+ case GGML_TYPE_IQ4_NL:
7962
+ return true ;
7963
+ default :
7964
+ break ;
7965
+ }
7910
7966
}
7911
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
7912
- return true ;
7967
// src1_type is F32: report support when src0 is one of the six quantized types
// that have a quant -> F32 copy pipeline. Other src0 types fall through to the
// later checks.
+ if (src1_type == GGML_TYPE_F32) {
7968
+ switch (src0_type) {
7969
+ case GGML_TYPE_Q4_0:
7970
+ case GGML_TYPE_Q4_1:
7971
+ case GGML_TYPE_Q5_0:
7972
+ case GGML_TYPE_Q5_1:
7973
+ case GGML_TYPE_Q8_0:
7974
+ case GGML_TYPE_IQ4_NL:
7975
+ return true ;
7976
+ default :
7977
+ break ;
7978
+ }
7913
7979
}
7980
+
7914
7981
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
7915
7982
return true ;
7916
7983
}
0 commit comments