Skip to content

Commit 902fc09

Browse files
sunyuhan1998ilayaperumalg
authored andcommitted
fix: Fixed a bug in the augmentSystemMessage method of the Prompt, where an extra system message was incorrectly added when the system message was not the first one in the message list.
Signed-off-by: Sun Yuhan <1085481446@qq.com>
1 parent 64ea88a commit 902fc09

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,21 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
198198
* @return a new {@link Prompt} instance with the augmented system message.
199199
*/
200200
public Prompt augmentSystemMessage(Function<SystemMessage, SystemMessage> systemMessageAugmenter) {
201-
202201
var messagesCopy = new ArrayList<>(this.messages);
203-
for (int i = 0; i <= this.messages.size() - 1; i++) {
202+
boolean found = false;
203+
for (int i = 0; i < messagesCopy.size(); i++) {
204204
Message message = messagesCopy.get(i);
205205
if (message instanceof SystemMessage systemMessage) {
206206
messagesCopy.set(i, systemMessageAugmenter.apply(systemMessage));
207+
found = true;
207208
break;
208209
}
209-
if (i == 0) {
210-
// If no system message is found, create a new one with the provided text
211-
// and add it as the first item in the list.
212-
messagesCopy.add(0, systemMessageAugmenter.apply(new SystemMessage("")));
213-
}
214210
}
215-
211+
if (!found) {
212+
// If no system message is found, create a new one with the provided text
213+
// and add it as the first item in the list.
214+
messagesCopy.add(0, systemMessageAugmenter.apply(new SystemMessage("")));
215+
}
216216
return new Prompt(messagesCopy, null == this.chatOptions ? null : this.chatOptions.copy());
217217
}
218218

spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,26 @@ void augmentSystemMessageWhenNone() {
239239
assertThat(prompt.getSystemMessage().getText()).isEqualTo("");
240240
}
241241

242+
@Test
243+
void augmentSystemMessageWhenNotFirst() {
244+
Message[] messages = { new UserMessage("Hi"), new SystemMessage("Hello") };
245+
Prompt prompt = Prompt.builder().messages(messages).build();
246+
247+
assertThat(prompt.getSystemMessage()).isNotNull();
248+
assertThat(prompt.getUserMessage()).isNotNull();
249+
assertThat(prompt.getUserMessage().getText()).isEqualTo("Hi");
250+
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
251+
252+
Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("How are you?").build());
253+
254+
assertThat(copy.getSystemMessage()).isNotNull();
255+
assertThat(copy.getInstructions().size()).isEqualTo(messages.length);
256+
assertThat(copy.getSystemMessage().getText()).isEqualTo("How are you?");
257+
258+
assertThat(prompt.getSystemMessage()).isNotNull();
259+
assertThat(prompt.getUserMessage()).isNotNull();
260+
assertThat(prompt.getUserMessage().getText()).isEqualTo("Hi");
261+
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
262+
}
263+
242264
}

0 commit comments

Comments
 (0)