Commit 528deeb

chat.cpp: simplify calls to apply to ensure systematic propagation of extra_context (+ the odd existing additional_context)
1 parent e45a43a commit 528deeb

1 file changed (+27 −21)

common/chat.cpp

Lines changed: 27 additions & 21 deletions
@@ -16,6 +16,7 @@
 #include <string>
 #include <vector>
 
+using json = nlohmann::ordered_json;
 
 static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
     auto time = std::chrono::system_clock::to_time_t(now);
@@ -721,16 +722,23 @@ static void foreach_function(const json & tools, const std::function<void(const
 
 static std::string apply(
     const common_chat_template & tmpl,
-    const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools,
-    bool add_generation_prompt,
-    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+    const struct templates_params & inputs,
+    const std::optional<json> & messages_override = std::nullopt,
+    const std::optional<json> & tools_override = std::nullopt,
+    const std::optional<json> & additional_context = std::nullopt)
 {
     minja::chat_template_inputs tmpl_inputs;
-    tmpl_inputs.messages = messages;
-    tmpl_inputs.tools = tools;
-    tmpl_inputs.add_generation_prompt = add_generation_prompt;
-    tmpl_inputs.extra_context = extra_context;
+    tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
+    if (tools_override) {
+        tmpl_inputs.tools = *tools_override;
+    } else {
+        tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
+    }
+    tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
+    tmpl_inputs.extra_context = inputs.extra_context;
+    if (additional_context) {
+        tmpl_inputs.extra_context.merge_patch(*additional_context);
+    }
     // TODO: add flag to control date/time, if only for testing purposes.
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
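The behavioural detail worth noting above is the merge_patch call: nlohmann::json implements it per RFC 7386, so keys from additional_context override same-named keys already present in inputs.extra_context, unrelated keys are preserved, and a null value in the patch would delete the key. A minimal standalone sketch of that semantics (the key names below are illustrative, borrowed from this commit's call sites, not part of the diff itself):

    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        using json = nlohmann::ordered_json;

        // Base context, as carried in templates_params::extra_context.
        json extra_context = {{"enable_thinking", true}, {"date_string", "01 Jan 2025"}};

        // Per-format additional_context, as passed by a call site.
        json additional_context = {{"enable_thinking", false}};

        // RFC 7386 merge: overlapping keys are overridden, the rest kept.
        extra_context.merge_patch(additional_context);

        std::cout << extra_context.dump() << std::endl;
        // prints {"enable_thinking":false,"date_string":"01 Jan 2025"}
        return 0;
    }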
@@ -829,7 +837,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
-    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -901,7 +909,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.preserved_tokens = {
         "[TOOL_CALLS]",
     };
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -926,7 +934,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             adjusted_messages.push_back(msg);
         }
     }
-    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
         if (!inputs.enable_thinking) {
@@ -1119,7 +1127,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
         {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
@@ -1181,7 +1189,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs);
 
     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1272,7 +1280,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ json(), json {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
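Note the tools handling at this call site: the old code passed /* tools= */ nullptr, and the new code keeps suppressing the template's native tools section by passing an explicit json() (FireFunction v2 instead receives its tools pre-rendered as a string under the "functions" key, as the hunk shows). Passing std::nullopt would inherit inputs.tools. A small self-contained sketch of that three-way dispatch, with resolve_tools as an illustrative stand-in for the logic inside apply():

    #include <optional>
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    // Toy stand-in for the tools_override resolution inside apply().
    static json resolve_tools(const json & inputs_tools, const std::optional<json> & tools_override) {
        if (tools_override) {
            return *tools_override;  // explicit override, possibly json() i.e. null
        }
        return inputs_tools.empty() ? json() : inputs_tools;
    }

    int main() {
        json tools = json::array();
        tools.push_back({{"name", "get_weather"}});

        // std::nullopt -> inherit inputs.tools (the common case).
        std::cout << resolve_tools(tools, std::nullopt).dump() << "\n"; // [{"name":"get_weather"}]

        // json() -> render with no tools at all (the FireFunction case).
        std::cout << resolve_tools(tools, json()).dump() << "\n";       // null
        return 0;
    }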
@@ -1324,7 +1332,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1451,7 +1459,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
 
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     // TODO: if (has_raw_python)
     return data;
 }
@@ -1481,11 +1489,9 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
-    json additional_context = {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
         {"enable_thinking", inputs.enable_thinking},
-    };
-
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
+    });
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     if (string_ends_with(data.prompt, "<think>\n")) {
         if (!inputs.enable_thinking) {
@@ -1672,7 +1678,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
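Taken together, the refactor collapses the common case to apply(tmpl, inputs), with the defaulted std::optional parameters carrying every per-format deviation. A compilable toy mirroring that shape, under the assumption that templates_params carries the four fields the diff reads (toy_params and toy_apply are illustrative stand-ins, not the real llama.cpp types):

    #include <optional>
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    // Toy mirror of the templates_params fields that apply() reads.
    struct toy_params {
        json messages;
        json tools;
        bool add_generation_prompt = true;
        json extra_context;
    };

    // Same parameter shape as the new apply(): inputs.* are the defaults,
    // the optionals override only what a call site needs.
    static void toy_apply(
        const toy_params & inputs,
        const std::optional<json> & messages_override = std::nullopt,
        const std::optional<json> & tools_override = std::nullopt,
        const std::optional<json> & additional_context = std::nullopt)
    {
        json messages = messages_override ? *messages_override : inputs.messages;
        json tools = tools_override ? *tools_override
                                    : (inputs.tools.empty() ? json() : inputs.tools);
        json extra = inputs.extra_context;
        if (additional_context) {
            extra.merge_patch(*additional_context);
        }
        std::cout << messages.dump() << " | " << tools.dump() << " | " << extra.dump() << "\n";
    }

    int main() {
        toy_params inputs;
        inputs.messages = json::array();
        inputs.messages.push_back({{"role", "user"}, {"content", "hi"}});
        inputs.extra_context = {{"enable_thinking", true}};

        toy_apply(inputs);  // the collapsed common case (Mistral Nemo, DeepSeek R1, ...)
        toy_apply(inputs, std::nullopt, std::nullopt,
                  json {{"enable_thinking", false}});  // Hermes-2-Pro-style additional_context
        return 0;
    }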

0 commit comments