cuda: refactored ssm_scan and use CUB #13291
base: master
@@ -1,104 +1,220 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB

#include "ssm-scan.cuh"

template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0);
const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;

extern __shared__ float smem[];
const int stride_sA = N + 1;
const int stride_ss0 = N + 1;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;

const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
float *__restrict__ dst, const int64_t L)
{

const int stride_s0 = src0_nb1 / sizeof(float);
const int stride_x = src1_nb1 / sizeof(float);
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);

Comment on lines +22 to +29: In GPU code there can be performance issues if you cast to

const int stride_x = src1_nb1 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb1 / sizeof(float);
const int stride_C = src5_nb1 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = stride_x;

// can N not be 16? for example 32?
if (N == 16) {
const int stride_B = src4_nb1 / sizeof(float);
const int stride_C = src5_nb1 / sizeof(float);
const int stride_y = stride_x;

float regA[N];
float regs0[N];

__shared__ float smemB[N];
__shared__ float smemC[N];

#ifdef USE_CUB
using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;

__shared__ typename BlockLoadA::TempStorage block_load_tempA;
__shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
__shared__ typename BlockStoreS::TempStorage block_store_tempS;

BlockLoadA(block_load_tempA).Load(A_block, regA);
BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
for (size_t n = 0; n < N; ++n)
{
regA[n] = A_block[threadIdx.x * stride_A + n];
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];

Comment on lines +61 to +62:

The memory access pattern here is inefficient, though I also wouldn't know how to improve it.

Does the problem lie in that the loads aren't coalesced? Wouldn't using a coalesced loading pattern require the data to be in a different layout?

Yes, the problem is the uncoalesced I/O. If you could somehow re-write the kernel to make the loads coalesced or change the memory pattern the previous kernel puts out, the performance would likely be better. (I did not try to analyze whether something like this is possible.)

(A coalesced-load sketch follows the diff below.)

}
#endif

for (int i = 0; i < L; i++)
{
if (threadIdx.x < N)
{
smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
}
__syncthreads();

float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
if (dt_soft_plus <= 20.0f)
{
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

float sumf = 0.0f;
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
for (size_t n = 0; n < N; n++)
{
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
sumf += state * smemC[n];
regs0[n] = state;
}
y_block[i * stride_y + threadIdx.x] = sumf;
}

#ifdef USE_CUB
BlockStoreS(block_store_tempS).Store(s_block, regs0);
#else
const int stride_s = stride_s0;
#pragma unroll
for (size_t n = 0; n < N; ++n)
{
s_block[threadIdx.x * stride_s + n] = regs0[n];

Comment: The memory access pattern here is also inefficient.

}
#endif
}

template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_single_step_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb2,
const int src1_nb3, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src5_nb2,
float *__restrict__ dst)
{
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);

float regA[N];
float regs0[N];

__shared__ float smemB[N];
__shared__ float smemC[N];

#ifdef USE_CUB
using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;

__shared__ typename BlockLoadA::TempStorage block_load_tempA;
__shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
__shared__ typename BlockStoreS::TempStorage block_store_tempS;

BlockLoadA(block_load_tempA).Load(A_block, regA);
BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
for (size_t n = 0; n < N; ++n)
{
regA[n] = A_block[threadIdx.x * stride_A + n];
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
}
#endif

if (threadIdx.x < N)
{
smemB[threadIdx.x] = B_block[threadIdx.x];
smemC[threadIdx.x] = C_block[threadIdx.x];
}
__syncthreads();

for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
{
float dt_soft_plus = dt_block[threadIdx.x];
if (dt_soft_plus <= 20.0f)
{
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float x_dt = x_block[threadIdx.x] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[tid * stride_s + j] = state;
} else {
smem_s0[tid * stride_ss0 + j] = state;
}
for (size_t n = 0; n < N; n++)
{
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
sumf += state * smemC[n];
regs0[n] = state;
}
__syncthreads();
y_block[i * stride_y + tid] = sumf;
y_block[threadIdx.x] = sumf;
}

#ifdef USE_CUB
BlockStoreS(block_store_tempS).Store(s_block, regs0);
#else
const int stride_s = stride_s0;
#pragma unroll
for (size_t n = 0; n < N; ++n)
{
s_block[threadIdx.x * stride_s + n] = regs0[n];
}
#endif
}

static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float *src2, const float *src3,
const float *src4, const float *src5, const int src0_nb1, const int src0_nb2,
const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
cudaStream_t stream) {
float *dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
cudaStream_t stream)
{
const int threads = 128;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0);
const dim3 blocks(B, (D + threads - 1) / threads, 1);
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
} else {
if (N == 16)
{
if (L > 1)
{
ssm_scan_f32<threads, 16><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
}
else
{
ssm_scan_single_step_f32<threads, 16><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb2,
src1_nb3, src2_nb2, src3_nb1,
src4_nb2, src5_nb2,
dst);
}
}
else
{
GGML_ABORT("doesn't support N!=16.");
}
}

@@ -147,7 +263,7 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2],
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], src3->nb[1],
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
}
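
To make the uncoalesced-I/O discussion from the inline comments on lines +61 to +62 more concrete, here is a minimal sketch. It is not part of this PR; the helper name load_rows_coalesced and the assumptions that blockDim.x == splitD and that rows are densely packed (stride == N) are mine, not guarantees of the actual ggml tensor layout. The point is that A_block[threadIdx.x * stride_A + n] makes neighbouring threads touch addresses stride_A floats apart on the same instruction, while staging through shared memory lets consecutive threads read consecutive floats from global memory:

```cuda
// Hypothetical helper, for illustration only (not from ssm-scan.cu).
// Stages a [splitD x N] row-major block through shared memory so the
// global-memory reads are coalesced, then hands each thread its own row
// in registers, matching what the kernels above expect in regA/regs0.
template <int splitD, int N>
__device__ void load_rows_coalesced(const float * __restrict__ src, float (&reg)[N]) {
    __shared__ float staging[splitD * N];

    // Consecutive threads read consecutive floats -> the loads coalesce.
    // Assumes blockDim.x == splitD and densely packed rows (stride == N).
    for (int idx = threadIdx.x; idx < splitD * N; idx += splitD) {
        staging[idx] = src[idx];
    }
    __syncthreads();

    // Each thread copies its own row from shared memory into registers;
    // this is where shared-memory bank conflicts could still show up.
#pragma unroll
    for (int n = 0; n < N; ++n) {
        reg[n] = staging[threadIdx.x * N + n];
    }
    __syncthreads(); // make the staging buffer safe to reuse on a later call
}
```

This staging is essentially what cub::BlockLoad does with its BLOCK_LOAD_TRANSPOSE algorithm; the BLOCK_LOAD_VECTORIZE variant used in this PR instead keeps the per-thread (blocked) access pattern but widens each thread's reads into vector loads when alignment allows.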
Comment on __launch_bounds__(splitD, 2):

In CUDA there are 64k registers per SM and each thread can at most use 255 registers. So with 128 threads the occupancy limit in terms of registers is 4, and telling the compiler to limit register usage in order to fit 2 blocks effectively tells it to just use as many registers as it wants. You could maybe change the args to (splitD, 1) to make this a little clearer, but I think it's also fine as-is.

I could just remove it if it's not doing anything then, so it would be (splitD) only.

No, this does in fact do something. The compiler is by default very conservative with how many registers it uses because this avoids the worst-performing cases, but it also leaves potential performance on the table. If you explicitly tell the compiler to use as many registers as it wants, the performance can be better (for this kernel it probably doesn't matter anyway).

Oh, I see, that's why the register count was 64 when I removed it. It does seem to make a small difference in performance. I'll change it to 1 since there doesn't seem to be a difference from 2 in the generated assembly.
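
For reference, a toy kernel (not taken from this PR) that just makes the two __launch_bounds__ arguments discussed above concrete; the register arithmetic in the comment follows the reasoning in this thread:

```cuda
// __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor):
// with maxThreadsPerBlock = 128 and 64K 32-bit registers per SM, requesting
// at least 2 resident blocks still allows 65536 / (128 * 2) = 256 registers
// per thread, which is above the 255-per-thread hardware cap, so the hint
// imposes no real limit and the compiler may spend registers freely.
// Passing 1 states that intent directly; omitting the second argument lets
// the compiler fall back to its more conservative default allocation.
template <int block_size>
__global__ void __launch_bounds__(block_size, 1) toy_scale_kernel(float * data, float factor) {
    const int i = blockIdx.x * block_size + threadIdx.x;
    data[i] *= factor; // placeholder body; only the launch bounds matter here
}
```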