Skip to content

Commit 6432cca

Browse files
authored
Merge pull request #84 from psqlpy-python/feature/add_support_for_multiquery_execution
Added method execute_batch for Connection and Transaction
2 parents 903008e + aa021ec commit 6432cca

File tree

7 files changed

+147
-0
lines changed

7 files changed

+147
-0
lines changed

docs/components/connection.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,28 @@ async def main() -> None:
6161
dict_results: list[dict[str, Any]] = results.result()
6262
```
6363

64+
### Execute Batch
65+
66+
#### Parameters:
67+
68+
- `querystring`: querystrings separated by semicolons.
69+
70+
Executes a sequence of SQL statements using the simple query protocol.
71+
72+
Statements should be separated by semicolons.
73+
If an error occurs, execution of the sequence will stop at that point.
74+
This is intended for use when, for example,
75+
initializing a database schema.
76+
77+
```python
78+
async def main() -> None:
79+
...
80+
connection = await db_pool.connection()
81+
await connection.execute_batch(
82+
"CREATE TABLE psqlpy (name VARCHAR); CREATE TABLE psqlpy2 (name VARCHAR);",
83+
)
84+
```
85+
6486
### Fetch
6587

6688
#### Parameters:

docs/components/transaction.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,29 @@ async def main() -> None:
144144
dict_results: list[dict[str, Any]] = results.result()
145145
```
146146

147+
### Execute Batch
148+
149+
#### Parameters:
150+
151+
- `querystring`: querystrings separated by semicolons.
152+
153+
Executes a sequence of SQL statements using the simple query protocol.
154+
155+
Statements should be separated by semicolons.
156+
If an error occurs, execution of the sequence will stop at that point.
157+
This is intended for use when, for example,
158+
initializing a database schema.
159+
160+
```python
161+
async def main() -> None:
162+
...
163+
connection = await db_pool.connection()
164+
async with connection.transaction() as transaction:
165+
await transaction.execute_batch(
166+
"CREATE TABLE psqlpy (name VARCHAR); CREATE TABLE psqlpy2 (name VARCHAR);",
167+
)
168+
```
169+
147170
### Fetch
148171

149172
#### Parameters:

python/psqlpy/_internal/__init__.pyi

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,21 @@ class Transaction:
455455
await transaction.commit()
456456
```
457457
"""
458+
async def execute_batch(
459+
self: Self,
460+
querystring: str,
461+
) -> None:
462+
"""
463+
Executes a sequence of SQL statements using the simple query protocol.
464+
465+
Statements should be separated by semicolons.
466+
If an error occurs, execution of the sequence will stop at that point.
467+
This is intended for use when, for example,
468+
initializing a database schema.
469+
470+
### Parameters:
471+
- `querystring`: querystrings separated by semicolons.
472+
"""
458473
async def execute_many(
459474
self: Self,
460475
querystring: str,
@@ -885,6 +900,21 @@ class Connection:
885900
dict_result: List[Dict[Any, Any]] = query_result.result()
886901
```
887902
"""
903+
async def execute_batch(
904+
self: Self,
905+
querystring: str,
906+
) -> None:
907+
"""
908+
Executes a sequence of SQL statements using the simple query protocol.
909+
910+
Statements should be separated by semicolons.
911+
If an error occurs, execution of the sequence will stop at that point.
912+
This is intended for use when, for example,
913+
initializing a database schema.
914+
915+
### Parameters:
916+
- `querystring`: querystrings separated by semicolons.
917+
"""
888918
async def execute_many(
889919
self: Self,
890920
querystring: str,

python/tests/test_connection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,14 @@ async def test_binary_copy_to_table(
236236
f"SELECT COUNT(*) AS rows_count FROM {table_name}",
237237
)
238238
assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row
239+
240+
241+
async def test_execute_batch_method(psql_pool: ConnectionPool) -> None:
242+
"""Test `execute_batch` method."""
243+
await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch")
244+
await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch2")
245+
query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);"
246+
async with psql_pool.acquire() as conn:
247+
await conn.execute_batch(querystring=query)
248+
await conn.execute(querystring="SELECT * FROM execute_batch")
249+
await conn.execute(querystring="SELECT * FROM execute_batch2")

python/tests/test_transaction.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,14 @@ async def test_binary_copy_to_table(
390390
f"SELECT COUNT(*) AS rows_count FROM {table_name}",
391391
)
392392
assert real_table_rows.result()[0]["rows_count"] == expected_inserted_row
393+
394+
395+
async def test_execute_batch_method(psql_pool: ConnectionPool) -> None:
396+
"""Test `execute_batch` method."""
397+
await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch")
398+
await psql_pool.execute(querystring="DROP TABLE IF EXISTS execute_batch2")
399+
query = "CREATE TABLE execute_batch (name VARCHAR);CREATE TABLE execute_batch2 (name VARCHAR);"
400+
async with psql_pool.acquire() as conn, conn.transaction() as transaction:
401+
await transaction.execute_batch(querystring=query)
402+
await transaction.execute(querystring="SELECT * FROM execute_batch")
403+
await transaction.execute(querystring="SELECT * FROM execute_batch2")

src/driver/connection.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,31 @@ impl Connection {
260260
Err(RustPSQLDriverError::ConnectionClosedError)
261261
}
262262

263+
/// Executes a sequence of SQL statements using the simple query protocol.
264+
///
265+
/// Statements should be separated by semicolons.
266+
/// If an error occurs, execution of the sequence will stop at that point.
267+
/// This is intended for use when, for example,
268+
/// initializing a database schema.
269+
///
270+
/// # Errors
271+
///
272+
/// May return Err Result if:
273+
/// 1) Connection is closed.
274+
/// 2) Cannot execute querystring.
275+
pub async fn execute_batch(
276+
self_: pyo3::Py<Self>,
277+
querystring: String,
278+
) -> RustPSQLDriverPyResult<()> {
279+
let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone());
280+
281+
if let Some(db_client) = db_client {
282+
return Ok(db_client.batch_execute(&querystring).await?);
283+
}
284+
285+
Err(RustPSQLDriverError::ConnectionClosedError)
286+
}
287+
263288
/// Execute querystring with parameters.
264289
///
265290
/// It converts incoming parameters to rust readable

src/driver/transaction.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,31 @@ impl Transaction {
301301
Err(RustPSQLDriverError::TransactionClosedError)
302302
}
303303

304+
/// Executes a sequence of SQL statements using the simple query protocol.
305+
///
306+
/// Statements should be separated by semicolons.
307+
/// If an error occurs, execution of the sequence will stop at that point.
308+
/// This is intended for use when, for example,
309+
/// initializing a database schema.
310+
///
311+
/// # Errors
312+
///
313+
/// May return Err Result if:
314+
/// 1) Transaction is closed.
315+
/// 2) Cannot execute querystring.
316+
pub async fn execute_batch(self_: Py<Self>, querystring: String) -> RustPSQLDriverPyResult<()> {
317+
let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| {
318+
let self_ = self_.borrow(gil);
319+
(self_.check_is_transaction_ready(), self_.db_client.clone())
320+
});
321+
is_transaction_ready?;
322+
if let Some(db_client) = db_client {
323+
return Ok(db_client.batch_execute(&querystring).await?);
324+
}
325+
326+
Err(RustPSQLDriverError::TransactionClosedError)
327+
}
328+
304329
/// Fetch result from the database.
305330
///
306331
/// It converts incoming parameters to rust readable

0 commit comments

Comments
 (0)