Skip to content

Commit 33f3077

Browse files
committed
Merge branch 'PYTHON-231'
Conflicts: cassandra/cqltypes.py cassandra/util.py
2 parents a1a2919 + dce4d17 commit 33f3077

File tree

5 files changed

+45
-25
lines changed

5 files changed

+45
-25
lines changed

cassandra/cqltypes.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -784,12 +784,12 @@ class MapType(_ParameterizedType):
784784

785785
@classmethod
786786
def validate(cls, val):
787-
subkeytype, subvaltype = cls.subtypes
788-
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
787+
key_type, value_type = cls.subtypes
788+
return dict((key_type.validate(k), value_type.validate(v)) for (k, v) in six.iteritems(val))
789789

790790
@classmethod
791791
def deserialize_safe(cls, byts, protocol_version):
792-
subkeytype, subvaltype = cls.subtypes
792+
key_type, value_type = cls.subtypes
793793
if protocol_version >= 3:
794794
unpack = int32_unpack
795795
length = 4
@@ -798,7 +798,7 @@ def deserialize_safe(cls, byts, protocol_version):
798798
length = 2
799799
numelements = unpack(byts[:length])
800800
p = length
801-
themap = util.OrderedMap()
801+
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
802802
for _ in range(numelements):
803803
key_len = unpack(byts[p:p + length])
804804
p += length
@@ -808,14 +808,14 @@ def deserialize_safe(cls, byts, protocol_version):
808808
p += length
809809
valbytes = byts[p:p + val_len]
810810
p += val_len
811-
key = subkeytype.from_binary(keybytes, protocol_version)
812-
val = subvaltype.from_binary(valbytes, protocol_version)
813-
themap._insert(key, val)
811+
key = key_type.from_binary(keybytes, protocol_version)
812+
val = value_type.from_binary(valbytes, protocol_version)
813+
themap._insert_unchecked(key, keybytes, val)
814814
return themap
815815

816816
@classmethod
817817
def serialize_safe(cls, themap, protocol_version):
818-
subkeytype, subvaltype = cls.subtypes
818+
key_type, value_type = cls.subtypes
819819
pack = int32_pack if protocol_version >= 3 else uint16_pack
820820
buf = io.BytesIO()
821821
buf.write(pack(len(themap)))
@@ -824,8 +824,8 @@ def serialize_safe(cls, themap, protocol_version):
824824
except AttributeError:
825825
raise TypeError("Got a non-map object for a map value")
826826
for key, val in items:
827-
keybytes = subkeytype.to_binary(key, protocol_version)
828-
valbytes = subvaltype.to_binary(val, protocol_version)
827+
keybytes = key_type.to_binary(key, protocol_version)
828+
valbytes = value_type.to_binary(val, protocol_version)
829829
buf.write(pack(len(keybytes)))
830830
buf.write(keybytes)
831831
buf.write(pack(len(valbytes)))

cassandra/util.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def _intersect(self, other):
680680
isect.add(item)
681681
return isect
682682

683+
683684
from collections import Mapping
684685
from six.moves import cPickle
685686

