Skip to content

1.0 session pool #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 15, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 1 addition & 44 deletions neo4j/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,53 +26,10 @@
from json import loads as json_loads
from sys import stdout, stderr

from .util import Watcher
from .v1.session import GraphDatabase, CypherError


class ColourFormatter(logging.Formatter):
""" Colour formatter for pretty log output.
"""

def format(self, record):
s = super(ColourFormatter, self).format(record)
if record.levelno == logging.CRITICAL:
return "\x1b[31;1m%s\x1b[0m" % s # bright red
elif record.levelno == logging.ERROR:
return "\x1b[33;1m%s\x1b[0m" % s # bright yellow
elif record.levelno == logging.WARNING:
return "\x1b[33m%s\x1b[0m" % s # yellow
elif record.levelno == logging.INFO:
return "\x1b[36m%s\x1b[0m" % s # cyan
elif record.levelno == logging.DEBUG:
return "\x1b[34m%s\x1b[0m" % s # blue
else:
return s


class Watcher(object):
""" Log watcher for debug output.
"""

handlers = {}

def __init__(self, logger_name):
super(Watcher, self).__init__()
self.logger_name = logger_name
self.logger = logging.getLogger(self.logger_name)
self.formatter = ColourFormatter("%(asctime)s %(message)s")

def watch(self, level=logging.INFO, out=stdout):
try:
self.logger.removeHandler(self.handlers[self.logger_name])
except KeyError:
pass
handler = logging.StreamHandler(out)
handler.setFormatter(self.formatter)
self.handlers[self.logger_name] = handler
self.logger.addHandler(handler)
self.logger.setLevel(level)


def main():
parser = ArgumentParser(description="Execute one or more Cypher statements using Bolt.")
parser.add_argument("statement", nargs="+")
Expand Down
76 changes: 76 additions & 0 deletions neo4j/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) 2002-2016 "Neo Technology,"
# Network Engine for Objects in Lund AB [http://neotechnology.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import unicode_literals

import logging
from argparse import ArgumentParser
from json import loads as json_loads
from sys import stdout, stderr

from .v1.session import GraphDatabase, CypherError


class ColourFormatter(logging.Formatter):
""" Colour formatter for pretty log output.
"""

def format(self, record):
s = super(ColourFormatter, self).format(record)
if record.levelno == logging.CRITICAL:
return "\x1b[31;1m%s\x1b[0m" % s # bright red
elif record.levelno == logging.ERROR:
return "\x1b[33;1m%s\x1b[0m" % s # bright yellow
elif record.levelno == logging.WARNING:
return "\x1b[33m%s\x1b[0m" % s # yellow
elif record.levelno == logging.INFO:
return "\x1b[36m%s\x1b[0m" % s # cyan
elif record.levelno == logging.DEBUG:
return "\x1b[34m%s\x1b[0m" % s # blue
else:
return s


class Watcher(object):
""" Log watcher for debug output.
"""

handlers = {}

def __init__(self, logger_name):
super(Watcher, self).__init__()
self.logger_name = logger_name
self.logger = logging.getLogger(self.logger_name)
self.formatter = ColourFormatter("%(asctime)s %(message)s")

def watch(self, level=logging.INFO, out=stdout):
self.stop()
handler = logging.StreamHandler(out)
handler.setFormatter(self.formatter)
self.handlers[self.logger_name] = handler
self.logger.addHandler(handler)
self.logger.setLevel(level)

def stop(self):
try:
self.logger.removeHandler(self.handlers[self.logger_name])
except KeyError:
pass
78 changes: 53 additions & 25 deletions neo4j/v1/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

# Signature bytes for each message type
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
ACK_FAILURE = b"\x0F" # 0000 1111 // ACK_FAILURE
RESET = b"\x0F" # 0000 1111 // RESET
RUN = b"\x10" # 0001 0000 // RUN <statement> <parameters>
DISCARD_ALL = b"\x2F" # 0010 1111 // DISCARD *
PULL_ALL = b"\x3F" # 0011 1111 // PULL *
Expand All @@ -56,7 +56,7 @@

