diff --git a/memsql/common/database.py b/memsql/common/database.py index 4d3ea3f..da9995a 100644 --- a/memsql/common/database.py +++ b/memsql/common/database.py @@ -4,6 +4,7 @@ import time import operator import six +from memsql.common import util try: from _thread import get_ident as _get_ident @@ -150,7 +151,7 @@ def _query(self, query, parameters, kwparameters, debug=False): if self._result is None: return self._rowcount - fields = [ f[0] for f in self._result.describe() ] + fields = [ (f[0], util.get_field_type_by_code(f[1]), ) for f in self._result.describe() ] rows = self._result.fetch_row(0) return SelectResult(fields, rows) @@ -180,31 +181,32 @@ def _ensure_connected(self): class Row(object): """A fast, ordered, partially-immutable dictlike object (or objectlike dict).""" - def __init__(self, fields, values): - self._fields = fields + def __init__(self, fields_and_types_tuple, values): + self._fields = list(map(lambda a: a[0].lower(), fields_and_types_tuple)) self._values = values + self._types = fields_and_types_tuple def __getattr__(self, name): try: - return self._values[self._fields.index(name)] + return self._values[self._fields.index(name.lower())] except (ValueError, IndexError): raise AttributeError(name) def __getitem__(self, name): try: - return self._values[self._fields.index(name)] + return self._values[self._fields.index(name.lower())] except (ValueError, IndexError): raise KeyError(name) def __setitem__(self, name, value): try: - self._values[self._fields.index(name)] = value + self._values[self._fields.index(name.lower())] = value except (ValueError, IndexError): - self._fields += (name,) + self._fields += (name.lower(),) self._values += (value,) def __contains__(self, name): - return name in self._fields + return name.lower() in self._fields has_key = __contains__ @@ -219,7 +221,7 @@ def __len__(self): def get(self, name, default=None): try: - return self.__getitem__(name) + return self.__getitem__(name.lower()) except KeyError: return default @@ -235,6 +237,9 @@ def items(self): for item in zip(self._fields, self._values): yield item + def get_types(self): + return self._types + def __eq__(self, other): if isinstance(other, Row): return dict.__eq__(dict(self.items()), other) and all(map(operator.eq, self, other)) @@ -273,10 +278,11 @@ def nope(self, *args, **kwargs): class SelectResult(list): def __init__(self, fieldnames, rows): - self.fieldnames = tuple(fieldnames) + # self.fieldnames = tuple(map(lambda a: a[0], fieldnames)) + self.fieldnames = fieldnames self.rows = rows - data = [Row(self.fieldnames, row) for row in self.rows] + data = [Row(fieldnames, row) for row in self.rows] list.__init__(self, data) def width(self): diff --git a/memsql/common/test/test_select_result.py b/memsql/common/test/test_select_result.py index 0f78902..56d157f 100644 --- a/memsql/common/test/test_select_result.py +++ b/memsql/common/test/test_select_result.py @@ -9,34 +9,41 @@ except ImportError: from ordereddict import OrderedDict -FIELDS = ['l\\u203pez', 'ಠ_ಠ', 'cloud', 'moon', 'water', 'computer', 'school', 'network', - 'hammer', 'walking', 'mediocre', 'literature', 'chair', 'two', 'window', 'cords', 'musical', - 'zebra', 'xylophone', 'penguin', 'home', 'dog', 'final', 'ink', 'teacher', 'fun', 'website', - 'banana', 'uncle', 'softly', 'mega', 'ten', 'awesome', 'attatch', 'blue', 'internet', 'bottle', - 'tight', 'zone', 'tomato', 'prison', 'hydro', 'cleaning', 'telivision', 'send', 'frog', 'cup', - 'book', 'zooming', 'falling', 'evily', 'gamer', 'lid', 'juice', 'moniter', 'captain', 'bonding'] +FIELDS = [('l\\u203pez', 'VARCHAR'), ('ಠ_ಠ', 'VARCHAR'), ('cloud', 'VARCHAR'), ('moon', 'VARCHAR'), ('water', 'VARCHAR'), + ('computer', 'VARCHAR'), ('school', 'VARCHAR'), ('network', 'VARCHAR'), + ('hammer', 'VARCHAR'), ('walking', 'VARCHAR'), ('mediocre', 'VARCHAR'), ('literature', 'VARCHAR'), + ('chair', 'VARCHAR'), ('two', 'VARCHAR'), ('window', 'VARCHAR'), ('cords', 'VARCHAR'), ('musical', 'VARCHAR'), + ('zebra', 'VARCHAR'), ('xylophone', 'VARCHAR'), ('penguin', 'VARCHAR'), ('home', 'VARCHAR'), + ('dog', 'VARCHAR'), ('final', 'VARCHAR'), ('ink', 'VARCHAR'), ('teacher', 'VARCHAR'), ('fun', 'VARCHAR'), ('website', 'VARCHAR'), + ('banana', 'VARCHAR'), ('uncle', 'VARCHAR'), ('softly', 'VARCHAR'), ('mega', 'VARCHAR'), ('ten', 'VARCHAR'), + ('awesome', 'VARCHAR'), ('attatch', 'VARCHAR'), ('blue', 'VARCHAR'), ('internet', 'VARCHAR'), ('bottle', 'VARCHAR'), + ('tight', 'VARCHAR'), ('zone', 'VARCHAR'), ('tomato', 'VARCHAR'), ('prison', 'VARCHAR'), ('hydro', 'VARCHAR'), + ('cleaning', 'VARCHAR'), ('telivision', 'VARCHAR'), ('send', 'VARCHAR'), ('frog', 'VARCHAR'), ('cup', 'VARCHAR'), + ('book', 'VARCHAR'), ('zooming', 'VARCHAR'), ('falling', 'VARCHAR'), ('evily', 'VARCHAR'), ('gamer', 'VARCHAR'), + ('lid', 'VARCHAR'), ('juice', 'VARCHAR'), ('moniter', 'VARCHAR'), ('captain', 'VARCHAR'), ('bonding', 'VARCHAR')] def test_result_order(): raw_data = [[random.randint(1, 2 ** 32) for _ in range(len(FIELDS))] for _ in range(256)] res = database.SelectResult(FIELDS, raw_data) for i, row in enumerate(res): - reference = dict(zip(FIELDS, raw_data[i])) - ordered = OrderedDict(zip(FIELDS, raw_data[i])) + _FIELDS = list(map(lambda a: a[0], FIELDS)) + reference = dict(zip(_FIELDS, raw_data[i])) + ordered = OrderedDict(zip(_FIELDS, raw_data[i])) doppel = database.Row(FIELDS, raw_data[i]) assert doppel == row assert row == reference assert row == ordered - assert list(row.keys()) == FIELDS + assert list(row.keys()) == _FIELDS assert list(row.values()) == raw_data[i] - assert sorted(row) == sorted(FIELDS) - assert list(row.items()) == list(zip(FIELDS, raw_data[i])) + assert sorted(row) == sorted(_FIELDS) + assert list(row.items()) == list(zip(_FIELDS, raw_data[i])) assert list(row.values()) == raw_data[i] - assert list(row.keys()) == FIELDS - assert list(row.items()) == list(zip(FIELDS, raw_data[i])) + assert list(row.keys()) == _FIELDS + assert list(row.items()) == list(zip(_FIELDS, raw_data[i])) - for f in FIELDS: + for f in _FIELDS: assert f in row assert f in row assert row[f] == reference[f] diff --git a/memsql/common/util.py b/memsql/common/util.py index 8f337e8..7d7eda6 100644 --- a/memsql/common/util.py +++ b/memsql/common/util.py @@ -1,3 +1,41 @@ +FIELD_TYPE_DICT = { + 0: 'DECIMAL', + 1: 'TINY', + 2: 'SHORT', + 3: 'LONG', + 4: 'FLOAT', + 5: 'DOUBLE', + 6: 'NULL', + 7: 'TIMESTAMP', + 8: 'LONGLONG', + 9: 'INT24', + 10: 'DATE', + 11: 'TIME', + 12: 'DATETIME', + 13: 'YEAR', + 14: 'NEWDATE', + 15: 'VARCHAR', + 16: 'BIT', + 246: 'NEWDECIMAL', + 247: 'INTERVAL', + 248: 'SET', + 249: 'TINY_BLOB', + 250: 'MEDIUM_BLOB', + 251: 'LONG_BLOB', + 252: 'BLOB', + 253: 'VAR_STRING', + 254: 'STRING', + 255: 'GEOMETRY' +} + +N_A_TYPE = 'N/A' + + def timedelta_total_seconds(td): """ Needed for python 2.6 compat """ return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10. ** 6) / 10. ** 6 + + +def get_field_type_by_code(id): + return FIELD_TYPE_DICT.get(id, N_A_TYPE) +