8
8
import React , { createContext , useCallback , useContext , useEffect , useState } from "react" ;
9
9
import useWebSocket , { ReadyState } from "react-use-websocket" ;
10
10
import { useLocation } from "react-router-dom" ;
11
- import { BotMessage , DebugMessage , AiModel , StateMessage } from "../../types/api/entities/bot" ;
11
+ import { BotMessage , DebugMessage , AiModel , StateMessage , StreamMessage } from "../../types/api/entities/bot" ;
12
12
import { getChatsWithWholeThreads } from "../../api/bot/getChatsWithWholeThreads" ;
13
13
import { getChats } from "api/bot/getChats" ;
14
14
import { useAlertSnackbar } from "@postgres.ai/shared/components/AlertSnackbar/useAlertSnackbar" ;
@@ -58,6 +58,8 @@ type UseAiBotReturnType = {
58
58
aiModelsLoading : UseAiModelsList [ 'loading' ] ;
59
59
debugMessagesLoading : boolean ;
60
60
stateMessage : StateMessage | null ;
61
+ isStreamingInProcess : boolean
62
+ currentStreamMessage : StreamMessage | null
61
63
}
62
64
63
65
type UseAiBotArgs = {
@@ -90,6 +92,8 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
90
92
const [ wsLoading , setWsLoading ] = useState < boolean > ( false ) ;
91
93
const [ chatVisibility , setChatVisibility ] = useState < UseAiBotReturnType [ 'chatVisibility' ] > ( 'public' ) ;
92
94
const [ stateMessage , setStateMessage ] = useState < StateMessage | null > ( null )
95
+ const [ currentStreamMessage , setCurrentStreamMessage ] = useState < StreamMessage | null > ( null )
96
+ const [ isStreamingInProcess , setStreamingInProcess ] = useState < boolean > ( false )
93
97
94
98
const [ isChangeVisibilityLoading , setIsChangeVisibilityLoading ] = useState < boolean > ( false ) ;
95
99
@@ -102,51 +106,35 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
102
106
103
107
const onWebSocketMessage = ( event : WebSocketEventMap [ 'message' ] ) => {
104
108
if ( event . data ) {
105
- const messageData : BotMessage | DebugMessage | StateMessage = JSON . parse ( event . data ) ;
109
+ const messageData : BotMessage | DebugMessage | StateMessage | StreamMessage = JSON . parse ( event . data ) ;
106
110
if ( messageData ) {
107
111
const isThreadMatching = threadId && threadId === messageData . thread_id ;
108
112
const isParentMatching = ! threadId && 'parent_id' in messageData && messageData . parent_id && messages ;
109
113
const isDebugMessage = messageData . type === 'debug' ;
110
114
const isStateMessage = messageData . type === 'state' ;
111
- if ( isThreadMatching || isParentMatching || isDebugMessage || isStateMessage ) {
112
- if ( isDebugMessage ) {
113
- let currentDebugMessages = [ ...( debugMessages || [ ] ) ] ;
114
- currentDebugMessages . push ( messageData )
115
- setDebugMessages ( currentDebugMessages )
116
- } else if ( isStateMessage ) {
117
- if ( isThreadMatching || ! threadId ) {
118
- if ( messageData . state ) {
119
- setStateMessage ( messageData )
120
- } else {
121
- setStateMessage ( null )
122
- }
123
- }
124
- } else {
125
- // Check if the last message needs its data updated
126
- let currentMessages = [ ...( messages || [ ] ) ] ;
127
- const lastMessage = currentMessages [ currentMessages . length - 1 ] ;
128
- if ( lastMessage && ! lastMessage . id && messageData . parent_id ) {
129
- lastMessage . id = messageData . parent_id ;
130
- lastMessage . created_at = messageData . created_at ;
131
- lastMessage . is_public = messageData . is_public ;
132
- }
133
-
134
- currentMessages . push ( messageData ) ;
135
- setMessages ( currentMessages ) ;
136
- setWsLoading ( false ) ;
137
- if ( document . visibilityState === "hidden" ) {
138
- if ( Notification . permission === "granted" ) {
139
- new Notification ( "New message" , {
140
- body : 'New message from Postgres.AI Bot' ,
141
- icon : '/images/bot_avatar.png'
142
- } ) ;
143
- }
144
- }
115
+ const isStreamMessage = messageData . type === 'stream' ;
116
+
117
+ if ( isThreadMatching || isParentMatching || isDebugMessage || isStateMessage || isStreamMessage ) {
118
+ switch ( messageData . type ) {
119
+ case 'debug' :
120
+ handleDebugMessage ( messageData )
121
+ break ;
122
+ case 'state' :
123
+ handleStateMessage ( messageData , Boolean ( isThreadMatching ) )
124
+ break ;
125
+ case 'stream' :
126
+ handleStreamMessage ( messageData , Boolean ( isThreadMatching ) )
127
+ break ;
128
+ case 'message' :
129
+ handleBotMessage ( messageData )
130
+ break ;
145
131
}
146
132
} else if ( threadId !== messageData . thread_id ) {
147
133
const threadInList = chatsList ?. find ( ( item ) => item . thread_id === messageData . thread_id )
148
134
if ( ! threadInList ) getChatsList ( )
149
- setWsLoading ( false ) ;
135
+ if ( currentStreamMessage ) setCurrentStreamMessage ( null )
136
+ if ( wsLoading ) setWsLoading ( false ) ;
137
+ if ( isStreamingInProcess ) setStreamingInProcess ( false )
150
138
}
151
139
} else {
152
140
showMessage ( 'An error occurred. Please try again' )
@@ -158,6 +146,56 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
158
146
setLoading ( false ) ;
159
147
}
160
148
149
+ const handleDebugMessage = ( message : DebugMessage ) => {
150
+ let currentDebugMessages = [ ...( debugMessages || [ ] ) ] ;
151
+ currentDebugMessages . push ( message )
152
+ setDebugMessages ( currentDebugMessages )
153
+ }
154
+
155
+ const handleStateMessage = ( message : StateMessage , isThreadMatching ?: boolean ) => {
156
+ if ( isThreadMatching || ! threadId ) {
157
+ if ( message . state ) {
158
+ setStateMessage ( message )
159
+ } else {
160
+ setStateMessage ( null )
161
+ }
162
+ }
163
+ }
164
+
165
+ const handleStreamMessage = ( message : StreamMessage , isThreadMatching ?: boolean ) => {
166
+ if ( isThreadMatching || ! threadId ) {
167
+ if ( ! isStreamingInProcess ) setStreamingInProcess ( true )
168
+ setCurrentStreamMessage ( message )
169
+ setWsLoading ( false ) ;
170
+ }
171
+ }
172
+
173
+ const handleBotMessage = ( message : BotMessage ) => {
174
+ if ( messages && messages . length > 0 ) {
175
+ let currentMessages = [ ...messages ] ;
176
+ const lastMessage = currentMessages [ currentMessages . length - 1 ] ;
177
+ if ( lastMessage && ! lastMessage . id && message . parent_id ) {
178
+ lastMessage . id = message . parent_id ;
179
+ lastMessage . created_at = message . created_at ;
180
+ lastMessage . is_public = message . is_public ;
181
+ }
182
+
183
+ currentMessages . push ( message ) ;
184
+ if ( currentStreamMessage ) setCurrentStreamMessage ( null )
185
+ setMessages ( currentMessages ) ;
186
+ setWsLoading ( false ) ;
187
+ setStreamingInProcess ( false ) ;
188
+ if ( document . visibilityState === "hidden" ) {
189
+ if ( Notification . permission === "granted" ) {
190
+ new Notification ( "New message" , {
191
+ body : 'New message from Postgres.AI Bot' ,
192
+ icon : '/images/bot_avatar.png'
193
+ } ) ;
194
+ }
195
+ }
196
+ }
197
+ }
198
+
161
199
const onWebSocketOpen = ( ) => {
162
200
console . log ( 'WebSocket connection established' ) ;
163
201
if ( threadId ) {
@@ -381,6 +419,8 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
381
419
debugMessages,
382
420
debugMessagesLoading,
383
421
stateMessage,
422
+ isStreamingInProcess,
423
+ currentStreamMessage
384
424
}
385
425
}
386
426
0 commit comments