Skip to content

Commit e90c0f5

Browse files
authored
PYTHON-1371 Add explicit exception type for serialization failures (#1193)
1 parent 120277d commit e90c0f5

File tree

3 files changed

+84
-14
lines changed

3 files changed

+84
-14
lines changed

cassandra/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,4 +743,10 @@ def __init__(self, msg, excs=[]):
743743
complete_msg = msg
744744
if excs:
745745
complete_msg += ("The following exceptions were observed: \n" + '\n'.join(str(e) for e in excs))
746-
Exception.__init__(self, complete_msg)
746+
Exception.__init__(self, complete_msg)
747+
748+
class VectorDeserializationFailure(DriverException):
749+
"""
750+
The driver was unable to deserialize a given vector
751+
"""
752+
pass

cassandra/cqltypes.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
float_pack, float_unpack, double_pack, double_unpack,
5050
varint_pack, varint_unpack, point_be, point_le,
5151
vints_pack, vints_unpack)
52-
from cassandra import util
52+
from cassandra import util, VectorDeserializationFailure
5353

5454
_little_endian_flag = 1 # we always serialize LE
5555
import ipaddress
@@ -461,6 +461,7 @@ def serialize(uuid, protocol_version):
461461

462462
class BooleanType(_CassandraType):
463463
typename = 'boolean'
464+
serial_size = 1
464465

465466
@staticmethod
466467
def deserialize(byts, protocol_version):
@@ -500,6 +501,7 @@ def serialize(var, protocol_version):
500501

501502
class FloatType(_CassandraType):
502503
typename = 'float'
504+
serial_size = 4
503505

504506
@staticmethod
505507
def deserialize(byts, protocol_version):
@@ -512,6 +514,7 @@ def serialize(byts, protocol_version):
512514

513515
class DoubleType(_CassandraType):
514516
typename = 'double'
517+
serial_size = 8
515518

516519
@staticmethod
517520
def deserialize(byts, protocol_version):
@@ -524,6 +527,7 @@ def serialize(byts, protocol_version):
524527

525528
class LongType(_CassandraType):
526529
typename = 'bigint'
530+
serial_size = 8
527531

528532
@staticmethod
529533
def deserialize(byts, protocol_version):
@@ -536,6 +540,7 @@ def serialize(byts, protocol_version):
536540

537541
class Int32Type(_CassandraType):
538542
typename = 'int'
543+
serial_size = 4
539544

540545
@staticmethod
541546
def deserialize(byts, protocol_version):
@@ -648,6 +653,7 @@ class TimestampType(DateType):
648653

649654
class TimeUUIDType(DateType):
650655
typename = 'timeuuid'
656+
serial_size = 16
651657

652658
def my_timestamp(self):
653659
return util.unix_time_from_uuid1(self.val)
@@ -694,6 +700,7 @@ def serialize(val, protocol_version):
694700

695701
class ShortType(_CassandraType):
696702
typename = 'smallint'
703+
serial_size = 2
697704

698705
@staticmethod
699706
def deserialize(byts, protocol_version):
@@ -706,6 +713,7 @@ def serialize(byts, protocol_version):
706713

707714
class TimeType(_CassandraType):
708715
typename = 'time'
716+
serial_size = 8
709717

710718
@staticmethod
711719
def deserialize(byts, protocol_version):
@@ -1411,8 +1419,11 @@ def apply_parameters(cls, params, names):
14111419

14121420
@classmethod
14131421
def deserialize(cls, byts, protocol_version):
1414-
indexes = (4 * x for x in range(0, cls.vector_size))
1415-
return [cls.subtype.deserialize(byts[idx:idx + 4], protocol_version) for idx in indexes]
1422+
serialized_size = getattr(cls.subtype, "serial_size", None)
1423+
if not serialized_size:
1424+
raise VectorDeserializationFailure("Cannot determine serialized size for vector with subtype %s" % cls.subtype.__name__)
1425+
indexes = (serialized_size * x for x in range(0, cls.vector_size))
1426+
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
14161427

14171428
@classmethod
14181429
def serialize(cls, v, protocol_version):

