
Commit 1aa18ef

lshzh-ww and ggerganov authored
metal : concurrently dispatch commands (#2358)
* metal: concurrently dispatch commands

  Function `ggml_metal_graph_find_concurrency` will run and write commands that can be issued concurrently to metal context `concur_list` array, when `ggml_metal_graph_compute` is called for the first time.

* metal: don't call find_concurrency automatically.

* metal : code style changes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 9a08eaf commit 1aa18ef

3 files changed: 138 additions & 19 deletions

ggml-metal.h

Lines changed: 7 additions & 0 deletions
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
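
For orientation, here is a minimal usage sketch of the two new entry points. This is only a sketch: the helper name run_graph_on_metal and the surrounding ctx/gf setup are assumptions, not part of this commit.

// hypothetical helper, assuming `ctx` is an initialized ggml_metal_context
// and `gf` is a fully built ggml_cgraph
static void run_graph_on_metal(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
    // analyze the graph once; per the header comment above, this should be
    // re-run whenever the graph topology changes
    if (!ggml_metal_if_optimized(ctx)) {
        ggml_metal_graph_find_concurrency(ctx, gf);
    }

    // dispatches concurrently when a concur_list was found, serially otherwise
    ggml_metal_graph_compute(ctx, gf);
}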

ggml-metal.m

Lines changed: 128 additions & 19 deletions
@@ -36,6 +36,9 @@
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction> function_##name; \
@@ -98,6 +101,7 @@ @implementation GGMLMetalClass
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -355,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel

@@ -378,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -389,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -463,7 +572,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ADD:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         if (ggml_nelements(src1) == ne10) {
@@ -484,7 +593,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         if (ggml_nelements(src1) == ne10) {
@@ -505,7 +614,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_SCALE:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const float scale = *(const float *) src1->data;
@@ -524,7 +633,7 @@ void ggml_metal_graph_compute(
                     case GGML_UNARY_OP_SILU:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             [encoder setComputePipelineState:ctx->pipeline_silu];
@@ -538,7 +647,7 @@ void ggml_metal_graph_compute(
                     case GGML_UNARY_OP_RELU:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             [encoder setComputePipelineState:ctx->pipeline_relu];
@@ -552,7 +661,7 @@ void ggml_metal_graph_compute(
                     case GGML_UNARY_OP_GELU:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                            }
 
                             [encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -572,7 +681,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_SOFT_MAX:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int nth = 32;
@@ -590,7 +699,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_DIAG_MASK_INF:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -653,7 +762,7 @@ void ggml_metal_graph_compute(
                             }
                         } else {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             int nth0 = 32;
@@ -780,7 +889,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_GET_ROWS:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         switch (src0->type) {
@@ -809,7 +918,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_RMS_NORM:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         float eps;
@@ -832,7 +941,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_NORM:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const float eps = 1e-5f;
@@ -854,7 +963,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ALIBI:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -897,7 +1006,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ROPE:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int n_past = ((int32_t *) dst->op_params)[0];
@@ -941,7 +1050,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_CONT:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                        }
 
                         const int nth = 32;
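
To make the concur_list layout concrete, here is a small standalone illustration. The graph and the numbers are invented for this example; only the convention — node indices grouped per level, with -1 acting as a barrier between levels — comes from the code above.

#include <stdio.h>

// Print the dispatch levels encoded in a concur_list-style array, where every
// -1 entry marks a barrier between groups of nodes that may run concurrently.
static void print_concur_levels(const int * concur_list, int len) {
    int level = 0;
    printf("level %d:", level);
    for (int i = 0; i < len; i++) {
        if (concur_list[i] == -1) {
            // barrier: in ggml_metal_graph_compute this becomes
            // [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]
            printf("\n");
            if (i + 1 < len) {
                printf("level %d:", ++level);
            }
        } else {
            printf(" node %d", concur_list[i]);
        }
    }
}

int main(void) {
    // hypothetical result for a 5-node graph: nodes 0 and 1 have no unresolved
    // dependencies, nodes 2 and 3 depend only on level 0, node 4 depends on both
    const int concur_list[] = { 0, 1, -1, 2, 3, -1, 4, -1 };
    print_concur_levels(concur_list, (int) (sizeof(concur_list)/sizeof(concur_list[0])));
    return 0;
}

The sketch prints one group per line: "level 0: node 0 node 1", "level 1: node 2 node 3", "level 2: node 4". Entries >= 0 are encoded back-to-back into the concurrent compute pass; each -1 is turned into a buffer memory barrier so the next level only starts once the previous level's writes are visible.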

llama.cpp

Lines changed: 3 additions & 0 deletions
@@ -1720,6 +1720,9 @@ static bool llama_eval_internal(
 
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
+        if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
+            ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
+        }
         ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor (lctx.ctx_metal, cur);
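
Note the guard here: in line with the commit message ("don't call find_concurrency automatically"), the analysis is triggered from the caller and at most once, since ggml_metal_if_optimized returns true after the first pass; per the new header comment, ggml_metal_graph_find_concurrency would need to be re-run if the graph topology changed.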
