Skip to content

Commit 2c0b708

Browse files
author
Zhen
committed
Add max connection lifetime on each connection
By default the max connection lifetime is infinite.
1 parent ffcc17c commit 2c0b708

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

neo4j/bolt/connection.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,14 @@
4242
from neo4j.meta import version
4343
from neo4j.packstream import Packer, Unpacker
4444
from neo4j.util import import_best as _import_best
45+
from time import clock
4546

4647
ChunkedInputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedInputBuffer
4748
ChunkedOutputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedOutputBuffer
4849

4950

51+
INFINITE_CONNECTION_LIFETIME = -1
52+
DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE_CONNECTION_LIFETIME
5053
DEFAULT_CONNECTION_TIMEOUT = 5.0
5154
DEFAULT_PORT = 7687
5255
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
@@ -178,6 +181,8 @@ def __init__(self, address, sock, error_handler, **config):
178181
self.packer = Packer(self.output_buffer)
179182
self.unpacker = Unpacker()
180183
self.responses = deque()
184+
self._max_connection_lifetime = config.get("max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
185+
self._creation_timestamp = clock()
181186

182187
# Determine the user agent and ensure it is a Unicode value
183188
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
@@ -201,6 +206,7 @@ def __init__(self, address, sock, error_handler, **config):
201206
# Pick up the server certificate, if any
202207
self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")
203208

209+
def Init(self):
204210
response = InitResponse(self)
205211
self.append(INIT, (self.user_agent, self.auth_dict), response=response)
206212
self.sync()
@@ -360,6 +366,9 @@ def _unpack(self):
360366
more = False
361367
return details, summary_signature, summary_metadata
362368

369+
def timedout(self):
370+
return 0 <= self._max_connection_lifetime <= clock() - self._creation_timestamp
371+
363372
def sync(self):
364373
""" Send and fetch all outstanding messages.
365374
@@ -425,7 +434,7 @@ def acquire_direct(self, address):
425434
except KeyError:
426435
connections = self.connections[address] = deque()
427436
for connection in list(connections):
428-
if connection.closed() or connection.defunct():
437+
if connection.closed() or connection.defunct() or connection.timedout():
429438
connections.remove(connection)
430439
continue
431440
if not connection.in_use:
@@ -600,8 +609,10 @@ def connect(address, ssl_context=None, error_handler=None, **config):
600609
s.shutdown(SHUT_RDWR)
601610
s.close()
602611
elif agreed_version == 1:
603-
return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate,
612+
connection = Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate,
604613
error_handler=error_handler, **config)
614+
connection.Init()
615+
return connection
605616
elif agreed_version == 0x48545450:
606617
log_error("S: [CLOSE]")
607618
s.close()

test/integration/test_connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from socket import create_connection
2323

2424
from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler
25-
2625
from test.integration.tools import IntegrationTestCase
2726

2827

@@ -44,6 +43,9 @@ def closed(self):
4443
def defunct(self):
4544
return False
4645

46+
def timedout(self):
47+
return False
48+
4749

4850
def connector(address, _):
4951
return QuickConnection(create_connection(address))
@@ -119,4 +121,4 @@ def test_in_use_count(self):
119121
connection = self.pool.acquire_direct(address)
120122
self.assertEqual(self.pool.in_use_connection_count(address), 1)
121123
self.pool.release(connection)
122-
self.assertEqual(self.pool.in_use_connection_count(address), 0)
124+
self.assertEqual(self.pool.in_use_connection_count(address), 0)

test/unit/test_connection.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2017 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
from unittest import TestCase
22+
from neo4j.v1 import DirectConnectionErrorHandler
23+
from neo4j.bolt import Connection
24+
25+
26+
class FakeSocket(object):
27+
def __init__(self, address):
28+
self.address = address
29+
30+
def getpeername(self):
31+
return self.address
32+
33+
def sendall(self, data):
34+
return
35+
36+
def close(self):
37+
return
38+
39+
40+
class ConnectionTestCase(TestCase):
41+
42+
def test_conn_timedout(self):
43+
address = ("127.0.0.1", 7687)
44+
connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), max_connection_lifetime=0)
45+
self.assertEqual(connection.timedout(), True)
46+
47+
def test_conn_not_timedout_if_not_enabled(self):
48+
address = ("127.0.0.1", 7687)
49+
connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(),
50+
max_connection_lifetime=-1)
51+
self.assertEqual(connection.timedout(), False)
52+
53+
def test_conn_not_timedout(self):
54+
address = ("127.0.0.1", 7687)
55+
connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(),
56+
max_connection_lifetime=999999999)
57+
self.assertEqual(connection.timedout(), False)

0 commit comments

Comments
 (0)