Skip to content

Commit 50058cb

Browse files
committed
feat: Guardrail for input validation
Signed-off-by: Karson To <karsontao@hotmail.com>
1 parent f78b549 commit 50058cb

File tree

1 file changed

+176
-0
lines changed

1 file changed

+176
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/*
2+
* The {@code GuardrailAdvisor} class is an implementation of both {@link CallAdvisor} and {@link StreamAdvisor}
3+
* that provides flexible input and output validation for chat client requests and responses.
4+
*
5+
* This advisor allows you to define custom validation logic for both user input and model output
6+
* by supplying {@link Predicate} functions. If the input or output does not pass the specified validation,
7+
* a configurable failure response is returned instead of proceeding with the normal processing chain.
8+
*
9+
* Typical use cases include enforcing content policies, blocking sensitive or inappropriate content,
10+
* or implementing custom guardrails for AI-powered chat applications.
11+
*
12+
* The class also provides a builder for convenient and readable instantiation.
13+
*
14+
* Example usage:
15+
* <pre>
16+
* GuardrailAdvisor advisor = new GuardrailAdvisor.Builder()
17+
* .inputValidator(input -> !input.contains("forbidden"))
18+
* .outputValidator(output -> !output.contains("restricted"))
19+
* .failureResponse("Your request cannot be processed due to policy restrictions.")
20+
* .order(1)
21+
* .build();
22+
* </pre>
23+
*
24+
* @author Karson To
25+
* @since 1.0.0
26+
*/
27+
28+
package org.springframework.ai.chat.client.advisor;
29+
30+
import org.springframework.ai.chat.client.ChatClientRequest;
31+
32+
import org.springframework.ai.chat.client.ChatClientResponse;
33+
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
34+
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
35+
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
36+
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
37+
import org.springframework.ai.chat.messages.AssistantMessage;
38+
import org.springframework.ai.chat.model.ChatResponse;
39+
import org.springframework.ai.chat.model.Generation;
40+
import org.springframework.util.Assert;
41+
import reactor.core.publisher.Flux;
42+
43+
import java.util.List;
44+
import java.util.Map;
45+
import java.util.function.Predicate;
46+
47+
public class GuardrailAdvisor implements CallAdvisor, StreamAdvisor {
48+
49+
private static final String DEFAULT_FAILURE_RESPONSE =
50+
"Sorry, your request cannot be processed because it contains content that does not comply with our policy. "
51+
+ "Please revise your input and try again.";
52+
53+
private static final int DEFAULT_ORDER = 0;
54+
55+
private final String failureResponse;
56+
57+
private final Predicate<String> inputValidator;
58+
59+
private final Predicate<String> outputValidator;
60+
61+
private final int order;
62+
63+
64+
public GuardrailAdvisor(Predicate<String> inputValidator, Predicate<String> outputValidator, String failureResponse,
65+
int order) {
66+
Assert.notNull(inputValidator, "Input validator must not be null!");
67+
Assert.notNull(outputValidator, "Output validator must not be null!");
68+
Assert.notNull(failureResponse, "Failure response must not be null!");
69+
this.inputValidator = inputValidator;
70+
this.outputValidator = outputValidator;
71+
this.failureResponse = failureResponse;
72+
this.order = order;
73+
}
74+
75+
private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) {
76+
return ChatClientResponse.builder().chatResponse(
77+
ChatResponse.builder().generations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
78+
.build()).context(Map.copyOf(chatClientRequest.context())).build();
79+
}
80+
81+
@Override
82+
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
83+
String input = chatClientRequest.prompt().getContents();
84+
if (!inputValidator.test(input)) {
85+
return createFailureResponse(chatClientRequest);
86+
}
87+
ChatClientResponse response = callAdvisorChain.nextCall(chatClientRequest);
88+
String output = null;
89+
if (response != null && response.chatResponse() != null && response.chatResponse().getResults() != null
90+
&& !response.chatResponse().getResults().isEmpty()) {
91+
Generation generation = response.chatResponse().getResults().get(0);
92+
if (generation != null && generation.getOutput() != null) {
93+
output = generation.getOutput().getText();
94+
}
95+
}
96+
if (!outputValidator.test(output != null ? output : "")) {
97+
return createFailureResponse(chatClientRequest);
98+
}
99+
return response;
100+
}
101+
102+
@Override
103+
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
104+
StreamAdvisorChain streamAdvisorChain) {
105+
String input = chatClientRequest.prompt().getContents();
106+
if (!inputValidator.test(input)) {
107+
return Flux.just(createFailureResponse(chatClientRequest));
108+
}
109+
return streamAdvisorChain.nextStream(chatClientRequest).map(response -> {
110+
String output = null;
111+
if (response != null && response.chatResponse() != null && response.chatResponse().getResults() != null
112+
&& !response.chatResponse().getResults().isEmpty()) {
113+
Generation generation = response.chatResponse().getResults().get(0);
114+
if (generation != null && generation.getOutput() != null) {
115+
output = generation.getOutput().getText();
116+
}
117+
}
118+
if (!outputValidator.test(output != null ? output : "")) {
119+
return createFailureResponse(chatClientRequest);
120+
}
121+
return response;
122+
});
123+
}
124+
125+
@Override
126+
public String getName() {
127+
return this.getClass().getSimpleName();
128+
}
129+
130+
@Override
131+
public int getOrder() {
132+
return this.order;
133+
}
134+
135+
public static final class Builder {
136+
137+
private Predicate<String> inputValidator = s -> true;
138+
139+
private Predicate<String> outputValidator = s -> true;
140+
141+
private String failureResponse = DEFAULT_FAILURE_RESPONSE;
142+
143+
private int order = DEFAULT_ORDER;
144+
145+
private Builder() {
146+
}
147+
148+
public static Builder builder() {
149+
return new Builder();
150+
}
151+
152+
public Builder inputValidator(Predicate<String> inputValidator) {
153+
this.inputValidator = inputValidator;
154+
return this;
155+
}
156+
157+
public Builder outputValidator(Predicate<String> outputValidator) {
158+
this.outputValidator = outputValidator;
159+
return this;
160+
}
161+
162+
public Builder failureResponse(String failureResponse) {
163+
this.failureResponse = failureResponse;
164+
return this;
165+
}
166+
167+
public Builder order(int order) {
168+
this.order = order;
169+
return this;
170+
}
171+
172+
public GuardrailAdvisor build() {
173+
return new GuardrailAdvisor(this.inputValidator, this.outputValidator, this.failureResponse, this.order);
174+
}
175+
}
176+
}

0 commit comments

Comments
 (0)