Skip to content

Commit aea38ec

Browse files
author
Bogdan Tsechoev
committed
Merge branch 'bot_ui_streaming' into 'master'
Refactor message handling to real-time WebSocket updates See merge request postgres-ai/database-lab!890
2 parents 86a0fba + 3a45610 commit aea38ec

File tree

7 files changed

+120
-51
lines changed

7 files changed

+120
-51
lines changed

ui/packages/platform/src/pages/Bot/Command/Command.tsx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ export const Command = React.memo((props: Props) => {
8585
wsLoading,
8686
loading,
8787
sendMessage,
88-
chatVisibility
88+
chatVisibility,
89+
isStreamingInProcess
8990
} = useAiBot();
9091

91-
const sendDisabled = error !== null || loading || wsLoading || wsReadyState !== ReadyState.OPEN;
92+
const sendDisabled = error !== null || loading || wsLoading || wsReadyState !== ReadyState.OPEN || isStreamingInProcess;
9293

9394
// Handle value.
9495
const [value, setValue] = useState('')

ui/packages/platform/src/pages/Bot/Messages/Message/Message.tsx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ const useStyles = makeStyles(
203203
marginBlockEnd: '1em',
204204
marginInlineStart: 0,
205205
marginInlineEnd: 0,
206+
//animation: `$typing 0.5s steps(30, end), $blinkCaret 0.75s step-end infinite`,
206207
},
207208
'& .MuiExpansionPanel-root div': {
208209
marginBlockStart: 0,
@@ -216,6 +217,7 @@ const useStyles = makeStyles(
216217
marginInlineStart: 0,
217218
marginInlineEnd: 0,
218219
fontSize: 14,
220+
color: colors.pgaiDarkGray,
219221
'&:after': {
220222
overflow: 'hidden',
221223
display: 'inline-block',
@@ -229,7 +231,16 @@ const useStyles = makeStyles(
229231
'to': {
230232
width: '0.9em'
231233
},
232-
}
234+
},
235+
'@keyframes typing': {
236+
from: { width: 0 },
237+
to: { width: '100%' },
238+
},
239+
'@keyframes blinkCaret': {
240+
from: { borderRightColor: 'transparent' },
241+
to: { borderRightColor: 'transparent' },
242+
'50%': { borderRightColor: 'black' },
243+
},
233244
}),
234245

235246
)
@@ -336,7 +347,8 @@ export const Message = React.memo((props: MessageProps) => {
336347
</div>
337348
<div>
338349
{isLoading
339-
? <div className={classes.markdown}>
350+
?
351+
<div className={classes.markdown}>
340352
<div className={classes.loading}>
341353
{stateMessage && stateMessage.state ? stateMessage.state : 'Thinking'}
342354
</div>

ui/packages/platform/src/pages/Bot/Messages/Messages.tsx

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,23 +111,22 @@ const useStyles = makeStyles(
111111
padding: 10
112112
}
113113
}),
114-
115114
)
116115

117-
type MessagesProps = {
118-
messages: BotMessageWithDebugInfo[] | null
119-
isLoading: boolean
120-
isWaitingForAnswer: boolean
121-
}
122-
123116
type Time = string
124117

125118
type FormattedTime = {
126119
[id: string]: Time
127120
}
128121

129122
export const Messages = React.memo(() => {
130-
const { messages, loading: isLoading, wsLoading: isWaitingForAnswer, stateMessage } = useAiBot();
123+
const {
124+
messages,
125+
loading: isLoading,
126+
wsLoading: isWaitingForAnswer,
127+
stateMessage,
128+
currentStreamMessage
129+
} = useAiBot();
131130

132131
const rootRef = useRef<HTMLDivElement>(null);
133132
const wrapperRef = useRef<HTMLDivElement>(null);
@@ -275,6 +274,14 @@ export const Messages = React.memo(() => {
275274
/>
276275
)
277276
})}
277+
{
278+
currentStreamMessage && <Message
279+
id={null}
280+
isAi
281+
content={currentStreamMessage.content}
282+
aiModel={currentStreamMessage.ai_model}
283+
/>
284+
}
278285
{isWaitingForAnswer &&
279286
<Message id={null} isLoading isAi={true} stateMessage={stateMessage} />
280287
}

ui/packages/platform/src/pages/Bot/hooks.tsx

