From 8aba164564678c133267fd24b8590c68aa13f5a2 Mon Sep 17 00:00:00 2001 From: Andrzej Jackowski Date: Wed, 12 Mar 2025 13:18:14 +0100 Subject: [PATCH 1/2] Add support for SCYLLA_USE_METADATA_ID SCYLLA_USE_METADATA_ID extension allows using METADATA_ID (which was introduced in CQLv5) in CQLv4. This commit: - Introduce support for SCYLLA_USE_METADATA_ID in protocol extension negotation - Reuse CQLv5 metadata id implemnetation if use_metadata_id feature is enabled - Modify existing unit tests for introduced changes --- cassandra/connection.py | 2 +- cassandra/protocol.py | 36 +++++++++++++++++----------------- cassandra/protocol_features.py | 14 +++++++++++-- tests/unit/test_protocol.py | 31 +++++++++++++++-------------- 4 files changed, 47 insertions(+), 36 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index a2540a967b..f090a5a3eb 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1110,7 +1110,7 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) - msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, + msg = encoder(msg, request_id, self.protocol_version, self.features, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version) if self._is_checksumming_enabled: diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 29ae404048..388aca21f3 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -421,7 +421,7 @@ def __init__(self, cqlversion, options): self.cqlversion = cqlversion self.options = options - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): optmap = self.options.copy() optmap['CQL_VERSION'] = self.cqlversion write_stringmap(f, optmap) @@ -456,7 +456,7 @@ class CredentialsMessage(_MessageType): def __init__(self, creds): self.creds = creds - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): if protocol_version > 1: raise UnsupportedOperation( "Credentials-based authentication is not supported with " @@ -487,7 +487,7 @@ class AuthResponseMessage(_MessageType): def __init__(self, response): self.response = response - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_longstring(f, self.response) @@ -507,7 +507,7 @@ class OptionsMessage(_MessageType): opcode = 0x05 name = 'OPTIONS' - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): pass @@ -645,7 +645,7 @@ def __init__(self, query, consistency_level, serial_consistency_level=None, super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, False, continuous_paging_options, keyspace) - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_longstring(f, self.query) self._write_query_params(f, protocol_version) @@ -681,9 +681,9 @@ def _write_query_params(self, f, protocol_version): else: super(ExecuteMessage, self)._write_query_params(f, protocol_version) - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_string(f, self.query_id) - if ProtocolVersion.uses_prepared_metadata(protocol_version): + if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id: write_string(f, self.result_metadata_id) self._write_query_params(f, protocol_version) @@ -734,7 +734,7 @@ class ResultMessage(_MessageType): def __init__(self, kind): self.kind = kind - def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): if self.kind == RESULT_KIND_VOID: return elif self.kind == RESULT_KIND_ROWS: @@ -742,7 +742,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry elif self.kind == RESULT_KIND_SET_KEYSPACE: self.new_keyspace = read_string(f) elif self.kind == RESULT_KIND_PREPARED: - self.recv_results_prepared(f, protocol_version, user_type_map) + self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map) elif self.kind == RESULT_KIND_SCHEMA_CHANGE: self.recv_results_schema_change(f, protocol_version) else: @@ -752,7 +752,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): kind = read_int(f) msg = cls(kind) - msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy) return msg def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): @@ -785,9 +785,9 @@ def decode_row(row): col_md[3].cql_parameterized_type(), str(e))) - def recv_results_prepared(self, f, protocol_version, user_type_map): + def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): self.query_id = read_binary_string(f) - if ProtocolVersion.uses_prepared_metadata(protocol_version): + if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id: self.result_metadata_id = read_binary_string(f) else: self.result_metadata_id = None @@ -909,7 +909,7 @@ def __init__(self, query, keyspace=None): self.query = query self.keyspace = keyspace - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_longstring(f, self.query) flags = 0x00 @@ -953,7 +953,7 @@ def __init__(self, batch_type, queries, consistency_level, self.timestamp = timestamp self.keyspace = keyspace - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_byte(f, self.batch_type.value) write_short(f, len(self.queries)) for prepared, string_or_query_id, params in self.queries: @@ -1012,7 +1012,7 @@ class RegisterMessage(_MessageType): def __init__(self, event_list): self.event_list = event_list - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_stringlist(f, self.event_list) @@ -1086,7 +1086,7 @@ def __init__(self, op_type, op_id, next_pages=0): self.op_id = op_id self.next_pages = next_pages - def send_body(self, f, protocol_version): + def send_body(self, f, protocol_version, protocol_features): write_int(f, self.op_type) write_int(f, self.op_id) if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE: @@ -1122,7 +1122,7 @@ class _ProtocolHandler(object): """Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler""" @classmethod - def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version): + def encode_message(cls, msg, stream_id, protocol_version, protocol_features, compressor, allow_beta_protocol_version): """ Encodes a message using the specified frame parameters, and compressor @@ -1138,7 +1138,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") flags |= CUSTOM_PAYLOAD_FLAG write_bytesmap(body, msg.custom_payload) - msg.send_body(body, protocol_version) + msg.send_body(body, protocol_version, protocol_features) body = body.getvalue() # With checksumming, the compression is done at the segment frame encoding diff --git a/cassandra/protocol_features.py b/cassandra/protocol_features.py index 4eb7019f84..84c108319a 100644 --- a/cassandra/protocol_features.py +++ b/cassandra/protocol_features.py @@ -7,25 +7,29 @@ RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR" TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1" +USE_METADATA_ID = "SCYLLA_USE_METADATA_ID" class ProtocolFeatures(object): rate_limit_error = None shard_id = 0 sharding_info = None tablets_routing_v1 = False + use_metadata_id = False - def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False): + def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, use_metadata_id=False): self.rate_limit_error = rate_limit_error self.shard_id = shard_id self.sharding_info = sharding_info self.tablets_routing_v1 = tablets_routing_v1 + self.use_metadata_id = use_metadata_id @staticmethod def parse_from_supported(supported): rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported) shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported) tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported) - return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1) + use_metadata_id = ProtocolFeatures.parse_metadata_id_info(supported) + return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, use_metadata_id) @staticmethod def maybe_parse_rate_limit_error(supported): @@ -49,6 +53,8 @@ def add_startup_options(self, options): options[RATE_LIMIT_ERROR_EXTENSION] = "" if self.tablets_routing_v1: options[TABLETS_ROUTING_V1] = "" + if self.use_metadata_id: + options[USE_METADATA_ID] = "" @staticmethod def parse_sharding_info(options): @@ -72,3 +78,7 @@ def parse_sharding_info(options): @staticmethod def parse_tablets_info(options): return TABLETS_ROUTING_V1 in options + + @staticmethod + def parse_metadata_id_info(options): + return USE_METADATA_ID in options diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 907f62f2bb..b614293a94 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -23,6 +23,7 @@ _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, BatchMessage ) +from cassandra.protocol_features import ProtocolFeatures from cassandra.query import BatchType from cassandra.marshal import uint32_unpack from cassandra.cluster import ContinuousPagingOptions @@ -43,11 +44,11 @@ def test_prepare_message(self): message = PrepareMessage("a") io = Mock() - message.send_body(io, 4) + message.send_body(io, 4, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)]) io.reset_mock() - message.send_body(io, 5) + message.send_body(io, 5, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)]) @@ -55,12 +56,12 @@ def test_execute_message(self): message = ExecuteMessage('1', [], 4) io = Mock() - message.send_body(io, 4) + message.send_body(io, 4, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) io.reset_mock() message.result_metadata_id = 'foo' - message.send_body(io, 5) + message.send_body(io, 5, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x03',), (b'foo',), @@ -80,11 +81,11 @@ def test_query_message(self): message = QueryMessage("a", 3) io = Mock() - message.send_body(io, 4) + message.send_body(io, 4, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)]) io.reset_mock() - message.send_body(io, 5) + message.send_body(io, 5, ProtocolFeatures()) self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) def _check_calls(self, io, expected): @@ -112,10 +113,10 @@ def test_continuous_paging(self): io = Mock() for version in [version for version in ProtocolVersion.SUPPORTED_VERSIONS if not ProtocolVersion.has_continuous_paging_support(version)]: - self.assertRaises(UnsupportedOperation, message.send_body, io, version) + self.assertRaises(UnsupportedOperation, message.send_body, io, version, ProtocolFeatures()) io.reset_mock() - message.send_body(io, ProtocolVersion.DSE_V1) + message.send_body(io, ProtocolVersion.DSE_V1, ProtocolFeatures()) # continuous paging adds two write calls to the buffer self.assertEqual(len(io.write.mock_calls), 6) @@ -142,7 +143,7 @@ def test_prepare_flag(self): message = PrepareMessage("a") io = Mock() for version in ProtocolVersion.SUPPORTED_VERSIONS: - message.send_body(io, version) + message.send_body(io, version, ProtocolFeatures()) if ProtocolVersion.uses_prepare_flags(version): self.assertEqual(len(io.write.mock_calls), 3) else: @@ -155,7 +156,7 @@ def test_prepare_flag_with_keyspace(self): for version in ProtocolVersion.SUPPORTED_VERSIONS: if ProtocolVersion.uses_keyspace_flag(version): - message.send_body(io, version) + message.send_body(io, version, ProtocolFeatures()) self._check_calls(io, [ (b'\x00\x00\x00\x01',), (b'a',), @@ -165,7 +166,7 @@ def test_prepare_flag_with_keyspace(self): ]) else: with self.assertRaises(UnsupportedOperation): - message.send_body(io, version) + message.send_body(io, version, ProtocolFeatures()) io.reset_mock() def test_keyspace_flag_raises_before_v5(self): @@ -173,7 +174,7 @@ def test_keyspace_flag_raises_before_v5(self): io = Mock(name='io') with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'): - keyspace_message.send_body(io, protocol_version=4) + keyspace_message.send_body(io, protocol_version=4, protocol_features=ProtocolFeatures()) io.assert_not_called() def test_keyspace_written_with_length(self): @@ -186,7 +187,7 @@ def test_keyspace_written_with_length(self): ] QueryMessage('a', consistency_level=3, keyspace='ks').send_body( - io, protocol_version=5 + io, protocol_version=5, protocol_features=ProtocolFeatures() ) self._check_calls(io, base_expected + [ (b'\x00\x02',), # length of keyspace string @@ -196,7 +197,7 @@ def test_keyspace_written_with_length(self): io.reset_mock() QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( - io, protocol_version=5 + io, protocol_version=5, protocol_features=ProtocolFeatures() ) self._check_calls(io, base_expected + [ (b'\x00\x08',), # length of keyspace string @@ -215,7 +216,7 @@ def test_batch_message_with_keyspace(self): consistency_level=3, keyspace='ks' ) - batch.send_body(io, protocol_version=5) + batch.send_body(io, protocol_version=5, protocol_features=ProtocolFeatures()) self._check_calls(io, ((b'\x00',), (b'\x00\x03',), (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt a',), From c1809c1744aec7029a76cd84a6a8cb0a940f7777 Mon Sep 17 00:00:00 2001 From: Andrzej Jackowski Date: Fri, 14 Mar 2025 09:48:50 +0100 Subject: [PATCH 2/2] Add skip metadata handling in _QueryMessage This change restores handling of skip_metadata in _QueryMessage, that was removed in 2019 (most likely to prevent metadata inconsistencies in prepared statements). I'm not sure if any other changes are required, as there were many modifications in the codebase since the flag handling was removed. --- cassandra/protocol.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 388aca21f3..4ed1c7dfa8 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -606,6 +606,9 @@ def _write_query_params(self, f, protocol_version): "Keyspaces may only be set on queries with protocol version " "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + if self.skip_meta is not None and self.skip_meta: + flags |= _SKIP_METADATA_FLAG + if ProtocolVersion.uses_int_query_flags(protocol_version): write_uint(f, flags) else: