    int n_buffers;
    struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

+   int concur_list[GGML_MAX_NODES];
+   int concur_list_len;
+
    // custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
    id<MTLFunction> function_##name; \
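The two new fields encode the schedule computed later in this diff: concur_list holds node indices grouped into levels that can run concurrently, with -1 acting as a level separator (a barrier), and concur_list_len counts entries and separators together. A minimal sketch of how a consumer could walk such a list (the loop itself is illustrative, not part of the patch):

    // illustrative only: print the schedule, one level per line
    // assumes concur_list/concur_list_len were filled by ggml_metal_graph_find_concurrency (below)
    for (int k = 0; k < ctx->concur_list_len; k++) {
        if (ctx->concur_list[k] == -1) {
            printf("\n");                       // barrier: the next level starts here
        } else {
            printf(" %d", ctx->concur_list[k]); // index into gf->nodes
        }
    }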
@@ -98,6 +101,7 @@ @implementation GGMLMetalClass
    ctx->device = MTLCreateSystemDefaultDevice();
    ctx->queue  = [ctx->device newCommandQueue];
    ctx->n_buffers = 0;
+   ctx->concur_list_len = 0;

    // determine if we can use MPS
    if (MPSSupportsMTLDevice(ctx->device)) {
@@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
    ctx->n_cb = n_cb;
}

+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
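The helper simply reports whether a concurrency schedule has been computed for this context. A hypothetical call site (the surrounding control flow is an assumption for illustration, not part of this diff):

    // hypothetical: compute the schedule once, then reuse it on every graph evaluation
    if (!ggml_metal_if_optimized(ctx)) {
        ggml_metal_graph_find_concurrency(ctx, gf);
    }
    ggml_metal_graph_compute(ctx, gf);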
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
@@ -355,11 +366,98 @@ void ggml_metal_get_tensor(
    memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
}

+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; // we only look for concurrency in this range, to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;    i++) { nodes_unused[i]    = 1; }
+    ctx->concur_list_len = 0;
+
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
+    int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // check if the requirements for gf->nodes[i] are satisfied
+                int exe_flag = 1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if the src is a leaf node, the requirement is satisfied
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { continue; }
+
+                        // otherwise the src must be the output of a previously scheduled node
+                        int is_found = 0;
+                        // scan up to 2*search_depth entries back, since barriers have been inserted into the list
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) { is_found = 1; break; }
+                        }
+                        if (is_found == 0) { exe_flag = 0; break; }
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data would be overwritten by a node scheduled before it:
+                    // if node[5] and node[3] write to the same memory region, we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t) gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t) gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // add a barrier to separate this layer from the next
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // skip over the nodes in nodes_unused that have already been scheduled
+        while (!nodes_unused[n_start]) { n_start++; }
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
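To make the output concrete, consider an invented three-node graph: nodes 0 and 1 read only leaf tensors, node 2 reads the outputs of both, and no destination buffers alias. Tracing the loop above on that graph (the graph is hypothetical, chosen only to illustrate the encoding):

    // pass 1: nodes 0 and 1 have all srcs satisfied (leaves) and do not
    //         overlap in memory -> level { 0, 1 }, then a barrier
    // pass 2: node 2 finds both of its srcs in concur_list -> level { 2 }, then a barrier
    //
    // result: ctx->concur_list     = { 0, 1, -1, 2, -1 }
    //         ctx->concur_list_len = 5   // 3 nodes + 2 barriers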
void ggml_metal_graph_compute(
        struct ggml_metal_context * ctx,
        struct ggml_cgraph * gf) {
    metal_printf("%s: evaluating graph\n", __func__);

+   // if there is a ctx->concur_list, dispatch concurrently
+   // else fall back to serial dispatch
+   MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+   const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+   const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
+   edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
    // create multiple command buffers and enqueue them
    // then, we encode the graph into the command buffers in parallel
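For context, this builds on a standard Metal facility: a compute pass whose dispatchType is MTLDispatchTypeConcurrent allows the GPU to overlap dispatches within one encoder, and ordering is reintroduced only where explicitly requested. A minimal sketch of that pattern, independent of ggml:

    // standard Metal pattern (sketch): concurrent pass with explicit barriers
    MTLComputePassDescriptor * desc = MTLComputePassDescriptor.computePassDescriptor;
    desc.dispatchType = MTLDispatchTypeConcurrent;

    id<MTLComputeCommandEncoder> enc = [command_buffer computeCommandEncoderWithDescriptor:desc];
    // ... dispatches encoded here may execute in any order, possibly overlapping ...
    [enc memoryBarrierWithScope:MTLBarrierScopeBuffers]; // later dispatches see earlier writes
    // ... dispatches encoded here run strictly after the barrier ...
    [enc endEncoding];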
@@ -378,7 +476,7 @@ void ggml_metal_graph_compute(
    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-       const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+       const int n_nodes_per_cb = (n_nodes     + n_cb - 1) / n_cb;

        dispatch_async(queue, ^{
            size_t offs_src0 = 0;
@@ -389,10 +487,21 @@ void ggml_metal_graph_compute(
            id<MTLComputeCommandEncoder> encoder = nil;

-           const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-           const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+           const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+           const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+           for (int ind = node_start; ind < node_end; ++ind) {
+               const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+               if (i == -1) {
+                   if (encoder == nil) {
+                       encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                       continue;
+                   }
+                   [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                   continue;
+               }

-           for (int i = node_start; i < node_end; ++i) {
                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

                struct ggml_tensor * src0 = gf->nodes[i]->src[0];
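The remaining hunks are mechanical: every op case that lazily creates its encoder switches from computeCommandEncoder to computeCommandEncoderWithDescriptor:edesc, so the encoder picks up the chosen dispatch type no matter which node happens to create it first.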
@@ -463,7 +572,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_ADD:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                if (ggml_nelements(src1) == ne10) {
@@ -484,7 +593,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_MUL:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                if (ggml_nelements(src1) == ne10) {
@@ -505,7 +614,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_SCALE:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                const float scale = *(const float *) src1->data;
@@ -524,7 +633,7 @@ void ggml_metal_graph_compute(
        case GGML_UNARY_OP_SILU:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                [encoder setComputePipelineState:ctx->pipeline_silu];
@@ -538,7 +647,7 @@ void ggml_metal_graph_compute(
        case GGML_UNARY_OP_RELU:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                [encoder setComputePipelineState:ctx->pipeline_relu];
@@ -552,7 +661,7 @@ void ggml_metal_graph_compute(
        case GGML_UNARY_OP_GELU:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                [encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -572,7 +681,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_SOFT_MAX:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                const int nth = 32;
@@ -590,7 +699,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_DIAG_MASK_INF:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -653,7 +762,7 @@ void ggml_metal_graph_compute(
            }
        } else {
            if (encoder == nil) {
-               encoder = [command_buffer computeCommandEncoder];
+               encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
            }

            int nth0 = 32;
@@ -780,7 +889,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_GET_ROWS:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                switch (src0->type) {
@@ -809,7 +918,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_RMS_NORM:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                float eps;
@@ -832,7 +941,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_NORM:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                const float eps = 1e-5f;
@@ -854,7 +963,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_ALIBI:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -897,7 +1006,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_ROPE:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                const int n_past = ((int32_t *) dst->op_params)[0];
@@ -941,7 +1050,7 @@ void ggml_metal_graph_compute(
        case GGML_OP_CONT:
            {
                if (encoder == nil) {
-                   encoder = [command_buffer computeCommandEncoder];
+                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                }

                const int nth = 32;