From 144a96b998d2fa09b0b23bdfe1a97cb751a10aff Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sat, 3 Aug 2019 17:52:19 -0700 Subject: [PATCH 1/6] Migrating db module to new exception types --- firebase_admin/db.py | 91 +++++++++++++++++++++++--------------------- tests/test_db.py | 71 +++++++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 56 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 53efd9b15..be39159ae 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -32,6 +32,7 @@ from six.moves import urllib import firebase_admin +from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _sseclient from firebase_admin import _utils @@ -209,7 +210,7 @@ def get(self, etag=False, shallow=False): Raises: ValueError: If both ``etag`` and ``shallow`` are set to True. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if etag: if shallow: @@ -236,7 +237,7 @@ def get_if_changed(self, etag): Raises: ValueError: If the ETag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not isinstance(etag, six.string_types): raise ValueError('ETag must be a string.') @@ -258,7 +259,7 @@ def set(self, value): Raises: ValueError: If the provided value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -281,7 +282,7 @@ def set_if_unchanged(self, expected_etag, value): Raises: ValueError: If the value is None, or if expected_etag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ # pylint: disable=missing-raises-doc if not isinstance(expected_etag, six.string_types): @@ -293,11 +294,11 @@ def set_if_unchanged(self, expected_etag, value): headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) return True, value, headers.get('ETag') - except ApiCallError as error: - detail = error.detail - if detail.response is not None and 'ETag' in detail.response.headers: - etag = detail.response.headers['ETag'] - snapshot = detail.response.json() + except exceptions.FirebaseError as error: + http_response = error.http_response + if http_response is not None and 'ETag' in http_response.headers: + etag = http_response.headers['ETag'] + snapshot = http_response.json() return False, snapshot, etag else: raise error @@ -317,7 +318,7 @@ def push(self, value=''): Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -333,7 +334,7 @@ def update(self, value): Raises: ValueError: If value is empty or not a dictionary. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not value or not isinstance(value, dict): raise ValueError('Value argument must be a non-empty dictionary.') @@ -345,7 +346,7 @@ def delete(self): """Deletes this node from the database. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ self._client.request('delete', self._add_suffix()) @@ -371,7 +372,7 @@ def listen(self, callback): ListenerRegistration: An object that can be used to stop the event listener. Raises: - ApiCallError: If an error occurs while starting the initial HTTP connection. + FirebaseError: If an error occurs while starting the initial HTTP connection. """ session = _sseclient.KeepAuthSession(self._client.credential) return self._listen_with_session(callback, session) @@ -387,9 +388,9 @@ def transaction(self, transaction_update): value of this reference into a new value. If another client writes to this location before the new value is successfully saved, the update function is called again with the new current value, and the write will be retried. In case of repeated failures, this method - will retry the transaction up to 25 times before giving up and raising a TransactionError. - The update function may also force an early abort by raising an exception instead of - returning a value. + will retry the transaction up to 25 times before giving up and raising a + TransactionAbortedError. The update function may also force an early abort by raising an + exception instead of returning a value. Args: transaction_update: A function which will be passed the current data stored at this @@ -402,7 +403,7 @@ def transaction(self, transaction_update): object: New value of the current database Reference (only if the transaction commits). Raises: - TransactionError: If the transaction aborts after exhausting all retry attempts. + TransactionAbortedError: If the transaction aborts after exhausting all retry attempts. ValueError: If transaction_update is not a function. """ if not callable(transaction_update): @@ -416,7 +417,7 @@ def transaction(self, transaction_update): if success: return new_data tries += 1 - raise TransactionError('Transaction aborted after failed retries.') + raise TransactionAbortedError('Transaction aborted after failed retries.') def order_by_child(self, path): """Returns a Query that orders data by child values. @@ -468,7 +469,7 @@ def _listen_with_session(self, callback, session): sse = _sseclient.SSEClient(url, session) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) class Query(object): @@ -614,7 +615,7 @@ def get(self): object: Decoded JSON result of the Query. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ result = self._client.body('get', self._pathurl, params=self._querystr) if isinstance(result, (dict, list)) and self._order_by != '$priority': @@ -622,19 +623,11 @@ def get(self): return result -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase database server API.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - -class TransactionError(Exception): - """Represents an Exception encountered while performing a transaction.""" +class TransactionAbortedError(exceptions.AbortedError): + """Represents an transaction aborted after exhausting all available retries.""" def __init__(self, message): - Exception.__init__(self, message) + exceptions.AbortedError.__init__(self, message) @@ -934,7 +927,7 @@ def request(self, method, url, **kwargs): Response: An HTTP response object. Raises: - ApiCallError: If an error occurs while making the HTTP call. + FirebaseError: If an error occurs while making the HTTP call. """ query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) extra_params = kwargs.get('params') @@ -950,33 +943,45 @@ def request(self, method, url, **kwargs): try: return super(_Client, self).request(method, url, **kwargs) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) + + @classmethod + def handle_rtdb_error(cls, error): + """Converts an error encountered while calling RTDB into a FirebaseError.""" + if error.response is None: + return _utils.handle_requests_error(error) + + message = cls._extract_error_message(error.response) + return _utils.handle_requests_error(error, message=message) @classmethod - def extract_error_message(cls, error): - """Extracts an error message from an exception. + def _extract_error_message(cls, response): + """Extracts an error message from an error response. - If the server has not sent any response, simply converts the exception into a string. If the server has sent a JSON response with an 'error' field, which is the typical behavior of the Realtime Database REST API, parses the response to retrieve the error message. If the server has sent a non-JSON response, returns the full response as the error message. Args: - error: An exception raised by the requests library. + response: An HTTP error response. Returns: - str: A string error message extracted from the exception. + str: A string error message extracted from the response. """ - if error.response is None: - return str(error) + message = None try: - data = error.response.json() + # RTDB error format: {"error": "text message"} + data = response.json() if isinstance(data, dict): - return '{0}\nReason: {1}'.format(error, data.get('error', 'unknown')) + message = data.get('error') except ValueError: pass - return '{0}\nReason: {1}'.format(error, error.response.content.decode()) + + if not message: + message = 'Unexpected response from database: {0}'.format(response.content.decode()) + + return message class _EmulatorAdminCredentials(google.auth.credentials.Credentials): diff --git a/tests/test_db.py b/tests/test_db.py index 211eabb4b..aef41bd6b 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from firebase_admin import _sseclient from tests import testutils @@ -64,6 +65,34 @@ class _Object(object): pass +class _RefOperations(object): + + @classmethod + def get(cls, ref): + ref.get() + + @classmethod + def push(cls, ref): + ref.push() + + @classmethod + def set(cls, ref): + ref.set({'foo': 'bar'}) + + @classmethod + def delete(cls, ref): + ref.delete() + + @classmethod + def query(cls, ref): + query = ref.order_by_key() + query.get() + + @classmethod + def get_ops(cls): + return [cls.get, cls.push, cls.set, cls.delete, cls.query] + + class TestReferencePath(object): """Test cases for Reference paths.""" @@ -132,6 +161,12 @@ class TestReference(object): valid_values = [ '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} ] + error_codes = { + 400: exceptions.InvalidArgumentError, + 401: exceptions.UnauthenticatedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + } @classmethod def setup_class(cls): @@ -449,21 +484,29 @@ def test_get_reference(self, path, expected): else: assert ref.parent.path == parent - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_server_error(self, error_code): + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_server_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: json error message' in str(excinfo.value) - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_other_error(self, error_code): + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None + + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_other_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, 'custom error message', error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: custom error message' in str(excinfo.value) + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'Unexpected response from database: custom error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None class TestListenerRegistration(object): @@ -481,9 +524,11 @@ def test_listen_error(self): session.mount(test_url, adapter) def callback(_): pass - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.InternalError) as excinfo: ref._listen_with_session(callback, session) - assert 'Reason: json error message' in str(excinfo.value) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None finally: testutils.cleanup_apps() From 822d97ce8deea3b328af4c6e4d69a1494d86a4db Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sat, 3 Aug 2019 18:14:00 -0700 Subject: [PATCH 2/6] Error handling for transactions --- firebase_admin/_utils.py | 1 + firebase_admin/db.py | 32 +++++++++++++++++--------------- tests/test_db.py | 28 ++++++++++++++++++++++------ 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 42b83809e..95ed2c414 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -52,6 +52,7 @@ 403: exceptions.PERMISSION_DENIED, 404: exceptions.NOT_FOUND, 409: exceptions.CONFLICT, + 412: exceptions.FAILED_PRECONDITION, 429: exceptions.RESOURCE_EXHAUSTED, 500: exceptions.INTERNAL, 503: exceptions.UNAVAILABLE, diff --git a/firebase_admin/db.py b/firebase_admin/db.py index be39159ae..402c4d7b9 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -294,7 +294,7 @@ def set_if_unchanged(self, expected_etag, value): headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) return True, value, headers.get('ETag') - except exceptions.FirebaseError as error: + except exceptions.FailedPreconditionError as error: http_response = error.http_response if http_response is not None and 'ETag' in http_response.headers: etag = http_response.headers['ETag'] @@ -412,11 +412,15 @@ def transaction(self, transaction_update): tries = 0 data, etag = self.get(etag=True) while tries < _TRANSACTION_MAX_RETRIES: - new_data = transaction_update(data) - success, data, etag = self.set_if_unchanged(etag, new_data) - if success: - return new_data - tries += 1 + try: + new_data = transaction_update(data) + success, data, etag = self.set_if_unchanged(etag, new_data) + if success: + return new_data + tries += 1 + except Exception as error: + message = 'Transaction aborted by raising an exception: {0}'.format(error) + raise TransactionAbortedError(message, cause=error) raise TransactionAbortedError('Transaction aborted after failed retries.') def order_by_child(self, path): @@ -624,11 +628,15 @@ def get(self): class TransactionAbortedError(exceptions.AbortedError): - """Represents an transaction aborted after exhausting all available retries.""" + """A transaction was aborted. - def __init__(self, message): - exceptions.AbortedError.__init__(self, message) + A transaction is aborted when the corresponding update function raises an exception, or when + the number of allowed retries is exceeded. In the former case, the original exception that + caused the transaction to abort can be accessed via the ``cause`` property. + """ + def __init__(self, message, cause=None): + exceptions.AbortedError.__init__(self, message, cause) class _Sorter(object): @@ -962,12 +970,6 @@ def _extract_error_message(cls, response): behavior of the Realtime Database REST API, parses the response to retrieve the error message. If the server has sent a non-JSON response, returns the full response as the error message. - - Args: - response: An HTTP error response. - - Returns: - str: A string error message extracted from the response. """ message = None try: diff --git a/tests/test_db.py b/tests/test_db.py index aef41bd6b..7fd4f127d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -32,14 +32,15 @@ class MockAdapter(testutils.MockAdapter): ETAG = '0' - def __init__(self, data, status, recorder): + def __init__(self, data, status, recorder, etag=ETAG): testutils.MockAdapter.__init__(self, data, status, recorder) + self._etag = etag def send(self, request, **kwargs): if_match = request.headers.get('if-match') if_none_match = request.headers.get('if-none-match') resp = super(MockAdapter, self).send(request, **kwargs) - resp.headers = {'ETag': MockAdapter.ETAG} + resp.headers = {'ETag': self._etag} if if_match and if_match != MockAdapter.ETAG: resp.status_code = 412 elif if_none_match == MockAdapter.ETAG: @@ -176,9 +177,9 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() - def instrument(self, ref, payload, status=200): + def instrument(self, ref, payload, status=200, etag=MockAdapter.ETAG): recorder = [] - adapter = MockAdapter(payload, status, recorder) + adapter = MockAdapter(payload, status, recorder, etag) ref._client.session.mount(self.test_url, adapter) return recorder @@ -456,12 +457,27 @@ def transaction_update(data): del data raise ValueError('test error') - with pytest.raises(ValueError) as excinfo: + with pytest.raises(db.TransactionAbortedError) as excinfo: ref.transaction(transaction_update) - assert str(excinfo.value) == 'test error' + assert str(excinfo.value) == 'Transaction aborted by raising an exception: test error' + assert isinstance(excinfo.value.cause, ValueError) + assert excinfo.value.http_response is None assert len(recorder) == 1 assert recorder[0].method == 'GET' + def test_transaction_abort(self): + ref = db.reference('/test/count') + data = 42 + recorder = self.instrument(ref, json.dumps(data), etag='1') + + with pytest.raises(db.TransactionAbortedError) as excinfo: + ref.transaction(lambda x: x + 1 if x else 1) + assert isinstance(excinfo.value, exceptions.AbortedError) + assert str(excinfo.value) == 'Transaction aborted after failed retries.' + assert excinfo.value.cause is None + assert excinfo.value.http_response is None + assert len(recorder) == 1 + 25 + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) def test_transaction_invalid_function(self, func): ref = db.reference('/test') From 54dd1ad9ae2369fbfb08c12d5cde2f0b35be9d45 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sat, 3 Aug 2019 18:33:03 -0700 Subject: [PATCH 3/6] Updated integration tests --- integration/test_db.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/integration/test_db.py b/integration/test_db.py index d88d145ba..4c2f6bde2 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from integration import conftest from tests import testutils @@ -359,30 +360,26 @@ def init_ref(self, path, app): admin_ref.set('test') assert admin_ref.get() == 'test' - def check_permission_error(self, excinfo): - assert isinstance(excinfo.value, db.ApiCallError) - assert 'Reason: Permission denied' in str(excinfo.value) - def test_no_access(self, app, override_app): path = '_adminsdk/python/admin' self.init_ref(path, app) user_ref = db.reference(path, override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert user_ref.get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read(self, app, override_app): path = '_adminsdk/python/protected/user2' self.init_ref(path, app) user_ref = db.reference(path, override_app) assert user_ref.get() == 'test' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read_write(self, app, override_app): path = '_adminsdk/python/protected/user1' @@ -394,9 +391,9 @@ def test_read_write(self, app, override_app): def test_query(self, override_app): user_ref = db.reference('_adminsdk/python/protected', override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.order_by_key().limit_to_first(2).get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_none_auth_override(self, app, none_override_app): path = '_adminsdk/python/public' @@ -405,14 +402,14 @@ def test_none_auth_override(self, app, none_override_app): assert public_ref.get() == 'test' ref = db.reference('_adminsdk/python', none_override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user1').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user2').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('admin').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' From 9d821dc4de4704cf8e8c9119308c3ea2da70bf07 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 4 Aug 2019 15:05:29 -0700 Subject: [PATCH 4/6] Restoring the old txn abort behavior --- firebase_admin/db.py | 26 +++++++++----------------- tests/test_db.py | 6 ++---- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 402c4d7b9..ef7c96721 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -412,15 +412,12 @@ def transaction(self, transaction_update): tries = 0 data, etag = self.get(etag=True) while tries < _TRANSACTION_MAX_RETRIES: - try: - new_data = transaction_update(data) - success, data, etag = self.set_if_unchanged(etag, new_data) - if success: - return new_data - tries += 1 - except Exception as error: - message = 'Transaction aborted by raising an exception: {0}'.format(error) - raise TransactionAbortedError(message, cause=error) + new_data = transaction_update(data) + success, data, etag = self.set_if_unchanged(etag, new_data) + if success: + return new_data + tries += 1 + raise TransactionAbortedError('Transaction aborted after failed retries.') def order_by_child(self, path): @@ -628,15 +625,10 @@ def get(self): class TransactionAbortedError(exceptions.AbortedError): - """A transaction was aborted. - - A transaction is aborted when the corresponding update function raises an exception, or when - the number of allowed retries is exceeded. In the former case, the original exception that - caused the transaction to abort can be accessed via the ``cause`` property. - """ + """A transaction was aborted aftr exceeding the maximum number of retries.""" - def __init__(self, message, cause=None): - exceptions.AbortedError.__init__(self, message, cause) + def __init__(self, message): + exceptions.AbortedError.__init__(self, message) class _Sorter(object): diff --git a/tests/test_db.py b/tests/test_db.py index 7fd4f127d..81f3906d1 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -457,11 +457,9 @@ def transaction_update(data): del data raise ValueError('test error') - with pytest.raises(db.TransactionAbortedError) as excinfo: + with pytest.raises(ValueError) as excinfo: ref.transaction(transaction_update) - assert str(excinfo.value) == 'Transaction aborted by raising an exception: test error' - assert isinstance(excinfo.value.cause, ValueError) - assert excinfo.value.http_response is None + assert str(excinfo.value) == 'test error' assert len(recorder) == 1 assert recorder[0].method == 'GET' From 7bf06cd7ed65831e3a855658fbeb9ac6df4f5320 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Sun, 4 Aug 2019 17:03:36 -0700 Subject: [PATCH 5/6] Updated error type in snippet --- snippets/database/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snippets/database/index.py b/snippets/database/index.py index fee23f626..adfa13476 100644 --- a/snippets/database/index.py +++ b/snippets/database/index.py @@ -214,7 +214,7 @@ def increment_votes(current_value): try: new_vote_count = upvotes_ref.transaction(increment_votes) print('Transaction completed') - except db.TransactionError: + except db.TransactionAbortedError: print('Transaction failed to commit') # [END transaction] From 206604d9c6313de29b3e365fa3f7af3b30438af9 Mon Sep 17 00:00:00 2001 From: hiranya911 Date: Mon, 5 Aug 2019 13:51:09 -0700 Subject: [PATCH 6/6] Added comment --- tests/test_db.py | 60 ++++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/tests/test_db.py b/tests/test_db.py index 81f3906d1..081c31e3d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -66,34 +66,6 @@ class _Object(object): pass -class _RefOperations(object): - - @classmethod - def get(cls, ref): - ref.get() - - @classmethod - def push(cls, ref): - ref.push() - - @classmethod - def set(cls, ref): - ref.set({'foo': 'bar'}) - - @classmethod - def delete(cls, ref): - ref.delete() - - @classmethod - def query(cls, ref): - query = ref.order_by_key() - query.get() - - @classmethod - def get_ops(cls): - return [cls.get, cls.push, cls.set, cls.delete, cls.query] - - class TestReferencePath(object): """Test cases for Reference paths.""" @@ -155,6 +127,38 @@ def test_invalid_child(self, child): parent.child(child) +class _RefOperations(object): + """A collection of operations that can be performed using a ``db.Reference``. + + This can be used to test any functionality that is common across multiple API calls. + """ + + @classmethod + def get(cls, ref): + ref.get() + + @classmethod + def push(cls, ref): + ref.push() + + @classmethod + def set(cls, ref): + ref.set({'foo': 'bar'}) + + @classmethod + def delete(cls, ref): + ref.delete() + + @classmethod + def query(cls, ref): + query = ref.order_by_key() + query.get() + + @classmethod + def get_ops(cls): + return [cls.get, cls.push, cls.set, cls.delete, cls.query] + + class TestReference(object): """Test cases for database queries via References."""