|
3 | 3 | static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
4 | 4 | const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
|
5 | 5 |
|
6 |
| - const int nrows = item_ct1.get_group_range(2); |
| 6 | + const int nrows = item_ct1.get_group_range(0); |
7 | 7 | const int nchannels = item_ct1.get_group_range(1);
|
8 | 8 | const int nthreads = item_ct1.get_local_range(2);
|
9 |
| - const int sample = item_ct1.get_group(0); |
| 9 | + const int sample = item_ct1.get_group(2); |
10 | 10 | const int channel = item_ct1.get_group(1);
|
11 |
| - const int row = item_ct1.get_group(2); |
| 11 | + const int row = item_ct1.get_group(0); |
12 | 12 |
|
13 | 13 | const int tid = item_ct1.get_local_id(2);
|
14 | 14 | const int nwarps = nthreads / WARP_SIZE;
|
@@ -140,11 +140,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
140 | 140 | static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
141 | 141 | const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
142 | 142 |
|
143 |
| - const int nrows = item_ct1.get_group_range(2); |
| 143 | + const int nrows = item_ct1.get_group_range(0); |
144 | 144 | const int nchannels = item_ct1.get_group_range(1);
|
145 |
| - const int sample = item_ct1.get_group(0); |
| 145 | + const int sample = item_ct1.get_group(2); |
146 | 146 | const int channel = item_ct1.get_group(1);
|
147 |
| - const int row = item_ct1.get_group(2); |
| 147 | + const int row = item_ct1.get_group(0); |
148 | 148 | const int nthreads = item_ct1.get_local_range(2);
|
149 | 149 |
|
150 | 150 | const int tid = item_ct1.get_local_id(2);
|
@@ -237,10 +237,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
237 | 237 | const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
|
238 | 238 | const float eps, queue_ptr stream, int device) {
|
239 | 239 |
|
240 |
| - const sycl::range<3> global_dims(nsamples, nchannels, nrows); |
| 240 | + const sycl::range<3> global_dims(nrows, nchannels, nsamples); |
241 | 241 | GGML_ASSERT(ncols % WARP_SIZE == 0);
|
242 | 242 | if (ncols < 1024) {
|
243 |
| - const sycl::range<3> block_dims(1, 1, WARP_SIZE); // Equivalent to CUDA's (WARP_SIZE, 1, 1) |
| 243 | + const sycl::range<3> block_dims(1, 1, WARP_SIZE); |
244 | 244 | stream->submit([&](sycl::handler& cgh) {
|
245 | 245 | cgh.parallel_for(
|
246 | 246 | sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
@@ -324,7 +324,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
324 | 324 | GGML_ASSERT(ncols % WARP_SIZE == 0);
|
325 | 325 | // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
326 | 326 |
|
327 |
| - const sycl::range<3> global_dims(nsamples, nchannels, nrows); |
| 327 | + const sycl::range<3> global_dims(nrows, nchannels, nsamples); |
328 | 328 | if (ncols < 1024) {
|
329 | 329 | const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
330 | 330 | stream->submit([&](sycl::handler& cgh) {
|
|
0 commit comments