Skip to content

Commit aeaed61

Browse files
authored
Merge pull request #1 from arthw/update_warp
[SYCL] Fix WARP_SIZE=16 bug of Intel GPU (ggml-org#8266) cherry-pick b549a1b
2 parents c5009e6 + 74e3185 commit aeaed61

File tree

9 files changed

+203
-70
lines changed

9 files changed

+203
-70
lines changed

ggml/src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ if (GGML_SYCL)
490490
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
491491
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
492492
else()
493-
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
493+
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
494494
endif()
495495

496496
file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp")

ggml/src/ggml-sycl.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
906906

907907
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
908908
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
909+
const int nthreads = block_size;
910+
const int nwarps = nthreads / WARP_SIZE;
911+
int nreduce = nwarps / WARP_SIZE;
912+
909913

910914
float slope = 1.0f;
911915

@@ -919,7 +923,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
919923
slope = sycl::pow(base, float(exp));
920924
}
921925

922-
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
926+
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
923927
float max_val = -INFINITY;
924928

925929
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -943,6 +947,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
943947
if (block_size > WARP_SIZE) {
944948
if (warp_id == 0) {
945949
buf[lane_id] = -INFINITY;
950+
for (size_t i = 1; i < nreduce; i += 1)
951+
buf[lane_id + i * WARP_SIZE] = -INFINITY;
952+
946953
}
947954
item_ct1.barrier(sycl::access::fence_space::local_space);
948955

@@ -952,6 +959,11 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
952959
item_ct1.barrier(sycl::access::fence_space::local_space);
953960

954961
max_val = buf[lane_id];
962+
for (size_t i = 1; i < nreduce; i += 1)
963+
{
964+
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
965+
}
966+
955967
max_val = warp_reduce_max(max_val, item_ct1);
956968
}
957969

@@ -975,6 +987,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
975987
item_ct1.barrier(sycl::access::fence_space::local_space);
976988
if (warp_id == 0) {
977989
buf[lane_id] = 0.f;
990+
for (size_t i = 1; i < nreduce; i += 1)
991+
buf[lane_id + i * WARP_SIZE] = 0.f;
992+
978993
}
979994
item_ct1.barrier(sycl::access::fence_space::local_space);
980995

@@ -984,6 +999,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
984999
item_ct1.barrier(sycl::access::fence_space::local_space);
9851000

9861001
tmp = buf[lane_id];
1002+
for (size_t i = 1; i < nreduce; i += 1)
1003+
{
1004+
tmp += buf[lane_id + i * WARP_SIZE];
1005+
}
9871006
tmp = warp_reduce_sum(tmp, item_ct1);
9881007
}
9891008

ggml/src/ggml-sycl/common.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ void sycl_device_mgr::detect_all_sycl_device_list() try {
314314
dpct::get_device_info(prop, device);
315315
work_group_sizes.push_back(prop.get_max_work_group_size());
316316
max_compute_units.push_back(prop.get_max_compute_units());
317+
hw_familys.push_back(get_device_family(&device));
317318
}
318319
return;
319320
} catch (sycl::exception const &exc) {
@@ -498,4 +499,8 @@ int ggml_sycl_device_info::get_device_id(int device_index) {
498499
}
499500
}
500501

502+
int ggml_sycl_device_info::hw_family(int device_id) {
503+
return device_mgr->hw_familys[device_id];
504+
}
505+
501506
//--ggml_sycl_device_info--

ggml/src/ggml-sycl/common.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "dpct/helper.hpp"
2121
#include "ggml-sycl.h"
2222
#include "presets.hpp"
23+
#include "sycl_hw.hpp"
2324

2425
#define GGML_COMMON_DECL_SYCL
2526
#define GGML_COMMON_IMPL_SYCL
@@ -188,6 +189,8 @@ class sycl_device_mgr {
188189
std::vector<sycl::device> devices;
189190
std::vector<int> max_compute_units;
190191
std::vector<int> work_group_sizes;
192+
std::vector<int> hw_familys;
193+
191194
sycl::queue *first_queue;
192195
std::vector<sycl::queue> _queues;
193196
std::vector<sycl::context> ctxs;
@@ -236,6 +239,7 @@ struct ggml_sycl_device_info {
236239
bool is_allowed_device(int device_id);
237240
const char* devices_list();
238241
int get_device_id(int device_index);
242+
int hw_family(int device_id);
239243
};
240244

241245
struct ggml_sycl_pool {

0 commit comments

Comments
 (0)