Skip to content

Commit 43be2d6

Browse files
committed
Swap grid dims of nsamples and nrows
ggml-ci
1 parent 67c4a8c commit 43be2d6

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

ggml/src/ggml-sycl/norm.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
44
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
55

6-
const int nrows = item_ct1.get_group_range(2);
6+
const int nrows = item_ct1.get_group_range(0);
77
const int nchannels = item_ct1.get_group_range(1);
88
const int nthreads = item_ct1.get_local_range(2);
9-
const int sample = item_ct1.get_group(0);
9+
const int sample = item_ct1.get_group(2);
1010
const int channel = item_ct1.get_group(1);
11-
const int row = item_ct1.get_group(2);
11+
const int row = item_ct1.get_group(0);
1212

1313
const int tid = item_ct1.get_local_id(2);
1414
const int nwarps = nthreads / WARP_SIZE;
@@ -140,11 +140,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
140140
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
141141
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
142142

143-
const int nrows = item_ct1.get_group_range(2);
143+
const int nrows = item_ct1.get_group_range(0);
144144
const int nchannels = item_ct1.get_group_range(1);
145-
const int sample = item_ct1.get_group(0);
145+
const int sample = item_ct1.get_group(2);
146146
const int channel = item_ct1.get_group(1);
147-
const int row = item_ct1.get_group(2);
147+
const int row = item_ct1.get_group(0);
148148
const int nthreads = item_ct1.get_local_range(2);
149149

150150
const int tid = item_ct1.get_local_id(2);
@@ -237,10 +237,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
237237
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
238238
const float eps, queue_ptr stream, int device) {
239239

240-
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
240+
const sycl::range<3> global_dims(nrows, nchannels, nsamples);
241241
GGML_ASSERT(ncols % WARP_SIZE == 0);
242242
if (ncols < 1024) {
243-
const sycl::range<3> block_dims(1, 1, WARP_SIZE); // Equivalent to CUDA's (WARP_SIZE, 1, 1)
243+
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
244244
stream->submit([&](sycl::handler& cgh) {
245245
cgh.parallel_for(
246246
sycl::nd_range<3>(global_dims * block_dims, block_dims),
@@ -324,7 +324,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
324324
GGML_ASSERT(ncols % WARP_SIZE == 0);
325325
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
326326

327-
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
327+
const sycl::range<3> global_dims(nrows, nchannels, nsamples);
328328
if (ncols < 1024) {
329329
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
330330
stream->submit([&](sycl::handler& cgh) {

0 commit comments

Comments
 (0)