Skip to content

Add support for SCYLLA_USE_METADATA_ID and skip metadata #457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
# queue the decoder function with the request
# this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder, result_metadata)
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
msg = encoder(msg, request_id, self.protocol_version, self.features, compressor=self.compressor,
allow_beta_protocol_version=self.allow_beta_protocol_version)

if self._is_checksumming_enabled:
Expand Down
39 changes: 21 additions & 18 deletions cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def __init__(self, cqlversion, options):
self.cqlversion = cqlversion
self.options = options

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
optmap = self.options.copy()
optmap['CQL_VERSION'] = self.cqlversion
write_stringmap(f, optmap)
Expand Down Expand Up @@ -456,7 +456,7 @@ class CredentialsMessage(_MessageType):
def __init__(self, creds):
self.creds = creds

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
if protocol_version > 1:
raise UnsupportedOperation(
"Credentials-based authentication is not supported with "
Expand Down Expand Up @@ -487,7 +487,7 @@ class AuthResponseMessage(_MessageType):
def __init__(self, response):
self.response = response

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_longstring(f, self.response)


Expand All @@ -507,7 +507,7 @@ class OptionsMessage(_MessageType):
opcode = 0x05
name = 'OPTIONS'

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
pass


Expand Down Expand Up @@ -606,6 +606,9 @@ def _write_query_params(self, f, protocol_version):
"Keyspaces may only be set on queries with protocol version "
"5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.")

if self.skip_meta is not None and self.skip_meta:
flags |= _SKIP_METADATA_FLAG

if ProtocolVersion.uses_int_query_flags(protocol_version):
write_uint(f, flags)
else:
Expand Down Expand Up @@ -645,7 +648,7 @@ def __init__(self, query, consistency_level, serial_consistency_level=None,
super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size,
paging_state, timestamp, False, continuous_paging_options, keyspace)

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_longstring(f, self.query)
self._write_query_params(f, protocol_version)

Expand Down Expand Up @@ -681,9 +684,9 @@ def _write_query_params(self, f, protocol_version):
else:
super(ExecuteMessage, self)._write_query_params(f, protocol_version)

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_string(f, self.query_id)
if ProtocolVersion.uses_prepared_metadata(protocol_version):
if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id:
write_string(f, self.result_metadata_id)
self._write_query_params(f, protocol_version)

Expand Down Expand Up @@ -734,15 +737,15 @@ class ResultMessage(_MessageType):
def __init__(self, kind):
self.kind = kind

def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
if self.kind == RESULT_KIND_VOID:
return
elif self.kind == RESULT_KIND_ROWS:
self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
elif self.kind == RESULT_KIND_SET_KEYSPACE:
self.new_keyspace = read_string(f)
elif self.kind == RESULT_KIND_PREPARED:
self.recv_results_prepared(f, protocol_version, user_type_map)
self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map)
elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
self.recv_results_schema_change(f, protocol_version)
else:
Expand All @@ -752,7 +755,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
kind = read_int(f)
msg = cls(kind)
msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy)
return msg

def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
Expand Down Expand Up @@ -785,9 +788,9 @@ def decode_row(row):
col_md[3].cql_parameterized_type(),
str(e)))

def recv_results_prepared(self, f, protocol_version, user_type_map):
def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
self.query_id = read_binary_string(f)
if ProtocolVersion.uses_prepared_metadata(protocol_version):
if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id:
self.result_metadata_id = read_binary_string(f)
else:
self.result_metadata_id = None
Expand Down Expand Up @@ -909,7 +912,7 @@ def __init__(self, query, keyspace=None):
self.query = query
self.keyspace = keyspace

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_longstring(f, self.query)

