@@ -8459,6 +8459,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));

     const int ith = params->ith;
+    const int nth = params->nth;

     GGML_TENSOR_BINARY_OP_LOCALS

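The new `nth` line reads the total thread count from the compute params, alongside the existing `ith` (this thread's index), so ggml_compute_forward_concat_f32 can split its work across threads. For reference, the usual ggml partitioning idiom strides over rows by thread index; a minimal sketch of that pattern, not the literal concat loop (`ne2` would come from GGML_TENSOR_BINARY_OP_LOCALS):

    // each of the nth threads handles every nth-th slice, starting at its own index
    for (int i2 = ith; i2 < ne2; i2 += nth) {
        // ... copy the rows of src0/src1 that land in slice i2 of dst ...
    }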
@@ -16100,35 +16101,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {

     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct ggml_tensor * node = cgraph->nodes[i];

+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;

         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
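With `n_tasks` now computed up front by ggml_get_n_tasks, every case below that only needed `n_tasks = n_threads;` loses that boilerplate, and cases such as GGML_OP_IM2COL that set it and did nothing else (see the hunk further down) can be dropped from the switch entirely, leaving ggml_graph_plan to deal only with work-buffer sizing. A sketch of the shape of the extracted helper, assuming it centralizes the same per-op decisions this switch used to make (the actual op list in ggml.c is much longer):

    static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        int n_tasks = 1; // default: run on a single thread
        switch (node->op) {
            // ops that parallelize across all available threads
            case GGML_OP_CPY:
            case GGML_OP_DUP:
            case GGML_OP_ADD:
            case GGML_OP_ADD1:
            case GGML_OP_ACC:
            case GGML_OP_IM2COL:
                n_tasks = n_threads;
                break;
            // ... remaining ops ...
            default:
                break;
        }
        return n_tasks;
    }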
@@ -16173,8 +16168,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
@@ -16208,10 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
                 } break;
-            case GGML_OP_IM2COL:
-                {
-                    n_tasks = n_threads;
-                } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -16228,8 +16217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);

                     if (node->src[1]->type == GGML_TYPE_F32) {
@@ -16242,8 +16229,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16254,8 +16239,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16270,8 +16253,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {

             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT:
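For intuition on the cross-entropy sizing: the formula reserves, per task, one element for the partial loss sum plus one full row of node->src[0]->ne[0] elements of scratch. A worked example with hypothetical numbers:

    // hypothetical: F32 output (4 bytes), ne[0] = 32000 logits, n_tasks = 8
    // cur = 4 * (8 + 32000 * 8) = 4 * 256008 = 1024032 bytes of work buffer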