Commit 1c19b36

refactor and add RMS_NORM non-contiguous input support
ggml-ci
1 parent dfaadfa commit 1c19b36
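
In effect, the SYCL backend can now run GGML_OP_RMS_NORM directly on a non-contiguous src0 (e.g. a permuted view) instead of requiring ggml_is_contiguous. A minimal sketch of the kind of graph this unlocks, using the public ggml API (the helper name and eps value are illustrative, not from this commit):

    #include "ggml.h"

    // Sketch: RMS norm over a non-contiguous view. ggml_permute(ctx, a, 0, 2, 1, 3)
    // keeps the innermost dim contiguous (nb[0] == type size, which the new kernel
    // asserts) but swaps the row/channel byte strides, so the view is non-contiguous.
    static ggml_tensor * rms_norm_noncont(ggml_context * ctx, ggml_tensor * a) {
        ggml_tensor * a_perm = ggml_permute(ctx, a, 0, 2, 1, 3);
        return ggml_rms_norm(ctx, a_perm, 1e-6f);   // eps value is illustrative
    }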

File tree

2 files changed: +40 −39 lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 1 deletion
@@ -4160,8 +4160,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
 #endif
         case GGML_OP_NORM:
-            return true;
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
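
Reconstructed from the hunk above, the cases now fall through so NORM and RMS_NORM share the unconditional branch, while L2_NORM and GROUP_NORM keep the contiguity check:

        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
            return true;                               // non-contiguous src0 accepted
        case GGML_OP_L2_NORM:
        case GGML_OP_GROUP_NORM:
            return ggml_is_contiguous(op->src[0]);     // still contiguous-only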

ggml/src/ggml-sycl/norm.cpp

Lines changed: 39 additions & 38 deletions
@@ -1,18 +1,17 @@
 #include "norm.hpp"
-#include "ggml-sycl/presets.hpp"
 
 static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
                      const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
 
     const int nrows = item_ct1.get_group_range(2);
     const int nchannels = item_ct1.get_group_range(1);
-    int sample = item_ct1.get_group(0);
-    int channel = item_ct1.get_group(1);
-    int row = item_ct1.get_group(2);
+    const int sample = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row = item_ct1.get_group(2);
 
-    int tid = item_ct1.get_local_id(2);
+    const int tid = item_ct1.get_local_id(2);
 
-    x += sample * stride_sample + channel * stride_channel + row * stride_row;
+    x += sample * stride_sample + channel * stride_channel + row * stride_row;
     dst += ((sample * nchannels + channel) * nrows + row) * ncols;
 
     sycl::float2 mean_var{0.f, 0.f};
@@ -132,17 +131,25 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     }
 }
 
-static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
-                         const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
+static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+                         const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+
+    const int nrows = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);
+    const int sample = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row = item_ct1.get_group(2);
+
     const int tid = item_ct1.get_local_id(2);
-    const int nthreads = item_ct1.get_local_range(2);
-    const int nwarps = nthreads / WARP_SIZE;
+
+    x += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }
 
@@ -155,25 +162,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
-        /*
-        DPCT1118:3: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
+
         item_ct1.barrier(sycl::access::fence_space::local_space);
-        size_t nreduce = nwarps / WARP_SIZE;
-        tmp = 0.f;
-        for (size_t i = 0; i < nreduce; i += 1)
-        {
-            tmp += s_sum[lane_id + i * WARP_SIZE];
-        }
+        tmp = s_sum[lane_id];
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
 
     const float mean = tmp / ncols;
     const float scale = sycl::rsqrt(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = scale * x[row * ncols + col];
+        dst[col] = scale * x[col];
     }
 }
 
@@ -307,21 +306,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
 }
 
-static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
-                              const int nrows, const float eps,
-                              queue_ptr stream, int device) {
+static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+                              const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                     nullptr, WARP_SIZE);
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                     });
         });
     }
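
With global_dims shared by both launch paths, each work-group maps one (sample, channel, row) triple and WARP_SIZE lanes stride over the columns. A stripped-down sketch of the same nd_range construction (the queue, sizes, and WARP_SIZE value here are illustrative, not the backend's):

    #include <sycl/sycl.hpp>

    void launch_shape_demo(sycl::queue & q) {
        constexpr int WARP_SIZE = 32;                      // illustrative value
        const sycl::range<3> global_dims(2, 4, 128);       // nsamples, nchannels, nrows
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);  // one sub-group per row
        q.parallel_for(sycl::nd_range<3>(global_dims * block_dims, block_dims),
                       [=](sycl::nd_item<3> it) {
                           // it.get_group(0) -> sample, it.get_group(1) -> channel,
                           // it.get_group(2) -> row; local id 2 strides over columns
                       });
    }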
@@ -338,12 +336,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                     get_pointer(s_sum_acc_ct1), work_group_size);
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                     });
         });
     }
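
The refactored rms_norm_f32 above folds all layout handling into one pointer setup per work-group: x is advanced by element strides, dst stays densely packed, and the column loop indexes x[col] and dst[col] directly. A serial C++ reference of the same per-row computation (a sketch for clarity, not the backend's code):

    #include <cmath>
    #include <cstdint>

    // One (sample, channel, row): mirrors the kernel's addressing and math.
    // Strides are in elements; the innermost dim must be contiguous.
    static void rms_norm_row_ref(const float * x, float * dst, const int ncols,
                                 const int64_t sample, const int64_t channel, const int64_t row,
                                 const int64_t stride_sample, const int64_t stride_channel,
                                 const int64_t stride_row, const int64_t nchannels,
                                 const int64_t nrows, const float eps) {
        x   += sample * stride_sample + channel * stride_channel + row * stride_row;
        dst += ((sample * nchannels + channel) * nrows + row) * ncols;
        float sumsq = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            sumsq += x[col] * x[col];                             // sum of squares
        }
        const float scale = 1.0f / std::sqrt(sumsq / ncols + eps); // rsqrt(mean + eps)
        for (int col = 0; col < ncols; ++col) {
            dst[col] = scale * x[col];
        }
    }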
@@ -436,11 +432,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
 void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
+    const ggml_tensor * src0 = dst->src[0];
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 
@@ -450,7 +445,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_TENSOR_UNARY_OP_LOCALS
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }
 
 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
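
On the host side, GGML_TENSOR_UNARY_OP_LOCALS brings the src0 byte strides (nb00..nb03) into scope, and dividing by the type size turns them into the element strides the kernel consumes. A worked example with illustrative contiguous sizes, following the same arithmetic as the hunk above:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Contiguous f32 tensor with ne = {4096, 32, 8, 1} (sizes are illustrative).
        const int64_t ne00 = 4096, ne01 = 32, ne02 = 8;
        const int64_t ts0  = sizeof(float);   // ggml_type_size(GGML_TYPE_F32) == 4
        const int64_t nb00 = ts0;             // must equal ts0 (asserted above)
        const int64_t nb01 = ne00 * ts0;      // 16384 bytes per row
        const int64_t nb02 = ne01 * nb01;     // 524288 bytes per channel
        const int64_t nb03 = ne02 * nb02;     // 4194304 bytes per sample
        // Same conversion as ggml_sycl_op_rms_norm:
        const int64_t s01 = nb01 / ts0;       // 4096 elements
        const int64_t s02 = nb02 / ts0;       // 131072 elements
        const int64_t s03 = nb03 / ts0;       // 1048576 elements
        std::printf("s01=%lld s02=%lld s03=%lld\n",
                    (long long) s01, (long long) s02, (long long) s03);
        return nb00 == ts0 ? 0 : 1;
    }

For a permuted or viewed tensor only the nb values change; the kernel consumes the element strides as-is, which is exactly what the non-contiguous path relies on.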
