@@ -317,6 +317,40 @@ class PostgresConnectionTests: XCTestCase {
317
317
}
318
318
}
319
319
320
+ func testCloseImmediatelyWithSimpleQuery( ) async throws {
321
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
322
+
323
+ try await withThrowingTaskGroup ( of: Void . self) { [ logger] taskGroup async throws -> ( ) in
324
+ for _ in 1 ... 2 {
325
+ taskGroup. addTask {
326
+ try await connection. __simpleQuery ( " SELECT 1; " , logger: logger)
327
+ }
328
+ }
329
+
330
+ let query = try await channel. waitForSimpleQueryRequest ( )
331
+ XCTAssertEqual ( query. query, " SELECT 1; " )
332
+
333
+ async let close : ( ) = connection. close ( )
334
+
335
+ try await channel. closeFuture. get ( )
336
+ XCTAssertEqual ( channel. isActive, false )
337
+
338
+ try await close
339
+
340
+ while let taskResult = await taskGroup. nextResult ( ) {
341
+ switch taskResult {
342
+ case . success:
343
+ XCTFail ( " Expected queries to fail " )
344
+ case . failure( let failure) :
345
+ guard let error = failure as? PSQLError else {
346
+ return XCTFail ( " Unexpected error type: \( failure) " )
347
+ }
348
+ XCTAssertEqual ( error. code, . clientClosedConnection)
349
+ }
350
+ }
351
+ }
352
+ }
353
+
320
354
func testIfServerJustClosesTheErrorReflectsThat( ) async throws {
321
355
let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
322
356
let logger = self . logger
@@ -346,6 +380,35 @@ class PostgresConnectionTests: XCTestCase {
346
380
}
347
381
}
348
382
383
+ func testIfServerJustClosesTheErrorReflectsThatInSimpleQuery( ) async throws {
384
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
385
+ let logger = self . logger
386
+
387
+ async let response = try await connection. __simpleQuery ( " SELECT 1; " , logger: logger)
388
+
389
+ let query = try await channel. waitForSimpleQueryRequest ( )
390
+ XCTAssertEqual ( query. query, " SELECT 1; " )
391
+
392
+ try await channel. testingEventLoop. executeInContext { channel. pipeline. fireChannelInactive ( ) }
393
+ try await channel. testingEventLoop. executeInContext { channel. pipeline. fireChannelUnregistered ( ) }
394
+
395
+ do {
396
+ _ = try await response
397
+ XCTFail ( " Expected to throw " )
398
+ } catch {
399
+ XCTAssertEqual ( ( error as? PSQLError ) ? . code, . serverClosedConnection)
400
+ }
401
+
402
+ // retry on same connection
403
+
404
+ do {
405
+ _ = try await connection. __simpleQuery ( " SELECT 1; " , logger: self . logger)
406
+ XCTFail ( " Expected to throw " )
407
+ } catch {
408
+ XCTAssertEqual ( ( error as? PSQLError ) ? . code, . serverClosedConnection)
409
+ }
410
+ }
411
+
349
412
struct TestPrepareStatement : PostgresPreparedStatement {
350
413
static let sql = " SELECT datname FROM pg_stat_activity WHERE state = $1 "
351
414
typealias Row = String
@@ -692,6 +755,14 @@ extension NIOAsyncTestingChannel {
692
755
return UnpreparedRequest ( parse: parse, describe: describe, bind: bind, execute: execute)
693
756
}
694
757
758
+ func waitForSimpleQueryRequest( ) async throws -> PostgresFrontendMessage . Query {
759
+ let query = try await self . waitForOutboundWrite ( as: PostgresFrontendMessage . self)
760
+ guard case . query( let query) = query else {
761
+ fatalError ( )
762
+ }
763
+ return query
764
+ }
765
+
695
766
func waitForPrepareRequest( ) async throws -> PrepareRequest {
696
767
let parse = try await self . waitForOutboundWrite ( as: PostgresFrontendMessage . self)
697
768
let describe = try await self . waitForOutboundWrite ( as: PostgresFrontendMessage . self)
0 commit comments