diff --git a/neo4j/v1/bolt.py b/neo4j/v1/bolt.py index 4863105c..453c1ec0 100644 --- a/neo4j/v1/bolt.py +++ b/neo4j/v1/bolt.py @@ -92,8 +92,7 @@ def fill(self): ready_to_read, _, _ = select((self.socket,), (), (), 0) received = self.socket.recv(65539) if received: - if __debug__: - log_debug("S: b%r", received) + log_debug("S: b%r", received) self.buffer[len(self.buffer):] = received else: if ready_to_read is not None: @@ -174,8 +173,7 @@ def send(self): """ Send all queued messages to the server. """ data = self.raw.getvalue() - if __debug__: - log_debug("C: b%r", data) + log_debug("C: b%r", data) self.socket.sendall(data) self.raw.seek(self.raw.truncate(0)) @@ -264,8 +262,7 @@ def append(self, signature, fields=(), response=None): :arg fields: the fields of the message as a tuple :arg response: a response object to handle callbacks """ - if __debug__: - log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields))) + log_info("C: %s %r", message_names[signature], fields) self.packer.pack_struct_header(len(fields), signature) for field in fields: @@ -329,29 +326,36 @@ def fetch(self): self.defunct = True self.close() raise - # Unpack from the raw byte stream and call the relevant message handler(s) - self.unpacker.load(message_data) - size, signature = self.unpacker.unpack_structure_header() - fields = [self.unpacker.unpack() for _ in range(size)] - if __debug__: - log_info("S: %s %r", message_names[signature], fields) + unpacker = self.unpacker + unpacker.load(message_data) + size, signature = unpacker.unpack_structure_header() + if size > 1: + raise ProtocolError("Expected one field") if signature == SUCCESS: + metadata = unpacker.unpack_map() + log_info("S: SUCCESS (%r)", metadata) response = self.responses.popleft() response.complete = True - response.on_success(*fields) + response.on_success(metadata or {}) elif signature == RECORD: + data = unpacker.unpack_list() + log_info("S: RECORD (%r)", data) response = self.responses[0] - response.on_record(*fields) + response.on_record(data or []) elif signature == IGNORED: + metadata = unpacker.unpack_map() + log_info("S: IGNORED (%r)", metadata) response = self.responses.popleft() response.complete = True - response.on_ignored(*fields) + response.on_ignored(metadata or {}) elif signature == FAILURE: + metadata = unpacker.unpack_map() + log_info("S: FAILURE (%r)", metadata) response = self.responses.popleft() response.complete = True - response.on_failure(*fields) + response.on_failure(metadata or {}) else: raise ProtocolError("Unexpected response message with signature %02X" % signature) @@ -365,8 +369,7 @@ def close(self): """ Close the connection. """ if not self.closed: - if __debug__: - log_info("~~ [CLOSE]") + log_info("~~ [CLOSE]") self.channel.socket.close() self.closed = True @@ -476,7 +479,7 @@ def connect(address, ssl_context=None, **config): # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html - if __debug__: log_info("~~ [CONNECT] %s", address) + log_info("~~ [CONNECT] %s", address) try: s = create_connection(address) except SocketError as error: @@ -488,7 +491,7 @@ def connect(address, ssl_context=None, **config): # Secure the connection if an SSL context has been provided if ssl_context and SSL_AVAILABLE: host, port = address - if __debug__: log_info("~~ [SECURE] %s", host) + log_info("~~ [SECURE] %s", host) try: s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) except SSLError as cause: @@ -514,9 +517,9 @@ def connect(address, ssl_context=None, **config): # Send details of the protocol versions supported supported_versions = [1, 0, 0, 0] handshake = [MAGIC_PREAMBLE] + supported_versions - if __debug__: log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions) + log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions) data = b"".join(struct_pack(">I", num) for num in handshake) - if __debug__: log_debug("C: b%r", data) + log_debug("C: b%r", data) s.sendall(data) # Handle the handshake response @@ -531,15 +534,15 @@ def connect(address, ssl_context=None, **config): log_error("S: [CLOSE]") raise ProtocolError("Connection to %r closed without handshake response" % (address,)) if data_size == 4: - if __debug__: log_debug("S: b%r", data) + log_debug("S: b%r", data) else: # Some other garbled data has been received log_error("S: @*#!") raise ProtocolError("Expected four byte handshake response, received %r instead" % data) agreed_version, = struct_unpack(">I", data) - if __debug__: log_info("S: [HANDSHAKE] %d", agreed_version) + log_info("S: [HANDSHAKE] %d", agreed_version) if agreed_version == 0: - if __debug__: log_info("~~ [CLOSE]") + log_info("~~ [CLOSE]") s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: diff --git a/neo4j/v1/packstream.py b/neo4j/v1/packstream.py index f9557688..406b7906 100644 --- a/neo4j/v1/packstream.py +++ b/neo4j/v1/packstream.py @@ -708,64 +708,12 @@ def unpack(self): return self.read(byte_size).decode(ENCODING) # List - elif marker_high == 0x90: - size = marker & 0x0F - return [unpack() for _ in range(size)] - elif marker == 0xD4: # LIST_8: - size = UNPACKED_UINT_8[self.read_bytes(1)] - return [unpack() for _ in range(size)] - elif marker == 0xD5: # LIST_16: - size = UNPACKED_UINT_16[self.read_bytes(2)] - return [unpack() for _ in range(size)] - elif marker == 0xD6: # LIST_32: - size = struct_unpack(UINT_32_STRUCT, self.read(4))[0] - return [unpack() for _ in range(size)] - elif marker == 0xD7: # LIST_STREAM: - value = [] - item = None - while item is not EndOfStream: - item = unpack() - if item is not EndOfStream: - value.append(item) - return value + elif 0x90 <= marker <= 0x9F or 0xD4 <= marker <= 0xD7: + return self._unpack_list(marker) # Map - elif marker_high == 0xA0: - size = marker & 0x0F - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xD8: # MAP_8: - size = UNPACKED_UINT_8[self.read_bytes(1)] - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xD9: # MAP_16: - size = UNPACKED_UINT_16[self.read_bytes(2)] - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xDA: # MAP_32: - size = struct_unpack(UINT_32_STRUCT, self.read(4))[0] - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xDB: # MAP_STREAM: - value = {} - key = None - while key is not EndOfStream: - key = unpack() - if key is not EndOfStream: - value[key] = unpack() - return value + elif 0xA0 <= marker <= 0xAF or 0xD8 <= marker <= 0xDB: + return self._unpack_map(marker) # Structure elif marker_high == 0xB0: @@ -793,6 +741,87 @@ def unpack(self): else: raise RuntimeError("Unknown PackStream marker %02X" % marker) + def unpack_list(self): + marker = self.read_marker() + return self._unpack_list(marker) + + def _unpack_list(self, marker): + marker_high = marker & 0xF0 + unpack = self.unpack + if marker_high == 0x90: + size = marker & 0x0F + if size == 0: + return [] + elif size == 1: + return [unpack()] + else: + return [unpack() for _ in range(size)] + elif marker == 0xD4: # LIST_8: + size = UNPACKED_UINT_8[self.read_bytes(1)] + return [unpack() for _ in range(size)] + elif marker == 0xD5: # LIST_16: + size = UNPACKED_UINT_16[self.read_bytes(2)] + return [unpack() for _ in range(size)] + elif marker == 0xD6: # LIST_32: + size = struct_unpack(UINT_32_STRUCT, self.read_bytes(4))[0] + return [unpack() for _ in range(size)] + elif marker == 0xD7: # LIST_STREAM: + value = [] + item = None + while item is not EndOfStream: + item = unpack() + if item is not EndOfStream: + value.append(item) + return value + else: + return None + + def unpack_map(self): + marker = self.read_marker() + return self._unpack_map(marker) + + def _unpack_map(self, marker): + marker_high = marker & 0xF0 + unpack = self.unpack + if marker_high == 0xA0: + size = marker & 0x0F + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xD8: # MAP_8: + size = UNPACKED_UINT_8[self.read_bytes(1)] + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xD9: # MAP_16: + size = UNPACKED_UINT_16[self.read_bytes(2)] + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xDA: # MAP_32: + size = struct_unpack(UINT_32_STRUCT, self.read_bytes(4))[0] + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xDB: # MAP_STREAM: + value = {} + key = None + while key is not EndOfStream: + key = unpack() + if key is not EndOfStream: + value[key] = unpack() + return value + else: + return None + def unpack_structure_header(self): marker = self.read_marker() if marker == -1: