Skip to content

Commit 22dbcdf

Browse files
committed
remove wrap around logic for handling broadcasts
1 parent 7f65c00 commit 22dbcdf

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -475,15 +475,11 @@ static __dpct_inline__ float warp_reduce_max(float x,
475475
/* Helper for computing the linear offset into a 4-dimensional ggml_tensor given
476476
per-dimension sizes, strides, and indices */
477477
template<int N>
478-
static __dpct_inline__ size_t calculate_offset(const std::array<int, N> & dims, const std::array<int, N> & strides, const std::array<int, N> & indices) {
478+
static __dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
479479
size_t offset = 0;
480480
#pragma unroll
481481
for (int i = 0; i < N; i++) {
482482
auto index_i = indices[i];
483-
// Handle wrap-around for indices that exceed dimensions
484-
if (indices[i] >= dims[i]) {
485-
index_i = indices[i] % dims[i];
486-
}
487483
offset += strides[i] * index_i;
488484
}
489485
return offset;

ggml/src/ggml-sycl/norm.cpp

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
66

77
const int nrows = item_ct1.get_group_range(2);
88
const int nchannels = item_ct1.get_group_range(1);
9-
const int nsamples = item_ct1.get_group_range(0);
109

1110
const int nthreads = item_ct1.get_local_range(2);
1211
const int sample = item_ct1.get_group(0);
@@ -16,8 +15,8 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
1615
const int tid = item_ct1.get_local_id(2);
1716
const int nwarps = nthreads / WARP_SIZE;
1817

19-
const auto strided_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
20-
const auto packed_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
18+
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
19+
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
2120

2221
x += strided_offset;
2322
dst += packed_offset;
@@ -150,7 +149,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
150149

151150
const int nrows = item_ct1.get_group_range(2);
152151
const int nchannels = item_ct1.get_group_range(1);
153-
const int nsamples = item_ct1.get_group_range(0);
154152

155153
const int sample = item_ct1.get_group(0);
156154
const int channel = item_ct1.get_group(1);
@@ -161,8 +159,8 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
161159
const int tid = item_ct1.get_local_id(2);
162160
const int nwarps = nthreads / WARP_SIZE;
163161

164-
const auto strided_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
165-
const auto packed_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
162+
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
163+
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
166164

167165
x += strided_offset;
168166
dst += packed_offset;

0 commit comments

Comments
 (0)