From 3cf123e8f6d3785f85f9b166319c67dc557d69a8 Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Thu, 8 Feb 2024 11:52:13 +0100
Subject: [PATCH] Fuse matrix multiplication + SiLU

---
 ggml.c    | 36 ++++++++++++++++++++++++++++++++++--
 ggml.h    |  3 +++
 llama.cpp |  2 ++
 3 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/ggml.c b/ggml.c
index f783a6fd3c336..0e771115486b3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1817,6 +1817,7 @@ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
 
 static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+    "IDENTITY",
     "ABS",
     "SGN",
     "NEG",
@@ -1831,7 +1832,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 
@@ -9992,6 +9993,7 @@ static void ggml_compute_forward_mul_mat(
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+    enum ggml_unary_op const activation = (enum ggml_unary_op) ggml_get_op_params_i32(dst, 0);
 
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
@@ -10197,7 +10199,20 @@ static void ggml_compute_forward_mul_mat(
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                         vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
                     }
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+
+                    float * dst_ptr = &dst_col[iir0];
+                    const int64_t n = MIN(iir0 + blck_0, ir011) - iir0;
+                    switch (activation) {
+                        case GGML_UNARY_OP_IDENTITY:
+                            memcpy(dst_ptr, tmp, n*sizeof(float));
+                            break;
+                        case GGML_UNARY_OP_SILU:
+                            ggml_vec_silu_f32(n, dst_ptr, tmp);
+                            break;
+                        default:
+                            GGML_ASSERT(false);
+                            break;
+                    }
                 }
             }
         }
@@ -14220,6 +14235,8 @@ static void ggml_compute_forward_unary(
     const enum ggml_unary_op op = ggml_get_unary_op(dst);
 
     switch (op) {
+        case GGML_UNARY_OP_IDENTITY:
+            break; // nothing to do
         case GGML_UNARY_OP_ABS:
             {
                 ggml_compute_forward_abs(params, src0, dst);
@@ -16517,6 +16534,20 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) {
     memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
 }
 
+void ggml_graph_optimize(struct ggml_cgraph * cgraph) {
+    for (int i = 1; i < cgraph->n_nodes; ++i) {
+        struct ggml_tensor * node_current  = cgraph->nodes[i-0];
+        struct ggml_tensor * node_previous = cgraph->nodes[i-1];
+
+        if (node_current->op == GGML_OP_UNARY && ggml_get_unary_op(node_current) == GGML_UNARY_OP_SILU
+                && node_previous->op == GGML_OP_MUL_MAT) {
+
+            ggml_set_op_params_i32(node_previous, 0, ggml_get_op_params_i32(node_current, 0));
+            ggml_set_op_params_i32(node_current, 0, (int32_t) GGML_UNARY_OP_IDENTITY);
+        }
+    }
+}
+
 //
 // thread data
 //
@@ -16696,6 +16727,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(node)) {
+                case GGML_UNARY_OP_IDENTITY:
                 case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_NEG:
diff --git a/ggml.h b/ggml.h
index e0a4799f3bd0a..3919863b99a50 100644
--- a/ggml.h
+++ b/ggml.h
@@ -481,6 +481,8 @@ extern "C" {
     };
 
     enum ggml_unary_op {
+        GGML_UNARY_OP_IDENTITY,
+
         GGML_UNARY_OP_ABS,
         GGML_UNARY_OP_SGN,
         GGML_UNARY_OP_NEG,
@@ -1877,6 +1879,7 @@ extern "C" {
     GGML_API void ggml_graph_cpy      (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void ggml_graph_reset    (struct ggml_cgraph * cgraph);  // zero grads
     GGML_API void ggml_graph_clear    (struct ggml_cgraph * cgraph);
+    GGML_API void ggml_graph_optimize (struct ggml_cgraph * cgraph);
 
     GGML_API size_t ggml_graph_overhead(void);
     GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
diff --git a/llama.cpp b/llama.cpp
index f8f5796a43814..3377053350e55 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7162,6 +7162,8 @@ static struct ggml_cgraph * llama_build_graph(
 
     llm.free();
 
+    ggml_graph_optimize(result);
+
     return result;
 }
 
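Usage note (not part of the patch): below is a minimal sketch of how the fused
path would be exercised through the public ggml API of this revision; the
tensor shapes, fill values, and thread count are illustrative only.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a mat-vec product followed by a separate SiLU node, the same shape
    // of subgraph that the llm_build_* functions in llama.cpp emit
    struct ggml_tensor * w   = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8), 1.0f);
    struct ggml_tensor * x   = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4),    1.0f);
    struct ggml_tensor * mm  = ggml_mul_mat(ctx, w, x);
    struct ggml_tensor * act = ggml_silu(ctx, mm);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, act);

    // the pass added by this patch: it moves the SiLU into the mul_mat
    // node's op_params and downgrades the unary node to IDENTITY (a no-op)
    ggml_graph_optimize(gf);

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    // after fusion the activated values land directly in mm's buffer;
    // the downgraded IDENTITY node leaves act untouched
    printf("silu(w*x)[0] = %f\n", ((float *) mm->data)[0]);

    ggml_free(ctx);
    return 0;
}

Two design points follow from the patch itself. First, because
GGML_UNARY_OP_IDENTITY is the first enum value (0) and new tensors are created
with zeroed op_params, every mul_mat node the pass does not touch keeps the
plain memcpy path, so un-fused graphs behave exactly as before. Second, the
downgraded IDENTITY node writes nothing to its own output, so correctness in
llama.cpp presumably relies on the graph allocator placing the unary output in
place over the mul_mat output, as it can for unary ops.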