-
Notifications
You must be signed in to change notification settings - Fork 12k
cuda: refactored ssm_scan and use CUB #13291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider deduplicating the code by adding an optional template parameter for L. Ignore the template parameter if it's 0, otherwise use it instead of the runtime parameter (add #pragma unroll
to the loops over L). Adding additional template specializations for L <= 8 would likely also improve performance. You can look at softmax.cu
for an example.
If you are not doing this already, my recommendation for optimizing CUDA performance would be to first use NVIDIA NSight Systems to identify which kernels take up a large percentage of the total runtime (and are thus worth optimizing). Then you can use NVIDIA NSight Compute to get a detailed breakdown of a specific kernel and to identify bottlenecks. For this kernel I assume the bottleneck is I/0.
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1); | ||
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float)); | ||
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); | ||
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1); | ||
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2)); | ||
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2)); | ||
float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float)); | ||
float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In GPU code there can be performance issues if you cast to char *
, do pointer arithmetic, and then cast back to float *
. But since this is only done once here it should be fine and in my experience this mostly affects the HIP port for AMD anyways.
ggml/src/ggml-cuda/ssm-scan.cu
Outdated
#include <cub/cub.cuh> | ||
using namespace cub; | ||
#endif // USE_CUB | ||
|
||
#include "ssm-scan.cuh" | ||
|
||
template <size_t splitD, size_t N> | ||
__global__ void __launch_bounds__(splitD, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In CUDA there are 64k registers per SM and each thread can at most use 255 registers. So with 128 threads the occupancy limit in terms of registers is 4 and telling the compiler to limit register usage in order to fit 2 blocks effectively tells it to just use as many registers as it wants. You could maybe change the args to (splitD, 1)
to make this a little clearer but I think it's also fine as-is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could just remove it if it's not doing anything then, so it would be (splitD)
only.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this does in fact do something. The compiler is by default very conservative with how many registers it uses because this avoids the worst-performing cases but it also leaves potential performance on the table. If you explicitly tell the compiler to use as many registers as it wants the performance can be better (for this kernel it probably doesn't matter anyways).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see that's why the register count used was 64 if I removed it. It does seem to make a small difference in performance. I'll change it to 1 since there doesn't seem to be a difference from 2 in the generated assembly.
regA[n] = A_block[threadIdx.x * stride_A + n]; | ||
regs0[n] = s0_block[threadIdx.x * stride_s0 + n]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The memory access pattern here is inefficient though I also wouldn't know how to improve it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the problem lie in that the loads aren't coalesced? Wouldn't using a coalesced loading pattern require the data to be in a different layout?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the problem is the uncoalesced I/O. If you could somehow re-write the kernel to make the loads coalesced or change the memory pattern the previous kernel puts out the performance would likely be better. (I did not try to analyze whether something like this is possible.)
ggml/src/ggml-cuda/ssm-scan.cu
Outdated
#pragma unroll | ||
for (size_t n = 0; n < N; ++n) | ||
{ | ||
s_block[threadIdx.x * stride_s + n] = regs0[n]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The memory access pattern here is also inefficient.
I modified the structure of the CUDA kernel for the ssm scan such parallelization is performed per thread across the channel dimension (D). This allows A and the initial state (s0) to be loaded into registers and reused across the sequence (L) and SSM state dimensions (N). Additionally, B and C can be loaded into shared memory since blocks process the same timestep in parallel. I also added another CUDA kernel specifically for a sequence length of 1 (recurrent mode) in order to reduce the number of registers used by removing the loop over the sequence dimension.
I'm unsure about optimizing the number of threads per block or the minimum number of blocks per multiprocessor in the launch bounds, however, so I left them as is.
Benchmarks
I got the following results with the following test cases added to test-backend-ops.cpp.
Hardware: Intel i7-13700K, Nvidia RTX 3090
Raw output:
cpu.txt
original_cuda.txt
improved_cuda.txt
improved_cuda_no_cub.txt
llama-bench
Original:
Improved:
Improved (No CUB):