Skip to content

Commit cc7d492

Browse files
committed
TLS settings and tests
1 parent 22d1684 commit cc7d492

File tree

6 files changed

+172
-39
lines changed

6 files changed

+172
-39
lines changed

neo4j/v1/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
2020

21+
from .constants import *
2122
from .session import *
2223
from .typesystem import *

neo4j/v1/compat.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,3 @@ def perf_counter():
9090
from urllib.parse import urlparse
9191
except ImportError:
9292
from urlparse import urlparse
93-
94-
95-
try:
96-
from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, HAS_SNI
97-
except ImportError:
98-
from ssl import wrap_socket, PROTOCOL_SSLv23
99-
100-
def secure_socket(s, host):
101-
return wrap_socket(s, ssl_version=PROTOCOL_SSLv23)
102-
103-
else:
104-
105-
def secure_socket(s, host):
106-
ssl_context = SSLContext(PROTOCOL_SSLv23)
107-
ssl_context.options |= OP_NO_SSLv2
108-
return ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)

neo4j/v1/connection.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,24 @@
2121

2222
from __future__ import division
2323

24+
from base64 import b64encode
2425
from collections import deque
2526
from io import BytesIO
2627
import logging
27-
from os import environ
28+
from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY
29+
from os.path import dirname, isfile
2830
from select import select
2931
from socket import create_connection, SHUT_RDWR
32+
from ssl import HAS_SNI, SSLError
3033
from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from
3134

32-
from ..meta import version
33-
from .compat import hex2, secure_socket
35+
from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \
36+
SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE
37+
from .compat import hex2
3438
from .exceptions import ProtocolError
3539
from .packstream import Packer, Unpacker
3640

3741

38-
DEFAULT_PORT = 7687
39-
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
40-
41-
MAGIC_PREAMBLE = 0x6060B017
42-
4342
# Signature bytes for each message type
4443
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
4544
RESET = b"\x0F" # 0000 1111 // RESET
@@ -211,14 +210,18 @@ def __init__(self, sock, **config):
211210
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
212211
if isinstance(user_agent, bytes):
213212
user_agent = user_agent.decode("UTF-8")
213+
self.user_agent = user_agent
214+
215+
# Pick up the server certificate, if any
216+
self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")
214217

215218
def on_failure(metadata):
216219
raise ProtocolError("Initialisation failed")
217220

218221
response = Response(self)
219222
response.on_failure = on_failure
220223

221-
self.append(INIT, (user_agent,), response=response)
224+
self.append(INIT, (self.user_agent,), response=response)
222225
self.send()
223226
while not response.complete:
224227
self.fetch_next()
@@ -313,7 +316,39 @@ def close(self):
313316
self.closed = True
314317

315318

