From d3fd04e92ecbff0fd383ca6ec9bae549df2e7c0c Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Thu, 27 Apr 2023 20:16:32 +0200
Subject: [PATCH 1/5] cuBLAS: dequantize simultaneously while copying memory

---
 ggml-cuda.cu | 35 +++++++++++++++++++++++++++++------
 ggml-cuda.h  |  7 ++++++-
 ggml.c       | 45 ++++++++++++++------------------------------
 3 files changed, 49 insertions(+), 38 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index eb244f409aafd..70127c4770ee0 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -227,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q4_2:
+            return dequantize_row_q4_2_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        default:
+            return nullptr;
+    }
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 16
 
@@ -286,18 +305,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
-cublasHandle_t g_cublasH = NULL;
-cudaStream_t g_cudaStream = NULL;
+cublasHandle_t g_cublasH = nullptr;
+cudaStream_t g_cudaStream = nullptr;
+cudaStream_t g_cudaStream2 = nullptr;
+cudaEvent_t g_cudaEvent = nullptr;
 
-void ggml_init_cublas(void) {
-    if (g_cublasH == NULL) {
+void ggml_init_cublas() {
+    if (g_cublasH == nullptr) {
         // create cublas handle, bind a stream
         CUBLAS_CHECK(cublasCreate(&g_cublasH));
-        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
-        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+        // create additional stream and event for synchronization
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
+        CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
+
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 1fd67ebeb71cc..c2b5c359dd25b 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -26,7 +26,9 @@ extern "C" {
     } while (0)
 
 extern cublasHandle_t g_cublasH;
-extern cudaStream_t g_cudaStream;
+extern cudaStream_t g_cudaStream;
+extern cudaStream_t g_cudaStream2;
+extern cudaEvent_t g_cudaEvent;
 
 void ggml_init_cublas(void);
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
@@ -41,6 +43,9 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
 
 cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
 
+typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
+
 #ifdef __cplusplus
 }
 #endif
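The new ggml_get_dequantize_row_q_cuda() helper centralizes the choice of dequantization kernel: one switch that returns the launcher for a given quantization type, or nullptr if the type is unsupported, so the caller can assert instead of carrying its own if/else chain. As a rough usage sketch (illustrative only; d_Q, d_X and x_ne stand for whatever quantized input, float output and element count the caller already has):

    // pick the dequantization launcher once, based on the tensor's quantization type
    const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
    GGML_ASSERT(dequantize_row_q_cuda != NULL);   // nullptr means the type is not supported

    // dequantize x_ne values from the device buffer d_Q into the float buffer d_X,
    // issued on the secondary stream so it can overlap with other transfers
    dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);

This is exactly how the ggml.c hunk below consumes it.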
diff --git a/ggml.c b/ggml.c
index 4ec637ee1e082..2dda75c7be75d 100644
--- a/ggml.c
+++ b/ggml.c
@@ -8033,7 +8033,7 @@ static void ggml_compute_forward_mul_mat_f32(
 #if defined(GGML_USE_CUBLAS)
     const float alpha = 1.0f;
     const float beta = 0.0f;
-    const int x_ne = ne01 * ne10;
+    const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
@@ -8239,7 +8239,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const float alpha = 1.0f;
     const float beta = 0.0f;
-    const int x_ne = ne01 * ne10;
+    const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
@@ -8498,39 +8498,19 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #if defined(GGML_USE_CUBLAS)
     const float alpha = 1.0f;
     const float beta = 0.0f;
-    const int x_ne = ne01 * ne10;
+    const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
     size_t x_size, y_size, d_size, q_size;
-    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-    float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+    float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
-    void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
-    if (type == GGML_TYPE_Q4_0) {
-        dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
-    }
-    else if (type == GGML_TYPE_Q4_1) {
-        dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
-    }
-    else if (type == GGML_TYPE_Q4_2) {
-        dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
-    }
-    else if (type == GGML_TYPE_Q5_0) {
-        dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
-    }
-    else if (type == GGML_TYPE_Q5_1) {
-        dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
-    }
-    else if (type == GGML_TYPE_Q8_0) {
-        dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
-    }
-    else {
-        GGML_ASSERT(false);
-    }
-#elif !defined(GGML_USE_CLBLAST)
+    const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
+    GGML_ASSERT(dequantize_row_q_cuda != NULL);
+#else
     float * const wdata = params->wdata;
     dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
@@ -8545,7 +8525,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
             // copy and dequantize on device
             CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
-            dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
+            dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
             CUDA_CHECK(cudaGetLastError());
 #elif defined(GGML_USE_CLBLAST)
             const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
@@ -8565,6 +8545,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
             // copy data to device
             CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
+            // wait for dequantization
+            CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
+
             // compute
             CUBLAS_CHECK(
                 cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
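These ggml.c changes also switch x_ne to ne01 * ne00, the element count of one src0 slice; the old ne01 * ne10 gives the same number only because ne10 == ne00 in a valid matmul, so ne00 is the semantically right factor. Functionally, the dequantization now runs on the second stream while the main stream keeps uploading src1, and the main stream waits on g_cudaEvent before the GEMM. Note that in this first patch only the cudaStreamWaitEvent() side of the handshake is present; the matching cudaEventRecord() on g_cudaStream2, and moving the quantized upload onto that stream, arrive in the final "fix rebase" patch. A minimal, self-contained sketch of the intended pattern (illustrative only, with placeholder buffers and a stub kernel; error checking omitted):

    #include <cuda_runtime.h>

    __global__ void dequantize_stub(const void * src, float * dst) { /* placeholder kernel */ }

    void overlapped_step(const void * h_q, void * d_q, float * d_x, size_t q_bytes,
                         const float * h_y, float * d_y, size_t y_bytes,
                         cudaStream_t stream, cudaStream_t stream2, cudaEvent_t event) {
        // upload the quantized data and dequantize it on the secondary stream
        cudaMemcpyAsync(d_q, h_q, q_bytes, cudaMemcpyHostToDevice, stream2);
        dequantize_stub<<<1, 1, 0, stream2>>>(d_q, d_x);
        cudaEventRecord(event, stream2);

        // meanwhile, upload the other operand on the main stream
        cudaMemcpyAsync(d_y, h_y, y_bytes, cudaMemcpyHostToDevice, stream);

        // the main stream waits for the dequantized X before any GEMM issued on it
        cudaStreamWaitEvent(stream, event, 0);
        // cublasSgemm(...) bound to `stream` would follow here
    }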
From 2dd6deeb4919370eb68f42a855d23629ce6fa283 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Thu, 27 Apr 2023 21:51:43 +0200
Subject: [PATCH 2/5] cuBLAS: use host pinned memory

---
 Makefile     |  5 +++--
 ggml-cuda.cu | 10 ++++++++++
 ggml-cuda.h  |  3 +++
 ggml.c       | 11 +++++++----
 llama.cpp    |  6 +++---
 llama_util.h | 26 ++++++++++++++++++++++++++
 6 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/Makefile b/Makefile
index 0715e857bc346..5a1cb3e83e365 100644
--- a/Makefile
+++ b/Makefile
@@ -106,6 +106,7 @@ ifdef LLAMA_OPENBLAS
 endif
 ifdef LLAMA_CUBLAS
 	CFLAGS   += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
+	CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
 	LDFLAGS  += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
 	OBJS     += ggml-cuda.o
 	NVCC     = nvcc
@@ -164,10 +165,10 @@ $(info )
 # Build library
 #
 
-ggml.o: ggml.c ggml.h
+ggml.o: ggml.c ggml.h ggml-cuda.h
 	$(CC)  $(CFLAGS)   -c $< -o $@
 
-llama.o: llama.cpp ggml.h llama.h llama_util.h
+llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama_util.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 common.o: examples/common.cpp examples/common.h
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 70127c4770ee0..5a2701cfeef68 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -353,3 +353,13 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
         return cudaSuccess;
     }
 }
+
+void * ggml_cuda_host_malloc(size_t size) {
+    void * ptr;
+    CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
diff --git a/ggml-cuda.h b/ggml-cuda.h
index c2b5c359dd25b..36782d9e796b7 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -31,6 +31,9 @@
 extern cudaStream_t g_cudaStream2;
 extern cudaEvent_t g_cudaEvent;
 
 void ggml_init_cublas(void);
+void * ggml_cuda_host_malloc(size_t size);
+void ggml_cuda_host_free(void * ptr);
+
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
 void ggml_cuda_pool_free(void * ptr, size_t size);
diff --git a/ggml.c b/ggml.c
index 2dda75c7be75d..f4dc48f819e95 100644
--- a/ggml.c
+++ b/ggml.c
@@ -8235,8 +8235,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     }
 
 #if defined(GGML_USE_CUBLAS)
-    ggml_fp16_t * const wdata = params->wdata;
-
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne00;
@@ -8254,6 +8252,7 @@
         for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
             // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
             {
                 size_t id = 0;
                 for (int64_t i01 = 0; i01 < ne11; ++i01) {
@@ -8540,7 +8539,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
             const float * x = wdata;
 #endif
-
 
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
             CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
@@ -11571,7 +11569,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                         node->n_tasks = 1; // TODO: this actually is doing nothing
                                            //       the threads are still spinning
-                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
                         //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
                         //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
                         //printf("cur = %zu\n", cur);
@@ -11583,6 +11581,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                 } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                     cur = 0;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+                    if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                        node->n_tasks = 1;
+                    }
+#endif
                 } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
diff --git a/llama.cpp b/llama.cpp
index 45f0d44acc548..13df67c23f675 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -167,7 +167,7 @@ struct llama_model {
     struct llama_kv_cache kv_self;
 
     // the model memory buffer
-    llama_buffer buf;
+    llama_ctx_buffer buf;
 
     // model memory mapped file
     std::unique_ptr<llama_mmap> mapping;
@@ -228,8 +228,8 @@ struct llama_context {
 
     // memory buffers used to evaluate the model
     // TODO: move in llama_state
-    llama_buffer buf_compute;
-    llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
+    llama_ctx_buffer buf_compute;
+    llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
     int    buf_last = 0;
     size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
diff --git a/llama_util.h b/llama_util.h
index acb207e653c10..6e66d12a8041c 100755
--- a/llama_util.h
+++ b/llama_util.h
@@ -405,4 +405,30 @@ struct llama_buffer {
         delete[] addr;
     }
 };
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+struct llama_ctx_buffer {
+    uint8_t * addr = NULL;
+    size_t size = 0;
+
+    void resize(size_t size) {
+        if (addr) {
+            ggml_cuda_host_free(addr);
+        }
+        addr = (uint8_t *) ggml_cuda_host_malloc(size);
+        this->size = size;
+    }
+
+    ~llama_ctx_buffer() {
+        if (addr) {
+            ggml_cuda_host_free(addr);
+        }
+    }
+};
+#else
+typedef llama_buffer llama_ctx_buffer;
+#endif
+
+
 #endif
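The point of routing these buffers through cudaMallocHost() is that cudaMemcpyAsync() only overlaps with kernels and other copies when the host side is page-locked; with ordinary pageable memory the runtime stages the transfer and it behaves largely synchronously. A hedged usage sketch of the two new helpers (the size and variable names are placeholders, not code from the patch):

    // allocate a page-locked staging buffer for data that will be uploaded with cudaMemcpyAsync
    size_t n_bytes = 16u * 1024 * 1024;                       // example size
    uint8_t * staging = (uint8_t *) ggml_cuda_host_malloc(n_bytes);

    // ... fill `staging` and hand it to the existing upload paths ...

    ggml_cuda_host_free(staging);

Failures are funneled through CUDA_CHECK inside the helpers rather than reported as a NULL return, so callers need no separate error path, and the llama_ctx_buffer wrapper above gives the compute and scratch buffers this pinned behavior automatically when GGML_USE_CUBLAS is defined.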
From d5d6a8083a05eb1ae3897f1302b5e5116146d83d Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Thu, 27 Apr 2023 23:27:59 +0200
Subject: [PATCH 3/5] cuBLAS: improve ggml_compute_forward_mul_mat_f16_f32 with pinned memory

---
 ggml.c | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/ggml.c b/ggml.c
index f4dc48f819e95..26e8b87292e2d 100644
--- a/ggml.c
+++ b/ggml.c
@@ -8242,15 +8242,18 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int d_ne = ne11 * ne01;
 
     size_t x_size, y_size, d_size;
-    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
     float * const wdata = params->wdata;
 #endif
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
+            // copy src0 while converting src1
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i02, i03, g_cudaStream));
+
             // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
             {
@@ -8274,11 +8277,9 @@
 #if defined(GGML_USE_CUBLAS)
             const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
-
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
             // copy data to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
             CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
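Hoisting the d_X upload to the top of the loop lets the host-to-device copy of src0 run while the CPU is still converting src1 to fp16: cudaMemcpyAsync() returns immediately, the conversion loop keeps the CPU busy, and because the later d_Y copy and the GEMM are issued on the same stream they remain ordered after the d_X transfer. (The i02/i03 argument order introduced here is corrected in the final patch.) A stripped-down sketch of that ordering, with placeholder names and a plain copy standing in for the real fp16 conversion:

    #include <cuda_runtime.h>

    void copy_x_while_preparing_y(const void * h_x, void * d_x, size_t x_bytes,
                                  const float * h_y, float * h_y_staging, void * d_y, size_t n_y,
                                  cudaStream_t stream) {
        // 1) issue the X transfer first; the call returns without waiting for the DMA
        cudaMemcpyAsync(d_x, h_x, x_bytes, cudaMemcpyHostToDevice, stream);

        // 2) CPU-side preparation of Y runs while the X transfer is in flight
        //    (in ggml this is the fp32 -> fp16 conversion of src1)
        for (size_t i = 0; i < n_y; ++i) {
            h_y_staging[i] = h_y[i];
        }

        // 3) same stream, so this copy and any GEMM issued after it stay ordered after the X copy
        cudaMemcpyAsync(d_y, h_y_staging, n_y * sizeof(float), cudaMemcpyHostToDevice, stream);
    }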
From 3cf2247d37327fa6aec29ff3cd0d799673cfd50a Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Fri, 28 Apr 2023 00:48:01 +0200
Subject: [PATCH 4/5] cuBLAS: also pin kv cache

---
 llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index 13df67c23f675..4699e5cf1de7c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -136,7 +136,7 @@ struct llama_kv_cache {
 
     struct ggml_context * ctx = NULL;
 
-    llama_buffer buf;
+    llama_ctx_buffer buf;
 
     int n; // number of tokens currently in the cache
 
From 38a021fafefcfe4b02779cea4e951cb1fd473698 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sat, 29 Apr 2023 01:55:50 +0200
Subject: [PATCH 5/5] fix rebase

---
 ggml.c | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ggml.c b/ggml.c
index 26e8b87292e2d..64ecd0867e4fb 100644
--- a/ggml.c
+++ b/ggml.c
@@ -8252,7 +8252,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
             // copy src0 while converting src1
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i02, i03, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
 
             // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
 
@@ -8523,10 +8523,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #if defined(GGML_USE_CUBLAS)
             // copy and dequantize on device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
             dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
             CUDA_CHECK(cudaGetLastError());
+            CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
 #elif defined(GGML_USE_CLBLAST)
             const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
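With this last patch the quantized upload moves onto g_cudaStream2 and the cudaEventRecord() completes the handshake that patch 1 started, so the copy and dequantization of src0 can run concurrently with the upload of src1. All of these uploads go through ggml_cuda_h2d_tensor_2d(), whose body is not part of the series. Purely as an illustration (an assumption about how such a strided slice copy can be expressed, not the function's actual implementation), a host-to-device copy of n_rows rows with a fixed host stride can be issued as a single 2D async copy:

    #include <cuda_runtime.h>

    // dst is assumed tightly packed (device pitch == row_bytes); the host rows may be
    // strided, as long as row_bytes <= src_stride
    cudaError_t h2d_copy_2d_sketch(void * dst, const void * src, size_t row_bytes,
                                   size_t src_stride, size_t n_rows, cudaStream_t stream) {
        return cudaMemcpy2DAsync(dst, row_bytes, src, src_stride,
                                 row_bytes, n_rows, cudaMemcpyHostToDevice, stream);
    }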