Skip to content

Commit 896c2d5

Browse files
committed
refactor: replace specific block copy functions with template
The changes replace multiple redundant block copy functions (e.g., cpy_block_q8_0_q8_0, cpy_block_q5_0_q5_0) with a single templated function cpy_blck_q_q. This reduces code duplication by using a generic template that works for any block type, improving maintainability while preserving the same functionality. The template is instantiated with specific block types (e.g., block_q8_0) where needed.
1 parent c8c2278 commit 896c2d5

File tree

1 file changed

+11
-34
lines changed

1 file changed

+11
-34
lines changed

ggml/src/ggml-sycl/cpy.cpp

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -120,36 +120,13 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
120120
}
121121

122122
/* quantized type same copy */
123-
static void cpy_block_q8_0_q8_0(const char * cxi, char * cdsti) {
124-
const block_q8_0 * xi = (const block_q8_0 *) cxi;
125-
block_q8_0 * dsti = (block_q8_0 *) cdsti;
123+
template<typename T>
124+
static void cpy_blck_q_q(const char * cxi, char * cdsti) {
125+
const T * xi = (const T *) cxi;
126+
T * dsti = (T *) cdsti;
126127
*dsti = *xi;
127128
}
128129

129-
static void cpy_block_q5_0_q5_0(const char * cxi, char * cdsti) {
130-
const block_q5_0 * xi = (const block_q5_0 *) cxi;
131-
block_q5_0 * dsti = (block_q5_0 *) cdsti;
132-
*dsti = *xi;
133-
}
134-
135-
136-
static void cpy_block_q5_1_q5_1(const char * cxi, char * cdsti) {
137-
const block_q5_1 * xi = (const block_q5_1 *) cxi;
138-
block_q5_1 * dsti = (block_q5_1 *) cdsti;
139-
*dsti = *xi;
140-
}
141-
142-
static void cpy_block_q4_0_q4_0(const char * cxi, char * cdsti) {
143-
const block_q4_0 * xi = (const block_q4_0 *) cxi;
144-
block_q4_0 * dsti = (block_q4_0 *) cdsti;
145-
*dsti = *xi;
146-
}
147-
148-
static void cpy_block_q4_1_q4_1(const char * cxi, char * cdsti) {
149-
const block_q4_1 * xi = (const block_q4_1 *) cxi;
150-
block_q4_1 * dsti = (block_q4_1 *) cdsti;
151-
*dsti = *xi;
152-
}
153130

154131
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
155132
float * cdstf = (float *) (cdsti);
@@ -347,7 +324,7 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
347324
}
348325

349326

350-
template <cpy_kernel_t cpy_blck, int qk>
327+
template <typename T, int qk>
351328
static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
352329
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
353330
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
@@ -371,7 +348,7 @@ static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00,
371348
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
372349
const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
373350

374-
cpy_blck(cx + x_offset, cdst + dst_offset);
351+
cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
375352
}
376353

377354
template <cpy_kernel_t cpy_blck, int qk>
@@ -687,7 +664,7 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
687664
const int num_blocks = ne;
688665
stream->parallel_for(
689666
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
690-
cpy_q_q<cpy_block_q8_0_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
667+
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
691668
});
692669
}
693670

@@ -700,7 +677,7 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
700677
const int num_blocks = ne;
701678
stream->parallel_for(
702679
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
703-
cpy_q_q<cpy_block_q5_0_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
680+
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
704681
});
705682
}
706683

@@ -713,7 +690,7 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
713690
const int num_blocks = ne;
714691
stream->parallel_for(
715692
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
716-
cpy_q_q<cpy_block_q5_1_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
693+
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
717694
});
718695
}
719696

@@ -726,7 +703,7 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
726703
const int num_blocks = ne;
727704
stream->parallel_for(
728705
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
729-
cpy_q_q<cpy_block_q4_0_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
706+
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
730707
});
731708
}
732709

@@ -739,7 +716,7 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
739716
const int num_blocks = ne;
740717
stream->parallel_for(
741718
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
742-
cpy_q_q<cpy_block_q4_1_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
719+
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
743720
});
744721
}
745722

0 commit comments

Comments
 (0)