18
18
# See the License for the specific language governing permissions and
19
19
# limitations under the License.
20
20
21
+ from enum import Enum
21
22
from logging import getLogger
22
23
from ssl import SSLSocket
23
24
37
38
Neo4jError ,
38
39
NotALeader ,
39
40
ServiceUnavailable ,
40
- SessionExpired ,
41
41
)
42
42
from neo4j .io import (
43
43
Bolt ,
48
48
InitResponse ,
49
49
Response ,
50
50
)
51
+ from neo4j .io ._bolt3 import (
52
+ ServerStates ,
53
+ STATE_TRANSITIONS ,
54
+ )
51
55
52
56
53
57
log = getLogger ("neo4j" )
@@ -65,6 +69,16 @@ class Bolt4x0(Bolt):
65
69
66
70
supports_multiple_databases = True
67
71
72
+ _server_state = ServerStates .CONNECTED
73
+
74
+ @property
75
+ def is_reset (self ):
76
+ if self .responses :
77
+ # we can't be sure of the server's state as there are still pending
78
+ # responses.
79
+ return False
80
+ return self ._server_state == ServerStates .READY
81
+
68
82
@property
69
83
def encrypted (self ):
70
84
return isinstance (self .socket , SSLSocket )
@@ -93,7 +107,8 @@ def hello(self):
93
107
logged_headers ["credentials" ] = "*******"
94
108
log .debug ("[#%04X] C: HELLO %r" , self .local_port , logged_headers )
95
109
self ._append (b"\x01 " , (headers ,),
96
- response = InitResponse (self , on_success = self .server_info .update ))
110
+ response = InitResponse (self , "hello" ,
111
+ on_success = self .server_info .update ))
97
112
self .send_all ()
98
113
self .fetch_all ()
99
114
check_supported_server_product (self .server_info .agent )
@@ -162,24 +177,25 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
162
177
fields = (query , parameters , extra )
163
178
log .debug ("[#%04X] C: RUN %s" , self .local_port , " " .join (map (repr , fields )))
164
179
if query .upper () == u"COMMIT" :
165
- self ._append (b"\x10 " , fields , CommitResponse (self , ** handlers ))
180
+ self ._append (b"\x10 " , fields , CommitResponse (self , "run" ,
181
+ ** handlers ))
166
182
else :
167
- self ._append (b"\x10 " , fields , Response (self , ** handlers ))
183
+ self ._append (b"\x10 " , fields , Response (self , "run" , ** handlers ))
168
184
self ._is_reset = False
169
185
170
186
def discard (self , n = - 1 , qid = - 1 , ** handlers ):
171
187
extra = {"n" : n }
172
188
if qid != - 1 :
173
189
extra ["qid" ] = qid
174
190
log .debug ("[#%04X] C: DISCARD %r" , self .local_port , extra )
175
- self ._append (b"\x2F " , (extra ,), Response (self , ** handlers ))
191
+ self ._append (b"\x2F " , (extra ,), Response (self , "discard" , ** handlers ))
176
192
177
193
def pull (self , n = - 1 , qid = - 1 , ** handlers ):
178
194
extra = {"n" : n }
179
195
if qid != - 1 :
180
196
extra ["qid" ] = qid
181
197
log .debug ("[#%04X] C: PULL %r" , self .local_port , extra )
182
- self ._append (b"\x3F " , (extra ,), Response (self , ** handlers ))
198
+ self ._append (b"\x3F " , (extra ,), Response (self , "pull" , ** handlers ))
183
199
self ._is_reset = False
184
200
185
201
def begin (self , mode = None , bookmarks = None , metadata = None , timeout = None ,
@@ -205,16 +221,16 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
205
221
except TypeError :
206
222
raise TypeError ("Timeout must be specified as a number of seconds" )
207
223
log .debug ("[#%04X] C: BEGIN %r" , self .local_port , extra )
208
- self ._append (b"\x11 " , (extra ,), Response (self , ** handlers ))
224
+ self ._append (b"\x11 " , (extra ,), Response (self , "begin" , ** handlers ))
209
225
self ._is_reset = False
210
226
211
227
def commit (self , ** handlers ):
212
228
log .debug ("[#%04X] C: COMMIT" , self .local_port )
213
- self ._append (b"\x12 " , (), CommitResponse (self , ** handlers ))
229
+ self ._append (b"\x12 " , (), CommitResponse (self , "commit" , ** handlers ))
214
230
215
231
def rollback (self , ** handlers ):
216
232
log .debug ("[#%04X] C: ROLLBACK" , self .local_port )
217
- self ._append (b"\x13 " , (), Response (self , ** handlers ))
233
+ self ._append (b"\x13 " , (), Response (self , "rollback" , ** handlers ))
218
234
219
235
def reset (self ):
220
236
""" Add a RESET message to the outgoing queue, send
@@ -225,11 +241,22 @@ def fail(metadata):
225
241
raise BoltProtocolError ("RESET failed %r" % metadata , self .unresolved_address )
226
242
227
243
log .debug ("[#%04X] C: RESET" , self .local_port )
228
- self ._append (b"\x0F " , response = Response (self , on_failure = fail ))
244
+ self ._append (b"\x0F " , response = Response (self , "reset" , on_failure = fail ))
229
245
self .send_all ()
230
246
self .fetch_all ()
231
247
self ._is_reset = True
232
248
249
+ def _update_server_state_on_success (self , metadata , message ):
250
+ if metadata .get ("has_more" ):
251
+ return
252
+ state_before = self ._server_state
253
+ self ._server_state = STATE_TRANSITIONS \
254
+ .get (self ._server_state , {})\
255
+ .get (message , self ._server_state )
256
+ if state_before != self ._server_state :
257
+ log .debug ("[#%04X] [%s]" , self .local_port ,
258
+ self ._server_state .name )
259
+
233
260
def fetch_message (self ):
234
261
""" Receive at most one message from the server, if available.
235
262
@@ -261,12 +288,15 @@ def fetch_message(self):
261
288
response .complete = True
262
289
if summary_signature == b"\x70 " :
263
290
log .debug ("[#%04X] S: SUCCESS %r" , self .local_port , summary_metadata )
291
+ self ._update_server_state_on_success (summary_metadata ,
292
+ response .message )
264
293
response .on_success (summary_metadata or {})
265
294
elif summary_signature == b"\x7E " :
266
295
log .debug ("[#%04X] S: IGNORED" , self .local_port )
267
296
response .on_ignored (summary_metadata or {})
268
297
elif summary_signature == b"\x7F " :
269
298
log .debug ("[#%04X] S: FAILURE %r" , self .local_port , summary_metadata )
299
+ self ._server_state = ServerStates .FAILED
270
300
try :
271
301
response .on_failure (summary_metadata or {})
272
302
except (ServiceUnavailable , DatabaseUnavailable ):
@@ -372,7 +402,9 @@ def fail(md):
372
402
else :
373
403
bookmarks = list (bookmarks )
374
404
self ._append (b"\x66 " , (routing_context , bookmarks , database ),
375
- response = Response (self , on_success = metadata .update , on_failure = fail ))
405
+ response = Response (self , "route" ,
406
+ on_success = metadata .update ,
407
+ on_failure = fail ))
376
408
self .send_all ()
377
409
self .fetch_all ()
378
410
return [metadata .get ("rt" )]
@@ -400,7 +432,8 @@ def on_success(metadata):
400
432
logged_headers ["credentials" ] = "*******"
401
433
log .debug ("[#%04X] C: HELLO %r" , self .local_port , logged_headers )
402
434
self ._append (b"\x01 " , (headers ,),
403
- response = InitResponse (self , on_success = on_success ))
435
+ response = InitResponse (self , "hello" ,
436
+ on_success = on_success ))
404
437
self .send_all ()
405
438
self .fetch_all ()
406
439
check_supported_server_product (self .server_info .agent )
0 commit comments