1
1
#include " norm.hpp"
2
+ #include " ggml-sycl/common.hpp"
2
3
3
4
static void norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
4
5
const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
5
6
6
7
const int nrows = item_ct1.get_group_range (2 );
7
8
const int nchannels = item_ct1.get_group_range (1 );
9
+ const int nsamples = item_ct1.get_group_range (0 );
10
+
8
11
const int nthreads = item_ct1.get_local_range (2 );
9
12
const int sample = item_ct1.get_group (0 );
10
13
const int channel = item_ct1.get_group (1 );
@@ -13,8 +16,11 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
13
16
const int tid = item_ct1.get_local_id (2 );
14
17
const int nwarps = nthreads / WARP_SIZE;
15
18
16
- x += sample * stride_sample + channel * stride_channel + row * stride_row;
17
- dst += ((sample * nchannels + channel) * nrows + row) * ncols;
19
+ const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
20
+ const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
21
+
22
+ x += strided_offset;
23
+ dst += packed_offset;
18
24
19
25
sycl::float2 mean_var = sycl::float2 (0 .f , 0 .f );
20
26
@@ -144,16 +150,22 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
144
150
145
151
const int nrows = item_ct1.get_group_range (2 );
146
152
const int nchannels = item_ct1.get_group_range (1 );
153
+ const int nsamples = item_ct1.get_group_range (0 );
154
+
147
155
const int sample = item_ct1.get_group (0 );
148
156
const int channel = item_ct1.get_group (1 );
149
157
const int row = item_ct1.get_group (2 );
158
+
150
159
const int nthreads = item_ct1.get_local_range (2 );
151
160
152
161
const int tid = item_ct1.get_local_id (2 );
153
162
const int nwarps = nthreads / WARP_SIZE;
154
163
155
- x += sample*stride_sample + channel*stride_channel + row*stride_row;
156
- dst += ((sample*nchannels + channel)*nrows + row)*ncols;
164
+ const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
165
+ const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
166
+
167
+ x += strided_offset;
168
+ dst += packed_offset;
157
169
158
170
159
171
float tmp = 0 .0f ; // partial sum for thread in warp
0 commit comments