Lines changed: 77 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import React, { createContext, useCallback, useContext, useEffect, useState } from "react";
99
import useWebSocket, {ReadyState} from "react-use-websocket";
1010
import { useLocation } from "react-router-dom";
11-
import { BotMessage, DebugMessage, AiModel, StateMessage } from "../../types/api/entities/bot";
11+
import { BotMessage, DebugMessage, AiModel, StateMessage, StreamMessage } from "../../types/api/entities/bot";
1212
import {getChatsWithWholeThreads} from "../../api/bot/getChatsWithWholeThreads";
1313
import {getChats} from "api/bot/getChats";
1414
import {useAlertSnackbar} from "@postgres.ai/shared/components/AlertSnackbar/useAlertSnackbar";
@@ -58,6 +58,8 @@ type UseAiBotReturnType = {
5858
aiModelsLoading: UseAiModelsList['loading'];
5959
debugMessagesLoading: boolean;
6060
stateMessage: StateMessage | null;
61+
isStreamingInProcess: boolean
62+
currentStreamMessage: StreamMessage | null
6163
}
6264

6365
type UseAiBotArgs = {
@@ -90,6 +92,8 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
9092
const [wsLoading, setWsLoading] = useState<boolean>(false);
9193
const [chatVisibility, setChatVisibility] = useState<UseAiBotReturnType['chatVisibility']>('public');
9294
const [stateMessage, setStateMessage] = useState<StateMessage | null>(null)
95+
const [currentStreamMessage, setCurrentStreamMessage] = useState<StreamMessage | null>(null)
96+
const [isStreamingInProcess, setStreamingInProcess] = useState<boolean>(false)
9397

9498
const [isChangeVisibilityLoading, setIsChangeVisibilityLoading] = useState<boolean>(false);
9599

@@ -102,51 +106,35 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
102106

103107
const onWebSocketMessage = (event: WebSocketEventMap['message']) => {
104108
if (event.data) {
105-
const messageData: BotMessage | DebugMessage | StateMessage = JSON.parse(event.data);
109+
const messageData: BotMessage | DebugMessage | StateMessage | StreamMessage = JSON.parse(event.data);
106110
if (messageData) {
107111
const isThreadMatching = threadId && threadId === messageData.thread_id;
108112
const isParentMatching = !threadId && 'parent_id' in messageData && messageData.parent_id && messages;
109113
const isDebugMessage = messageData.type === 'debug';
110114
const isStateMessage = messageData.type === 'state';
111-
if (isThreadMatching || isParentMatching || isDebugMessage || isStateMessage) {
112-
if (isDebugMessage) {
113-
let currentDebugMessages = [...(debugMessages || [])];
114-
currentDebugMessages.push(messageData)
115-
setDebugMessages(currentDebugMessages)
116-
} else if (isStateMessage) {
117-
if (isThreadMatching || !threadId) {
118-
if (messageData.state) {
119-
setStateMessage(messageData)
120-
} else {
121-
setStateMessage(null)
122-
}
123-
}
124-
} else {
125-
// Check if the last message needs its data updated
126-
let currentMessages = [...(messages || [])];
127-
const lastMessage = currentMessages[currentMessages.length - 1];
128-
if (lastMessage && !lastMessage.id && messageData.parent_id) {
129-
lastMessage.id = messageData.parent_id;
130-
lastMessage.created_at = messageData.created_at;
131-
lastMessage.is_public = messageData.is_public;
132-
}
133-
134-
currentMessages.push(messageData);
135-
setMessages(currentMessages);
136-
setWsLoading(false);
137-
if (document.visibilityState === "hidden") {
138-
if (Notification.permission === "granted") {
139-
new Notification("New message", {
140-
body: 'New message from Postgres.AI Bot',
141-
icon: '/images/bot_avatar.png'
142-
});
143-
}
144-
}
115+
const isStreamMessage = messageData.type === 'stream';
116+
117+
if (isThreadMatching || isParentMatching || isDebugMessage || isStateMessage || isStreamMessage) {
118+
switch (messageData.type) {
119+
case 'debug':
120+
handleDebugMessage(messageData)
121+
break;
122+
case 'state':
123+
handleStateMessage(messageData, Boolean(isThreadMatching))
124+
break;
125+
case 'stream':
126+
handleStreamMessage(messageData, Boolean(isThreadMatching))
127+
break;
128+
case 'message':
129+
handleBotMessage(messageData)
130+
break;
145131
}
146132
} else if (threadId !== messageData.thread_id) {
147133
const threadInList = chatsList?.find((item) => item.thread_id === messageData.thread_id)
148134
if (!threadInList) getChatsList()
149-
setWsLoading(false);
135+
if (currentStreamMessage) setCurrentStreamMessage(null)
136+
if (wsLoading) setWsLoading(false);
137+
if (isStreamingInProcess) setStreamingInProcess(false)
150138
}
151139
} else {
152140
showMessage('An error occurred. Please try again')
@@ -158,6 +146,56 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
158146
setLoading(false);
159147
}
160148

149+
const handleDebugMessage = (message: DebugMessage) => {
150+
let currentDebugMessages = [...(debugMessages || [])];
151+
currentDebugMessages.push(message)
152+
setDebugMessages(currentDebugMessages)
153+
}
154+
155+
const handleStateMessage = (message: StateMessage, isThreadMatching?: boolean) => {
156+
if (isThreadMatching || !threadId) {
157+
if (message.state) {
158+
setStateMessage(message)
159+
} else {
160+
setStateMessage(null)
161+
}
162+
}
163+
}
164+
165+
const handleStreamMessage = (message: StreamMessage, isThreadMatching?: boolean) => {
166+
if (isThreadMatching || !threadId) {
167+
if (!isStreamingInProcess) setStreamingInProcess(true)
168+
setCurrentStreamMessage(message)
169+
setWsLoading(false);
170+
}
171+
}
172+
173+
const handleBotMessage = (message: BotMessage) => {
174+
if (messages && messages.length > 0) {
175+
let currentMessages = [...messages];
176+
const lastMessage = currentMessages[currentMessages.length - 1];
177+
if (lastMessage && !lastMessage.id && message.parent_id) {
178+
lastMessage.id = message.parent_id;
179+
lastMessage.created_at = message.created_at;
180+
lastMessage.is_public = message.is_public;
181+
}
182+
183+
currentMessages.push(message);
184+
if (currentStreamMessage) setCurrentStreamMessage(null)
185+
setMessages(currentMessages);
186+
setWsLoading(false);
187+
setStreamingInProcess(false);
188+
if (document.visibilityState === "hidden") {
189+
if (Notification.permission === "granted") {
190+
new Notification("New message", {
191+
body: 'New message from Postgres.AI Bot',
192+
icon: '/images/bot_avatar.png'
193+
});
194+
}
195+
}
196+
}
197+
}
198+
161199
const onWebSocketOpen = () => {
162200
console.log('WebSocket connection established');
163201
if (threadId) {
@@ -381,6 +419,8 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
381419
debugMessages,
382420
debugMessagesLoading,
383421
stateMessage,
422+
isStreamingInProcess,
423+
currentStreamMessage
384424
}
385425
}
386426

ui/packages/platform/src/pages/Bot/index.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ export const BotPage = (props: BotPageProps) => {
156156
}
157157

158158
useEffect(() => {
159-
if (!match.params.threadId && !prevThreadId && messages && messages.length > 1 && messages[1].parent_id) {
159+
if (!match.params.threadId && !prevThreadId && messages && messages.length > 0 && messages[0].id) {
160160
// hack that skip additional loading chats_ancestors_and_descendants
161-
history.replace(`/${match.params.org}/bot/${messages[1].parent_id}`, { skipReloading: true })
161+
history.replace(`/${match.params.org}/bot/${messages[0].id}`, { skipReloading: true })
162162
getChatsList();
163163
} else if (prevThreadId && !match.params.threadId) {
164164
clearChat()

ui/packages/platform/src/types/api/entities/bot.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,11 @@ export type StateMessage = {
4040
type: 'state'
4141
state: string | null
4242
thread_id: string
43+
}
44+
45+
export type StreamMessage = {
46+
type: 'stream'
47+
content: string
48+
ai_model: string
49+
thread_id: string
4350
}

ui/packages/platform/src/utils/format.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ const Format = {
264264
},
265265

266266
timeAgo: function (date: string | Date): string | null {
267+
if (!date) return null
268+
267269
const now = new Date();
268270
const past = new Date(date);
269271
const diff = Math.abs(now.getTime() - past.getTime());

0 commit comments

Comments
 (0)