316-
def connect(host, port=None, **config):
319+
def verify_certificate(host, der_encoded_certificate):
320+
base64_encoded_certificate = b64encode(der_encoded_certificate)
321+
if isfile(KNOWN_HOSTS):
322+
with open(KNOWN_HOSTS) as f_in:
323+
for line in f_in:
324+
known_host, _, known_cert = line.strip().partition(":")
325+
if host == known_host:
326+
if base64_encoded_certificate == known_cert:
327+
# Certificate match
328+
return
329+
else:
330+
# Certificate mismatch
331+
print(base64_encoded_certificate)
332+
print(known_cert)
333+
raise ProtocolError("Server certificate does not match known certificate for %r; check "
334+
"details in file %r" % (host, KNOWN_HOSTS))
335+
# First use (no hosts match)
336+
try:
337+
makedirs(dirname(KNOWN_HOSTS))
338+
except OSError:
339+
pass
340+
f_out = os_open(KNOWN_HOSTS, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows
341+
if isinstance(host, bytes):
342+
os_write(f_out, host)
343+
else:
344+
os_write(f_out, host.encode("utf-8"))
345+
os_write(f_out, b":")
346+
os_write(f_out, base64_encoded_certificate)
347+
os_write(f_out, b"\n")
348+
os_close(f_out)
349+
350+
351+
def connect(host, port=None, ssl_context=None, **config):
317352
""" Connect and perform a handshake and return a valid Connection object, assuming
318353
a protocol version can be agreed.
319354
"""
@@ -323,10 +358,25 @@ def connect(host, port=None, **config):
323358
if __debug__: log_info("~~ [CONNECT] %s %d", host, port)
324359
s = create_connection((host, port))
325360

326-
# Secure the connection if so requested
327-
if config.get("secure", False):
361+
# Secure the connection if an SSL context has been provided
362+
if ssl_context:
328363
if __debug__: log_info("~~ [SECURE] %s", host)
329-
s = secure_socket(s, host)
364+
try:
365+
s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)
366+
except SSLError as cause:
367+
error = ProtocolError("Cannot establish secure connection; %s" % cause.args[1])
368+
error.__cause__ = cause
369+
raise error
370+
else:
371+
# Check that the server provides a certificate
372+
der_encoded_server_certificate = s.getpeercert(binary_form=True)
373+
if der_encoded_server_certificate is None:
374+
raise ProtocolError("When using a secure socket, the server should always provide a certificate")
375+
security = config.get("security", SECURITY_NONE)
376+
if security == SECURITY_TRUST_ON_FIRST_USE:
377+
verify_certificate(host, der_encoded_server_certificate)
378+
else:
379+
der_encoded_server_certificate = None
330380

331381
# Send details of the protocol versions supported
332382
supported_versions = [1, 0, 0, 0]
@@ -360,4 +410,4 @@ def connect(host, port=None, **config):
360410
s.shutdown(SHUT_RDWR)
361411
s.close()
362412
else:
363-
return Connection(s, **config)
413+
return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config)

neo4j/v1/constants.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2016 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
from os.path import expanduser, join
23+
24+
from ..meta import version
25+
26+
27+
DEFAULT_PORT = 7687
28+
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
29+
30+
KNOWN_HOSTS = join(expanduser("~"), ".neo4j", "known_hosts")
31+
32+
MAGIC_PREAMBLE = 0x6060B017
33+
34+
SECURITY_NONE = 0
35+
SECURITY_TRUST_ON_FIRST_USE = 1
36+
SECURITY_VERIFIED = 2

neo4j/v1/session.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ class which can be used to obtain `Driver` instances that are used for
2929
from __future__ import division
3030

3131
from collections import deque, namedtuple
32+
from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED, Purpose
3233

3334
from .compat import integer, string, urlparse
3435
from .connection import connect, Response, RUN, PULL_ALL
36+
from .constants import SECURITY_NONE, SECURITY_VERIFIED
3537
from .exceptions import CypherError
3638
from .typesystem import hydrated
3739

@@ -77,6 +79,16 @@ def __init__(self, url, **config):
7779
self.config = config
7880
self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE)
7981
self.session_pool = deque()
82+
self.security = security = config.get("security", SECURITY_NONE)
83+
if security > SECURITY_NONE:
84+
ssl_context = SSLContext(PROTOCOL_SSLv23)
85+
ssl_context.options |= OP_NO_SSLv2
86+
if security >= SECURITY_VERIFIED:
87+
ssl_context.verify_mode = CERT_REQUIRED
88+
ssl_context.load_default_certs(Purpose.SERVER_AUTH)
89+
self.ssl_context = ssl_context
90+
else:
91+
self.ssl_context = None
8092

