Skip to content

Commit 8aba164

Browse files
Add support for SCYLLA_USE_METADATA_ID
SCYLLA_USE_METADATA_ID extension allows using METADATA_ID (which was introduced in CQLv5) in CQLv4. This commit: - Introduce support for SCYLLA_USE_METADATA_ID in protocol extension negotation - Reuse CQLv5 metadata id implemnetation if use_metadata_id feature is enabled - Modify existing unit tests for introduced changes
1 parent a401409 commit 8aba164

File tree

4 files changed

+47
-36
lines changed

4 files changed

+47
-36
lines changed

cassandra/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
11101110
# queue the decoder function with the request
11111111
# this allows us to inject custom functions per request to encode, decode messages
11121112
self._requests[request_id] = (cb, decoder, result_metadata)
1113-
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
1113+
msg = encoder(msg, request_id, self.protocol_version, self.features, compressor=self.compressor,
11141114
allow_beta_protocol_version=self.allow_beta_protocol_version)
11151115

11161116
if self._is_checksumming_enabled:

cassandra/protocol.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def __init__(self, cqlversion, options):
421421
self.cqlversion = cqlversion
422422
self.options = options
423423

424-
def send_body(self, f, protocol_version):
424+
def send_body(self, f, protocol_version, protocol_features):
425425
optmap = self.options.copy()
426426
optmap['CQL_VERSION'] = self.cqlversion
427427
write_stringmap(f, optmap)
@@ -456,7 +456,7 @@ class CredentialsMessage(_MessageType):
456456
def __init__(self, creds):
457457
self.creds = creds
458458

459-
def send_body(self, f, protocol_version):
459+
def send_body(self, f, protocol_version, protocol_features):
460460
if protocol_version > 1:
461461
raise UnsupportedOperation(
462462
"Credentials-based authentication is not supported with "
@@ -487,7 +487,7 @@ class AuthResponseMessage(_MessageType):
487487
def __init__(self, response):
488488
self.response = response
489489

490-
def send_body(self, f, protocol_version):
490+
def send_body(self, f, protocol_version, protocol_features):
491491
write_longstring(f, self.response)
492492

493493

@@ -507,7 +507,7 @@ class OptionsMessage(_MessageType):
507507
opcode = 0x05
508508
name = 'OPTIONS'
509509

510-
def send_body(self, f, protocol_version):
510+
def send_body(self, f, protocol_version, protocol_features):
511511
pass
512512

513513

@@ -645,7 +645,7 @@ def __init__(self, query, consistency_level, serial_consistency_level=None,
645645
super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size,
646646
paging_state, timestamp, False, continuous_paging_options, keyspace)
647647

648-
def send_body(self, f, protocol_version):
648+
def send_body(self, f, protocol_version, protocol_features):
649649
write_longstring(f, self.query)
650650
self._write_query_params(f, protocol_version)
651651

@@ -681,9 +681,9 @@ def _write_query_params(self, f, protocol_version):
681681
else:
682682
super(ExecuteMessage, self)._write_query_params(f, protocol_version)
683683

684-
def send_body(self, f, protocol_version):
684+
def send_body(self, f, protocol_version, protocol_features):
685685
write_string(f, self.query_id)
686-
if ProtocolVersion.uses_prepared_metadata(protocol_version):
686+
if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id:
687687
write_string(f, self.result_metadata_id)
688688
self._write_query_params(f, protocol_version)
689689

@@ -734,15 +734,15 @@ class ResultMessage(_MessageType):
734734
def __init__(self, kind):
735735
self.kind = kind
736736

737-
def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
737+
def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
738738
if self.kind == RESULT_KIND_VOID:
739739
return
740740
elif self.kind == RESULT_KIND_ROWS:
741741
self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
742742
elif self.kind == RESULT_KIND_SET_KEYSPACE:
743743
self.new_keyspace = read_string(f)
744744
elif self.kind == RESULT_KIND_PREPARED:
745-
self.recv_results_prepared(f, protocol_version, user_type_map)
745+
self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map)
746746
elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
747747
self.recv_results_schema_change(f, protocol_version)
748748
else:
@@ -752,7 +752,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
752752
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
753753
kind = read_int(f)
754754
msg = cls(kind)
755-
msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
755+
msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy)
756756
return msg
757757

758758
def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
@@ -785,9 +785,9 @@ def decode_row(row):
785785
col_md[3].cql_parameterized_type(),
786786
str(e)))
787787

788-
def recv_results_prepared(self, f, protocol_version, user_type_map):
788+
def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
789789
self.query_id = read_binary_string(f)
790-
if ProtocolVersion.uses_prepared_metadata(protocol_version):
790+
if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id:
791791
self.result_metadata_id = read_binary_string(f)
792792
else:
793793
self.result_metadata_id = None
@@ -909,7 +909,7 @@ def __init__(self, query, keyspace=None):
909909
self.query = query
910910
self.keyspace = keyspace
911911

