1
1
#include " norm.hpp"
2
- #include " ggml-sycl/presets.hpp"
3
2
4
3
static void norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
5
4
const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
6
5
7
6
const int nrows = item_ct1.get_group_range (2 );
8
7
const int nchannels = item_ct1.get_group_range (1 );
9
- int sample = item_ct1.get_group (0 );
10
- int channel = item_ct1.get_group (1 );
11
- int row = item_ct1.get_group (2 );
8
+ const int sample = item_ct1.get_group (0 );
9
+ const int channel = item_ct1.get_group (1 );
10
+ const int row = item_ct1.get_group (2 );
12
11
13
- int tid = item_ct1.get_local_id (2 );
12
+ const int tid = item_ct1.get_local_id (2 );
14
13
15
- x += sample * stride_sample + channel * stride_channel + row * stride_row;
14
+ x += sample * stride_sample + channel * stride_channel + row * stride_row;
16
15
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
17
16
18
17
sycl::float2 mean_var{0 .f , 0 .f };
@@ -132,17 +131,25 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
132
131
}
133
132
}
134
133
135
- static void rms_norm_f32 (const float * x, float * dst, const int ncols, const float eps,
136
- const sycl::nd_item<3 >& item_ct1, float * s_sum, int block_size) {
137
- const int row = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) +
138
- item_ct1.get_local_id (1 );
134
+ static void rms_norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
135
+ const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, float * s_sum, int block_size) {
136
+
137
+ const int nrows = item_ct1.get_group_range (2 );
138
+ const int nchannels = item_ct1.get_group_range (1 );
139
+ const int sample = item_ct1.get_group (0 );
140
+ const int channel = item_ct1.get_group (1 );
141
+ const int row = item_ct1.get_group (2 );
142
+
139
143
const int tid = item_ct1.get_local_id (2 );
140
- const int nthreads = item_ct1.get_local_range (2 );
141
- const int nwarps = nthreads / WARP_SIZE;
144
+
145
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
146
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
147
+
148
+
142
149
float tmp = 0 .0f ; // partial sum for thread in warp
143
150
144
151
for (int col = tid; col < ncols; col += block_size) {
145
- const float xi = x[row * ncols + col];
152
+ const float xi = x[col];
146
153
tmp += xi * xi;
147
154
}
148
155
@@ -155,25 +162,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
155
162
if (lane_id == 0 ) {
156
163
s_sum[warp_id] = tmp;
157
164
}
158
- /*
159
- DPCT1118:3: SYCL group functions and algorithms must be encountered in
160
- converged control flow. You may need to adjust the code.
161
- */
165
+
162
166
item_ct1.barrier (sycl::access::fence_space::local_space);
163
- size_t nreduce = nwarps / WARP_SIZE;
164
- tmp = 0 .f ;
165
- for (size_t i = 0 ; i < nreduce; i += 1 )
166
- {
167
- tmp += s_sum[lane_id + i * WARP_SIZE];
168
- }
167
+ tmp = s_sum[lane_id];
169
168
tmp = warp_reduce_sum (tmp, item_ct1);
170
169
}
171
170
172
171
const float mean = tmp / ncols;
173
172
const float scale = sycl::rsqrt (mean + eps);
174
173
175
174
for (int col = tid; col < ncols; col += block_size) {
176
- dst[row * ncols + col] = scale * x[row * ncols + col];
175
+ dst[col] = scale * x[col];
177
176
}
178
177
}
179
178
@@ -307,21 +306,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
307
306
}
308
307
}
309
308
310
- static void rms_norm_f32_sycl (const float * x, float * dst, const int ncols,
311
- const int nrows, const float eps,
312
- queue_ptr stream, int device) {
309
+ static void rms_norm_f32_sycl (const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
310
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
313
311
GGML_ASSERT (ncols % WARP_SIZE == 0 );
314
312
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
313
+
314
+ const sycl::range<3 > global_dims (nsamples, nchannels, nrows);
315
315
if (ncols < 1024 ) {
316
316
const sycl::range<3 > block_dims (1 , 1 , WARP_SIZE);
317
317
stream->submit ([&](sycl::handler& cgh) {
318
318
cgh.parallel_for (
319
- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , nrows) * block_dims,
320
- block_dims),
319
+ sycl::nd_range<3 >(global_dims * block_dims, block_dims),
321
320
[=](sycl::nd_item<3 > item_ct1)
322
321
[[sycl::reqd_sub_group_size (WARP_SIZE)]] {
323
- rms_norm_f32 (x, dst, ncols, eps, item_ct1,
324
- nullptr , WARP_SIZE);
322
+ rms_norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr , WARP_SIZE);
325
323
});
326
324
});
327
325
}
@@ -338,12 +336,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
338
336
sycl::local_accessor<float , 1 > s_sum_acc_ct1 (sycl::range<1 >(work_group_size / WARP_SIZE),
339
337
cgh);
340
338
cgh.parallel_for (
341
- sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , nrows) * block_dims,
342
- block_dims),
339
+ sycl::nd_range<3 >(global_dims * block_dims, block_dims),
343
340
[=](sycl::nd_item<3 > item_ct1)
344
341
[[sycl::reqd_sub_group_size (WARP_SIZE)]] {
345
- rms_norm_f32 (x, dst, ncols, eps, item_ct1,
346
- get_pointer (s_sum_acc_ct1), work_group_size);
342
+ rms_norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
347
343
});
348
344
});
349
345
}
@@ -436,11 +432,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
436
432
437
433
void ggml_sycl_op_rms_norm (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
438
434
435
+ const ggml_tensor * src0 = dst->src [0 ];
439
436
GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
440
437
GGML_ASSERT (dst->type == GGML_TYPE_F32);
441
438
442
- const int64_t ne00 = dst->src [0 ]->ne [0 ];
443
- const int64_t nrows = ggml_nrows (dst->src [0 ]);
444
439
dpct::queue_ptr main_stream = ctx.stream ();
445
440
SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
446
441
@@ -450,7 +445,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
450
445
float eps;
451
446
memcpy (&eps, dst->op_params , sizeof (float ));
452
447
453
- rms_norm_f32_sycl (src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device );
448
+ GGML_TENSOR_UNARY_OP_LOCALS
449
+ const size_t ts0 = ggml_type_size (src0->type );
450
+ GGML_ASSERT (nb00 == ts0);
451
+ const int64_t s01 = nb01 / ts0;
452
+ const int64_t s02 = nb02 / ts0;
453
+ const int64_t s03 = nb03 / ts0;
454
+ rms_norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device );
454
455
}
455
456
456
457
void ggml_sycl_op_l2_norm (ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
0 commit comments