diff --git a/neo4j/v1/exceptions.py b/neo4j/v1/exceptions.py index 8f18e31b..ca85fd02 100644 --- a/neo4j/v1/exceptions.py +++ b/neo4j/v1/exceptions.py @@ -23,8 +23,6 @@ class ProtocolError(Exception): """ Raised when an unexpected or unsupported protocol event occurs. """ - pass - class CypherError(Exception): """ Raised when the Cypher engine returns an error to the client. @@ -38,3 +36,8 @@ def __init__(self, data): for key, value in data.items(): if not key.startswith("_"): setattr(self, key, value) + + +class ResultError(Exception): + """ Raised when an error occurs while consuming a result. + """ diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 04ea8647..18808bb8 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -33,7 +33,7 @@ class which can be used to obtain `Driver` instances that are used for from .compat import integer, string, urlparse from .connection import connect, Response, RUN, PULL_ALL from .constants import ENCRYPTED_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES -from .exceptions import CypherError +from .exceptions import CypherError, ResultError from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED from .types import hydrated @@ -256,6 +256,32 @@ def consume(self): self.connection = None return self._summary + def single(self): + """ Return the next record, failing if none or more than one remain. + """ + records = list(self) + num_records = len(records) + if num_records == 0: + raise ResultError("No records found in stream") + elif num_records != 1: + raise ResultError("Multiple records found in stream") + else: + return records[0] + + def peek(self): + """ Return the next record without advancing the cursor. Fails + if no records remain. + """ + if self._buffer: + values = self._buffer[0] + return Record(self.keys(), tuple(map(hydrated, values))) + while not self._buffer and not self._consumed: + self.connection.fetch() + if self._buffer: + values = self._buffer[0] + return Record(self.keys(), tuple(map(hydrated, values))) + raise ResultError("End of stream") + class ResultSummary(object): """ A summary of execution returned with a :class:`.StatementResult` object. diff --git a/runtests.sh b/runtests.sh index ccd63d82..5f834fef 100755 --- a/runtests.sh +++ b/runtests.sh @@ -21,6 +21,7 @@ DRIVER_HOME=$(dirname $0) NEORUN_OPTIONS="" RUNNING=0 +QUICK=0 KNOWN_HOSTS="${HOME}/.neo4j/known_hosts" KNOWN_HOSTS_BACKUP="${KNOWN_HOSTS}.backup" @@ -28,12 +29,15 @@ FG_BRIGHT_RED='\033[1;31m' FG_DEFAULT='\033[0m' # Parse options -while getopts ":dr" OPTION +while getopts ":dqr" OPTION do case ${OPTION} in d) NEORUN_OPTIONS="-f" ;; + q) + QUICK=1 + ;; r) RUNNING=1 ;; @@ -79,34 +83,33 @@ echo "" TEST_RUNNER="coverage run -m ${UNITTEST} discover -vfs ${TEST}" EXAMPLES_RUNNER="coverage run -m ${UNITTEST} discover -vfs examples" -BEHAVE_RUNNER="behave --tags=-db --tags=-in_dev --tags=-streaming_and_cursor_navigation test/tck" +BEHAVE_RUNNER="behave --tags=-db --tags=-in_dev test/tck" if [ ${RUNNING} -eq 1 ] then ${TEST_RUNNER} check_exit_status $? else - #echo "Updating password" - #mv ${KNOWN_HOSTS} ${KNOWN_HOSTS_BACKUP} - #neokit/neorun ${NEORUN_OPTIONS} "python -m test.auth password" ${VERSIONS} - #EXIT_STATUS=$? - #mv ${KNOWN_HOSTS_BACKUP} ${KNOWN_HOSTS} - #check_exit_status ${EXIT_STATUS} export NEO4J_PASSWORD="password" echo "Running unit tests" neokit/neorun ${NEORUN_OPTIONS} "${TEST_RUNNER}" ${VERSIONS} check_exit_status $? - echo "Testing example code" - neokit/neorun ${NEORUN_OPTIONS} "${EXAMPLES_RUNNER}" ${VERSIONS} - check_exit_status $? + if [ ${QUICK} -eq 0 ] + then + echo "Testing example code" + neokit/neorun ${NEORUN_OPTIONS} "${EXAMPLES_RUNNER}" ${VERSIONS} + check_exit_status $? + + echo "Testing TCK" + coverage report --show-missing + python -c 'from test.tck.configure_feature_files import *; set_up()' + echo "Feature files downloaded" + neokit/neorun ${NEORUN_OPTIONS} "${BEHAVE_RUNNER}" ${VERSIONS} + python -c 'from test.tck.configure_feature_files import *; clean_up()' + echo "Feature files removed" - coverage report --show-missing - python -c 'from test.tck.configure_feature_files import *; set_up()' - echo "Feature files downloaded" - neokit/neorun ${NEORUN_OPTIONS} "${BEHAVE_RUNNER}" ${VERSIONS} - python -c 'from test.tck.configure_feature_files import *; clean_up()' - echo "Feature files removed" + fi fi diff --git a/test/test_session.py b/test/test_session.py index 58d8513a..c383e365 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -26,7 +26,7 @@ from mock import patch from neo4j.v1.constants import TRUST_ON_FIRST_USE -from neo4j.v1.exceptions import CypherError +from neo4j.v1.exceptions import CypherError, ResultError from neo4j.v1.session import GraphDatabase, basic_auth, Record, SSL_AVAILABLE from neo4j.v1.types import Node, Relationship, Path @@ -575,3 +575,74 @@ def test_can_consume_result_after_session_with_error(self): tx.commit() session.close() assert [record[0] for record in result] == [1, 2, 3] + + def test_single_with_exactly_one_record(self): + session = self.driver.session() + result = session.run("UNWIND range(1, 1) AS n RETURN n") + record = result.single() + assert list(record.values()) == [1] + + def test_single_with_no_records(self): + session = self.driver.session() + result = session.run("CREATE ()") + with self.assertRaises(ResultError): + result.single() + + def test_single_with_multiple_records(self): + session = self.driver.session() + result = session.run("UNWIND range(1, 3) AS n RETURN n") + with self.assertRaises(ResultError): + result.single() + + def test_single_consumes_entire_result_if_one_record(self): + session = self.driver.session() + result = session.run("UNWIND range(1, 1) AS n RETURN n") + _ = result.single() + assert result._consumed + + def test_single_consumes_entire_result_if_multiple_records(self): + session = self.driver.session() + result = session.run("UNWIND range(1, 3) AS n RETURN n") + with self.assertRaises(ResultError): + _ = result.single() + assert result._consumed + + def test_peek_can_look_one_ahead(self): + session = self.driver.session() + result = session.run("UNWIND range(1, 3) AS n RETURN n") + record = result.peek() + assert list(record.values()) == [1] + + def test_peek_fails_if_nothing_remains(self): + session = self.driver.session() + result = session.run("CREATE ()") + with self.assertRaises(ResultError): + result.peek() + + def test_peek_does_not_advance_cursor(self): + session = self.driver.session() + result = session.run("UNWIND range(1, 3) AS n RETURN n") + result.peek() + assert [record[0] for record in result] == [1, 2, 3] + + def test_peek_at_different_stages(self): + session = self.driver.session() + result = session.run("UNWIND range(0, 9) AS n RETURN n") + # Peek ahead to the first record + expected_next = 0 + upcoming = result.peek() + assert upcoming[0] == expected_next + # Then look through all the other records + for expected, record in enumerate(result): + # Check this record is as expected + assert record[0] == expected + # Check the upcoming record is as expected... + if expected < 9: + # ...when one should follow + expected_next = expected + 1 + upcoming = result.peek() + assert upcoming[0] == expected_next + else: + # ...when none should follow + with self.assertRaises(ResultError): + result.peek()