flags = 0x00
Expand Down Expand Up @@ -953,7 +956,7 @@ def __init__(self, batch_type, queries, consistency_level,
self.timestamp = timestamp
self.keyspace = keyspace

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_byte(f, self.batch_type.value)
write_short(f, len(self.queries))
for prepared, string_or_query_id, params in self.queries:
Expand Down Expand Up @@ -1012,7 +1015,7 @@ class RegisterMessage(_MessageType):
def __init__(self, event_list):
self.event_list = event_list

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_stringlist(f, self.event_list)


Expand Down Expand Up @@ -1086,7 +1089,7 @@ def __init__(self, op_type, op_id, next_pages=0):
self.op_id = op_id
self.next_pages = next_pages

def send_body(self, f, protocol_version):
def send_body(self, f, protocol_version, protocol_features):
write_int(f, self.op_type)
write_int(f, self.op_id)
if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE:
Expand Down Expand Up @@ -1122,7 +1125,7 @@ class _ProtocolHandler(object):
"""Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler"""

@classmethod
def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
def encode_message(cls, msg, stream_id, protocol_version, protocol_features, compressor, allow_beta_protocol_version):
"""
Encodes a message using the specified frame parameters, and compressor

Expand All @@ -1138,7 +1141,7 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
flags |= CUSTOM_PAYLOAD_FLAG
write_bytesmap(body, msg.custom_payload)
msg.send_body(body, protocol_version)
msg.send_body(body, protocol_version, protocol_features)
body = body.getvalue()

# With checksumming, the compression is done at the segment frame encoding
Expand Down
14 changes: 12 additions & 2 deletions cassandra/protocol_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,29 @@

RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"
USE_METADATA_ID = "SCYLLA_USE_METADATA_ID"

class ProtocolFeatures(object):
rate_limit_error = None
shard_id = 0
sharding_info = None
tablets_routing_v1 = False
use_metadata_id = False

def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False):
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, use_metadata_id=False):
self.rate_limit_error = rate_limit_error
self.shard_id = shard_id
self.sharding_info = sharding_info
self.tablets_routing_v1 = tablets_routing_v1
self.use_metadata_id = use_metadata_id

@staticmethod
def parse_from_supported(supported):
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1)
use_metadata_id = ProtocolFeatures.parse_metadata_id_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, use_metadata_id)

@staticmethod
def maybe_parse_rate_limit_error(supported):
Expand All @@ -49,6 +53,8 @@ def add_startup_options(self, options):
options[RATE_LIMIT_ERROR_EXTENSION] = ""
if self.tablets_routing_v1:
options[TABLETS_ROUTING_V1] = ""
if self.use_metadata_id:
options[USE_METADATA_ID] = ""

@staticmethod
def parse_sharding_info(options):
Expand All @@ -72,3 +78,7 @@ def parse_sharding_info(options):
@staticmethod
def parse_tablets_info(options):
return TABLETS_ROUTING_V1 in options

@staticmethod
def parse_metadata_id_info(options):
return USE_METADATA_ID in options
31 changes: 16 additions & 15 deletions tests/unit/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
BatchMessage
)
from cassandra.protocol_features import ProtocolFeatures
from cassandra.query import BatchType
from cassandra.marshal import uint32_unpack
from cassandra.cluster import ContinuousPagingOptions
Expand All @@ -43,24 +44,24 @@ def test_prepare_message(self):
message = PrepareMessage("a")
io = Mock()

message.send_body(io, 4)
message.send_body(io, 4, ProtocolFeatures())
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)])

io.reset_mock()
message.send_body(io, 5)
message.send_body(io, 5, ProtocolFeatures())

self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)])

def test_execute_message(self):
message = ExecuteMessage('1', [], 4)
io = Mock()

message.send_body(io, 4)
message.send_body(io, 4, ProtocolFeatures())
self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)])

io.reset_mock()
message.result_metadata_id = 'foo'
message.send_body(io, 5)
message.send_body(io, 5, ProtocolFeatures())

self._check_calls(io, [(b'\x00\x01',), (b'1',),
(b'\x00\x03',), (b'foo',),
Expand All @@ -80,11 +81,11 @@ def test_query_message(self):
message = QueryMessage("a", 3)
io = Mock()

message.send_body(io, 4)
message.send_body(io, 4, ProtocolFeatures())
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)])

io.reset_mock()
message.send_body(io, 5)
message.send_body(io, 5, ProtocolFeatures())
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)])

def _check_calls(self, io, expected):
Expand Down Expand Up @@ -112,10 +113,10 @@ def test_continuous_paging(self):
io = Mock()
for version in [version for version in ProtocolVersion.SUPPORTED_VERSIONS
if not ProtocolVersion.has_continuous_paging_support(version)]:
self.assertRaises(UnsupportedOperation, message.send_body, io, version)
self.assertRaises(UnsupportedOperation, message.send_body, io, version, ProtocolFeatures())

io.reset_mock()
message.send_body(io, ProtocolVersion.DSE_V1)
message.send_body(io, ProtocolVersion.DSE_V1, ProtocolFeatures())

# continuous paging adds two write calls to the buffer
self.assertEqual(len(io.write.mock_calls), 6)
Expand All @@ -142,7 +143,7 @@ def test_prepare_flag(self):
message = PrepareMessage("a")
io = Mock()
for version in ProtocolVersion.SUPPORTED_VERSIONS:
message.send_body(io, version)
message.send_body(io, version, ProtocolFeatures())
if ProtocolVersion.uses_prepare_flags(version):
self.assertEqual(len(io.write.mock_calls), 3)
else:
Expand All @@ -155,7 +156,7 @@ def test_prepare_flag_with_keyspace(self):

for version in ProtocolVersion.SUPPORTED_VERSIONS:
if ProtocolVersion.uses_keyspace_flag(version):
message.send_body(io, version)
message.send_body(io, version, ProtocolFeatures())
self._check_calls(io, [
(b'\x00\x00\x00\x01',),
(b'a',),
Expand All @@ -165,15 +166,15 @@ def test_prepare_flag_with_keyspace(self):
])
else:
with self.assertRaises(UnsupportedOperation):
message.send_body(io, version)
message.send_body(io, version, ProtocolFeatures())
io.reset_mock()

def test_keyspace_flag_raises_before_v5(self):
keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks')
io = Mock(name='io')

with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'):
keyspace_message.send_body(io, protocol_version=4)
keyspace_message.send_body(io, protocol_version=4, protocol_features=ProtocolFeatures())
io.assert_not_called()

def test_keyspace_written_with_length(self):
Expand All @@ -186,7 +187,7 @@ def test_keyspace_written_with_length(self):
]

QueryMessage('a', consistency_level=3, keyspace='ks').send_body(
io, protocol_version=5
io, protocol_version=5, protocol_features=ProtocolFeatures()
)
self._check_calls(io, base_expected + [
(b'\x00\x02',), # length of keyspace string
Expand All @@ -196,7 +197,7 @@ def test_keyspace_written_with_length(self):
io.reset_mock()

QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body(
io, protocol_version=5
io, protocol_version=5, protocol_features=ProtocolFeatures()
)
self._check_calls(io, base_expected + [
(b'\x00\x08',), # length of keyspace string
Expand All @@ -215,7 +216,7 @@ def test_batch_message_with_keyspace(self):
consistency_level=3,
keyspace='ks'
)
batch.send_body(io, protocol_version=5)
batch.send_body(io, protocol_version=5, protocol_features=ProtocolFeatures())
self._check_calls(io,
((b'\x00',), (b'\x00\x03',), (b'\x00',),
(b'\x00\x00\x00\x06',), (b'stmt a',),
Expand Down
Loading