diff --git a/arangoasync/database.py b/arangoasync/database.py index f5cb8e4..8ba7c62 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -1,6 +1,7 @@ __all__ = [ "Database", "StandardDatabase", + "TransactionDatabase", ] @@ -24,6 +25,11 @@ PermissionResetError, PermissionUpdateError, ServerStatusError, + TransactionAbortError, + TransactionCommitError, + TransactionInitError, + TransactionListError, + TransactionStatusError, UserCreateError, UserDeleteError, UserGetError, @@ -31,7 +37,7 @@ UserReplaceError, UserUpdateError, ) -from arangoasync.executor import ApiExecutor, DefaultApiExecutor +from arangoasync.executor import ApiExecutor, DefaultApiExecutor, TransactionApiExecutor from arangoasync.request import Method, Request from arangoasync.response import Response from arangoasync.serialization import Deserializer, Serializer @@ -84,6 +90,16 @@ def deserializer(self) -> Deserializer[Json, Jsons]: """Return the deserializer.""" return self._executor.deserializer + @property + def context(self) -> str: + """Return the API execution context. + + Returns: + str: API execution context. Possible values are "default", "transaction". + :rtype: str + """ + return self._executor.context + async def properties(self) -> Result[DatabaseProperties]: """Return database properties. @@ -1065,7 +1081,209 @@ def response_handler(resp: Response) -> Json: class StandardDatabase(Database): - """Standard database API wrapper.""" + """Standard database API wrapper. + + Args: + connection (Connection): Connection object to be used by the API executor. + """ def __init__(self, connection: Connection) -> None: super().__init__(DefaultApiExecutor(connection)) + + def __repr__(self) -> str: + return f"" + + async def begin_transaction( + self, + read: Optional[str | Sequence[str]] = None, + write: Optional[str | Sequence[str]] = None, + exclusive: Optional[str | Sequence[str]] = None, + wait_for_sync: Optional[bool] = None, + allow_implicit: Optional[bool] = None, + lock_timeout: Optional[int] = None, + max_transaction_size: Optional[int] = None, + allow_dirty_read: Optional[bool] = None, + skip_fast_lock_round: Optional[bool] = None, + ) -> "TransactionDatabase": + """Begin a Stream Transaction. + + Args: + read (str | list | None): Name(s) of collections read during transaction. + Read-only collections are added lazily but should be declared if + possible to avoid deadlocks. + write (str | list | None): Name(s) of collections written to during + transaction with shared access. + exclusive (str | list | None): Name(s) of collections written to during + transaction with exclusive access. + wait_for_sync (bool | None): If `True`, will force the transaction to write + all data to disk before returning + allow_implicit (bool | None): Allow reading from undeclared collections. + lock_timeout (int | None): Timeout for waiting on collection locks. Setting + it to 0 will make ArangoDB not time out waiting for a lock. + max_transaction_size (int | None): Transaction size limit in bytes. + allow_dirty_read (bool | None): If `True`, allows the Coordinator to ask any + shard replica for the data, not only the shard leader. This may result + in “dirty reads”. This setting decides about dirty reads for the entire + transaction. Individual read operations, that are performed as part of + the transaction, cannot override it. + skip_fast_lock_round (bool | None): Whether to disable fast locking for + write operations. + + Returns: + TransactionDatabase: Database API wrapper specifically tailored for + transactions. + + Raises: + TransactionInitError: If the operation fails on the server side. + """ + collections = dict() + if read is not None: + collections["read"] = read + if write is not None: + collections["write"] = write + if exclusive is not None: + collections["exclusive"] = exclusive + + data: Json = dict(collections=collections) + if wait_for_sync is not None: + data["waitForSync"] = wait_for_sync + if allow_implicit is not None: + data["allowImplicit"] = allow_implicit + if lock_timeout is not None: + data["lockTimeout"] = lock_timeout + if max_transaction_size is not None: + data["maxTransactionSize"] = max_transaction_size + if skip_fast_lock_round is not None: + data["skipFastLockRound"] = skip_fast_lock_round + + headers = dict() + if allow_dirty_read is not None: + headers["x-arango-allow-dirty-read"] = str(allow_dirty_read).lower() + + request = Request( + method=Method.POST, + endpoint="/_api/transaction/begin", + data=self.serializer.dumps(data), + headers=headers, + ) + + def response_handler(resp: Response) -> str: + if not resp.is_success: + raise TransactionInitError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body)["result"] + return cast(str, result["id"]) + + transaction_id = await self._executor.execute(request, response_handler) + return TransactionDatabase(self.connection, transaction_id) + + def fetch_transaction(self, transaction_id: str) -> "TransactionDatabase": + """Fetch an existing transaction. + + Args: + transaction_id (str): Transaction ID. + + Returns: + TransactionDatabase: Database API wrapper specifically tailored for + transactions. + """ + return TransactionDatabase(self.connection, transaction_id) + + async def list_transactions(self) -> Result[Jsons]: + """List all currently running stream transactions. + + Returns: + list: List of transactions, with each transaction containing + an "id" and a "state" field. + + Raises: + TransactionListError: If the operation fails on the server side. + """ + request = Request(method=Method.GET, endpoint="/_api/transaction") + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise TransactionListError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body) + return cast(Jsons, result["transactions"]) + + return await self._executor.execute(request, response_handler) + + +class TransactionDatabase(Database): + """Database API tailored specifically for + `Stream Transactions `__. + + It allows you start a transaction, run multiple operations (eg. AQL queries) over a short period of time, + and then commit or abort the transaction. + + See :func:`arangoasync.database.StandardDatabase.begin_transaction`. + + Args: + connection (Connection): Connection object to be used by the API executor. + transaction_id (str): Transaction ID. + """ # noqa: E501 + + def __init__(self, connection: Connection, transaction_id: str) -> None: + super().__init__(TransactionApiExecutor(connection, transaction_id)) + self._standard_executor = DefaultApiExecutor(connection) + self._transaction_id = transaction_id + + def __repr__(self) -> str: + return f"" + + @property + def transaction_id(self) -> str: + """Transaction ID.""" + return self._transaction_id + + async def transaction_status(self) -> str: + """Get the status of the transaction. + + Returns: + str: Transaction status: one of "running", "committed" or "aborted". + + Raises: + TransactionStatusError: If the transaction is not found. + """ + request = Request( + method=Method.GET, + endpoint=f"/_api/transaction/{self.transaction_id}", + ) + + def response_handler(resp: Response) -> str: + if not resp.is_success: + raise TransactionStatusError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body)["result"] + return cast(str, result["status"]) + + return await self._executor.execute(request, response_handler) + + async def commit_transaction(self) -> None: + """Commit the transaction. + + Raises: + TransactionCommitError: If the operation fails on the server side. + """ + request = Request( + method=Method.PUT, + endpoint=f"/_api/transaction/{self.transaction_id}", + ) + + def response_handler(resp: Response) -> None: + if not resp.is_success: + raise TransactionCommitError(resp, request) + + await self._executor.execute(request, response_handler) + + async def abort_transaction(self) -> None: + """Abort the transaction.""" + request = Request( + method=Method.DELETE, + endpoint=f"/_api/transaction/{self.transaction_id}", + ) + + def response_handler(resp: Response) -> None: + if not resp.is_success: + raise TransactionAbortError(resp, request) + + await self._executor.execute(request, response_handler) diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index 62b871a..00e668b 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -195,6 +195,26 @@ class ServerStatusError(ArangoServerError): """Failed to retrieve server status.""" +class TransactionAbortError(ArangoServerError): + """Failed to abort transaction.""" + + +class TransactionCommitError(ArangoServerError): + """Failed to commit transaction.""" + + +class TransactionInitError(ArangoServerError): + """Failed to initialize transaction.""" + + +class TransactionListError(ArangoServerError): + """Failed to retrieve transactions.""" + + +class TransactionStatusError(ArangoServerError): + """Failed to retrieve transaction status.""" + + class UserCreateError(ArangoServerError): """Failed to create user.""" diff --git a/arangoasync/executor.py b/arangoasync/executor.py index 175096e..c830be7 100644 --- a/arangoasync/executor.py +++ b/arangoasync/executor.py @@ -60,4 +60,39 @@ async def execute( return response_handler(response) -ApiExecutor = DefaultApiExecutor +class TransactionApiExecutor(DefaultApiExecutor): + """Executes transaction API requests. + + Args: + connection: HTTP connection. + transaction_id: str: Transaction ID generated by the server. + """ + + def __init__(self, connection: Connection, transaction_id: str) -> None: + super().__init__(connection) + self._id = transaction_id + + @property + def context(self) -> str: + return "transaction" + + @property + def id(self) -> str: + """Return the transaction ID.""" + return self._id + + async def execute( + self, request: Request, response_handler: Callable[[Response], T] + ) -> T: + """Execute the request and handle the response. + + Args: + request: HTTP request. + response_handler: HTTP response handler. + """ + request.headers["x-arango-trx-id"] = self.id + response = await self._conn.send_request(request) + return response_handler(response) + + +ApiExecutor = DefaultApiExecutor | TransactionApiExecutor diff --git a/arangoasync/typings.py b/arangoasync/typings.py index 0744ada..aefe43c 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -836,7 +836,7 @@ class IndexProperties(JsonWrapper): } References: - - `get-an-index `__ + - `get-an-index `__ """ # noqa: E501 def __init__(self, data: Json) -> None: diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 0000000..e8730fd --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,190 @@ +import asyncio + +import pytest + +from arangoasync.database import TransactionDatabase +from arangoasync.errno import BAD_PARAMETER, FORBIDDEN, TRANSACTION_NOT_FOUND +from arangoasync.exceptions import ( + TransactionAbortError, + TransactionCommitError, + TransactionInitError, + TransactionStatusError, +) + + +@pytest.mark.asyncio +async def test_transaction_document_insert(db, bad_db, doc_col, docs): + # Start a basic transaction + txn_db = await db.begin_transaction( + read=doc_col.name, + write=doc_col.name, + exclusive=[], + wait_for_sync=True, + allow_implicit=False, + lock_timeout=1000, + max_transaction_size=1024 * 1024, + skip_fast_lock_round=True, + allow_dirty_read=False, + ) + + # Make sure the object properties are set correctly + assert isinstance(txn_db, TransactionDatabase) + assert txn_db.name == db.name + assert txn_db.context == "transaction" + assert txn_db.transaction_id is not None + assert repr(txn_db) == f"" + txn_col = txn_db.collection(doc_col.name) + assert txn_col.db_name == db.name + + with pytest.raises(TransactionInitError) as err: + await bad_db.begin_transaction() + assert err.value.error_code == FORBIDDEN + + # Insert a document in the transaction + for doc in docs: + result = await txn_col.insert(doc) + assert result["_id"] == f"{doc_col.name}/{doc['_key']}" + assert result["_key"] == doc["_key"] + assert isinstance(result["_rev"], str) + assert (await txn_col.get(doc["_key"]))["val"] == doc["val"] + + # Abort the transaction + await txn_db.abort_transaction() + + +@pytest.mark.asyncio +async def test_transaction_status(db, doc_col): + # Begin a transaction + txn_db = await db.begin_transaction(read=doc_col.name) + assert await txn_db.transaction_status() == "running" + + # Commit the transaction + await txn_db.commit_transaction() + assert await txn_db.transaction_status() == "committed" + + # Begin another transaction + txn_db = await db.begin_transaction(read=doc_col.name) + assert await txn_db.transaction_status() == "running" + + # Abort the transaction + await txn_db.abort_transaction() + assert await txn_db.transaction_status() == "aborted" + + # Test with an illegal transaction ID + txn_db = db.fetch_transaction("illegal") + with pytest.raises(TransactionStatusError) as err: + await txn_db.transaction_status() + # Error code differs between single server and cluster mode + assert err.value.error_code in {BAD_PARAMETER, TRANSACTION_NOT_FOUND} + + +@pytest.mark.asyncio +async def test_transaction_commit(db, doc_col, docs): + # Begin a transaction + txn_db = await db.begin_transaction( + read=doc_col.name, + write=doc_col.name, + ) + txn_col = txn_db.collection(doc_col.name) + + # Insert documents in the transaction + assert "_rev" in await txn_col.insert(docs[0]) + assert "_rev" in await txn_col.insert(docs[2]) + await txn_db.commit_transaction() + assert await txn_db.transaction_status() == "committed" + + # Check the documents, after transaction has been committed + doc = await doc_col.get(docs[2]["_key"]) + assert doc["_key"] == docs[2]["_key"] + assert doc["val"] == docs[2]["val"] + + # Test with an illegal transaction ID + txn_db = db.fetch_transaction("illegal") + with pytest.raises(TransactionCommitError) as err: + await txn_db.commit_transaction() + # Error code differs between single server and cluster mode + assert err.value.error_code in {BAD_PARAMETER, TRANSACTION_NOT_FOUND} + + +@pytest.mark.asyncio +async def test_transaction_abort(db, doc_col, docs): + # Begin a transaction + txn_db = await db.begin_transaction( + read=doc_col.name, + write=doc_col.name, + ) + txn_col = txn_db.collection(doc_col.name) + + # Insert documents in the transaction + assert "_rev" in await txn_col.insert(docs[0]) + assert "_rev" in await txn_col.insert(docs[2]) + await txn_db.abort_transaction() + assert await txn_db.transaction_status() == "aborted" + + # Check the documents, after transaction has been aborted + assert await doc_col.get(docs[2]["_key"]) is None + + # Test with an illegal transaction ID + txn_db = db.fetch_transaction("illegal") + with pytest.raises(TransactionAbortError) as err: + await txn_db.abort_transaction() + # Error code differs between single server and cluster mode + assert err.value.error_code in {BAD_PARAMETER, TRANSACTION_NOT_FOUND} + + +@pytest.mark.asyncio +async def test_transaction_fetch_existing(db, doc_col, docs): + # Begin a transaction + txn_db = await db.begin_transaction( + read=doc_col.name, + write=doc_col.name, + ) + txn_col = txn_db.collection(doc_col.name) + + # Insert documents in the transaction + assert "_rev" in await txn_col.insert(docs[0]) + assert "_rev" in await txn_col.insert(docs[1]) + + txn_db2 = db.fetch_transaction(txn_db.transaction_id) + assert txn_db2.transaction_id == txn_db.transaction_id + txn_col2 = txn_db2.collection(doc_col.name) + assert "_rev" in await txn_col2.insert(docs[2]) + + await txn_db2.commit_transaction() + assert await txn_db.transaction_status() == "committed" + assert await txn_db2.transaction_status() == "committed" + + # Check the documents, after transaction has been aborted + assert all( + await asyncio.gather(*(doc_col.get(docs[idx]["_key"]) for idx in range(3))) + ) + + +@pytest.mark.asyncio +async def test_transaction_list(db): + # There should be no transactions initially + assert await db.list_transactions() == [] + + # Begin a transaction + txn_db1 = await db.begin_transaction() + tx_ls = await db.list_transactions() + assert len(tx_ls) == 1 + assert any(txn_db1.transaction_id == tx["id"] for tx in tx_ls) + + # Begin another transaction + txn_db2 = await db.begin_transaction() + tx_ls = await db.list_transactions() + assert len(tx_ls) == 2 + assert any(txn_db1.transaction_id == tx["id"] for tx in tx_ls) + assert any(txn_db2.transaction_id == tx["id"] for tx in tx_ls) + + # Only the first transaction should be running after aborting the second + await txn_db2.abort_transaction() + tx_ls = await db.list_transactions() + assert len(tx_ls) == 1 + assert any(txn_db1.transaction_id == tx["id"] for tx in tx_ls) + + # Commit the first transaction, no transactions should be left + await txn_db1.commit_transaction() + tx_ls = await db.list_transactions() + assert len(tx_ls) == 0