message_names = {
INIT: "INIT",
ACK_FAILURE: "ACK_FAILURE",
RESET: "RESET",
RUN: "RUN",
DISCARD_ALL: "DISCARD_ALL",
PULL_ALL: "PULL_ALL",
Expand Down Expand Up @@ -169,14 +169,6 @@ def chunk_reader(self):
data = self._recv(chunk_size)
yield data

def close(self):
""" Shut down and close the connection.
"""
if __debug__: log_info("~~ [CLOSE]")
socket = self.socket
socket.shutdown(SHUT_RDWR)
socket.close()


class Response(object):
""" Subscriber object for a full response (zero or
Expand All @@ -200,12 +192,6 @@ def on_ignored(self, metadata=None):
pass


class AckFailureResponse(Response):

def on_failure(self, metadata):
raise ProtocolError("Could not acknowledge failure")


class Connection(object):
""" Server connection through which all protocol messages
are sent and received. This class is designed for protocol
Expand All @@ -215,9 +201,11 @@ class Connection(object):
"""

def __init__(self, sock, **config):
self.defunct = False
self.channel = ChunkChannel(sock)
self.packer = Packer(self.channel)
self.responses = deque()
self.closed = False

# Determine the user agent and ensure it is a Unicode value
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
Expand All @@ -235,8 +223,15 @@ def on_failure(metadata):
while not response.complete:
self.fetch_next()

def __del__(self):
self.close()

def append(self, signature, fields=(), response=None):
""" Add a message to the outgoing queue.

:arg signature: the signature of the message
:arg fields: the fields of the message as a tuple
:arg response: a response object to handle callbacks
"""
if __debug__:
log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields)))
Expand All @@ -247,42 +242,75 @@ def append(self, signature, fields=(), response=None):
self.channel.flush(end_of_message=True)
self.responses.append(response)

def reset(self):
""" Add a RESET message to the outgoing queue, send
it and consume all remaining messages.
"""
response = Response(self)

def on_failure(metadata):
raise ProtocolError("Reset failed")

response.on_failure = on_failure

self.append(RESET, response=response)
self.send()
fetch_next = self.fetch_next
while not response.complete:
fetch_next()

def send(self):
""" Send all queued messages to the server.
"""
if self.closed:
raise ProtocolError("Cannot write to a closed connection")
if self.defunct:
raise ProtocolError("Cannot write to a defunct connection")
self.channel.send()

def fetch_next(self):
""" Receive exactly one message from the server.
"""
if self.closed:
raise ProtocolError("Cannot read from a closed connection")
if self.defunct:
raise ProtocolError("Cannot read from a defunct connection")
raw = BytesIO()
unpack = Unpacker(raw).unpack
raw.writelines(self.channel.chunk_reader())

try:
raw.writelines(self.channel.chunk_reader())
except ProtocolError:
self.defunct = True
self.close()
raise
# Unpack from the raw byte stream and call the relevant message handler(s)
raw.seek(0)
response = self.responses[0]
for signature, fields in unpack():
if __debug__:
log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields)))
if signature in SUMMARY:
response.complete = True
self.responses.popleft()
if signature == FAILURE:
self.reset()
handler_name = "on_%s" % message_names[signature].lower()
try:
handler = getattr(response, handler_name)
except AttributeError:
pass
else:
handler(*fields)
if signature in SUMMARY:
response.complete = True
self.responses.popleft()
if signature == FAILURE:
self.append(ACK_FAILURE, response=AckFailureResponse(self))
raw.close()

def close(self):
""" Shut down and close the connection.
""" Close the connection.
"""
self.channel.close()
if not self.closed:
if __debug__:
log_info("~~ [CLOSE]")
self.channel.socket.close()
self.closed = True


def connect(host, port=None, **config):
Expand Down
Loading