Fuse matrix multiplication + SiLU #5413


Draft · wants to merge 2 commits into master
35 changes: 34 additions & 1 deletion ggml.c
@@ -2324,6 +2324,7 @@ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");


static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"IDENTITY",
"ABS",
"SGN",
"NEG",
@@ -2338,7 +2339,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSIGMOID",
};

-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");


static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -11765,6 +11766,7 @@ static void ggml_compute_forward_mul_mat(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
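// fused activation for this mat-mul, set by ggml_graph_optimize (added below);
// defaults to 0 == GGML_UNARY_OP_IDENTITY since op_params are zero-initialized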
enum ggml_unary_op const activation = (enum ggml_unary_op) ggml_get_op_params_i32(dst, 0);

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
@@ -12028,6 +12030,20 @@ UseGgmlGemm2:;
for (int cn = 0; cn < nrc; ++cn) {
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}

// apply the fused activation to every row staged in tmp; the memcpy above
// already copied the raw values, so IDENTITY has nothing left to do
const int64_t n = MIN(iir0 + blck_0, ir011) - iir0;
for (int cn = 0; cn < nrc; ++cn) {
    float * dst_ptr = &dst_col[iir0 + cn*nb1/nb0];
    switch (activation) {
        case GGML_UNARY_OP_IDENTITY:
            break; // rows were already written by the memcpy above
        case GGML_UNARY_OP_SILU:
            ggml_vec_silu_f32(n, dst_ptr, tmp + (cn*16));
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
}
}
}
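For reference, ggml_vec_silu_f32 applies the SiLU activation elementwise: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)). Below is a minimal scalar sketch of what the fused path computes; the name silu_f32_ref is made up here, and ggml's own implementation adds SIMD variants:

#include <math.h>

// y[i] = silu(x[i]) for i in [0, n); illustrative scalar version only
static void silu_f32_ref(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] / (1.0f + expf(-x[i]));
    }
}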
@@ -16806,6 +16822,8 @@ static void ggml_compute_forward_unary(
const enum ggml_unary_op op = ggml_get_unary_op(dst);

switch (op) {
case GGML_UNARY_OP_IDENTITY:
    {
        // after fusion the mul_mat has already produced the activated values;
        // copy them only when dst does not alias its source (with an inplace
        // graph allocation this is a no-op)
        if (dst->data != dst->src[0]->data) {
            ggml_compute_forward_dup(params, dst);
        }
    } break;
case GGML_UNARY_OP_ABS:
{
ggml_compute_forward_abs(params, dst);
@@ -19143,6 +19161,20 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) {
memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
}

// fuse MUL_MAT + SILU: record the activation in the mat-mul's op params and
// downgrade the unary node to IDENTITY so it becomes a no-op
void ggml_graph_optimize(struct ggml_cgraph * cgraph) {
    for (int i = 1; i < cgraph->n_nodes; ++i) {
        struct ggml_tensor * node_current  = cgraph->nodes[i];
        struct ggml_tensor * node_previous = cgraph->nodes[i-1];

        if (node_current->op == GGML_OP_UNARY && ggml_get_unary_op(node_current) == GGML_UNARY_OP_SILU
            && node_previous->op == GGML_OP_MUL_MAT
            && node_current->src[0] == node_previous) { // the SiLU must consume the mat-mul's output

            ggml_set_op_params_i32(node_previous, 0, ggml_get_op_params_i32(node_current, 0));
            ggml_set_op_params_i32(node_current, 0, (int32_t) GGML_UNARY_OP_IDENTITY);
        }
    }
}

//
// thread data
//
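For context, here is a minimal sketch of how the pass would be exercised, assuming the usual ggml context/graph API; the function name example and the tensor shapes are made up for illustration:

#include "ggml.h"

void example(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // y = silu(w @ x) builds a MUL_MAT node followed by a SILU unary node
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 128);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 1);
    struct ggml_tensor * y = ggml_silu(ctx, ggml_mul_mat(ctx, w, x));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    // rewrites the pair: the mat-mul now carries the activation and the
    // SILU node is downgraded to GGML_UNARY_OP_IDENTITY
    ggml_graph_optimize(gf);

    ggml_free(ctx);
}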
@@ -19348,6 +19380,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_IDENTITY:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_NEG:
3 changes: 3 additions & 0 deletions ggml.h
@@ -513,6 +513,8 @@ extern "C" {
};

enum ggml_unary_op {
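// IDENTITY is placed first so it has value 0: op_params are zero-initialized,
// meaning an untouched mat-mul reads back "no fused activation" by default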
GGML_UNARY_OP_IDENTITY,

GGML_UNARY_OP_ABS,
GGML_UNARY_OP_SGN,
GGML_UNARY_OP_NEG,
@@ -1999,6 +2001,7 @@ extern "C" {
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_optimize (struct ggml_cgraph * cgraph);

GGML_API size_t ggml_graph_overhead(void);
GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
2 changes: 2 additions & 0 deletions llama.cpp
@@ -10919,6 +10919,8 @@ static struct ggml_cgraph * llama_build_graph(

llm.free();

ggml_graph_optimize(result);

return result;
}
