@@ -892,117 +892,6 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
892
892
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
893
893
}
894
894
895
-
896
- template <bool vals_smem, int ncols_template, int block_size_template>
897
- static void soft_max_f32 (const float * x, const float * mask, float * dst, const int ncols_par,
898
- const int nrows_y, const float scale, const float max_bias, const float m0,
899
- const float m1, uint32_t n_head_log2, const sycl::nd_item<3 > &item_ct1, float *buf) {
900
- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
901
-
902
- const int tid = item_ct1.get_local_id (2 );
903
- const int rowx = item_ct1.get_group (2 );
904
- const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
905
-
906
- const int block_size = block_size_template == 0 ? item_ct1.get_local_range (2 ) : block_size_template;
907
-
908
- const int warp_id = item_ct1.get_local_id (2 ) / WARP_SIZE;
909
- const int lane_id = item_ct1.get_local_id (2 ) % WARP_SIZE;
910
-
911
- float slope = 1 .0f ;
912
-
913
- // ALiBi
914
- if (max_bias > 0 .0f ) {
915
- const uint32_t h = rowx/nrows_y; // head index
916
-
917
- const float base = h < n_head_log2 ? m0 : m1;
918
- const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
919
-
920
- slope = sycl::pow (base, float (exp));
921
- }
922
-
923
- float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
924
- float max_val = -INFINITY;
925
-
926
- for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
927
- const int col = col0 + tid;
928
-
929
- if (ncols_template == 0 && col >= ncols) {
930
- break ;
931
- }
932
-
933
- const int ix = rowx*ncols + col;
934
- const int iy = rowy*ncols + col;
935
-
936
- const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0 .0f );
937
-
938
- vals[col] = val;
939
- max_val = sycl::max (max_val, val);
940
- }
941
-
942
- // find the max value in the block
943
- max_val = warp_reduce_max (max_val, item_ct1);
944
- if (block_size > WARP_SIZE) {
945
- if (warp_id == 0 ) {
946
- buf[lane_id] = -INFINITY;
947
- }
948
- item_ct1.barrier (sycl::access::fence_space::local_space);
949
-
950
- if (lane_id == 0 ) {
951
- buf[warp_id] = max_val;
952
- }
953
- item_ct1.barrier (sycl::access::fence_space::local_space);
954
-
955
- max_val = buf[lane_id];
956
- max_val = warp_reduce_max (max_val, item_ct1);
957
- }
958
-
959
- float tmp = 0 .f ;
960
-
961
- #pragma unroll
962
- for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
963
- const int col = col0 + tid;
964
- if (ncols_template == 0 && col >= ncols) {
965
- break ;
966
- }
967
-
968
- const float val = sycl::native::exp (vals[col] - max_val);
969
- tmp += val;
970
- vals[col] = val;
971
- }
972
-
973
- // find the sum of exps in the block
974
- tmp = warp_reduce_sum (tmp, item_ct1);
975
- if (block_size > WARP_SIZE) {
976
- item_ct1.barrier (sycl::access::fence_space::local_space);
977
- if (warp_id == 0 ) {
978
- buf[lane_id] = 0 .f ;
979
- }
980
- item_ct1.barrier (sycl::access::fence_space::local_space);
981
-
982
- if (lane_id == 0 ) {
983
- buf[warp_id] = tmp;
984
- }
985
- item_ct1.barrier (sycl::access::fence_space::local_space);
986
-
987
- tmp = buf[lane_id];
988
- tmp = warp_reduce_sum (tmp, item_ct1);
989
- }
990
-
991
- const float inv_sum = 1 .f / tmp;
992
-
993
- #pragma unroll
994
- for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
995
- const int col = col0 + tid;
996
-
997
- if (ncols_template == 0 && col >= ncols) {
998
- return ;
999
- }
1000
-
1001
- const int idst = rowx*ncols + col;
1002
- dst[idst] = vals[col] * inv_sum;
1003
- }
1004
- }
1005
-
1006
895
static void scale_f32 (const float * x, float * dst, const float scale, const int k,
1007
896
const sycl::nd_item<3 > &item_ct1) {
1008
897
const int i = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
@@ -1890,106 +1779,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
1890
1779
});
1891
1780
}
1892
1781
1893
- template <bool vals_smem, int ncols_template, int block_size_template>
1894
- static void soft_max_f32_submitter (const float * x, const float * mask, float * dst, const int ncols_par,
1895
- const int nrows_y, const float scale, const float max_bias, const float m0,
1896
- const float m1, uint32_t n_head_log2, sycl::range<3 > block_nums, sycl::range<3 > block_dims,
1897
- const size_t n_local_scratch, queue_ptr stream) {
1898
- stream->submit ([&](sycl::handler &cgh) {
1899
- sycl::local_accessor<float , 1 > local_buf_acc (n_local_scratch, cgh);
1900
-
1901
- cgh.parallel_for (
1902
- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
1903
- [=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (WARP_SIZE)]] {
1904
- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
1905
- nrows_y, scale, max_bias, m0,
1906
- m1, n_head_log2, item_ct1,
1907
- local_buf_acc.get_pointer ());
1908
- });
1909
- });
1910
- }
1911
-
1912
- static void soft_max_f32_sycl (const float * x, const float * mask,
1913
- float * dst, const int ncols_x, const int nrows_x,
1914
- const int nrows_y, const float scale, const float max_bias,
1915
- queue_ptr stream, int device) {
1916
- int nth = WARP_SIZE;
1917
- int max_block_size = ggml_sycl_info ().max_work_group_sizes [device];
1918
- while (nth < ncols_x && nth < max_block_size) nth *= 2 ;
1919
- if (nth>max_block_size) nth = max_block_size;
1920
-
1921
- const sycl::range<3 > block_dims (1 , 1 , nth);
1922
- const sycl::range<3 > block_nums (1 , 1 , nrows_x);
1923
- const size_t n_local_scratch = (GGML_PAD (ncols_x, WARP_SIZE) + WARP_SIZE);
1924
-
1925
- const uint32_t n_head_kv = nrows_x/nrows_y;
1926
- const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head_kv));
1927
-
1928
- const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
1929
- const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
1930
-
1931
- const size_t local_mem_size = stream->get_device ().get_info <sycl::info::device::local_mem_size>();
1932
- if (n_local_scratch*sizeof (float ) < local_mem_size) {
1933
- if (ncols_x > max_block_size) {
1934
- soft_max_f32_submitter<true , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
1935
- max_bias, m0, m1, n_head_log2, block_nums,
1936
- block_dims, n_local_scratch, stream);
1937
- return ;
1938
- }
1939
- switch (ncols_x) {
1940
- case 32 :
1941
- soft_max_f32_submitter<true , 32 , 32 >(x, mask, dst, ncols_x, nrows_y, scale,
1942
- max_bias, m0, m1, n_head_log2, block_nums,
1943
- block_dims, n_local_scratch, stream);
1944
- break ;
1945
- case 64 :
1946
- soft_max_f32_submitter<true , 64 , 64 >(x, mask, dst, ncols_x, nrows_y, scale,
1947
- max_bias, m0, m1, n_head_log2, block_nums,
1948
- block_dims, n_local_scratch, stream);
1949
- break ;
1950
- case 128 :
1951
- soft_max_f32_submitter<true , 128 , 128 >(x, mask, dst, ncols_x, nrows_y, scale,
1952
- max_bias, m0, m1, n_head_log2, block_nums,
1953
- block_dims, n_local_scratch, stream);
1954
- break ;
1955
- case 256 :
1956
- soft_max_f32_submitter<true , 256 , 256 >(x, mask, dst, ncols_x, nrows_y, scale,
1957
- max_bias, m0, m1, n_head_log2, block_nums,
1958
- block_dims, n_local_scratch, stream);
1959
- break ;
1960
- case 512 :
1961
- soft_max_f32_submitter<true , 512 , 512 >(x, mask, dst, ncols_x, nrows_y, scale,
1962
- max_bias, m0, m1, n_head_log2, block_nums,
1963
- block_dims, n_local_scratch, stream);
1964
- break ;
1965
- case 1024 :
1966
- soft_max_f32_submitter<true , 1024 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
1967
- max_bias, m0, m1, n_head_log2, block_nums,
1968
- block_dims, n_local_scratch, stream);
1969
- break ;
1970
- case 2048 :
1971
- soft_max_f32_submitter<true , 2048 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
1972
- max_bias, m0, m1, n_head_log2, block_nums,
1973
- block_dims, n_local_scratch, stream);
1974
- break ;
1975
- case 4096 :
1976
- soft_max_f32_submitter<true , 4096 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
1977
- max_bias, m0, m1, n_head_log2, block_nums,
1978
- block_dims, n_local_scratch, stream);
1979
- break ;
1980
- default :
1981
- soft_max_f32_submitter<true , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
1982
- max_bias, m0, m1, n_head_log2, block_nums,
1983
- block_dims, n_local_scratch, stream);
1984
- break ;
1985
- }
1986
- } else {
1987
- soft_max_f32_submitter<false , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
1988
- max_bias, m0, m1, n_head_log2, block_nums,
1989
- block_dims, WARP_SIZE, stream);
1990
- }
1991
- }
1992
-
1993
1782
template <typename T>
1994
1783
static void im2col_sycl (const float *x, T *dst, int IW, int IH,
1995
1784
int OW, int OH, int KW, int KH, int IC,
@@ -3009,33 +2798,6 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const gg
3009
2798
(void ) src1_dd;
3010
2799
}
3011
2800
3012
- inline void ggml_sycl_op_soft_max (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3013
- const ggml_tensor *src1, ggml_tensor *dst,
3014
- const float *src0_dd, const float *src1_dd,
3015
- float *dst_dd,
3016
- const queue_ptr &main_stream) {
3017
-
3018
- GGML_ASSERT (src0->type == GGML_TYPE_F32);
3019
- GGML_ASSERT ( dst->type == GGML_TYPE_F32);
3020
-
3021
- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
3022
- #pragma message("ref: https:// github.com/ggerganov/llama.cpp/pull/5021")
3023
- GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
3024
-
3025
- const int64_t ne00 = src0->ne [0 ];
3026
- const int64_t nrows_x = ggml_nrows (src0);
3027
- const int64_t nrows_y = src0->ne [1 ];
3028
-
3029
- float scale = 1 .0f ;
3030
- float max_bias = 0 .0f ;
3031
-
3032
- memcpy (&scale, dst->op_params + 0 , sizeof (float ));
3033
- memcpy (&max_bias, dst->op_params + 1 , sizeof (float ));
3034
-
3035
- soft_max_f32_sycl (src0_dd, src1 ? src1_dd : nullptr , dst_dd, ne00,
3036
- nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device );
3037
- }
3038
-
3039
2801
inline void ggml_sycl_op_scale (ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3040
2802
ggml_tensor *dst, const float *src0_dd,
3041
2803
const float *src1_dd, float *dst_dd,
@@ -5532,7 +5294,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
5532
5294
case GGML_OP_CONCAT:
5533
5295
{
5534
5296
ggml_type src0_type = op->src [0 ]->type ;
5535
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
5297
+ int dim = op->op_params [0 ];
5298
+ return ggml_is_contiguous (op->src [0 ]) && ggml_is_contiguous (op->src [1 ]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2 ;
5536
5299
} break ;
5537
5300
case GGML_OP_DUP:
5538
5301
case GGML_OP_NONE:
0 commit comments