+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
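+// CUB block-wide load/store primitives require CUDA >= 11.7 and are not available on HIP/MUSA builds;
+// without them the kernels below fall back to plain per-thread loops.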
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
 #include "ssm-scan.cuh"
 
 template <size_t splitD, size_t N>
 __global__ void __launch_bounds__(splitD, 2)
-    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
-                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
-                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
-                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
+    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
+                 const int src0_nb1, const int src0_nb2, const int src1_nb1, const int src1_nb2,
+                 const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1,
                  const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                 float * __restrict__ dst, const int64_t L) {
-    GGML_UNUSED(src1_nb0);
-    GGML_UNUSED(src2_nb0);
-    const int bidx = blockIdx.x; // split along B
-    const int bidy = blockIdx.y; // split along D
-    const int tid  = threadIdx.x;
-    const int wid  = tid / 32;
-    const int wtid = tid % 32;
-
-    extern __shared__ float smem[];
-    const int stride_sA  = N + 1;
-    const int stride_ss0 = N + 1;
-    float * smem_A  = smem;
-    float * smem_s0 = smem_A + splitD * stride_sA;
-
-    const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
-    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
-    const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb2));
-    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb2));
-    float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
-    float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+                 float *__restrict__ dst, const int64_t L)
+{
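+    // Grid layout: blockIdx.x selects the sequence in the batch, blockIdx.y selects a chunk of splitD channels along D.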
 
-    const int stride_s0 = src0_nb1 / sizeof(float);
-    const int stride_x  = src1_nb1 / sizeof(float);
+    const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
+    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
+    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
+    float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
+    float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
+
+    const int stride_x = src1_nb1 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
-    const int stride_A  = src3_nb1 / sizeof(float);
-    const int stride_B  = src4_nb1 / sizeof(float);
-    const int stride_C  = src5_nb1 / sizeof(float);
-    const int stride_s  = stride_s0;
-    const int stride_y  = stride_x;
-
-    // can N not be 16? for example 32?
-    if (N == 16) {
+    const int stride_B = src4_nb1 / sizeof(float);
+    const int stride_C = src5_nb1 / sizeof(float);
+    const int stride_y = stride_x;
+
+    float regA[N];
+    float regs0[N];
+
+    __shared__ float smemB[N];
+    __shared__ float smemC[N];
+
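+    // Keep A and the incoming state s0 in registers for the whole scan. With CUB this becomes a
+    // vectorized block-wide load of splitD * N contiguous floats; otherwise each thread loads its own row.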
+#ifdef USE_CUB
+    using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
+    using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
+    using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
+
+    __shared__ typename BlockLoadA::TempStorage block_load_tempA;
+    __shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
+    __shared__ typename BlockStoreS::TempStorage block_store_tempS;
+
+    BlockLoadA(block_load_tempA).Load(A_block, regA);
+    BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
+#else
+    const int stride_s0 = src0_nb1 / sizeof(float);
+    const int stride_A = src3_nb1 / sizeof(float);
 #pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid * warpSize + i) * stride_A + wtid];
-            // todo: bank conflict
-            // I am always confused with how to use the swizzling method to solve
-            // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+    for (int j = 0; j < N; ++j)
+    {
+        regA[j] = A_block[threadIdx.x * stride_A + j];
+        regs0[j] = s0_block[threadIdx.x * stride_s0 + j];
+    }
+#endif
+
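+    // Recurrent scan over the L tokens of this sequence; regs0 carries the per-channel state across steps.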
+    for (int i = 0; i < L; i++)
+    {
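+        // The first N threads stage this step's B and C vectors in shared memory for the whole block.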
+        if (threadIdx.x < N)
+        {
+            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+        }
+        __syncthreads();
+
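+        // Softplus with an overflow guard: for dt > 20 the result is numerically equal to dt in float.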
+        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+        if (dt_soft_plus <= 20.0f)
+        {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
         }
+        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
+        float sumf = 0.0f;
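+        // Selective-scan update: state = state * exp(dt * A) + B * (x * dt); the output accumulates state * C.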
 #pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
-            smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+        for (int j = 0; j < N; j++)
+        {
+            float state = regs0[j] * expf(dt_soft_plus * regA[j]) + smemB[j] * x_dt;
+            sumf += state * smemC[j];
+            regs0[j] = state;
         }
+        y_block[i * stride_y + threadIdx.x] = sumf;
+    }
+
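+    // Persist the final per-channel state back into dst (it is stored after the y output region).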
+#ifdef USE_CUB
+    BlockStoreS(block_store_tempS).Store(s_block, regs0);
+#else
+    const int stride_s = stride_s0;
+#pragma unroll
+    for (int j = 0; j < N; ++j)
+    {
+        s_block[threadIdx.x * stride_s + j] = regs0[j];
     }
+#endif
+}
+
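+// Single-time-step variant (L == 1, e.g. token-by-token decoding): the same recurrence as above,
+// without the loop over the sequence or the per-step strides.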
+template <size_t splitD, size_t N>
+__global__ void __launch_bounds__(splitD, 2)
+    ssm_scan_single_step_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                             const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
+                             const int src0_nb1, const int src0_nb2, const int src1_nb2,
+                             const int src1_nb3, const int src2_nb2, const int src3_nb1,
+                             const int src4_nb2, const int src5_nb2,
+                             float *__restrict__ dst)
+{
+    const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
+    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
+    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
+    float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
+    float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
+
+    float regA[N];
+    float regs0[N];
+
+    __shared__ float smemB[N];
+    __shared__ float smemC[N];
+
+#ifdef USE_CUB
+    using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
+    using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
+    using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
+
+    __shared__ typename BlockLoadA::TempStorage block_load_tempA;
+    __shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
+    __shared__ typename BlockStoreS::TempStorage block_store_tempS;
+
+    BlockLoadA(block_load_tempA).Load(A_block, regA);
+    BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
+#else
+    const int stride_s0 = src0_nb1 / sizeof(float);
+    const int stride_A = src3_nb1 / sizeof(float);
+#pragma unroll
+    for (int j = 0; j < N; ++j)
+    {
+        regA[j] = A_block[threadIdx.x * stride_A + j];
+        regs0[j] = s0_block[threadIdx.x * stride_s0 + j];
+    }
+#endif
 
+    if (threadIdx.x < N)
+    {
+        smemB[threadIdx.x] = B_block[threadIdx.x];
+        smemC[threadIdx.x] = C_block[threadIdx.x];
+    }
     __syncthreads();
 
-    for (int64_t i = 0; i < L; i++) {
-        float dt_soft_plus = dt_block[i * stride_dt + tid];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(exp(dt_soft_plus));
+    {
+        float dt_soft_plus = dt_block[threadIdx.x];
+        if (dt_soft_plus <= 20.0f)
+        {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
         }
-        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float x_dt = x_block[threadIdx.x] * dt_soft_plus;
         float sumf = 0.0f;
 #pragma unroll
-        for (size_t j = 0; j < N; j++) {
-            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
-                          (B_block[i * stride_B + j] * x_dt);
-            sumf += state * C_block[i * stride_C + j];
-            if (i == L - 1) {
-                s_block[tid * stride_s + j] = state;
-            } else {
-                smem_s0[tid * stride_ss0 + j] = state;
-            }
+        for (int j = 0; j < N; j++)
+        {
+            float state = regs0[j] * expf(dt_soft_plus * regA[j]) + smemB[j] * x_dt;
+            sumf += state * smemC[j];
+            regs0[j] = state;
         }
-        __syncthreads();
-        y_block[i * stride_y + tid] = sumf;
+        y_block[threadIdx.x] = sumf;
+    }
+
+#ifdef USE_CUB
+    BlockStoreS(block_store_tempS).Store(s_block, regs0);
+#else
+    const int stride_s = stride_s0;
+#pragma unroll
+    for (int j = 0; j < N; ++j)
+    {
+        s_block[threadIdx.x * stride_s + j] = regs0[j];
     }
+#endif
 }
|
84 | 186 |
|
85 |
| -static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, |
86 |
| - const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, |
87 |
| - const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, |
88 |
| - const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, |
| 187 | +static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float *src2, const float *src3, |
| 188 | + const float *src4, const float *src5, const int src0_nb1, const int src0_nb2, |
| 189 | + const int src1_nb1, const int src1_nb2, const int src1_nb3, |
| 190 | + const int src2_nb1, const int src2_nb2, const int src3_nb1, |
89 | 191 | const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
|
90 |
| - float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, |
91 |
| - cudaStream_t stream) { |
| 192 | + float *dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, |
| 193 | + cudaStream_t stream) |
| 194 | +{ |
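+    // Each block handles splitD = 128 channels of one sequence: grid = (B sequences, ceil(D / 128) chunks).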
     const int threads = 128;
     // todo: consider D cannot be divided,does this situation exist?
     GGML_ASSERT(D % threads == 0);
     const dim3 blocks(B, (D + threads - 1) / threads, 1);
-    const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
-    if (N == 16) {
-        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
-            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
-    } else {
+    if (N == 16)
+    {
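+        // Prompt processing (L > 1) uses the looping kernel; single-token decode takes the single-step kernel.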
+        if (L > 1)
+        {
+            ssm_scan_f32<threads, 16><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+        }
+        else
+        {
+            ssm_scan_single_step_f32<threads, 16><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb2,
+                src1_nb3, src2_nb2, src3_nb1,
+                src4_nb2, src5_nb2,
+                dst);
+        }
+    }
+    else
+    {
         GGML_ABORT("doesn't support N!=16.");
     }
 }
@@ -147,7 +263,7 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
-                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2],
+                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], src3->nb[1],
                       src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
 }