Skip to content

Commit b2f8eea

Browse files
committed
cuda: refactored ssm_scan to use CUB
1 parent 3e959f0 commit b2f8eea

File tree

1 file changed

+192
-76
lines changed

1 file changed

+192
-76
lines changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 192 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,220 @@
1+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
2+
#define USE_CUB
3+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
4+
5+
#ifdef USE_CUB
6+
#include <cub/cub.cuh>
7+
using namespace cub;
8+
#endif // USE_CUB
9+
110
#include "ssm-scan.cuh"
211

312
template <size_t splitD, size_t N>
413
__global__ void __launch_bounds__(splitD, 2)
5-
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
6-
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
7-
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
8-
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
14+
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
15+
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
16+
const int src0_nb1, const int src0_nb2, const int src1_nb1, const int src1_nb2,
17+
const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1,
918
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
10-
float * __restrict__ dst, const int64_t L) {
11-
GGML_UNUSED(src1_nb0);
12-
GGML_UNUSED(src2_nb0);
13-
const int bidx = blockIdx.x; // split along B
14-
const int bidy = blockIdx.y; // split along D
15-
const int tid = threadIdx.x;
16-
const int wid = tid / 32;
17-
const int wtid = tid % 32;
18-
19-
extern __shared__ float smem[];
20-
const int stride_sA = N + 1;
21-
const int stride_ss0 = N + 1;
22-
float * smem_A = smem;
23-
float * smem_s0 = smem_A + splitD * stride_sA;
24-
25-
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
26-
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
27-
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
28-
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
29-
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
30-
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
31-
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
32-
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
19+
float *__restrict__ dst, const int64_t L)
20+
{
3321

34-
const int stride_s0 = src0_nb1 / sizeof(float);
35-
const int stride_x = src1_nb1 / sizeof(float);
22+
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
23+
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
24+
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
25+
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
26+
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
27+
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
28+
float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
29+
float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
30+
31+
const int stride_x = src1_nb1 / sizeof(float);
3632
const int stride_dt = src2_nb1 / sizeof(float);
37-
const int stride_A = src3_nb1 / sizeof(float);
38-
const int stride_B = src4_nb1 / sizeof(float);
39-
const int stride_C = src5_nb1 / sizeof(float);
40-
const int stride_s = stride_s0;
41-
const int stride_y = stride_x;
42-
43-
// can N not be 16? for example 32?
44-
if (N == 16) {
33+
const int stride_B = src4_nb1 / sizeof(float);
34+
const int stride_C = src5_nb1 / sizeof(float);
35+
const int stride_y = stride_x;
36+
37+
float regA[N];
38+
float regs0[N];
39+
40+
__shared__ float smemB[N];
41+
__shared__ float smemC[N];
42+
43+
#ifdef USE_CUB
44+
using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
45+
using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
46+
using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
47+
48+
__shared__ typename BlockLoadA::TempStorage block_load_tempA;
49+
__shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
50+
__shared__ typename BlockStoreS::TempStorage block_store_tempS;
51+
52+
BlockLoadA(block_load_tempA).Load(A_block, regA);
53+
BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
54+
#else
55+
const int stride_s0 = src0_nb1 / sizeof(float);
56+
const int stride_A = src3_nb1 / sizeof(float);
4557
#pragma unroll
46-
for (size_t i = 0; i < splitD / 4; i += 2) {
47-
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
48-
// todo: bank conflict
49-
// I am always confused with how to use the swizzling method to solve
50-
// bank conflit. Hoping somebody can tell me.
51-
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
58+
for (int j = 0; j < N; ++j)
59+
{
60+
regA[j] = A_block[threadIdx.x * stride_A + j];
61+
regs0[j] = s0_block[threadIdx.x * stride_s0 + j];
62+
}
63+
#endif
64+
65+
for (int i = 0; i < L; i++)
66+
{
67+
if (threadIdx.x < N)
68+
{
69+
smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
70+
smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
71+
}
72+
__syncthreads();
73+
74+
float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
75+
if (dt_soft_plus <= 20.0f)
76+
{
77+
dt_soft_plus = log1pf(expf(dt_soft_plus));
5278
}
79+
float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
80+
81+
float sumf = 0.0f;
5382
#pragma unroll
54-
for (size_t i = 0; i < splitD / 4; i += 2) {
55-
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
56-
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
83+
for (int j = 0; j < N; j++)
84+
{
85+
float state = regs0[j] * expf(dt_soft_plus * regA[j]) + smemB[j] * x_dt;
86+
sumf += state * smemC[j];
87+
regs0[j] = state;
5788
}
89+
y_block[i * stride_y + threadIdx.x] = sumf;
90+
}
91+
92+
#ifdef USE_CUB
93+
BlockStoreS(block_store_tempS).Store(s_block, regs0);
94+
#else
95+
const int stride_s = stride_s0;
96+
#pragma unroll
97+
for (int j = 0; j < N; ++j)
98+
{
99+
s_block[threadIdx.x * stride_s + j] = regs0[j];
58100
}
101+
#endif
102+
}
103+
104+
template <size_t splitD, size_t N>
105+
__global__ void __launch_bounds__(splitD, 2)
106+
ssm_scan_single_step_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
107+
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
108+
const int src0_nb1, const int src0_nb2, const int src1_nb2,
109+
const int src1_nb3, const int src2_nb2, const int src3_nb1,
110+
const int src4_nb2, const int src5_nb2,
111+
float *__restrict__ dst)
112+
{
113+
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
114+
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
115+
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
116+
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
117+
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
118+
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
119+
float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
120+
float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
121+
122+
float regA[N];
123+
float regs0[N];
124+
125+
__shared__ float smemB[N];
126+
__shared__ float smemC[N];
127+
128+
#ifdef USE_CUB
129+
using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
130+
using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
131+
using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
132+
133+
__shared__ typename BlockLoadA::TempStorage block_load_tempA;
134+
__shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
135+
__shared__ typename BlockStoreS::TempStorage block_store_tempS;
136+
137+
BlockLoadA(block_load_tempA).Load(A_block, regA);
138+
BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
139+
#else
140+
const int stride_s0 = src0_nb1 / sizeof(float);
141+
const int stride_A = src3_nb1 / sizeof(float);
142+
#pragma unroll
143+
for (int j = 0; j < N; ++j)
144+
{
145+
regA[j] = A_block[threadIdx.x * stride_A + j];
146+
regs0[j] = s0_block[threadIdx.x * stride_s0 + j];
147+
}
148+
#endif
59149

150+
if (threadIdx.x < N)
151+
{
152+
smemB[threadIdx.x] = B_block[threadIdx.x];
153+
smemC[threadIdx.x] = C_block[threadIdx.x];
154+
}
60155
__syncthreads();
61156

62-
for (int64_t i = 0; i < L; i++) {
63-
float dt_soft_plus = dt_block[i * stride_dt + tid];
64-
if (dt_soft_plus <= 20.0f) {
65-
dt_soft_plus = log1pf(exp(dt_soft_plus));
157+
{
158+
float dt_soft_plus = dt_block[threadIdx.x];
159+
if (dt_soft_plus <= 20.0f)
160+
{
161+
dt_soft_plus = log1pf(expf(dt_soft_plus));
66162
}
67-
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
163+
float x_dt = x_block[threadIdx.x] * dt_soft_plus;
68164
float sumf = 0.0f;
69165
#pragma unroll
70-
for (size_t j = 0; j < N; j++) {
71-
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
72-
(B_block[i * stride_B + j] * x_dt);
73-
sumf += state * C_block[i * stride_C + j];
74-
if (i == L - 1) {
75-
s_block[tid * stride_s + j] = state;
76-
} else {
77-
smem_s0[tid * stride_ss0 + j] = state;
78-
}
166+
for (int j = 0; j < N; j++)
167+
{
168+
float state = regs0[j] * expf(dt_soft_plus * regA[j]) + smemB[j] * x_dt;
169+
sumf += state * smemC[j];
170+
regs0[j] = state;
79171
}
80-
__syncthreads();
81-
y_block[i * stride_y + tid] = sumf;
172+
y_block[threadIdx.x] = sumf;
173+
}
174+
175+
#ifdef USE_CUB
176+
BlockStoreS(block_store_tempS).Store(s_block, regs0);
177+
#else
178+
const int stride_s = s0;
179+
#pragma unroll
180+
for (int j = 0; j < N; ++j)
181+
{
182+
s_block[threadIdx.x * stride_s + j] = regs0[j];
82183
}
184+
#endif
83185
}
84186

85-
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
86-
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
87-
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
88-
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
187+
static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float *src2, const float *src3,
188+
const float *src4, const float *src5, const int src0_nb1, const int src0_nb2,
189+
const int src1_nb1, const int src1_nb2, const int src1_nb3,
190+
const int src2_nb1, const int src2_nb2, const int src3_nb1,
89191
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
90-
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
91-
cudaStream_t stream) {
192+
float *dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
193+
cudaStream_t stream)
194+
{
92195
const int threads = 128;
93196
// todo: consider D cannot be divided,does this situation exist?
94197
GGML_ASSERT(D % threads == 0);
95198
const dim3 blocks(B, (D + threads - 1) / threads, 1);
96-
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
97-
if (N == 16) {
98-
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
99-
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
100-
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
101-
} else {
199+
if (N == 16)
200+
{
201+
if (L > 1)
202+
{
203+
ssm_scan_f32<threads, 16><<<blocks, threads, 0, stream>>>(
204+
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
205+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
206+
}
207+
else
208+
{
209+
ssm_scan_single_step_f32<threads, 16><<<blocks, threads, 0, stream>>>(
210+
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb2,
211+
src1_nb3, src2_nb2, src3_nb1,
212+
src4_nb2, src5_nb2,
213+
dst);
214+
}
215+
}
216+
else
217+
{
102218
GGML_ABORT("doesn't support N!=16.");
103219
}
104220
}
@@ -147,7 +263,7 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
147263
GGML_ASSERT(src0->type == GGML_TYPE_F32);
148264
GGML_ASSERT(dst->type == GGML_TYPE_F32);
149265

150-
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
151-
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
266+
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2],
267+
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], src3->nb[1],
152268
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
153269
}

0 commit comments

Comments
 (0)