@@ -715,6 +716,7 @@ class OrderedMap(Mapping):
715716
or higher.
716717
717718
'''
719+
718720
def __init__(self, *args, **kwargs):
719721
if len(args) > 1:
720722
raise TypeError('expected at most 1 arguments, got %d' % len(args))
@@ -776,11 +778,25 @@ def __repr__(self):
776778
def __str__(self):
777779
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
778780

779-
@staticmethod
780-
def _serialize_key(key):
781+
def _serialize_key(self, key):
781782
return cPickle.dumps(key)
782783

783784

785+
class OrderedMapSerializedKey(OrderedMap):
786+
787+
def __init__(self, cass_type, protocol_version):
788+
super(OrderedMapSerializedKey, self).__init__()
789+
self.cass_key_type = cass_type
790+
self.protocol_version = protocol_version
791+
792+
def _insert_unchecked(self, key, flat_key, value):
793+
self._items.append((key, value))
794+
self._index[flat_key] = len(self._items) - 1
795+
796+
def _serialize_key(self, key):
797+
return self.cass_key_type.serialize(key, self.protocol_version)
798+
799+
784800
import datetime
785801
import time
786802

tests/unit/test_marshalling.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from decimal import Decimal
2424
from uuid import UUID
2525

26-
from cassandra.cqltypes import lookup_casstype
27-
from cassandra.util import OrderedMap, sortedset, Time
26+
from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type
27+
from cassandra.util import OrderedMap, OrderedMapSerializedKey, sortedset, Time
2828

2929
marshalled_value_pairs = (
3030
# binary form, type, python native type
@@ -75,7 +75,7 @@
7575
(b'', 'MapType(AsciiType, BooleanType)', None),
7676
(b'', 'ListType(FloatType)', None),
7777
(b'', 'SetType(LongType)', None),
78-
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
78+
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 0)),
7979
(b'\x00\x00', 'ListType(FloatType)', []),
8080
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
8181
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
@@ -84,9 +84,10 @@
8484
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1))
8585
)
8686

87-
ordered_map_value = OrderedMap([(u'\u307fbob', 199),
88-
(u'', -1),
89-
(u'\\', 0)])
87+
ordered_map_value = OrderedMapSerializedKey(UTF8Type, 2)
88+
ordered_map_value._insert(u'\u307fbob', 199)
89+
ordered_map_value._insert(u'', -1)
90+
ordered_map_value._insert(u'\\', 0)
9091

9192
# these following entries work for me right now, but they're dependent on
9293
# vagaries of internal python ordering for unordered types

tests/unit/test_policies.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
try:
1616
import unittest2 as unittest
1717
except ImportError:
18-
import unittest # noqa
18+
import unittest # noqa
1919

2020
from itertools import islice, cycle
2121
from mock import Mock
@@ -139,21 +139,22 @@ def host_down():
139139
threads.append(Thread(target=host_down))
140140

141141
# make the GIL switch after every instruction, maximizing
142-
# the chace of race conditions
143-
if six.PY2:
142+
# the chance of race conditions
143+
check = six.PY2 or '__pypy__' in sys.builtin_module_names
144+
if check:
144145
original_interval = sys.getcheckinterval()
145146
else:
146147
original_interval = sys.getswitchinterval()
147148

148149
try:
149-
if six.PY2:
150+
if check:
150151
sys.setcheckinterval(0)
151152
else:
152153
sys.setswitchinterval(0.0001)
153154
map(lambda t: t.start(), threads)
154155
map(lambda t: t.join(), threads)
155156
finally:
156-
if six.PY2:
157+
if check:
157158
sys.setcheckinterval(original_interval)
158159
else:
159160
sys.setswitchinterval(original_interval)
@@ -362,6 +363,7 @@ def test_default_dc(self):
362363
policy.on_add(host_remote)
363364
self.assertFalse(policy.local_dc)
364365

366+
365367
class TokenAwarePolicyTest(unittest.TestCase):
366368

367369
def test_wrap_round_robin(self):
@@ -519,7 +521,6 @@ def test_status_updates(self):
519521
qplan = list(policy.make_query_plan())
520522
self.assertEqual(qplan, [])
521523

522-
523524
def test_statement_keyspace(self):
524525
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
525526
for host in hosts:
@@ -666,6 +667,7 @@ def test_schedule(self):
666667

667668
ONE = ConsistencyLevel.ONE
668669

670+
669671
class RetryPolicyTest(unittest.TestCase):
670672

671673
def test_read_timeout(self):

tests/unit/test_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import calendar
2121
import datetime
2222
import tempfile
23+
import six
2324
import time
2425

2526
import cassandra
@@ -228,7 +229,7 @@ def __init__(self, subtypes, names):
228229

229230
@classmethod
230231
def apply_parameters(cls, subtypes, names):
231-
return cls(subtypes, [unhexlify(name) if name is not None else name for name in names])
232+
return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names])
232233

233234
class BarType(FooType):
234235
typename = 'org.apache.cassandra.db.marshal.BarType'

0 commit comments

Comments
 (0)