From 6588f14d4d4ba09ab63f2007597ab6faa45d8aec Mon Sep 17 00:00:00 2001 From: Simon Lundmark Date: Mon, 30 Aug 2021 10:18:46 +0200 Subject: [PATCH 1/2] Let multi statements be optional - Disabling multi statements can help protect against SQL injection attacks. --- MySQLdb/connections.py | 22 +++++++++++++++------- tests/test_connection.py | 26 ++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 7 deletions(-) create mode 100644 tests/test_connection.py diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index eca51ed5..5dffb577 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -110,6 +110,11 @@ class object, used to create cursors (keyword only) :param int client_flag: flags to use or 0 (see MySQL docs or constants/CLIENTS.py) + :param bool multi_statements: + If True, enable multi statements for clients >= 4.1 and multi + results for clients >= 5.0. Set to False to disable it, which gives + some protection against injection attacks. Defaults to True. + :param str ssl_mode: specify the security settings for connection to the server; see the MySQL documentation for more details @@ -169,13 +174,16 @@ class object, used to create cursors (keyword only) self._binary_prefix = kwargs2.pop("binary_prefix", False) client_flag = kwargs.get("client_flag", 0) - client_version = tuple( - [numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]] - ) - if client_version >= (4, 1): - client_flag |= CLIENT.MULTI_STATEMENTS - if client_version >= (5, 0): - client_flag |= CLIENT.MULTI_RESULTS + + multi_statements = kwargs2.pop("multi_statements", True) + if multi_statements: + client_version = tuple( + [numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]] + ) + if client_version >= (4, 1): + client_flag |= CLIENT.MULTI_STATEMENTS + if client_version >= (5, 0): + client_flag |= CLIENT.MULTI_RESULTS kwargs2["client_flag"] = client_flag diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 00000000..960de572 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,26 @@ +import pytest + +from MySQLdb._exceptions import ProgrammingError + +from configdb import connection_factory + + +def test_multi_statements_default_true(): + conn = connection_factory() + cursor = conn.cursor() + + cursor.execute("select 17; select 2") + rows = cursor.fetchall() + assert rows == ((17,),) + + +def test_multi_statements_false(): + conn = connection_factory(multi_statements=False) + cursor = conn.cursor() + + with pytest.raises(ProgrammingError): + cursor.execute("select 17; select 2") + + cursor.execute("select 17") + rows = cursor.fetchall() + assert rows == ((17,),) From 77788e65404de839786ccb47de696fe1db3f4fcc Mon Sep 17 00:00:00 2001 From: Simon Lundmark Date: Thu, 2 Sep 2021 15:48:17 +0200 Subject: [PATCH 2/2] multi statements: Fixes after PR review --- MySQLdb/connections.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index 5dffb577..7adc8835 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -111,9 +111,8 @@ class object, used to create cursors (keyword only) flags to use or 0 (see MySQL docs or constants/CLIENTS.py) :param bool multi_statements: - If True, enable multi statements for clients >= 4.1 and multi - results for clients >= 5.0. Set to False to disable it, which gives - some protection against injection attacks. Defaults to True. + If True, enable multi statements for clients >= 4.1. + Defaults to True. :param str ssl_mode: specify the security settings for connection to the server; @@ -175,15 +174,17 @@ class object, used to create cursors (keyword only) client_flag = kwargs.get("client_flag", 0) + client_version = tuple( + [numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]] + ) + multi_statements = kwargs2.pop("multi_statements", True) if multi_statements: - client_version = tuple( - [numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]] - ) if client_version >= (4, 1): client_flag |= CLIENT.MULTI_STATEMENTS - if client_version >= (5, 0): - client_flag |= CLIENT.MULTI_RESULTS + + if client_version >= (5, 0): + client_flag |= CLIENT.MULTI_RESULTS kwargs2["client_flag"] = client_flag