Skip to content

Commit c4a808d

Browse files
authored
PYTHON-1369 Extend driver vector support to arbitrary subtypes and fix handling of variable length types (OSS C* 5.0) (#1217)
1 parent d05e9d3 commit c4a808d

File tree

7 files changed

+504
-72
lines changed

7 files changed

+504
-72
lines changed

cassandra/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -744,9 +744,3 @@ def __init__(self, msg, excs=[]):
744744
if excs:
745745
complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs))
746746
Exception.__init__(self, complete_msg)
747-
748-
class VectorDeserializationFailure(DriverException):
749-
"""
750-
The driver was unable to deserialize a given vector
751-
"""
752-
pass

cassandra/cqltypes.py

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
int32_pack, int32_unpack, int64_pack, int64_unpack,
4949
float_pack, float_unpack, double_pack, double_unpack,
5050
varint_pack, varint_unpack, point_be, point_le,
51-
vints_pack, vints_unpack)
52-
from cassandra import util, VectorDeserializationFailure
51+
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
52+
from cassandra import util
5353

5454
_little_endian_flag = 1 # we always serialize LE
5555
import ipaddress
@@ -392,6 +392,9 @@ def cass_parameterized_type(cls, full=False):
392392
"""
393393
return cls.cass_parameterized_type_with(cls.subtypes, full=full)
394394

395+
@classmethod
396+
def serial_size(cls):
397+
return None
395398

396399
# it's initially named with a _ to avoid registering it as a real type, but
397400
# client programs may want to use the name still for isinstance(), etc
@@ -457,10 +460,12 @@ def serialize(uuid, protocol_version):
457460
except AttributeError:
458461
raise TypeError("Got a non-UUID object for a UUID value")
459462

463+
@classmethod
464+
def serial_size(cls):
465+
return 16
460466

461467
class BooleanType(_CassandraType):
462468
typename = 'boolean'
463-
serial_size = 1
464469

465470
@staticmethod
466471
def deserialize(byts, protocol_version):
@@ -470,6 +475,10 @@ def deserialize(byts, protocol_version):
470475
def serialize(truth, protocol_version):
471476
return int8_pack(truth)
472477

478+
@classmethod
479+
def serial_size(cls):
480+
return 1
481+
473482
class ByteType(_CassandraType):
474483
typename = 'tinyint'
475484

@@ -500,7 +509,6 @@ def serialize(var, protocol_version):
500509

501510
class FloatType(_CassandraType):
502511
typename = 'float'
503-
serial_size = 4
504512

505513
@staticmethod
506514
def deserialize(byts, protocol_version):
@@ -510,10 +518,12 @@ def deserialize(byts, protocol_version):
510518
def serialize(byts, protocol_version):
511519
return float_pack(byts)
512520

521+
@classmethod
522+
def serial_size(cls):
523+
return 4
513524

514525
class DoubleType(_CassandraType):
515526
typename = 'double'
516-
serial_size = 8
517527

518528
@staticmethod
519529
def deserialize(byts, protocol_version):
@@ -523,10 +533,12 @@ def deserialize(byts, protocol_version):
523533
def serialize(byts, protocol_version):
524534
return double_pack(byts)
525535

536+
@classmethod
537+
def serial_size(cls):
538+
return 8
526539

527540
class LongType(_CassandraType):
528541
typename = 'bigint'
529-
serial_size = 8
530542

531543
@staticmethod
532544
def deserialize(byts, protocol_version):
@@ -536,10 +548,12 @@ def deserialize(byts, protocol_version):
536548
def serialize(byts, protocol_version):
537549
return int64_pack(byts)
538550

551+
@classmethod
552+
def serial_size(cls):
553+
return 8
539554

540555
class Int32Type(_CassandraType):
541556
typename = 'int'
542-
serial_size = 4
543557

544558
@staticmethod
545559
def deserialize(byts, protocol_version):
@@ -549,6 +563,9 @@ def deserialize(byts, protocol_version):
549563
def serialize(byts, protocol_version):
550564
return int32_pack(byts)
551565

566+
@classmethod
567+
def serial_size(cls):
568+
return 4
552569

553570
class IntegerType(_CassandraType):
554571
typename = 'varint'
@@ -645,14 +662,16 @@ def serialize(v, protocol_version):
645662

646663
return int64_pack(int(timestamp))
647664

665+
@classmethod
666+
def serial_size(cls):
667+
return 8
648668

649669
class TimestampType(DateType):
650670
pass
651671

652672

653673
class TimeUUIDType(DateType):
654674
typename = 'timeuuid'
655-
serial_size = 16
656675

657676
def my_timestamp(self):
658677
return util.unix_time_from_uuid1(self.val)
@@ -668,6 +687,9 @@ def serialize(timeuuid, protocol_version):
668687
except AttributeError:
669688
raise TypeError("Got a non-UUID object for a UUID value")
670689

690+
@classmethod
691+
def serial_size(cls):
692+
return 16
671693

672694
class SimpleDateType(_CassandraType):
673695
typename = 'date'
@@ -699,7 +721,6 @@ def serialize(val, protocol_version):
699721

700722
class ShortType(_CassandraType):
701723
typename = 'smallint'
702-
serial_size = 2
703724

704725
@staticmethod
705726
def deserialize(byts, protocol_version):
@@ -709,10 +730,14 @@ def deserialize(byts, protocol_version):
709730
def serialize(byts, protocol_version):
710731
return int16_pack(byts)
711732

712-
713733
class TimeType(_CassandraType):
714734
typename = 'time'
715-
serial_size = 8
735+
# Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as
736+
# variable size... and we have to match what the server expects since the server
737+
# uses that specification to encode data of that type.
738+
#@classmethod
739+
#def serial_size(cls):
740+
# return 8
716741

717742
@staticmethod
718743
def deserialize(byts, protocol_version):
@@ -1409,6 +1434,11 @@ class VectorType(_CassandraType):
14091434
vector_size = 0
14101435
subtype = None
14111436

1437+
@classmethod
def serial_size(cls):
    """
    Return the total serialized size of the vector in bytes, or None.

    The result is ``cls.vector_size`` multiplied by the subtype's own
    ``serial_size()``. When the subtype reports None (no fixed size) the
    vector has no fixed size either and None is returned.
    """
    element_size = cls.subtype.serial_size()
    if element_size is None:
        return None
    return cls.vector_size * element_size
1441+
14121442
@classmethod
14131443
def apply_parameters(cls, params, names):
14141444
assert len(params) == 2
@@ -1418,19 +1448,50 @@ def apply_parameters(cls, params, names):
14181448

14191449
@classmethod
def deserialize(cls, byts, protocol_version):
    """
    Decode a serialized vector into a list of ``cls.vector_size`` elements.

    When the subtype reports a fixed ``serial_size()`` the elements are read
    back to back with no framing. Otherwise each element is preceded by its
    byte length encoded as an unsigned vint (read with ``uvint_unpack``).

    Raises ValueError when the fixed-size payload has the wrong length, when
    an element cannot be read or is truncated, or when bytes remain after the
    last element.
    """
    serialized_size = cls.subtype.serial_size()
    if serialized_size is not None:
        expected_byte_size = serialized_size * cls.vector_size
        if len(byts) != expected_byte_size:
            raise ValueError(
                "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\
                .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))
        indexes = (serialized_size * x for x in range(0, cls.vector_size))
        return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]

    idx = 0
    rv = []
    total_bytes = len(byts)
    while len(rv) < cls.vector_size:
        try:
            size, bytes_read = uvint_unpack(byts[idx:])
            idx += bytes_read
            # Slicing past the end silently yields a short element; reject it
            # here instead of handing truncated data to the subtype.
            if idx + size > total_bytes:
                raise ValueError(
                    "Element declares {0} bytes but only {1} remain".format(size, total_bytes - idx))
            rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
            idx += size
        except Exception as exc:
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit propagate, and chain the cause for debugging.
            raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\
                .format(len(rv))) from exc

    # If we have any additional data in the serialized vector treat that as an error as well
    if idx < total_bytes:
        raise ValueError("Additional bytes remaining after vector deserialization completed")
    return rv
14261477

14271478
@classmethod
def serialize(cls, v, protocol_version):
    """
    Encode the sequence ``v`` as a vector of ``cls.vector_size`` elements.

    Each element is serialized by the subtype. When the subtype's
    ``serial_size()`` is None, every element is preceded by its byte length
    written with ``uvint_pack``; otherwise elements are written back to back.

    Raises ValueError when ``len(v)`` differs from ``cls.vector_size``.
    """
    observed_length = len(v)
    if observed_length != cls.vector_size:
        raise ValueError(
            "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\
            .format(cls.vector_size, cls.subtype.typename, observed_length))

    # Decide once whether elements need a length prefix.
    needs_length_prefix = cls.subtype.serial_size() is None
    out = io.BytesIO()
    for element in v:
        encoded = cls.subtype.serialize(element, protocol_version)
        if needs_length_prefix:
            out.write(uvint_pack(len(encoded)))
        out.write(encoded)
    return out.getvalue()
14331494

14341495
@classmethod
def cql_parameterized_type(cls):
    """
    Return the CQL spelling of this vector type.

    The element type is rendered through the subtype's own
    ``cql_parameterized_type()`` so that parameterized subtypes appear in full.
    """
    element_cql = cls.subtype.cql_parameterized_type()
    return "%s<%s, %s>" % (cls.typename, element_cql, cls.vector_size)

cassandra/encoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
log = logging.getLogger(__name__)
2222

2323
from binascii import hexlify
24+
from decimal import Decimal
2425
import calendar
2526
import datetime
2627
import math
@@ -59,6 +60,7 @@ class Encoder(object):
5960
def __init__(self):
6061
self.mapping = {
6162
float: self.cql_encode_float,
63+
Decimal: self.cql_encode_decimal,
6264
bytearray: self.cql_encode_bytes,
6365
str: self.cql_encode_str,
6466
int: self.cql_encode_object,
@@ -217,3 +219,6 @@ def cql_encode_ipaddress(self, val):
217219
is suitable for ``inet`` type columns.
218220
"""
219221
return "'%s'" % val.compressed
222+
223+
def cql_encode_decimal(self, val):
    """
    Encode ``val`` by converting it to ``float`` and delegating to
    ``cql_encode_float``.

    NOTE(review): the ``float()`` conversion can drop digits that do not fit
    in a double; confirm this is acceptable for the columns it is used with.
    """
    as_float = float(val)
    return self.cql_encode_float(as_float)