8193
def session(self):
8294
""" Create a new session based on the graph database details
@@ -417,13 +429,16 @@ class Session(object):
417429

418430
def __init__(self, driver):
419431
self.driver = driver
420-
self.connection = connect(driver.host, driver.port, **driver.config)
432+
self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config)
421433
self.transaction = None
422434
self.last_cursor = None
423435

424436
def __del__(self):
425-
if not self.connection.closed:
426-
self.connection.close()
437+
try:
438+
if not self.connection.closed:
439+
self.connection.close()
440+
except AttributeError:
441+
pass
427442

428443
def __enter__(self):
429444
return self
@@ -643,6 +658,7 @@ def __eq__(self, other):
643658
def __ne__(self, other):
644659
return not self.__eq__(other)
645660

661+
646662
def record(obj):
647663
""" Obtain an immutable record for the given object
648664
(either by calling obj.__record__() or by copying out the record data)

test/test_session.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,21 @@
1919
# limitations under the License.
2020

2121

22+
from os import remove, rename
23+
from os.path import isfile
2224
from socket import socket
2325
from ssl import SSLSocket
2426
from unittest import TestCase
2527

2628
from mock import patch
29+
from neo4j.v1.constants import KNOWN_HOSTS, SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE, SECURITY_VERIFIED
2730
from neo4j.v1.session import GraphDatabase, CypherError, Record, record
2831
from neo4j.v1.typesystem import Node, Relationship, Path
2932

3033

34+
KNOWN_HOSTS_BACKUP = KNOWN_HOSTS + ".backup"
35+
36+
3137
class DriverTestCase(TestCase):
3238

3339
def test_healthy_session_will_be_returned_to_the_pool_on_close(self):
@@ -81,17 +87,57 @@ def test_sessions_are_not_reused_if_still_in_use(self):
8187
session_1.close()
8288
assert session_1 is not session_2
8389

84-
def test_insecure_session_uses_insecure_socket(self):
85-
driver = GraphDatabase.driver("bolt://localhost", secure=False)
90+
91+
class SecurityTestCase(TestCase):
92+
93+
def setUp(self):
94+
if isfile(KNOWN_HOSTS):
95+
rename(KNOWN_HOSTS, KNOWN_HOSTS_BACKUP)
96+
97+
def tearDown(self):
98+
if isfile(KNOWN_HOSTS_BACKUP):
99+
rename(KNOWN_HOSTS_BACKUP, KNOWN_HOSTS)
100+
101+
def test_default_session_uses_security_none(self):
102+
# TODO: verify this is the correct default (maybe TOFU?)
103+
driver = GraphDatabase.driver("bolt://localhost")
104+
assert driver.security == SECURITY_NONE
105+
106+
def test_insecure_session_uses_normal_socket(self):
107+
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE)
108+
session = driver.session()
109+
connection = session.connection
110+
assert isinstance(connection.channel.socket, socket)
111+
assert connection.der_encoded_server_certificate is None
112+
session.close()
113+
114+
def test_tofu_session_uses_secure_socket(self):
115+
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE)
86116
session = driver.session()
87-
assert isinstance(session.connection.channel.socket, socket)
117+
connection = session.connection
118+
assert isinstance(connection.channel.socket, SSLSocket)
119+
assert connection.der_encoded_server_certificate is not None
88120
session.close()
89121

90-
def test_secure_session_uses_secure_socket(self):
91-
driver = GraphDatabase.driver("bolt://localhost", secure=True)
122+
def test_tofu_session_trusts_certificate_after_first_use(self):
123+
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE)
92124
session = driver.session()
93-
assert isinstance(session.connection.channel.socket, SSLSocket)
125+
connection = session.connection
126+
certificate = connection.der_encoded_server_certificate
94127
session.close()
128+
session = driver.session()
129+
connection = session.connection
130+
assert connection.der_encoded_server_certificate == certificate
131+
session.close()
132+
133+
# TODO: Find a way to run this test
134+
# def test_verified_session_uses_secure_socket(self):
135+
# driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_VERIFIED)
136+
# session = driver.session()
137+
# connection = session.connection
138+
# assert isinstance(connection.channel.socket, SSLSocket)
139+
# assert connection.der_encoded_server_certificate is not None
140+
# session.close()
95141

96142

97143
class RunTestCase(TestCase):

0 commit comments

Comments
 (0)