metal: concurrently dispatch commands #2358

Merged (4 commits, Jul 25, 2023)
Changes from 3 commits
7 changes: 7 additions & 0 deletions ggml-metal.h
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
// get data from the device into host memory
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);

// try to find operations that can be run concurrently in the graph
// you should run it again if the topology of your graph changes
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);

// returns true if the graph has been optimized for concurrent dispatch
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);

// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
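The intended call order for this new API mirrors the llama.cpp hunk at the bottom of this diff. As a minimal host-side sketch (not code from this patch; ctx and gf are assumed to be an initialized Metal context and a fully built compute graph):

static void eval_with_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
    // analyze once; the plan stays valid until the graph topology changes
    if (!ggml_metal_if_optimized(ctx)) {
        ggml_metal_graph_find_concurrency(ctx, gf);
    }
    // dispatches concurrently when a plan exists, serially otherwise
    ggml_metal_graph_compute(ctx, gf);
}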
146 changes: 128 additions & 18 deletions ggml-metal.m
@@ -36,6 +36,9 @@
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

int concur_list[GGML_MAX_NODES];
int concur_list_len;

// custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
id<MTLFunction> function_##name; \
@@ -98,6 +101,7 @@ @implementation GGMLMetalClass
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
ctx->concur_list_len = 0;

// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
@@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb;
}

bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
if (ctx->concur_list_len) {
return true;
}
return false;
}

// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
@@ -355,11 +366,98 @@ void ggml_metal_get_tensor(
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
}

void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
int search_depth = gf->n_nodes; // we only search for concurrency within this range, to avoid wasting too much time
int nodes_unused[GGML_MAX_NODES];

for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
ctx->concur_list_len = 0;

int n_left = gf->n_nodes;
int n_start = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos

while (n_left > 0) {
// number of nodes at a layer (that can be issued concurrently)
int concurrency = 0;
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
if (nodes_unused[i]) {
// if the requirements for gf->nodes[i] are satisfied
int exe_flag=1;
// scan all srcs
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
if (src_cur) {
// if the src is a leaf node, the requirement is satisfied
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}

// otherwise this src must be the output of an already-scheduled node
int is_found = 0;
// scan up to 2*search_depth entries back, because barrier entries have been inserted into the list
for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
}
if (is_found == 0) {exe_flag = 0; break;}
}
}
if (exe_flag) {
// check if nodes[i]'s data would be overwritten by a node scheduled before it.
// e.g. if node[3] and node[5] write to the same memory region, node[5] cannot be issued before node[3]
int64_t data_start = (int64_t) gf->nodes[i]->data;
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
for (int j = n_start; j < i; j++) {
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
&& gf->nodes[j]->op != GGML_OP_VIEW \
&& gf->nodes[j]->op != GGML_OP_TRANSPOSE \
&& gf->nodes[j]->op != GGML_OP_PERMUTE) {
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
continue;
} else {
exe_flag = 0;
}
}
}
}
if (exe_flag) {
ctx->concur_list[level_pos + concurrency] = i;
nodes_unused[i] = 0;
concurrency++;
ctx->concur_list_len++;
}
}
}
n_left -= concurrency;
// add a barrier to separate this layer from the next
ctx->concur_list[level_pos + concurrency] = -1;
ctx->concur_list_len++;
// advance n_start past nodes that have already been scheduled
while (!nodes_unused[n_start]) {n_start++;}
level_pos += concurrency + 1;
}

if (ctx->concur_list_len > GGML_MAX_NODES) {
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
}
}
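// [editor's illustration, not part of this patch] the plan produced above,
// for a 4-node graph where nodes 0 and 1 are independent (and their outputs
// do not alias), node 2 reads both, and node 3 reads node 2:
//
//   concur_list     = { 0, 1, -1, 2, -1, 3, -1 }
//   concur_list_len = 7
//
// each -1 is a level separator; ggml_metal_graph_compute (below) turns it
// into a memory barrier between the concurrent dispatches of adjacent levels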

void ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
metal_printf("%s: evaluating graph\n", __func__);

// if ctx->concur_list is populated, dispatch concurrently
// otherwise fall back to serial dispatch
MTLComputePassDescriptor * encoder_descriptor = MTLComputePassDescriptor.computePassDescriptor;
encoder_descriptor.dispatchType = MTLDispatchTypeSerial;
int all_nodes_len = gf->n_nodes;
if (ctx->concur_list_len) {
encoder_descriptor.dispatchType = MTLDispatchTypeConcurrent;
all_nodes_len = ctx->concur_list_len;
}
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel

@@ -378,7 +476,7 @@ void ggml_metal_graph_compute(
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
const int n_nodes_per_cb = (all_nodes_len + n_cb - 1) / n_cb;

dispatch_async(queue, ^{
size_t offs_src0 = 0;
@@ -390,9 +488,21 @@
id<MTLComputeCommandEncoder> encoder = nil;

const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? all_nodes_len : (cb_idx + 1) * n_nodes_per_cb;

for (int i = node_start; i < node_end; ++i) {
for (int ind = node_start; ind < node_end; ++ind) {
int i = ind;
if (ctx->concur_list_len) {
i = ctx->concur_list[ind];
}
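// note: i == -1 is the level separator written by find_concurrency; encode a
// buffer-scope barrier so the next level observes all writes from this one
// (if no encoder exists yet, just create it and move on)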
if (i == -1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
continue;
}
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -463,7 +573,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

if (ggml_nelements(src1) == ne10) {
@@ -484,7 +594,7 @@
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

if (ggml_nelements(src1) == ne10) {
@@ -505,7 +615,7 @@
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

const float scale = *(const float *) src1->data;
@@ -524,7 +634,7 @@
case GGML_UNARY_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

[encoder setComputePipelineState:ctx->pipeline_silu];
@@ -538,7 +648,7 @@
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

[encoder setComputePipelineState:ctx->pipeline_relu];
@@ -552,7 +662,7 @@
case GGML_UNARY_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

[encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -572,7 +682,7 @@
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

const int nth = 32;
@@ -590,7 +700,7 @@
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -653,7 +763,7 @@
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

int nth0 = 32;
@@ -780,7 +890,7 @@
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

switch (src0->type) {
@@ -809,7 +919,7 @@
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

float eps;
@@ -832,7 +942,7 @@
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

const float eps = 1e-5f;
@@ -854,7 +964,7 @@
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -897,7 +1007,7 @@
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

const int n_past = ((int32_t *) dst->op_params)[0];
@@ -941,7 +1051,7 @@
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
}

const int nth = 32;
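The mechanical change repeated through this file, every computeCommandEncoder becoming computeCommandEncoderWithDescriptor:, hinges on one Metal API. A condensed sketch of the pattern (standalone, assuming a valid command_buffer; not code from this patch):

MTLComputePassDescriptor * desc = MTLComputePassDescriptor.computePassDescriptor;
desc.dispatchType = MTLDispatchTypeConcurrent; // the GPU may overlap the encoded dispatches
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor:desc];
// ... encode the kernel dispatches of one concurrency level ...
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; // order level N before level N+1
// ... encode the next level ...
[encoder endEncoding];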
3 changes: 3 additions & 0 deletions llama.cpp
@@ -1720,6 +1720,9 @@ static bool llama_eval_internal(

#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
}
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, &gf);
ggml_metal_get_tensor (lctx.ctx_metal, cur);