diff --git a/CHANGELOG.md b/CHANGELOG.md index 05885d8d..0664bab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- Support custom packer and unpacker factories (#191). ### Changed diff --git a/tarantool/connection.py b/tarantool/connection.py index 5849994c..fcaa7e83 100644 --- a/tarantool/connection.py +++ b/tarantool/connection.py @@ -21,8 +21,12 @@ import msgpack -from tarantool.response import Response +from tarantool.response import ( + unpacker_factory as default_unpacker_factory, + Response, +) from tarantool.request import ( + packer_factory as default_packer_factory, Request, # RequestOK, RequestCall, @@ -357,7 +361,9 @@ def __init__(self, host, port, ssl_key_file=DEFAULT_SSL_KEY_FILE, ssl_cert_file=DEFAULT_SSL_CERT_FILE, ssl_ca_file=DEFAULT_SSL_CA_FILE, - ssl_ciphers=DEFAULT_SSL_CIPHERS): + ssl_ciphers=DEFAULT_SSL_CIPHERS, + packer_factory=default_packer_factory, + unpacker_factory=default_unpacker_factory): """ :param host: Server hostname or IP address. Use ``None`` for Unix sockets. @@ -395,6 +401,16 @@ def __init__(self, host, port, :param encoding: ``'utf-8'`` or ``None``. Use ``None`` to work with non-UTF8 strings. + If non-default + :paramref:`~tarantool.Connection.packer_factory` option is + used, :paramref:`~tarantool.Connection.encoding` option + value is ignored on encode until the factory explicitly uses + its value. If non-default + :paramref:`~tarantool.Connection.unpacker_factory` option is + used, :paramref:`~tarantool.Connection.encoding` option + value is ignored on decode until the factory explicitly uses + its value. + If ``'utf-8'``, pack Unicode string (:obj:`str`) to MessagePack string (`mp_str`_) and unpack MessagePack string (`mp_str`_) Unicode string (:obj:`str`), pack :obj:`bytes` @@ -429,6 +445,13 @@ def __init__(self, host, port, :param use_list: If ``True``, unpack MessagePack array (`mp_array`_) to :obj:`list`. Otherwise, unpack to :obj:`tuple`. + + If non-default + :paramref:`~tarantool.Connection.unpacker_factory` option is + used, + :paramref:`~tarantool.Connection.use_list` option value is + ignored on decode until the factory explicitly uses its + value. :type use_list: :obj:`bool`, optional :param call_16: @@ -463,6 +486,23 @@ def __init__(self, host, port, suites the connection can use. :type ssl_ciphers: :obj:`str` or :obj:`None`, optional + :param packer_factory: Request MessagePack packer factory. + Supersedes :paramref:`~tarantool.Connection.encoding`. See + :func:`~tarantool.request.packer_factory` for example of + a packer factory. + :type packer_factory: + callable[[:obj:`~tarantool.Connection`], :obj:`~msgpack.Packer`], + optional + + :param unpacker_factory: Response MessagePack unpacker factory. + Supersedes :paramref:`~tarantool.Connection.encoding` and + :paramref:`~tarantool.Connection.use_list`. See + :func:`~tarantool.response.unpacker_factory` for example of + an unpacker factory. + :type unpacker_factory: + callable[[:obj:`~tarantool.Connection`], :obj:`~msgpack.Unpacker`], + optional + :raise: :exc:`~tarantool.error.ConfigurationError`, :meth:`~tarantool.Connection.connect` exceptions @@ -514,6 +554,8 @@ def __init__(self, host, port, IPROTO_FEATURE_ERROR_EXTENSION: False, IPROTO_FEATURE_WATCHERS: False, } + self._packer_factory_impl = packer_factory + self._unpacker_factory_impl = unpacker_factory if connect_now: self.connect() @@ -1749,3 +1791,9 @@ def _check_features(self): features_list = [val for val in CONNECTOR_FEATURES if val in server_features] for val in features_list: self._features[val] = True + + def _packer_factory(self): + return self._packer_factory_impl(self) + + def _unpacker_factory(self): + return self._unpacker_factory_impl(self) diff --git a/tarantool/request.py b/tarantool/request.py index 7274b8d2..e9ea6e01 100644 --- a/tarantool/request.py +++ b/tarantool/request.py @@ -63,7 +63,7 @@ from tarantool.msgpack_ext.packer import default as packer_default -def build_packer(conn): +def packer_factory(conn): """ Build packer to pack request. @@ -148,7 +148,7 @@ def __init__(self, conn): self._body = '' self.response_class = Response - self.packer = build_packer(conn) + self.packer = conn._packer_factory() def _dumps(self, src): """ diff --git a/tarantool/response.py b/tarantool/response.py index 7fef6e90..3ef931df 100644 --- a/tarantool/response.py +++ b/tarantool/response.py @@ -32,7 +32,7 @@ from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook -def build_unpacker(conn): +def unpacker_factory(conn): """ Build unpacker to unpack request response. @@ -108,7 +108,7 @@ def __init__(self, conn, response): # created in the __new__(). # super(Response, self).__init__() - unpacker = build_unpacker(conn) + unpacker = conn._unpacker_factory() unpacker.feed(response) header = unpacker.unpack() diff --git a/test/suites/__init__.py b/test/suites/__init__.py index ee951038..49014e97 100644 --- a/test/suites/__init__.py +++ b/test/suites/__init__.py @@ -22,6 +22,7 @@ from .test_package import TestSuite_Package from .test_error_ext import TestSuite_ErrorExt from .test_push import TestSuite_Push +from .test_connection import TestSuite_Connection test_cases = (TestSuite_Schema_UnicodeConnection, TestSuite_Schema_BinaryConnection, @@ -29,7 +30,8 @@ TestSuite_Mesh, TestSuite_Execute, TestSuite_DBAPI, TestSuite_Encoding, TestSuite_Pool, TestSuite_Ssl, TestSuite_Decimal, TestSuite_UUID, TestSuite_Datetime, - TestSuite_Interval, TestSuite_ErrorExt, TestSuite_Push,) + TestSuite_Interval, TestSuite_ErrorExt, TestSuite_Push, + TestSuite_Connection,) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/test/suites/test_connection.py b/test/suites/test_connection.py new file mode 100644 index 00000000..4132d35e --- /dev/null +++ b/test/suites/test_connection.py @@ -0,0 +1,161 @@ +import sys +import unittest + +import decimal +import msgpack + +import tarantool +import tarantool.msgpack_ext.decimal as ext_decimal + +from .lib.skip import skip_or_run_decimal_test, skip_or_run_varbinary_test +from .lib.tarantool_server import TarantoolServer + +class TestSuite_Connection(unittest.TestCase): + @classmethod + def setUpClass(self): + print(' CONNECTION '.center(70, '='), file=sys.stderr) + print('-' * 70, file=sys.stderr) + self.srv = TarantoolServer() + self.srv.script = 'test/suites/box.lua' + self.srv.start() + + self.adm = self.srv.admin + self.adm(r""" + box.schema.user.create('test', {password = 'test', if_not_exists = true}) + box.schema.user.grant('test', 'read,write,execute', 'universe') + + box.schema.create_space('space_varbin') + + box.space['space_varbin']:format({ + { + 'id', + type = 'number', + is_nullable = false + }, + { + 'varbin', + type = 'varbinary', + is_nullable = false, + } + }) + + box.space['space_varbin']:create_index('id', { + type = 'tree', + parts = {1, 'number'}, + unique = true}) + + box.space['space_varbin']:create_index('varbin', { + type = 'tree', + parts = {2, 'varbinary'}, + unique = true}) + """) + + def setUp(self): + # prevent a remote tarantool from clean our session + if self.srv.is_started(): + self.srv.touch_lock() + + @skip_or_run_decimal_test + def test_custom_packer(self): + def my_ext_type_encoder(obj): + if isinstance(obj, decimal.Decimal): + obj = obj + 1 + return msgpack.ExtType(ext_decimal.EXT_ID, ext_decimal.encode(obj, None)) + raise TypeError("Unknown type: %r" % (obj,)) + + def my_packer_factory(_): + return msgpack.Packer(default=my_ext_type_encoder) + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test', + packer_factory=my_packer_factory) + + resp = self.con.eval("return ...", (decimal.Decimal('27756'),)) + self.assertSequenceEqual(resp, [decimal.Decimal('27757')]) + + def test_custom_packer_supersedes_encoding(self): + def my_packer_factory(_): + return msgpack.Packer(use_bin_type=False) + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test', + encoding='utf-8', + packer_factory=my_packer_factory) + + # bytes -> mp_str (string) for encoding=None + # bytes -> mp_bin (varbinary) for encoding='utf-8' + resp = self.con.eval("return type(...)", (bytes(bytearray.fromhex('DEADBEAF0103')),)) + self.assertSequenceEqual(resp, ['string']) + + @skip_or_run_decimal_test + def test_custom_unpacker(self): + def my_ext_type_decoder(code, data): + if code == ext_decimal.EXT_ID: + return ext_decimal.decode(data, None) - 1 + raise NotImplementedError("Unknown msgpack extension type code %d" % (code,)) + + def my_unpacker_factory(_): + if msgpack.version >= (1, 0, 0): + return msgpack.Unpacker(ext_hook=my_ext_type_decoder, strict_map_key=False) + return msgpack.Unpacker(ext_hook=my_ext_type_decoder) + + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test', + unpacker_factory=my_unpacker_factory) + + resp = self.con.eval("return require('decimal').new('27756')") + self.assertSequenceEqual(resp, [decimal.Decimal('27755')]) + + @skip_or_run_varbinary_test + def test_custom_unpacker_supersedes_encoding(self): + def my_unpacker_factory(_): + if msgpack.version >= (0, 5, 2): + if msgpack.version >= (1, 0, 0): + return msgpack.Unpacker(raw=True, strict_map_key=False) + + return msgpack.Unpacker(raw=True) + return msgpack.Unpacker(encoding=None) + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test', + encoding='utf-8', + unpacker_factory=my_unpacker_factory) + + data_id = 1 + data_hex = 'DEADBEAF' + data = bytes(bytearray.fromhex(data_hex)) + space = 'space_varbin' + + self.con.execute(""" + INSERT INTO "%s" VALUES (%d, x'%s'); + """ % (space, data_id, data_hex)) + + resp = self.con.execute(""" + SELECT * FROM "%s" WHERE "varbin" == x'%s'; + """ % (space, data_hex)) + self.assertSequenceEqual(resp, [[data_id, data]]) + + def test_custom_unpacker_supersedes_use_list(self): + def my_unpacker_factory(_): + if msgpack.version >= (1, 0, 0): + return msgpack.Unpacker(use_list=False, strict_map_key=False) + return msgpack.Unpacker(use_list=False) + + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + user='test', password='test', + use_list=True, + unpacker_factory=my_unpacker_factory) + + resp = self.con.eval("return {1, 2, 3}") + self.assertIsInstance(resp[0], tuple) + + @classmethod + def tearDown(self): + if hasattr(self, 'con'): + self.con.close() + + @classmethod + def tearDownClass(self): + self.srv.stop() + self.srv.clean() diff --git a/test/suites/test_error_ext.py b/test/suites/test_error_ext.py index 199e7b50..89f5c35f 100644 --- a/test/suites/test_error_ext.py +++ b/test/suites/test_error_ext.py @@ -8,8 +8,6 @@ from tarantool.msgpack_ext.packer import default as packer_default from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook -from tarantool.request import build_packer -from tarantool.response import build_unpacker from .lib.tarantool_server import TarantoolServer from .lib.skip import skip_or_run_error_ext_type_test @@ -273,7 +271,7 @@ def test_msgpack_decode(self): unpacker_ext_hook( 3, case['msgpack'], - build_unpacker(conn) + conn._unpacker_factory(), ), case['python']) @@ -330,7 +328,7 @@ def test_msgpack_encode(self): case = self.cases[name] conn = getattr(self, case['conn']) - self.assertEqual(packer_default(case['python'], build_packer(conn)), + self.assertEqual(packer_default(case['python'], conn._packer_factory()), msgpack.ExtType(code=3, data=case['msgpack'])) @skip_or_run_error_ext_type_test diff --git a/test/suites/test_interval.py b/test/suites/test_interval.py index a3458ad7..2252ebe8 100644 --- a/test/suites/test_interval.py +++ b/test/suites/test_interval.py @@ -9,7 +9,6 @@ from tarantool.msgpack_ext.packer import default as packer_default from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook -from tarantool.response import build_unpacker from .lib.tarantool_server import TarantoolServer from .lib.skip import skip_or_run_datetime_test @@ -154,7 +153,7 @@ def test_msgpack_decode(self): self.assertEqual(unpacker_ext_hook( 6, case['msgpack'], - build_unpacker(self.con), + self.con._unpacker_factory(), ), case['python']) @@ -206,13 +205,13 @@ def test_unknown_field_decode(self): case = b'\x01\x09\xce\x00\x98\x96\x80' self.assertRaisesRegex( MsgpackError, 'Unknown interval field id 9', - lambda: unpacker_ext_hook(6, case, build_unpacker(self.con))) + lambda: unpacker_ext_hook(6, case, self.con._unpacker_factory())) def test_unknown_adjust_decode(self): case = b'\x02\x07\xce\x00\x98\x96\x80\x08\x03' self.assertRaisesRegex( MsgpackError, '3 is not a valid Adjust', - lambda: unpacker_ext_hook(6, case, build_unpacker(self.con))) + lambda: unpacker_ext_hook(6, case, self.con._unpacker_factory())) arithmetic_cases = {