@@ -6,7 +6,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
6
6
7
7
const int nrows = item_ct1.get_group_range (2 );
8
8
const int nchannels = item_ct1.get_group_range (1 );
9
- const int nsamples = item_ct1.get_group_range (0 );
10
9
11
10
const int nthreads = item_ct1.get_local_range (2 );
12
11
const int sample = item_ct1.get_group (0 );
@@ -16,8 +15,8 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
16
15
const int tid = item_ct1.get_local_id (2 );
17
16
const int nwarps = nthreads / WARP_SIZE;
18
17
19
- const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { stride_sample, stride_channel, stride_row}, {sample, channel, row});
20
- const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
18
+ const auto strided_offset = calculate_offset<3 >({stride_sample, stride_channel, stride_row}, {sample, channel, row});
19
+ const auto packed_offset = calculate_offset<3 >({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
21
20
22
21
x += strided_offset;
23
22
dst += packed_offset;
@@ -150,7 +149,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
150
149
151
150
const int nrows = item_ct1.get_group_range (2 );
152
151
const int nchannels = item_ct1.get_group_range (1 );
153
- const int nsamples = item_ct1.get_group_range (0 );
154
152
155
153
const int sample = item_ct1.get_group (0 );
156
154
const int channel = item_ct1.get_group (1 );
@@ -161,8 +159,8 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
161
159
const int tid = item_ct1.get_local_id (2 );
162
160
const int nwarps = nthreads / WARP_SIZE;
163
161
164
- const auto strided_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { stride_sample, stride_channel, stride_row}, {sample, channel, row});
165
- const auto packed_offset = calculate_offset<3 >({nsamples, nchannels, nrows}, { nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
162
+ const auto strided_offset = calculate_offset<3 >({stride_sample, stride_channel, stride_row}, {sample, channel, row});
163
+ const auto packed_offset = calculate_offset<3 >({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
166
164
167
165
x += strided_offset;
168
166
dst += packed_offset;
0 commit comments