From f2e6c2a2601d3013aa70c9e977f92fad108e4388 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 12 Jan 2022 10:25:06 +0100 Subject: [PATCH] Roll back data written to output buffer on packing failure While packing data to packstream, several errors can occur (integers that are out of bounds, unknown data types, etc.). On packing failure, the driver should never send the half-finished packed data over the wire. This will most likely cause the server to close the connection as the data will be corrupt. --- neo4j/_async/io/_bolt.py | 3 ++- neo4j/_async/io/_common.py | 20 ++++++++++++++++++++ neo4j/_sync/io/_bolt.py | 3 ++- neo4j/_sync/io/_common.py | 20 ++++++++++++++++++++ 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 503ebeaf..83c03c9e 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -443,7 +443,8 @@ def _append(self, signature, fields=(), response=None): :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks """ - self.packer.pack_struct(signature, fields) + with self.outbox.tmp_buffer(): + self.packer.pack_struct(signature, fields) self.outbox.wrap_message() self.responses.append(response) diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py index 486e33df..aaf458f7 100644 --- a/neo4j/_async/io/_common.py +++ b/neo4j/_async/io/_common.py @@ -17,6 +17,7 @@ import asyncio +from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -94,11 +95,14 @@ def __init__(self, max_chunk_size=16384): self._chunked_data = bytearray() self._raw_data = bytearray() self.write = self._raw_data.extend + self._tmp_buffering = 0 def max_chunk_size(self): return self._max_chunk_size def clear(self): + if self._tmp_buffering: + raise RuntimeError("Cannot clear while buffering") self._chunked_data = bytearray() self._raw_data.clear() @@ -128,13 +132,29 @@ def _chunk_data(self): self._raw_data.clear() def wrap_message(self): + if self._tmp_buffering: + raise RuntimeError("Cannot wrap message while buffering") self._chunk_data() self._chunked_data += b"\x00\x00" def view(self): + if self._tmp_buffering: + raise RuntimeError("Cannot view while buffering") self._chunk_data() return memoryview(self._chunked_data) + @contextmanager + def tmp_buffer(self): + self._tmp_buffering += 1 + old_len = len(self._raw_data) + try: + yield + except Exception: + del self._raw_data[old_len:] + raise + finally: + self._tmp_buffering -= 1 + class ConnectionErrorHandler: """ diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index 82ee8b62..007bb6b4 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -443,7 +443,8 @@ def _append(self, signature, fields=(), response=None): :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks """ - self.packer.pack_struct(signature, fields) + with self.outbox.tmp_buffer(): + self.packer.pack_struct(signature, fields) self.outbox.wrap_message() self.responses.append(response) diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py index 408de0a1..647da7ec 100644 --- a/neo4j/_sync/io/_common.py +++ b/neo4j/_sync/io/_common.py @@ -17,6 +17,7 @@ import asyncio +from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -94,11 +95,14 @@ def __init__(self, max_chunk_size=16384): self._chunked_data = bytearray() self._raw_data = bytearray() self.write = self._raw_data.extend + self._tmp_buffering = 0 def max_chunk_size(self): return self._max_chunk_size def clear(self): + if self._tmp_buffering: + raise RuntimeError("Cannot clear while buffering") self._chunked_data = bytearray() self._raw_data.clear() @@ -128,13 +132,29 @@ def _chunk_data(self): self._raw_data.clear() def wrap_message(self): + if self._tmp_buffering: + raise RuntimeError("Cannot wrap message while buffering") self._chunk_data() self._chunked_data += b"\x00\x00" def view(self): + if self._tmp_buffering: + raise RuntimeError("Cannot view while buffering") self._chunk_data() return memoryview(self._chunked_data) + @contextmanager + def tmp_buffer(self): + self._tmp_buffering += 1 + old_len = len(self._raw_data) + try: + yield + except Exception: + del self._raw_data[old_len:] + raise + finally: + self._tmp_buffering -= 1 + class ConnectionErrorHandler: """