912-
def send_body(self, f, protocol_version):
912+
def send_body(self, f, protocol_version, protocol_features):
913913
write_longstring(f, self.query)
914914

915915
flags = 0x00
@@ -953,7 +953,7 @@ def __init__(self, batch_type, queries, consistency_level,
953953
self.timestamp = timestamp
954954
self.keyspace = keyspace
955955

956-
def send_body(self, f, protocol_version):
956+
def send_body(self, f, protocol_version, protocol_features):
957957
write_byte(f, self.batch_type.value)
958958
write_short(f, len(self.queries))
959959
for prepared, string_or_query_id, params in self.queries:
@@ -1012,7 +1012,7 @@ class RegisterMessage(_MessageType):
10121012
def __init__(self, event_list):
10131013
self.event_list = event_list
10141014

1015-
def send_body(self, f, protocol_version):
1015+
def send_body(self, f, protocol_version, protocol_features):
10161016
write_stringlist(f, self.event_list)
10171017

10181018

@@ -1086,7 +1086,7 @@ def __init__(self, op_type, op_id, next_pages=0):
10861086
self.op_id = op_id
10871087
self.next_pages = next_pages
10881088

1089-
def send_body(self, f, protocol_version):
1089+
def send_body(self, f, protocol_version, protocol_features):
10901090
write_int(f, self.op_type)
10911091
write_int(f, self.op_id)
10921092
if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE:
@@ -1122,7 +1122,7 @@ class _ProtocolHandler(object):
11221122
"""Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler"""
11231123

11241124
@classmethod
1125-
def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
1125+
def encode_message(cls, msg, stream_id, protocol_version, protocol_features, compressor, allow_beta_protocol_version):
11261126
"""
11271127
Encodes a message using the specified frame parameters, and compressor
11281128
@@ -1138,7 +1138,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
11381138
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
11391139
flags |= CUSTOM_PAYLOAD_FLAG
11401140
write_bytesmap(body, msg.custom_payload)
1141-
msg.send_body(body, protocol_version)
1141+
msg.send_body(body, protocol_version, protocol_features)
11421142
body = body.getvalue()
11431143

11441144
# With checksumming, the compression is done at the segment frame encoding

cassandra/protocol_features.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,29 @@
77

88
RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
99
TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"
10+
USE_METADATA_ID = "SCYLLA_USE_METADATA_ID"
1011

1112
class ProtocolFeatures(object):
1213
rate_limit_error = None
1314
shard_id = 0
1415
sharding_info = None
1516
tablets_routing_v1 = False
17+
use_metadata_id = False
1618

17-
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False):
19+
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, use_metadata_id=False):
1820
self.rate_limit_error = rate_limit_error
1921
self.shard_id = shard_id
2022
self.sharding_info = sharding_info
2123
self.tablets_routing_v1 = tablets_routing_v1
24+
self.use_metadata_id = use_metadata_id
2225

2326
@staticmethod
2427
def parse_from_supported(supported):
2528
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
2629
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
2730
tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
28-
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1)
31+
use_metadata_id = ProtocolFeatures.parse_metadata_id_info(supported)
32+
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, use_metadata_id)
2933

3034
@staticmethod
3135
def maybe_parse_rate_limit_error(supported):
@@ -49,6 +53,8 @@ def add_startup_options(self, options):
4953
options[RATE_LIMIT_ERROR_EXTENSION] = ""
5054
if self.tablets_routing_v1:
5155
options[TABLETS_ROUTING_V1] = ""
56+
if self.use_metadata_id:
57+
options[USE_METADATA_ID] = ""
5258

5359
@staticmethod
5460
def parse_sharding_info(options):
@@ -72,3 +78,7 @@ def parse_sharding_info(options):
7278
@staticmethod
7379
def parse_tablets_info(options):
7480
return TABLETS_ROUTING_V1 in options
81+
82+
@staticmethod
83+
def parse_metadata_id_info(options):
84+
return USE_METADATA_ID in options

tests/unit/test_protocol.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
2424
BatchMessage
2525
)
26+
from cassandra.protocol_features import ProtocolFeatures
2627
from cassandra.query import BatchType
2728
from cassandra.marshal import uint32_unpack
2829
from cassandra.cluster import ContinuousPagingOptions
@@ -43,24 +44,24 @@ def test_prepare_message(self):
4344
message = PrepareMessage("a")
4445
io = Mock()
4546

46-
message.send_body(io, 4)
47+
message.send_body(io, 4, ProtocolFeatures())
4748
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)])
4849

4950
io.reset_mock()
50-
message.send_body(io, 5)
51+
message.send_body(io, 5, ProtocolFeatures())
5152

5253
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)])
5354

5455
def test_execute_message(self):
5556
message = ExecuteMessage('1', [], 4)
5657
io = Mock()
5758

58-
message.send_body(io, 4)
59+
message.send_body(io, 4, ProtocolFeatures())
5960
self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)])
6061

6162
io.reset_mock()
6263
message.result_metadata_id = 'foo'
63-
message.send_body(io, 5)
64+
message.send_body(io, 5, ProtocolFeatures())
6465

6566
self._check_calls(io, [(b'\x00\x01',), (b'1',),
6667
(b'\x00\x03',), (b'foo',),
@@ -80,11 +81,11 @@ def test_query_message(self):
8081
message = QueryMessage("a", 3)
8182
io = Mock()
8283

83-
message.send_body(io, 4)
84+
message.send_body(io, 4, ProtocolFeatures())
8485
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)])
8586

8687
io.reset_mock()
87-
message.send_body(io, 5)
88+
message.send_body(io, 5, ProtocolFeatures())
8889
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)])
8990

9091
def _check_calls(self, io, expected):
@@ -112,10 +113,10 @@ def test_continuous_paging(self):
112113
io = Mock()
113114
for version in [version for version in ProtocolVersion.SUPPORTED_VERSIONS
114115
if not ProtocolVersion.has_continuous_paging_support(version)]:
115-
self.assertRaises(UnsupportedOperation, message.send_body, io, version)
116+
self.assertRaises(UnsupportedOperation, message.send_body, io, version, ProtocolFeatures())
116117

117118
io.reset_mock()
118-
message.send_body(io, ProtocolVersion.DSE_V1)
119+
message.send_body(io, ProtocolVersion.DSE_V1, ProtocolFeatures())
119120

120121
# continuous paging adds two write calls to the buffer
121122
self.assertEqual(len(io.write.mock_calls), 6)
@@ -142,7 +143,7 @@ def test_prepare_flag(self):
142143
message = PrepareMessage("a")
143144
io = Mock()
144145
for version in ProtocolVersion.SUPPORTED_VERSIONS:
145-
message.send_body(io, version)
146+
message.send_body(io, version, ProtocolFeatures())
146147
if ProtocolVersion.uses_prepare_flags(version):
147148
self.assertEqual(len(io.write.mock_calls), 3)
148149
else:
@@ -155,7 +156,7 @@ def test_prepare_flag_with_keyspace(self):
155156

156157
for version in ProtocolVersion.SUPPORTED_VERSIONS:
157158
if ProtocolVersion.uses_keyspace_flag(version):
158-
message.send_body(io, version)
159+
message.send_body(io, version, ProtocolFeatures())
159160
self._check_calls(io, [
160161
(b'\x00\x00\x00\x01',),
161162
(b'a',),
@@ -165,15 +166,15 @@ def test_prepare_flag_with_keyspace(self):
165166
])
166167
else:
167168
with self.assertRaises(UnsupportedOperation):
168-
message.send_body(io, version)
169+
message.send_body(io, version, ProtocolFeatures())
169170
io.reset_mock()
170171

171172
def test_keyspace_flag_raises_before_v5(self):
172173
keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks')
173174
io = Mock(name='io')
174175

175176
with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'):
176-
keyspace_message.send_body(io, protocol_version=4)
177+
keyspace_message.send_body(io, protocol_version=4, protocol_features=ProtocolFeatures())
177178
io.assert_not_called()
178179

179180
def test_keyspace_written_with_length(self):
@@ -186,7 +187,7 @@ def test_keyspace_written_with_length(self):
186187
]
187188

188189
QueryMessage('a', consistency_level=3, keyspace='ks').send_body(
189-
io, protocol_version=5
190+
io, protocol_version=5, protocol_features=ProtocolFeatures()
190191
)
191192
self._check_calls(io, base_expected + [
192193
(b'\x00\x02',), # length of keyspace string
@@ -196,7 +197,7 @@ def test_keyspace_written_with_length(self):
196197
io.reset_mock()
197198

198199
QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body(
199-
io, protocol_version=5
200+
io, protocol_version=5, protocol_features=ProtocolFeatures()
200201
)
201202
self._check_calls(io, base_expected + [
202203
(b'\x00\x08',), # length of keyspace string
@@ -215,7 +216,7 @@ def test_batch_message_with_keyspace(self):
215216
consistency_level=3,
216217
keyspace='ks'
217218
)
218-
batch.send_body(io, protocol_version=5)
219+
batch.send_body(io, protocol_version=5, protocol_features=ProtocolFeatures())
219220
self._check_calls(io,
220221
((b'\x00',), (b'\x00\x03',), (b'\x00',),
221222
(b'\x00\x00\x00\x06',), (b'stmt a',),

0 commit comments

Comments
 (0)