@@ -120,36 +120,13 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
120
120
}
121
121
122
122
/* quantized type same copy */
123
- static void cpy_block_q8_0_q8_0 (const char * cxi, char * cdsti) {
124
- const block_q8_0 * xi = (const block_q8_0 *) cxi;
125
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
123
+ template <typename T>
124
+ static void cpy_blck_q_q (const char * cxi, char * cdsti) {
125
+ const T * xi = (const T *) cxi;
126
+ T * dsti = (T *) cdsti;
126
127
*dsti = *xi;
127
128
}
128
129
129
- static void cpy_block_q5_0_q5_0 (const char * cxi, char * cdsti) {
130
- const block_q5_0 * xi = (const block_q5_0 *) cxi;
131
- block_q5_0 * dsti = (block_q5_0 *) cdsti;
132
- *dsti = *xi;
133
- }
134
-
135
-
136
- static void cpy_block_q5_1_q5_1 (const char * cxi, char * cdsti) {
137
- const block_q5_1 * xi = (const block_q5_1 *) cxi;
138
- block_q5_1 * dsti = (block_q5_1 *) cdsti;
139
- *dsti = *xi;
140
- }
141
-
142
- static void cpy_block_q4_0_q4_0 (const char * cxi, char * cdsti) {
143
- const block_q4_0 * xi = (const block_q4_0 *) cxi;
144
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
145
- *dsti = *xi;
146
- }
147
-
148
- static void cpy_block_q4_1_q4_1 (const char * cxi, char * cdsti) {
149
- const block_q4_1 * xi = (const block_q4_1 *) cxi;
150
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
151
- *dsti = *xi;
152
- }
153
130
154
131
static void cpy_blck_q8_0_f32 (const char * cxi, char * cdsti) {
155
132
float * cdstf = (float *) (cdsti);
@@ -347,7 +324,7 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
347
324
}
348
325
349
326
350
- template <cpy_kernel_t cpy_blck , int qk>
327
+ template <typename T , int qk>
351
328
static void cpy_q_q (const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
352
329
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
353
330
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
@@ -371,7 +348,7 @@ static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00,
371
348
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
372
349
const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
373
350
374
- cpy_blck (cx + x_offset, cdst + dst_offset);
351
+ cpy_blck_q_q<T> (cx + x_offset, cdst + dst_offset);
375
352
}
376
353
377
354
template <cpy_kernel_t cpy_blck, int qk>
@@ -687,7 +664,7 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
687
664
const int num_blocks = ne;
688
665
stream->parallel_for (
689
666
sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
690
- cpy_q_q<cpy_block_q8_0_q8_0 , QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
667
+ cpy_q_q<block_q8_0 , QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
691
668
});
692
669
}
693
670
@@ -700,7 +677,7 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
700
677
const int num_blocks = ne;
701
678
stream->parallel_for (
702
679
sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
703
- cpy_q_q<cpy_block_q5_0_q5_0 , QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
680
+ cpy_q_q<block_q5_0 , QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
704
681
});
705
682
}
706
683
@@ -713,7 +690,7 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
713
690
const int num_blocks = ne;
714
691
stream->parallel_for (
715
692
sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
716
- cpy_q_q<cpy_block_q5_1_q5_1 , QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
693
+ cpy_q_q<block_q5_1 , QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
717
694
});
718
695
}
719
696
@@ -726,7 +703,7 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
726
703
const int num_blocks = ne;
727
704
stream->parallel_for (
728
705
sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
729
- cpy_q_q<cpy_block_q4_0_q4_0 , QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
706
+ cpy_q_q<block_q4_0 , QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
730
707
});
731
708
}
732
709
@@ -739,7 +716,7 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
739
716
const int num_blocks = ne;
740
717
stream->parallel_for (
741
718
sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks), sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
742
- cpy_q_q<cpy_block_q4_1_q4_1 , QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
719
+ cpy_q_q<block_q4_1 , QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
743
720
});
744
721
}
745
722
0 commit comments