Skip to content

Add column Type to results, colname case insesitiveness #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions memsql/common/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import operator
import six
from memsql.common import util

try:
from _thread import get_ident as _get_ident
Expand Down Expand Up @@ -150,7 +151,7 @@ def _query(self, query, parameters, kwparameters, debug=False):
if self._result is None:
return self._rowcount

fields = [ f[0] for f in self._result.describe() ]
fields = [ (f[0], util.get_field_type_by_code(f[1]), ) for f in self._result.describe() ]
rows = self._result.fetch_row(0)
return SelectResult(fields, rows)

Expand Down Expand Up @@ -180,31 +181,32 @@ def _ensure_connected(self):
class Row(object):
"""A fast, ordered, partially-immutable dictlike object (or objectlike dict)."""

def __init__(self, fields, values):
self._fields = fields
def __init__(self, fields_and_types_tuple, values):
self._fields = list(map(lambda a: a[0].lower(), fields_and_types_tuple))
self._values = values
self._types = fields_and_types_tuple

def __getattr__(self, name):
try:
return self._values[self._fields.index(name)]
return self._values[self._fields.index(name.lower())]
except (ValueError, IndexError):
raise AttributeError(name)

def __getitem__(self, name):
try:
return self._values[self._fields.index(name)]
return self._values[self._fields.index(name.lower())]
except (ValueError, IndexError):
raise KeyError(name)

def __setitem__(self, name, value):
try:
self._values[self._fields.index(name)] = value
self._values[self._fields.index(name.lower())] = value
except (ValueError, IndexError):
self._fields += (name,)
self._fields += (name.lower(),)
self._values += (value,)

def __contains__(self, name):
return name in self._fields
return name.lower() in self._fields

has_key = __contains__

Expand All @@ -219,7 +221,7 @@ def __len__(self):

def get(self, name, default=None):
try:
return self.__getitem__(name)
return self.__getitem__(name.lower())
except KeyError:
return default

Expand All @@ -235,6 +237,9 @@ def items(self):
for item in zip(self._fields, self._values):
yield item

def get_types(self):
return self._types

def __eq__(self, other):
if isinstance(other, Row):
return dict.__eq__(dict(self.items()), other) and all(map(operator.eq, self, other))
Expand Down Expand Up @@ -273,10 +278,11 @@ def nope(self, *args, **kwargs):

class SelectResult(list):
def __init__(self, fieldnames, rows):
self.fieldnames = tuple(fieldnames)
# self.fieldnames = tuple(map(lambda a: a[0], fieldnames))
self.fieldnames = fieldnames
self.rows = rows

data = [Row(self.fieldnames, row) for row in self.rows]
data = [Row(fieldnames, row) for row in self.rows]
list.__init__(self, data)

def width(self):
Expand Down
35 changes: 21 additions & 14 deletions memsql/common/test/test_select_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,41 @@
except ImportError:
from ordereddict import OrderedDict

