Skip to content

ggml-cuda : increase max graph size #4084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
#define CC_OFFSET_AMD 1000000
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)

#define GGML_CUDA_MAX_NODES 8192

// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
Expand Down Expand Up @@ -7727,7 +7729,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}

void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}

Expand Down Expand Up @@ -7842,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0;

static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
if (g_temp_tensor_extras == nullptr) {
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
}

size_t alloc_index = g_temp_tensor_extra_index;
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
memset(extra, 0, sizeof(*extra));

Expand Down Expand Up @@ -8173,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda {

ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
if (temp_tensor_extras == nullptr) {
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
}

size_t alloc_index = temp_tensor_extra_index;
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
memset(extra, 0, sizeof(*extra));

Expand Down