diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 42b83809e..95ed2c414 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -52,6 +52,7 @@ 403: exceptions.PERMISSION_DENIED, 404: exceptions.NOT_FOUND, 409: exceptions.CONFLICT, + 412: exceptions.FAILED_PRECONDITION, 429: exceptions.RESOURCE_EXHAUSTED, 500: exceptions.INTERNAL, 503: exceptions.UNAVAILABLE, diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 53efd9b15..ef7c96721 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -32,6 +32,7 @@ from six.moves import urllib import firebase_admin +from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _sseclient from firebase_admin import _utils @@ -209,7 +210,7 @@ def get(self, etag=False, shallow=False): Raises: ValueError: If both ``etag`` and ``shallow`` are set to True. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if etag: if shallow: @@ -236,7 +237,7 @@ def get_if_changed(self, etag): Raises: ValueError: If the ETag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not isinstance(etag, six.string_types): raise ValueError('ETag must be a string.') @@ -258,7 +259,7 @@ def set(self, value): Raises: ValueError: If the provided value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -281,7 +282,7 @@ def set_if_unchanged(self, expected_etag, value): Raises: ValueError: If the value is None, or if expected_etag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ # pylint: disable=missing-raises-doc if not isinstance(expected_etag, six.string_types): @@ -293,11 +294,11 @@ def set_if_unchanged(self, expected_etag, value): headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) return True, value, headers.get('ETag') - except ApiCallError as error: - detail = error.detail - if detail.response is not None and 'ETag' in detail.response.headers: - etag = detail.response.headers['ETag'] - snapshot = detail.response.json() + except exceptions.FailedPreconditionError as error: + http_response = error.http_response + if http_response is not None and 'ETag' in http_response.headers: + etag = http_response.headers['ETag'] + snapshot = http_response.json() return False, snapshot, etag else: raise error @@ -317,7 +318,7 @@ def push(self, value=''): Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -333,7 +334,7 @@ def update(self, value): Raises: ValueError: If value is empty or not a dictionary. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not value or not isinstance(value, dict): raise ValueError('Value argument must be a non-empty dictionary.') @@ -345,7 +346,7 @@ def delete(self): """Deletes this node from the database. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ self._client.request('delete', self._add_suffix()) @@ -371,7 +372,7 @@ def listen(self, callback): ListenerRegistration: An object that can be used to stop the event listener. Raises: - ApiCallError: If an error occurs while starting the initial HTTP connection. + FirebaseError: If an error occurs while starting the initial HTTP connection. """ session = _sseclient.KeepAuthSession(self._client.credential) return self._listen_with_session(callback, session) @@ -387,9 +388,9 @@ def transaction(self, transaction_update): value of this reference into a new value. If another client writes to this location before the new value is successfully saved, the update function is called again with the new current value, and the write will be retried. In case of repeated failures, this method - will retry the transaction up to 25 times before giving up and raising a TransactionError. - The update function may also force an early abort by raising an exception instead of - returning a value. + will retry the transaction up to 25 times before giving up and raising a + TransactionAbortedError. The update function may also force an early abort by raising an + exception instead of returning a value. Args: transaction_update: A function which will be passed the current data stored at this @@ -402,7 +403,7 @@ def transaction(self, transaction_update): object: New value of the current database Reference (only if the transaction commits). Raises: - TransactionError: If the transaction aborts after exhausting all retry attempts. + TransactionAbortedError: If the transaction aborts after exhausting all retry attempts. ValueError: If transaction_update is not a function. """ if not callable(transaction_update): @@ -416,7 +417,8 @@ def transaction(self, transaction_update): if success: return new_data tries += 1 - raise TransactionError('Transaction aborted after failed retries.') + + raise TransactionAbortedError('Transaction aborted after failed retries.') def order_by_child(self, path): """Returns a Query that orders data by child values. @@ -468,7 +470,7 @@ def _listen_with_session(self, callback, session): sse = _sseclient.SSEClient(url, session) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) class Query(object): @@ -614,7 +616,7 @@ def get(self): object: Decoded JSON result of the Query. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ result = self._client.body('get', self._pathurl, params=self._querystr) if isinstance(result, (dict, list)) and self._order_by != '$priority': @@ -622,20 +624,11 @@ def get(self): return result -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase database server API.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - -class TransactionError(Exception): - """Represents an Exception encountered while performing a transaction.""" +class TransactionAbortedError(exceptions.AbortedError): + """A transaction was aborted aftr exceeding the maximum number of retries.""" def __init__(self, message): - Exception.__init__(self, message) - + exceptions.AbortedError.__init__(self, message) class _Sorter(object): @@ -934,7 +927,7 @@ def request(self, method, url, **kwargs): Response: An HTTP response object. Raises: - ApiCallError: If an error occurs while making the HTTP call. + FirebaseError: If an error occurs while making the HTTP call. """ query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) extra_params = kwargs.get('params') @@ -950,33 +943,39 @@ def request(self, method, url, **kwargs): try: return super(_Client, self).request(method, url, **kwargs) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) + + @classmethod + def handle_rtdb_error(cls, error): + """Converts an error encountered while calling RTDB into a FirebaseError.""" + if error.response is None: + return _utils.handle_requests_error(error) + + message = cls._extract_error_message(error.response) + return _utils.handle_requests_error(error, message=message) @classmethod - def extract_error_message(cls, error): - """Extracts an error message from an exception. + def _extract_error_message(cls, response): + """Extracts an error message from an error response. - If the server has not sent any response, simply converts the exception into a string. If the server has sent a JSON response with an 'error' field, which is the typical behavior of the Realtime Database REST API, parses the response to retrieve the error message. If the server has sent a non-JSON response, returns the full response as the error message. - - Args: - error: An exception raised by the requests library. - - Returns: - str: A string error message extracted from the exception. """ - if error.response is None: - return str(error) + message = None try: - data = error.response.json() + # RTDB error format: {"error": "text message"} + data = response.json() if isinstance(data, dict): - return '{0}\nReason: {1}'.format(error, data.get('error', 'unknown')) + message = data.get('error') except ValueError: pass - return '{0}\nReason: {1}'.format(error, error.response.content.decode()) + + if not message: + message = 'Unexpected response from database: {0}'.format(response.content.decode()) + + return message class _EmulatorAdminCredentials(google.auth.credentials.Credentials): diff --git a/integration/test_db.py b/integration/test_db.py index d88d145ba..4c2f6bde2 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from integration import conftest from tests import testutils @@ -359,30 +360,26 @@ def init_ref(self, path, app): admin_ref.set('test') assert admin_ref.get() == 'test' - def check_permission_error(self, excinfo): - assert isinstance(excinfo.value, db.ApiCallError) - assert 'Reason: Permission denied' in str(excinfo.value) - def test_no_access(self, app, override_app): path = '_adminsdk/python/admin' self.init_ref(path, app) user_ref = db.reference(path, override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert user_ref.get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read(self, app, override_app): path = '_adminsdk/python/protected/user2' self.init_ref(path, app) user_ref = db.reference(path, override_app) assert user_ref.get() == 'test' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read_write(self, app, override_app): path = '_adminsdk/python/protected/user1' @@ -394,9 +391,9 @@ def test_read_write(self, app, override_app): def test_query(self, override_app): user_ref = db.reference('_adminsdk/python/protected', override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.order_by_key().limit_to_first(2).get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_none_auth_override(self, app, none_override_app): path = '_adminsdk/python/public' @@ -405,14 +402,14 @@ def test_none_auth_override(self, app, none_override_app): assert public_ref.get() == 'test' ref = db.reference('_adminsdk/python', none_override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user1').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user2').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('admin').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' diff --git a/snippets/database/index.py b/snippets/database/index.py index fee23f626..adfa13476 100644 --- a/snippets/database/index.py +++ b/snippets/database/index.py @@ -214,7 +214,7 @@ def increment_votes(current_value): try: new_vote_count = upvotes_ref.transaction(increment_votes) print('Transaction completed') - except db.TransactionError: + except db.TransactionAbortedError: print('Transaction failed to commit') # [END transaction] diff --git a/tests/test_db.py b/tests/test_db.py index 211eabb4b..081c31e3d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from firebase_admin import _sseclient from tests import testutils @@ -31,14 +32,15 @@ class MockAdapter(testutils.MockAdapter): ETAG = '0' - def __init__(self, data, status, recorder): + def __init__(self, data, status, recorder, etag=ETAG): testutils.MockAdapter.__init__(self, data, status, recorder) + self._etag = etag def send(self, request, **kwargs): if_match = request.headers.get('if-match') if_none_match = request.headers.get('if-none-match') resp = super(MockAdapter, self).send(request, **kwargs) - resp.headers = {'ETag': MockAdapter.ETAG} + resp.headers = {'ETag': self._etag} if if_match and if_match != MockAdapter.ETAG: resp.status_code = 412 elif if_none_match == MockAdapter.ETAG: @@ -125,6 +127,38 @@ def test_invalid_child(self, child): parent.child(child) +class _RefOperations(object): + """A collection of operations that can be performed using a ``db.Reference``. + + This can be used to test any functionality that is common across multiple API calls. + """ + + @classmethod + def get(cls, ref): + ref.get() + + @classmethod + def push(cls, ref): + ref.push() + + @classmethod + def set(cls, ref): + ref.set({'foo': 'bar'}) + + @classmethod + def delete(cls, ref): + ref.delete() + + @classmethod + def query(cls, ref): + query = ref.order_by_key() + query.get() + + @classmethod + def get_ops(cls): + return [cls.get, cls.push, cls.set, cls.delete, cls.query] + + class TestReference(object): """Test cases for database queries via References.""" @@ -132,6 +166,12 @@ class TestReference(object): valid_values = [ '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} ] + error_codes = { + 400: exceptions.InvalidArgumentError, + 401: exceptions.UnauthenticatedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + } @classmethod def setup_class(cls): @@ -141,9 +181,9 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() - def instrument(self, ref, payload, status=200): + def instrument(self, ref, payload, status=200, etag=MockAdapter.ETAG): recorder = [] - adapter = MockAdapter(payload, status, recorder) + adapter = MockAdapter(payload, status, recorder, etag) ref._client.session.mount(self.test_url, adapter) return recorder @@ -427,6 +467,19 @@ def transaction_update(data): assert len(recorder) == 1 assert recorder[0].method == 'GET' + def test_transaction_abort(self): + ref = db.reference('/test/count') + data = 42 + recorder = self.instrument(ref, json.dumps(data), etag='1') + + with pytest.raises(db.TransactionAbortedError) as excinfo: + ref.transaction(lambda x: x + 1 if x else 1) + assert isinstance(excinfo.value, exceptions.AbortedError) + assert str(excinfo.value) == 'Transaction aborted after failed retries.' + assert excinfo.value.cause is None + assert excinfo.value.http_response is None + assert len(recorder) == 1 + 25 + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) def test_transaction_invalid_function(self, func): ref = db.reference('/test') @@ -449,21 +502,29 @@ def test_get_reference(self, path, expected): else: assert ref.parent.path == parent - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_server_error(self, error_code): + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_server_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: json error message' in str(excinfo.value) - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_other_error(self, error_code): + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None + + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_other_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, 'custom error message', error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: custom error message' in str(excinfo.value) + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'Unexpected response from database: custom error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None class TestListenerRegistration(object): @@ -481,9 +542,11 @@ def test_listen_error(self): session.mount(test_url, adapter) def callback(_): pass - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.InternalError) as excinfo: ref._listen_with_session(callback, session) - assert 'Reason: json error message' in str(excinfo.value) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None finally: testutils.cleanup_apps()