tests/unit/test_types.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import datetime
1717
import tempfile
1818
import time
19+
import uuid
1920
from binascii import unhexlify
2021

2122
import cassandra
22-
from cassandra import util
23+
from cassandra import util, VectorDeserializationFailure
2324
from cassandra.cqltypes import (
2425
CassandraType, DateRangeType, DateType, DecimalType,
2526
EmptyValue, LongType, SetType, UTF8Type,
@@ -308,15 +309,67 @@ def test_cql_quote(self):
308309
self.assertEqual(cql_quote('test'), "'test'")
309310
self.assertEqual(cql_quote(0), '0')
310311

311-
def test_vector_round_trip(self):
312-
base = [3.4, 2.9, 41.6, 12.0]
313-
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
314-
base_bytes = ctype.serialize(base, 0)
315-
self.assertEqual(16, len(base_bytes))
316-
result = ctype.deserialize(base_bytes, 0)
317-
self.assertEqual(len(base), len(result))
318-
for idx in range(0,len(base)):
319-
self.assertAlmostEqual(base[idx], result[idx], places=5)
312+
def test_vector_round_trip_types_with_serialized_size(self):
313+
# Test all the types which specify a serialized size... see PYTHON-1371 for details
314+
self._round_trip_test([True, False, False, True], \
315+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)")
316+
self._round_trip_test([3.4, 2.9, 41.6, 12.0], \
317+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
318+
self._round_trip_test([3.4, 2.9, 41.6, 12.0], \
319+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DoubleType, 4)")
320+
self._round_trip_test([3, 2, 41, 12], \
321+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.LongType, 4)")
322+
self._round_trip_test([3, 2, 41, 12], \
323+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 4)")
324+
self._round_trip_test([uuid.uuid1(), uuid.uuid1(), uuid.uuid1(), uuid.uuid1()], \
325+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)")
326+
self._round_trip_test([3, 2, 41, 12], \
327+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)")
328+
self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \
329+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)")
330+
331+
def test_vector_round_trip_types_without_serialized_size(self):
332+
# Test all the types which do not specify a serialized size... see PYTHON-1371 for details
333+
# Varints
334+
with self.assertRaises(VectorDeserializationFailure):
335+
self._round_trip_test([3, 2, 41, 12], \
336+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)")
337+
# ASCII text
338+
with self.assertRaises(VectorDeserializationFailure):
339+
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
340+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)")
341+
# UTF8 text
342+
with self.assertRaises(VectorDeserializationFailure):
343+
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
344+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)")
345+
# Duration (containts varints)
346+
with self.assertRaises(VectorDeserializationFailure):
347+
self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \
348+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)")
349+
# List (of otherwise serializable type)
350+
with self.assertRaises(VectorDeserializationFailure):
351+
self._round_trip_test([[3.4], [2.9], [41.6], [12.0]], \
352+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType(org.apache.cassandra.db.marshal.FloatType), 4)")
353+
# Set (of otherwise serializable type)
354+
with self.assertRaises(VectorDeserializationFailure):
355+
self._round_trip_test([set([3.4]), set([2.9]), set([41.6]), set([12.0])], \
356+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.FloatType), 4)")
357+
# Map (of otherwise serializable types)
358+
with self.assertRaises(VectorDeserializationFailure):
359+
self._round_trip_test([{1:3.4}, {2:2.9}, {3:41.6}, {4:12.0}], \
360+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \
361+
(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)")
362+
363+
def _round_trip_test(self, data, ctype_str):
364+
ctype = parse_casstype_args(ctype_str)
365+
data_bytes = ctype.serialize(data, 0)
366+
serialized_size = getattr(ctype.subtype, "serial_size", None)
367+
if serialized_size:
368+
self.assertEqual(serialized_size * len(data), len(data_bytes))
369+
result = ctype.deserialize(data_bytes, 0)
370+
self.assertEqual(len(data), len(result))
371+
for idx in range(0,len(data)):
372+
self.assertAlmostEqual(data[idx], result[idx], places=5)
320373

321374
def test_vector_cql_parameterized_type(self):
322375
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")

0 commit comments

Comments
 (0)