cassandra/marshal.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ def vints_unpack(term): # noqa
111111

112112
return tuple(values)
113113

114-
115114
def vints_pack(values):
116115
revbytes = bytearray()
117116
values = [int(v) for v in values[::-1]]
@@ -143,3 +142,48 @@ def vints_pack(values):
143142

144143
revbytes.reverse()
145144
return bytes(revbytes)
145+
146+
def uvint_unpack(bytes):
    """
    Decode one unsigned variable-length integer from the start of a buffer.

    The number of leading 1 bits in the first byte is the number of extra
    bytes that follow. The remaining bits of the first byte plus all extra
    bytes form the value, most significant byte first.

    :param bytes: sized, indexable buffer whose items are ints (for example a
        ``bytes`` object). The name shadows the builtin but is kept so
        existing callers are unaffected.
    :return: tuple ``(value, bytes_consumed)``
    :raises IndexError: if the buffer is empty or shorter than the encoding
        announced by its first byte
    """
    available = len(bytes)
    if available == 0:
        raise IndexError("Cannot unpack an unsigned vint from an empty buffer")

    first_byte = bytes[0]
    if (first_byte & 0x80) == 0:
        # High bit clear: the value fits in this single byte.
        return (first_byte, 1)

    # Leading ones of first_byte == 8 - (position of its highest zero bit).
    num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
    total_bytes = num_extra_bytes + 1
    if available < total_bytes:
        raise IndexError(
            "Unsigned vint announces %d bytes but only %d are available" % (total_bytes, available))

    # Strip the length prefix bits from the first byte, then append the rest.
    rv = first_byte & (0xff >> num_extra_bytes)
    for idx in range(1, total_bytes):
        rv = (rv << 8) | (bytes[idx] & 0xff)

    return (rv, total_bytes)
