Skip to content

Commit 7f65c00

Browse files
committed
Use a common function to calculate offset
1 parent 6dab4bf commit 7f65c00

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef GGML_SYCL_COMMON_HPP
1414
#define GGML_SYCL_COMMON_HPP
1515

16+
#include <cstddef>
1617
#include <fstream>
1718
#include <iostream>
1819

@@ -471,6 +472,23 @@ static __dpct_inline__ float warp_reduce_max(float x,
471472
return x;
472473
}
473474

475+
/* Helper for Computing the linear offset into an 4-dimensional ggml_tensor given
476+
per-dimension sizes, strides, and indices */
477+
template<int N>
478+
static __dpct_inline__ size_t calculate_offset(const std::array<int, N> & dims, const std::array<int, N> & strides, const std::array<int, N> & indices) {
479+
size_t offset = 0;
480+
#pragma unroll
481+
for (int i = 0; i < N; i++) {
482+
auto index_i = indices[i];
483+
// Handle wrap-around for indices that exceed dimensions
484+
if (indices[i] >= dims[i]) {
485+
index_i = indices[i] % dims[i];
486+
}
487+
offset += strides[i] * index_i;
488+
}
489+
return offset;
490+
}
491+
474492
// Helper for vec loading aligned data
475493
template <typename Tp, int n>
476494
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {

ggml/src/ggml-sycl/norm.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
#include "norm.hpp"
2+
#include "ggml-sycl/common.hpp"
23

34
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
45
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
56

67
const int nrows = item_ct1.get_group_range(2);
78
const int nchannels = item_ct1.get_group_range(1);
9+
const int nsamples = item_ct1.get_group_range(0);
10+
811
const int nthreads = item_ct1.get_local_range(2);
912
const int sample = item_ct1.get_group(0);
1013
const int channel = item_ct1.get_group(1);
@@ -13,8 +16,11 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
1316
const int tid = item_ct1.get_local_id(2);
1417
const int nwarps = nthreads / WARP_SIZE;
1518

16-
x += sample * stride_sample + channel * stride_channel + row * stride_row;
17-
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
19+
const auto strided_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
20+
const auto packed_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
21+
22+
x += strided_offset;
23+
dst += packed_offset;
1824

1925
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
2026

@@ -144,16 +150,22 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
144150

145151
const int nrows = item_ct1.get_group_range(2);
146152
const int nchannels = item_ct1.get_group_range(1);
153+
const int nsamples = item_ct1.get_group_range(0);
154+
147155
const int sample = item_ct1.get_group(0);
148156
const int channel = item_ct1.get_group(1);
149157
const int row = item_ct1.get_group(2);
158+
150159
const int nthreads = item_ct1.get_local_range(2);
151160

152161
const int tid = item_ct1.get_local_id(2);
153162
const int nwarps = nthreads / WARP_SIZE;
154163

155-
x += sample*stride_sample + channel*stride_channel + row*stride_row;
156-
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
164+
const auto strided_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {stride_sample, stride_channel, stride_row}, {sample, channel, row});
165+
const auto packed_offset = calculate_offset<3>({nsamples, nchannels, nrows}, {nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
166+
167+
x += strided_offset;
168+
dst += packed_offset;
157169

158170

159171
float tmp = 0.0f; // partial sum for thread in warp

0 commit comments

Comments
 (0)