Skip to content

Commit f2e8616

Browse files
committed
ggml : restore ggml_get_n_tasks() logic in ggml_graph_plan()
1 parent 3d154ad commit f2e8616

File tree

2 files changed: +7 additions, -26 deletions (lines changed)

2 files changed: +7 additions, -26 deletions (lines changed)

CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -662,11 +662,11 @@ add_library(ggml OBJECT
662662
ggml-backend.h
663663
ggml-quants.c
664664
ggml-quants.h
665-
${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
665+
${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
666666
${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
667-
${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
668-
${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
669-
${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
667+
${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
668+
${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
669+
${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
670670
)
671671

672672
target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})

ggml.c

Lines changed: 3 additions & 22 deletions
Original file line number · Diff line number · Diff line change
@@ -8459,6 +8459,7 @@ static void ggml_compute_forward_concat_f32(
84598459
GGML_ASSERT(src0->nb[0] == sizeof(float));
84608460

84618461
const int ith = params->ith;
8462+
const int nth = params->nth;
84628463

84638464
GGML_TENSOR_BINARY_OP_LOCALS
84648465

@@ -16100,35 +16101,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1610016101

1610116102
// thread scheduling for the different operations + work buffer size estimation
1610216103
for (int i = 0; i < cgraph->n_nodes; i++) {
16103-
int n_tasks = 1;
16104-
1610516104
struct ggml_tensor * node = cgraph->nodes[i];
1610616105

16106+
const int n_tasks = ggml_get_n_tasks(node, n_threads);
16107+
1610716108
size_t cur = 0;
1610816109

1610916110
switch (node->op) {
1611016111
case GGML_OP_CPY:
1611116112
case GGML_OP_DUP:
1611216113
{
16113-
n_tasks = n_threads;
16114-
1611516114
if (ggml_is_quantized(node->type)) {
1611616115
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1611716116
}
1611816117
} break;
1611916118
case GGML_OP_ADD:
1612016119
case GGML_OP_ADD1:
1612116120
{
16122-
n_tasks = n_threads;
16123-
1612416121
if (ggml_is_quantized(node->src[0]->type)) {
1612516122
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
1612616123
}
1612716124
} break;
1612816125
case GGML_OP_ACC:
1612916126
{
16130-
n_tasks = n_threads;
16131-
1613216127
if (ggml_is_quantized(node->src[0]->type)) {
1613316128
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
1613416129
}
@@ -16173,8 +16168,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1617316168
} break;
1617416169
case GGML_OP_OUT_PROD:
1617516170
{
16176-
n_tasks = n_threads;
16177-
1617816171
if (ggml_is_quantized(node->src[0]->type)) {
1617916172
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
1618016173
}
@@ -16208,10 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1620816201
GGML_ASSERT(false);
1620916202
}
1621016203
} break;
16211-
case GGML_OP_IM2COL:
16212-
{
16213-
n_tasks = n_threads;
16214-
} break;
1621516204
case GGML_OP_CONV_TRANSPOSE_2D:
1621616205
{
1621716206
const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -16228,8 +16217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1622816217
} break;
1622916218
case GGML_OP_FLASH_ATTN:
1623016219
{
16231-
n_tasks = n_threads;
16232-
1623316220
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
1623416221

1623516222
if (node->src[1]->type == GGML_TYPE_F32) {
@@ -16242,8 +16229,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1624216229
} break;
1624316230
case GGML_OP_FLASH_FF:
1624416231
{
16245-
n_tasks = n_threads;
16246-
1624716232
if (node->src[1]->type == GGML_TYPE_F32) {
1624816233
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
1624916234
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16254,8 +16239,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1625416239
} break;
1625516240
case GGML_OP_FLASH_ATTN_BACK:
1625616241
{
16257-
n_tasks = n_threads;
16258-
1625916242
const int64_t D = node->src[0]->ne[0];
1626016243
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
1626116244
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16270,8 +16253,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1627016253

1627116254
case GGML_OP_CROSS_ENTROPY_LOSS:
1627216255
{
16273-
n_tasks = n_threads;
16274-
1627516256
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
1627616257
} break;
1627716258
case GGML_OP_COUNT:

0 commit comments

Comments (0)