160+
161+
def uvint_pack(val):
    """
    Encode a non-negative integer as an unsigned variable-length integer.

    Values below 128 take one byte. Larger values take ``1 + n`` bytes where
    the top ``n`` bits of the first byte are set to 1 to announce ``n`` extra
    bytes (``n`` is at most 8, in which case the first byte is ``0xff``). The
    value itself is stored most significant byte first.

    :param val: integer in the range ``0 <= val < 2**64``
    :return: the encoded ``bytes``
    :raises ValueError: if ``val`` is negative or needs more than 64 bits
    """
    if val < 0:
        raise ValueError('Value %d is negative and cannot be encoded as unsigned vint' % val)
    if val < 128:
        return bytes((val,))

    num_bits = val.bit_length()
    if num_bits > 64:
        raise ValueError('Value %d is too big and cannot be encoded as vint' % val)

    # With n extra bytes (n < 8) the first byte keeps 7 - n payload bits, so
    # the encoding holds 7 * (n + 1) bits; 8 extra bytes hold a full 64 bits.
    num_extra_bytes = min((num_bits - 1) // 7, 8)

    encoded = bytearray(val.to_bytes(num_extra_bytes + 1, 'big'))
    # Set the top num_extra_bytes bits of the first byte as the length prefix.
    encoded[0] |= (0xff << (8 - num_extra_bytes)) & 0xff
    return bytes(encoded)

tests/integration/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,10 @@ def _id_and_mark(f):
330330
greaterthanorequalcass36 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.6'), 'Cassandra version 3.6 or greater required')
331331
greaterthanorequalcass3_10 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.10'), 'Cassandra version 3.10 or greater required')
332332
greaterthanorequalcass3_11 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.11'), 'Cassandra version 3.11 or greater required')
333-
greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0-a'), 'Cassandra version 4.0 or greater required')
334-
lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0-a'), 'Cassandra version less or equal to 4.0 required')
335-
lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0-a'), 'Cassandra version less than 4.0 required')
333+
greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version 4.0 or greater required')
334+
greaterthanorequalcass50 = unittest.skipUnless(CASSANDRA_VERSION >= Version('5.0-beta'), 'Cassandra version 5.0 or greater required')
335+
lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0'), 'Cassandra version less or equal to 4.0 required')
336+
lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0'), 'Cassandra version less than 4.0 required')
336337
lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < Version('3.0'), 'Cassandra version less then 3.0 required')
337338
greaterthanorequaldse68 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.8'), "DSE 6.8 or greater required for this test")
338339
greaterthanorequaldse67 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.7'), "DSE 6.7 or greater required for this test")

0 commit comments

Comments
 (0)