FIELDS = ['l\\u203pez', 'ಠ_ಠ', 'cloud', 'moon', 'water', 'computer', 'school', 'network',
'hammer', 'walking', 'mediocre', 'literature', 'chair', 'two', 'window', 'cords', 'musical',
'zebra', 'xylophone', 'penguin', 'home', 'dog', 'final', 'ink', 'teacher', 'fun', 'website',
'banana', 'uncle', 'softly', 'mega', 'ten', 'awesome', 'attatch', 'blue', 'internet', 'bottle',
'tight', 'zone', 'tomato', 'prison', 'hydro', 'cleaning', 'telivision', 'send', 'frog', 'cup',
'book', 'zooming', 'falling', 'evily', 'gamer', 'lid', 'juice', 'moniter', 'captain', 'bonding']
FIELDS = [('l\\u203pez', 'VARCHAR'), ('ಠ_ಠ', 'VARCHAR'), ('cloud', 'VARCHAR'), ('moon', 'VARCHAR'), ('water', 'VARCHAR'),
('computer', 'VARCHAR'), ('school', 'VARCHAR'), ('network', 'VARCHAR'),
('hammer', 'VARCHAR'), ('walking', 'VARCHAR'), ('mediocre', 'VARCHAR'), ('literature', 'VARCHAR'),
('chair', 'VARCHAR'), ('two', 'VARCHAR'), ('window', 'VARCHAR'), ('cords', 'VARCHAR'), ('musical', 'VARCHAR'),
('zebra', 'VARCHAR'), ('xylophone', 'VARCHAR'), ('penguin', 'VARCHAR'), ('home', 'VARCHAR'),
('dog', 'VARCHAR'), ('final', 'VARCHAR'), ('ink', 'VARCHAR'), ('teacher', 'VARCHAR'), ('fun', 'VARCHAR'), ('website', 'VARCHAR'),
('banana', 'VARCHAR'), ('uncle', 'VARCHAR'), ('softly', 'VARCHAR'), ('mega', 'VARCHAR'), ('ten', 'VARCHAR'),
('awesome', 'VARCHAR'), ('attatch', 'VARCHAR'), ('blue', 'VARCHAR'), ('internet', 'VARCHAR'), ('bottle', 'VARCHAR'),
('tight', 'VARCHAR'), ('zone', 'VARCHAR'), ('tomato', 'VARCHAR'), ('prison', 'VARCHAR'), ('hydro', 'VARCHAR'),
('cleaning', 'VARCHAR'), ('telivision', 'VARCHAR'), ('send', 'VARCHAR'), ('frog', 'VARCHAR'), ('cup', 'VARCHAR'),
('book', 'VARCHAR'), ('zooming', 'VARCHAR'), ('falling', 'VARCHAR'), ('evily', 'VARCHAR'), ('gamer', 'VARCHAR'),
('lid', 'VARCHAR'), ('juice', 'VARCHAR'), ('moniter', 'VARCHAR'), ('captain', 'VARCHAR'), ('bonding', 'VARCHAR')]

def test_result_order():
raw_data = [[random.randint(1, 2 ** 32) for _ in range(len(FIELDS))] for _ in range(256)]
res = database.SelectResult(FIELDS, raw_data)

for i, row in enumerate(res):
reference = dict(zip(FIELDS, raw_data[i]))
ordered = OrderedDict(zip(FIELDS, raw_data[i]))
_FIELDS = list(map(lambda a: a[0], FIELDS))
reference = dict(zip(_FIELDS, raw_data[i]))
ordered = OrderedDict(zip(_FIELDS, raw_data[i]))
doppel = database.Row(FIELDS, raw_data[i])

assert doppel == row
assert row == reference
assert row == ordered
assert list(row.keys()) == FIELDS
assert list(row.keys()) == _FIELDS
assert list(row.values()) == raw_data[i]
assert sorted(row) == sorted(FIELDS)
assert list(row.items()) == list(zip(FIELDS, raw_data[i]))
assert sorted(row) == sorted(_FIELDS)
assert list(row.items()) == list(zip(_FIELDS, raw_data[i]))
assert list(row.values()) == raw_data[i]
assert list(row.keys()) == FIELDS
assert list(row.items()) == list(zip(FIELDS, raw_data[i]))
assert list(row.keys()) == _FIELDS
assert list(row.items()) == list(zip(_FIELDS, raw_data[i]))

for f in FIELDS:
for f in _FIELDS:
assert f in row
assert f in row
assert row[f] == reference[f]
Expand Down
38 changes: 38 additions & 0 deletions memsql/common/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,41 @@
FIELD_TYPE_DICT = {
0: 'DECIMAL',
1: 'TINY',
2: 'SHORT',
3: 'LONG',
4: 'FLOAT',
5: 'DOUBLE',
6: 'NULL',
7: 'TIMESTAMP',
8: 'LONGLONG',
9: 'INT24',
10: 'DATE',
11: 'TIME',
12: 'DATETIME',
13: 'YEAR',
14: 'NEWDATE',
15: 'VARCHAR',
16: 'BIT',
246: 'NEWDECIMAL',
247: 'INTERVAL',
248: 'SET',
249: 'TINY_BLOB',
250: 'MEDIUM_BLOB',
251: 'LONG_BLOB',
252: 'BLOB',
253: 'VAR_STRING',
254: 'STRING',
255: 'GEOMETRY'
}

N_A_TYPE = 'N/A'


def timedelta_total_seconds(td):
""" Needed for python 2.6 compat """
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10. ** 6) / 10. ** 6


def get_field_type_by_code(id):
return FIELD_TYPE_DICT.get(id, N_A_TYPE)