38
38
)
39
39
from ._bolt import (
40
40
AsyncBolt ,
41
+ ClientStateManagerBase ,
41
42
ServerStateManagerBase ,
42
43
tx_timeout_as_ms ,
43
44
)
52
53
log = getLogger ("neo4j" )
53
54
54
55
55
- class ServerStates (Enum ):
56
+ class BoltStates (Enum ):
56
57
CONNECTED = "CONNECTED"
57
58
READY = "READY"
58
59
STREAMING = "STREAMING"
@@ -62,25 +63,25 @@ class ServerStates(Enum):
62
63
63
64
class ServerStateManager (ServerStateManagerBase ):
64
65
_STATE_TRANSITIONS : t .Dict [Enum , t .Dict [str , Enum ]] = {
65
- ServerStates .CONNECTED : {
66
- "hello" : ServerStates .READY ,
66
+ BoltStates .CONNECTED : {
67
+ "hello" : BoltStates .READY ,
67
68
},
68
- ServerStates .READY : {
69
- "run" : ServerStates .STREAMING ,
70
- "begin" : ServerStates .TX_READY_OR_TX_STREAMING ,
69
+ BoltStates .READY : {
70
+ "run" : BoltStates .STREAMING ,
71
+ "begin" : BoltStates .TX_READY_OR_TX_STREAMING ,
71
72
},
72
- ServerStates .STREAMING : {
73
- "pull" : ServerStates .READY ,
74
- "discard" : ServerStates .READY ,
75
- "reset" : ServerStates .READY ,
73
+ BoltStates .STREAMING : {
74
+ "pull" : BoltStates .READY ,
75
+ "discard" : BoltStates .READY ,
76
+ "reset" : BoltStates .READY ,
76
77
},
77
- ServerStates .TX_READY_OR_TX_STREAMING : {
78
- "commit" : ServerStates .READY ,
79
- "rollback" : ServerStates .READY ,
80
- "reset" : ServerStates .READY ,
78
+ BoltStates .TX_READY_OR_TX_STREAMING : {
79
+ "commit" : BoltStates .READY ,
80
+ "rollback" : BoltStates .READY ,
81
+ "reset" : BoltStates .READY ,
81
82
},
82
- ServerStates .FAILED : {
83
- "reset" : ServerStates .READY ,
83
+ BoltStates .FAILED : {
84
+ "reset" : BoltStates .READY ,
84
85
}
85
86
}
86
87
@@ -99,7 +100,40 @@ def transition(self, message, metadata):
99
100
self ._on_change (state_before , self .state )
100
101
101
102
def failed (self ):
102
- return self .state == ServerStates .FAILED
103
+ return self .state == BoltStates .FAILED
104
+
105
+
106
+ class ClientStateManager (ClientStateManagerBase ):
107
+ _STATE_TRANSITIONS : t .Dict [Enum , t .Dict [str , Enum ]] = {
108
+ BoltStates .CONNECTED : {
109
+ "hello" : BoltStates .READY ,
110
+ },
111
+ BoltStates .READY : {
112
+ "run" : BoltStates .STREAMING ,
113
+ "begin" : BoltStates .TX_READY_OR_TX_STREAMING ,
114
+ },
115
+ BoltStates .STREAMING : {
116
+ "begin" : BoltStates .TX_READY_OR_TX_STREAMING ,
117
+ "reset" : BoltStates .READY ,
118
+ },
119
+ BoltStates .TX_READY_OR_TX_STREAMING : {
120
+ "commit" : BoltStates .READY ,
121
+ "rollback" : BoltStates .READY ,
122
+ "reset" : BoltStates .READY ,
123
+ },
124
+ }
125
+
126
+ def __init__ (self , init_state , on_change = None ):
127
+ self .state = init_state
128
+ self ._on_change = on_change
129
+
130
+ def transition (self , message ):
131
+ state_before = self .state
132
+ self .state = self ._STATE_TRANSITIONS \
133
+ .get (self .state , {}) \
134
+ .get (message , self .state )
135
+ if state_before != self .state and callable (self ._on_change ):
136
+ self ._on_change (state_before , self .state )
103
137
104
138
105
139
class AsyncBolt3 (AsyncBolt ):
@@ -121,25 +155,34 @@ class AsyncBolt3(AsyncBolt):
121
155
def __init__ (self , * args , ** kwargs ):
122
156
super ().__init__ (* args , ** kwargs )
123
157
self ._server_state_manager = ServerStateManager (
124
- ServerStates .CONNECTED , on_change = self ._on_server_state_change
158
+ BoltStates .CONNECTED , on_change = self ._on_server_state_change
159
+ )
160
+ self ._client_state_manager = ClientStateManager (
161
+ BoltStates .CONNECTED , on_change = self ._on_client_state_change
125
162
)
126
163
127
164
def _on_server_state_change (self , old_state , new_state ):
128
- log .debug ("[#%04X] _: <CONNECTION> state: %s > %s" , self . local_port ,
129
- old_state .name , new_state .name )
165
+ log .debug ("[#%04X] _: <CONNECTION> server state: %s > %s" ,
166
+ self . local_port , old_state .name , new_state .name )
130
167
131
168
def _get_server_state_manager (self ) -> ServerStateManagerBase :
132
169
return self ._server_state_manager
133
170
171
+ def _on_client_state_change (self , old_state , new_state ):
172
+ log .debug ("[#%04X] _: <CONNECTION> client state: %s > %s" ,
173
+ self .local_port , old_state .name , new_state .name )
174
+
175
+ def _get_client_state_manager (self ) -> ClientStateManagerBase :
176
+ return self ._client_state_manager
177
+
134
178
@property
135
179
def is_reset (self ):
136
180
# We can't be sure of the server's state if there are still pending
137
181
# responses. Unless the last message we sent was RESET. In that case
138
182
# the server state will always be READY when we're done.
139
- if (self .responses and self .responses [- 1 ]
140
- and self .responses [- 1 ].message == "reset" ):
141
- return True
142
- return self ._server_state_manager .state == ServerStates .READY
183
+ if self .responses :
184
+ return self .responses [- 1 ] and self .responses [- 1 ].message == "reset"
185
+ return self ._server_state_manager .state == BoltStates .READY
143
186
144
187
@property
145
188
def encrypted (self ):
@@ -216,7 +259,7 @@ async def route(
216
259
hydration_hooks = hydration_hooks ,
217
260
on_success = metadata .update
218
261
)
219
- self .pull (dehydration_hooks = None , hydration_hooks = None ,
262
+ self .pull (dehydration_hooks = None , hydration_hooks = None ,
220
263
on_success = metadata .update , on_records = records .extend )
221
264
await self .send_all ()
222
265
await self .fetch_all ()
@@ -398,7 +441,7 @@ async def _process_message(self, tag, fields):
398
441
await response .on_ignored (summary_metadata or {})
399
442
elif summary_signature == b"\x7F " :
400
443
log .debug ("[#%04X] S: FAILURE %r" , self .local_port , summary_metadata )
401
- self ._server_state_manager .state = ServerStates .FAILED
444
+ self ._server_state_manager .state = BoltStates .FAILED
402
445
try :
403
446
await response .on_failure (summary_metadata or {})
404
447
except (ServiceUnavailable , DatabaseUnavailable ):
@@ -408,7 +451,8 @@ async def _process_message(self, tag, fields):
408
451
except (NotALeader , ForbiddenOnReadOnlyDatabase ):
409
452
if self .pool :
410
453
await self .pool .on_write_failure (
411
- address = self .unresolved_address
454
+ address = self .unresolved_address ,
455
+ database = self .last_database ,
412
456
)
413
457
raise
414
458
except Neo4jError as e :
0 commit comments