Skip to content

Commit 69991fb

Browse files
committed
Adjust unit tests
1 parent 02cc9d5 commit 69991fb

File tree

1 file changed

+23
-26
lines changed

1 file changed

+23
-26
lines changed

tests/unit/common/test_debug.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def setup_mock(*logger_names):
4848
loggers = [logging.getLogger(name) for name in logger_names]
4949
for logger in loggers:
5050
logger.addHandler = mocker.Mock()
51-
logger.addFilter = mocker.Mock(side_effect=logger.addFilter)
5251
logger.removeHandler = mocker.Mock()
5352
logger.setLevel = mocker.Mock()
5453
return loggers
@@ -179,6 +178,7 @@ def test_watcher_colour(logger_mocker, colour, thread, task) -> None:
179178
thread_info=thread, task_info=task)
180179
watcher.watch()
181180

181+
logger.addHandler.assert_called_once()
182182
(handler,), _ = logger.addHandler.call_args
183183
assert isinstance(handler, logging.Handler)
184184
assert isinstance(handler.formatter, logging.Formatter)
@@ -198,6 +198,7 @@ def test_watcher_format(logger_mocker, colour, thread, task) -> None:
198198
thread_info=thread, task_info=task)
199199
watcher.watch()
200200

201+
logger.addHandler.assert_called_once()
201202
(handler,), _ = logger.addHandler.call_args
202203
assert isinstance(handler, logging.Handler)
203204
assert isinstance(handler.formatter, logging.Formatter)
@@ -213,14 +214,12 @@ def test_watcher_format(logger_mocker, colour, thread, task) -> None:
213214
assert format_ == expected_format
214215

215216

216-
@pytest.mark.parametrize("colour", (True, False))
217-
@pytest.mark.parametrize("thread", (True, False))
218-
@pytest.mark.parametrize("task", (True, False))
219-
def test_watcher_task_injection(
220-
mocker, logger_mocker, colour, thread, task
217+
def _assert_task_injection(
218+
async_: bool, mocker, logger_mocker, colour: bool, thread: bool, task: bool
221219
) -> None:
220+
handler_cls_mock = mocker.patch("neo4j.debug.StreamHandler", autospec=True)
221+
handler_mock = handler_cls_mock.return_value
222222
logger_name = "neo4j"
223-
logger = logger_mocker(logger_name)[0]
224223
watcher = neo4j_debug.Watcher(logger_name, colour=colour,
225224
thread_info=thread, task_info=task)
226225
record_mock = mocker.Mock(spec=logging.LogRecord)
@@ -229,34 +228,32 @@ def test_watcher_task_injection(
229228
watcher.watch()
230229

231230
if task:
232-
(filter_,), _ = logger.addFilter.call_args
231+
handler_mock.addFilter.assert_called_once()
232+
(filter_,), _ = handler_mock.addFilter.call_args
233233
assert isinstance(filter_, logging.Filter)
234234
filter_.filter(record_mock)
235-
assert record_mock.task is None
235+
if async_:
236+
assert record_mock.task == id(asyncio.current_task())
237+
else:
238+
assert record_mock.task is None
236239
else:
237-
logger.addFilter.assert_not_called()
240+
handler_mock.addFilter.assert_not_called()
238241

239242

240243
@pytest.mark.parametrize("colour", (True, False))
241244
@pytest.mark.parametrize("thread", (True, False))
242245
@pytest.mark.parametrize("task", (True, False))
243-
@mark_async_test
244-
async def test_async_watcher_task_injection(
246+
def test_watcher_task_injection(
245247
mocker, logger_mocker, colour, thread, task
246248
) -> None:
247-
logger_name = "neo4j"
248-
logger = logger_mocker(logger_name)[0]
249-
watcher = neo4j_debug.Watcher(logger_name, colour=colour,
250-
thread_info=thread, task_info=task)
251-
record_mock = mocker.Mock(spec=logging.LogRecord)
252-
assert not hasattr(record_mock, "task")
249+
_assert_task_injection(False, mocker, logger_mocker, colour, thread, task)
253250

254-
watcher.watch()
255251

256-
if task:
257-
(filter_,), _ = logger.addFilter.call_args
258-
assert isinstance(filter_, logging.Filter)
259-
filter_.filter(record_mock)
260-
assert record_mock.task == id(asyncio.current_task())
261-
else:
262-
logger.addFilter.assert_not_called()
252+
@pytest.mark.parametrize("colour", (True, False))
253+
@pytest.mark.parametrize("thread", (True, False))
254+
@pytest.mark.parametrize("task", (True, False))
255+
@mark_async_test
256+
async def test_async_watcher_task_injection(
257+
mocker, logger_mocker, colour, thread, task
258+
) -> None:
259+
_assert_task_injection(True, mocker, logger_mocker, colour, thread, task)

0 commit comments

Comments
 (0)