|
40 | 40 | from .constants import DEFAULT_PORT, ENCRYPTION_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES, \
|
41 | 41 | TRUST_ON_FIRST_USE, READ_ACCESS, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, \
|
42 | 42 | TRUST_ALL_CERTIFICATES, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES
|
43 |
| -from .exceptions import CypherError, ProtocolError, ResultError, TransactionError |
| 43 | +from .exceptions import CypherError, ProtocolError, ResultError, TransactionError, \ |
| 44 | + ServiceUnavailable |
44 | 45 | from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED
|
45 | 46 | from .summary import ResultSummary
|
46 | 47 | from .types import hydrated
|
@@ -196,53 +197,88 @@ def session(self, access_mode=None):
|
196 | 197 | return Session(self.pool.acquire(self.address))
|
197 | 198 |
|
198 | 199 |
|
199 |
| -class RoutingDriver(Driver): |
200 |
| - """ A :class:`.RoutingDriver` is created from a `bolt+routing` URI. |
| 200 | +def parse_address(address): |
| 201 | + """ Convert an address string to a tuple. |
201 | 202 | """
|
| 203 | + host, _, port = address.partition(":") |
| 204 | + return host, int(port) |
202 | 205 |
|
203 |
| - def __init__(self, address, **config): |
204 |
| - self.address = address |
205 |
| - self.security_plan = security_plan = SecurityPlan.build(address, **config) |
206 |
| - self.encrypted = security_plan.encrypted |
207 |
| - if not security_plan.routing_compatible: |
208 |
| - # this error message is case-specific as there is only one incompatible |
209 |
| - # scenario right now |
210 |
| - raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing") |
211 |
| - Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config)) |
212 |
| - self._lock = Lock() |
213 |
| - self._expiry_time = None |
214 |
| - self._routers = RoundRobinSet([address]) |
215 |
| - self._readers = RoundRobinSet() |
216 |
| - self._writers = RoundRobinSet() |
217 |
| - self.discover() |
| 206 | + |
| 207 | +class Router(object): |
| 208 | + |
| 209 | + timer = monotonic |
| 210 | + |
| 211 | + def __init__(self, pool, initial_address): |
| 212 | + self.pool = pool |
| 213 | + self.lock = Lock() |
| 214 | + self.expiry_time = None |
| 215 | + self.routers = RoundRobinSet([initial_address]) |
| 216 | + self.readers = RoundRobinSet() |
| 217 | + self.writers = RoundRobinSet() |
| 218 | + |
| 219 | + def stale(self): |
| 220 | + expired = self.expiry_time is None or self.expiry_time <= self.timer() |
| 221 | + return expired or len(self.routers) <= 1 or not self.readers or not self.writers |
218 | 222 |
|
219 | 223 | def discover(self):
|
220 |
| - with self._lock: |
221 |
| - for router in list(self._routers): |
| 224 | + with self.lock: |
| 225 | + if not self.routers: |
| 226 | + raise ServiceUnavailable("No routers available") |
| 227 | + for router in list(self.routers): |
222 | 228 | session = Session(self.pool.acquire(router))
|
223 | 229 | try:
|
224 | 230 | record = session.run("CALL dbms.cluster.routing.getServers").single()
|
| 231 | + except CypherError as error: |
| 232 | + if error.code == "Neo.ClientError.Procedure.ProcedureNotFound": |
| 233 | + raise ServiceUnavailable("Server does not support routing") |
| 234 | + raise |
225 | 235 | except ResultError:
|
226 | 236 | raise RuntimeError("TODO")
|
227 |
| - new_expiry_time = monotonic() + record["ttl"] |
| 237 | + new_expiry_time = self.timer() + record["ttl"] |
228 | 238 | servers = record["servers"]
|
229 | 239 | new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0]
|
230 | 240 | new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0]
|
231 | 241 | new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0]
|
232 | 242 | if new_routers and new_readers and new_writers:
|
233 |
| - self._expiry_time = new_expiry_time |
234 |
| - self._routers.replace(new_routers) |
235 |
| - self._readers.replace(new_readers) |
236 |
| - self._writers.replace(new_writers) |
237 |
| - else: |
238 |
| - raise RuntimeError("TODO") |
| 243 | + self.expiry_time = new_expiry_time |
| 244 | + self.routers.replace(map(parse_address, new_routers)) |
| 245 | + self.readers.replace(map(parse_address, new_readers)) |
| 246 | + self.writers.replace(map(parse_address, new_writers)) |
| 247 | + return |
| 248 | + raise ServiceUnavailable("Unable to establish routing information") |
| 249 | + |
| 250 | + def acquire_read_connection(self): |
| 251 | + if self.stale(): |
| 252 | + self.discover() |
| 253 | + return self.pool.acquire(next(self.readers)) |
| 254 | + |
| 255 | + def acquire_write_connection(self): |
| 256 | + if self.stale(): |
| 257 | + self.discover() |
| 258 | + return self.pool.acquire(next(self.writers)) |
| 259 | + |
| 260 | + |
| 261 | +class RoutingDriver(Driver): |
| 262 | + """ A :class:`.RoutingDriver` is created from a `bolt+routing` URI. |
| 263 | + """ |
| 264 | + |
| 265 | + def __init__(self, address, **config): |
| 266 | + self.security_plan = security_plan = SecurityPlan.build(address, **config) |
| 267 | + self.encrypted = security_plan.encrypted |
| 268 | + if not security_plan.routing_compatible: |
| 269 | + # this error message is case-specific as there is only one incompatible |
| 270 | + # scenario right now |
| 271 | + raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing") |
| 272 | + Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config)) |
| 273 | + self.router = Router(self.pool, address) |
| 274 | + self.router.discover() |
239 | 275 |
|
240 | 276 | def session(self, access_mode=None):
|
241 | 277 | if access_mode == READ_ACCESS:
|
242 |
| - address = next(self._readers) |
| 278 | + connection = self.router.acquire_read_connection() |
243 | 279 | else:
|
244 |
| - address = next(self._writers) |
245 |
| - return Session(self.pool.acquire(address)) |
| 280 | + connection = self.router.acquire_write_connection() |
| 281 | + return Session(connection) |
246 | 282 |
|
247 | 283 |
|
248 | 284 | class StatementResult(object):
|
|
0 commit comments