Skip to content

Commit 1ed7e7d

Browse files
committed
add connection tests
1 parent e904c86 commit 1ed7e7d

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ extension PostgresFrontendMessage {
218218
case .saslResponse:
219219
preconditionFailure("TODO: Unimplemented")
220220
case .query:
221-
return .query
221+
guard let query = buffer.readNullTerminatedString() else {
222+
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
223+
}
224+
return .query(.init(query: query))
222225
case .sync:
223226
return .sync
224227
case .terminate:

Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ enum PostgresFrontendMessage: Equatable {
5959
}
6060
}
6161

62+
struct Query: Hashable {
63+
/// The query string.
64+
let query: String
65+
}
66+
6267
struct Parse: Hashable {
6368
/// The name of the destination prepared statement (an empty string selects the unnamed prepared statement).
6469
let preparedStatementName: String
@@ -179,7 +184,7 @@ enum PostgresFrontendMessage: Equatable {
179184
case saslInitialResponse(SASLInitialResponse)
180185
case saslResponse(SASLResponse)
181186
case sslRequest
182-
case query
187+
case query(Query)
183188
case sync
184189
case startup(Startup)
185190
case terminate

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,40 @@ class PostgresConnectionTests: XCTestCase {
317317
}
318318
}
319319

320+
func testCloseImmediatelyWithSimpleQuery() async throws {
321+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
322+
323+
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
324+
for _ in 1...2 {
325+
taskGroup.addTask {
326+
try await connection.__simpleQuery("SELECT 1;", logger: logger)
327+
}
328+
}
329+
330+
let query = try await channel.waitForSimpleQueryRequest()
331+
XCTAssertEqual(query.query, "SELECT 1;")
332+
333+
async let close: () = connection.close()
334+
335+
try await channel.closeFuture.get()
336+
XCTAssertEqual(channel.isActive, false)
337+
338+
try await close
339+
340+
while let taskResult = await taskGroup.nextResult() {
341+
switch taskResult {
342+
case .success:
343+
XCTFail("Expected queries to fail")
344+
case .failure(let failure):
345+
guard let error = failure as? PSQLError else {
346+
return XCTFail("Unexpected error type: \(failure)")
347+
}
348+
XCTAssertEqual(error.code, .clientClosedConnection)
349+
}
350+
}
351+
}
352+
}
353+
320354
func testIfServerJustClosesTheErrorReflectsThat() async throws {
321355
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
322356
let logger = self.logger
@@ -346,6 +380,35 @@ class PostgresConnectionTests: XCTestCase {
346380
}
347381
}
348382

383+
func testIfServerJustClosesTheErrorReflectsThatInSimpleQuery() async throws {
384+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
385+
let logger = self.logger
386+
387+
async let response = try await connection.__simpleQuery("SELECT 1;", logger: logger)
388+
389+
let query = try await channel.waitForSimpleQueryRequest()
390+
XCTAssertEqual(query.query, "SELECT 1;")
391+
392+
try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() }
393+
try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() }
394+
395+
do {
396+
_ = try await response
397+
XCTFail("Expected to throw")
398+
} catch {
399+
XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection)
400+
}
401+
402+
// retry on same connection
403+
404+
do {
405+
_ = try await connection.__simpleQuery("SELECT 1;", logger: self.logger)
406+
XCTFail("Expected to throw")
407+
} catch {
408+
XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection)
409+
}
410+
}
411+
349412
struct TestPrepareStatement: PostgresPreparedStatement {
350413
static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1"
351414
typealias Row = String
@@ -692,6 +755,14 @@ extension NIOAsyncTestingChannel {
692755
return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute)
693756
}
694757

758+
func waitForSimpleQueryRequest() async throws -> PostgresFrontendMessage.Query {
759+
let query = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
760+
guard case .query(let query) = query else {
761+
fatalError()
762+
}
763+
return query
764+
}
765+
695766
func waitForPrepareRequest() async throws -> PrepareRequest {
696767
let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
697768
let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)

0 commit comments

Comments
 (0)