diff --git a/memsql/common/database.py b/memsql/common/database.py index 4d3ea3f..3e051e6 100644 --- a/memsql/common/database.py +++ b/memsql/common/database.py @@ -4,6 +4,7 @@ import time import operator import six +from MySQLdb.constants import FIELD_TYPE try: from _thread import get_ident as _get_ident @@ -150,9 +151,9 @@ def _query(self, query, parameters, kwparameters, debug=False): if self._result is None: return self._rowcount - fields = [ f[0] for f in self._result.describe() ] + fields_and_types = [(f[0], f[1], ) for f in self._result.describe()] rows = self._result.fetch_row(0) - return SelectResult(fields, rows) + return SelectResult(fields_and_types, rows) def _execute(self, query, parameters, kwparameters, debug=False): if parameters and kwparameters: @@ -272,16 +273,20 @@ def nope(self, *args, **kwargs): __reversed__ = nope class SelectResult(list): - def __init__(self, fieldnames, rows): - self.fieldnames = tuple(fieldnames) + def __init__(self, fieldnames_and_types, rows): + self.fieldnames = tuple(map(lambda a: a[0], fieldnames_and_types)) + self._fields_and_types = fieldnames_and_types self.rows = rows - data = [Row(self.fieldnames, row) for row in self.rows] list.__init__(self, data) def width(self): return len(self.fieldnames) + def get_fields_and_types(self): + return map(lambda field_tuple: (field_tuple[0], list(FIELD_TYPE.__dict__.keys())[list(FIELD_TYPE.__dict__.values()).index(field_tuple[1])]), + self._fields_and_types) + def __getitem__(self, i): if isinstance(i, slice): return SelectResult(self.fieldnames, self.rows[i]) diff --git a/memsql/common/test/test_select_result.py b/memsql/common/test/test_select_result.py index 0f78902..82103e9 100644 --- a/memsql/common/test/test_select_result.py +++ b/memsql/common/test/test_select_result.py @@ -3,6 +3,7 @@ import random import pytest import simplejson as json +from MySQLdb.constants import FIELD_TYPE try: from collections import OrderedDict @@ -16,9 +17,20 @@ 'tight', 'zone', 'tomato', 'prison', 'hydro', 'cleaning', 'telivision', 'send', 'frog', 'cup', 'book', 'zooming', 'falling', 'evily', 'gamer', 'lid', 'juice', 'moniter', 'captain', 'bonding'] +FIELDS_AND_TYPES = [('VARCHAR_FIELD', FIELD_TYPE.VARCHAR), ('BIT_FIELD', FIELD_TYPE.BIT), + ('DATETIME_FIELD', FIELD_TYPE.DATETIME), ('TIMESTAMP_FIELD', FIELD_TYPE.TIMESTAMP)] + +def get_sample_data(fields): + return [[random.randint(1, 2 ** 32) for _ in range(len(fields))] for _ in range(256)] + +def test_get_fields_and_types(): + raw_data = get_sample_data(FIELDS_AND_TYPES) + res = database.SelectResult(FIELDS_AND_TYPES, raw_data) + assert list(res.get_fields_and_types()) == list(map(lambda field_tuple: (field_tuple[0], field_tuple[0].split('_')[0]), FIELDS_AND_TYPES)) + def test_result_order(): - raw_data = [[random.randint(1, 2 ** 32) for _ in range(len(FIELDS))] for _ in range(256)] - res = database.SelectResult(FIELDS, raw_data) + raw_data = get_sample_data(FIELDS) + res = database.SelectResult(map(lambda a: (a, FIELD_TYPE.VARCHAR), FIELDS), raw_data) for i, row in enumerate(res): reference = dict(zip(FIELDS, raw_data[i]))