diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 48c8f70e..44f43f0b 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -252,6 +252,36 @@ async def test_result_iteration(method, records): await fetch_and_compare_all_records(result, "x", records, method) +@mark_async_test +async def test_result_iteration_mixed_methods(): + records = [[i] for i in range(10)] + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), 4, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + iter1 = AsyncUtil.iter(result) + iter2 = AsyncUtil.iter(result) + assert (await AsyncUtil.next(iter1)).get("x") == records[0][0] + assert (await AsyncUtil.next(iter2)).get("x") == records[1][0] + assert (await AsyncUtil.next(iter2)).get("x") == records[2][0] + assert (await AsyncUtil.next(iter1)).get("x") == records[3][0] + assert (await AsyncUtil.next(iter1)).get("x") == records[4][0] + assert (await AsyncUtil.next(result)).get("x") == records[5][0] + assert (await AsyncUtil.next(iter2)).get("x") == records[6][0] + assert (await AsyncUtil.next(iter1)).get("x") == records[7][0] + assert ((await AsyncUtil.next(AsyncUtil.iter(result))).get("x") + == records[8][0]) + assert [r.get("x") async for r in result] == [records[9][0]] + with pytest.raises(StopAsyncIteration): + await AsyncUtil.next(iter1) + with pytest.raises(StopAsyncIteration): + await AsyncUtil.next(iter2) + with pytest.raises(StopAsyncIteration): + await AsyncUtil.next(result) + with pytest.raises(StopAsyncIteration): + await AsyncUtil.next(AsyncUtil.iter(result)) + assert [r.get("x") async for r in result] == [] + + @pytest.mark.parametrize("method", ("for loop", "next", "one iter", "new iter")) @pytest.mark.parametrize("invert_fetch", (True, False)) diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 6d615079..3c629cdf 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -252,6 +252,36 @@ def test_result_iteration(method, records): fetch_and_compare_all_records(result, "x", records, method) +@mark_sync_test +def test_result_iteration_mixed_methods(): + records = [[i] for i in range(10)] + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), 4, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + iter1 = Util.iter(result) + iter2 = Util.iter(result) + assert (Util.next(iter1)).get("x") == records[0][0] + assert (Util.next(iter2)).get("x") == records[1][0] + assert (Util.next(iter2)).get("x") == records[2][0] + assert (Util.next(iter1)).get("x") == records[3][0] + assert (Util.next(iter1)).get("x") == records[4][0] + assert (Util.next(result)).get("x") == records[5][0] + assert (Util.next(iter2)).get("x") == records[6][0] + assert (Util.next(iter1)).get("x") == records[7][0] + assert ((Util.next(Util.iter(result))).get("x") + == records[8][0]) + assert [r.get("x") for r in result] == [records[9][0]] + with pytest.raises(StopIteration): + Util.next(iter1) + with pytest.raises(StopIteration): + Util.next(iter2) + with pytest.raises(StopIteration): + Util.next(result) + with pytest.raises(StopIteration): + Util.next(Util.iter(result)) + assert [r.get("x") for r in result] == [] + + @pytest.mark.parametrize("method", ("for loop", "next", "one iter", "new iter")) @pytest.mark.parametrize("invert_fetch", (True, False))