Skip to content

Commit 0c66d40

Browse files
committed
update mock server for swft 6 compliance
1 parent b4f0164 commit 0c66d40

File tree

3 files changed

+288
-180
lines changed

3 files changed

+288
-180
lines changed

Package.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ let package = Package(
1717
.library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]),
1818
],
1919
dependencies: [
20-
.package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"),
20+
.package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"),
2121
.package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"),
2222
],
2323
targets: [
@@ -89,11 +89,11 @@ let package = Package(
8989
.executableTarget(
9090
name: "MockServer",
9191
dependencies: [
92+
.product(name: "Logging", package: "swift-log"),
9293
.product(name: "NIOHTTP1", package: "swift-nio"),
9394
.product(name: "NIOCore", package: "swift-nio"),
9495
.product(name: "NIOPosix", package: "swift-nio"),
95-
],
96-
swiftSettings: [.swiftLanguageMode(.v5)]
96+
]
9797
),
9898
]
9999
)
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftAWSLambdaRuntime open source project
4+
//
5+
// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import Logging
16+
import NIOCore
17+
import NIOHTTP1
18+
import NIOPosix
19+
20+
// for UUID and Date
21+
#if canImport(FoundationEssentials)
22+
import FoundationEssentials
23+
#else
24+
import Foundation
25+
#endif
26+
27+
@main
28+
public class MockHttpServer {
29+
30+
public static func main() throws {
31+
let server = MockHttpServer()
32+
try server.start()
33+
}
34+
35+
private func start() throws {
36+
let host = env("HOST") ?? "127.0.0.1"
37+
let port = env("PORT").flatMap(Int.init) ?? 7000
38+
let mode = env("MODE").flatMap(Mode.init) ?? .string
39+
var log = Logger(label: "MockServer")
40+
log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info
41+
let logger = log
42+
43+
let socketBootstrap = ServerBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount))
44+
// Specify backlog and enable SO_REUSEADDR for the server itself
45+
// .serverChannelOption(.backlog, value: 256)
46+
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
47+
// .childChannelOption(.maxMessagesPerRead, value: 1)
48+
49+
// Set the handlers that are applied to the accepted Channels
50+
.childChannelInitializer { channel in
51+
channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap {
52+
channel.pipeline.addHandler(HTTPHandler(mode: mode, logger: logger))
53+
}
54+
}
55+
56+
let channel = try socketBootstrap.bind(host: host, port: port).wait()
57+
logger.debug("Server started and listening on \(host):\(port)")
58+
59+
// This will never return as we don't close the ServerChannel
60+
try channel.closeFuture.wait()
61+
}
62+
}
63+
64+
private final class HTTPHandler: ChannelInboundHandler {
65+
public typealias InboundIn = HTTPServerRequestPart
66+
public typealias OutboundOut = HTTPServerResponsePart
67+
68+
private enum State {
69+
case idle
70+
case waitingForRequestBody
71+
case sendingResponse
72+
73+
mutating func requestReceived() {
74+
precondition(self == .idle, "Invalid state for request received: \(self)")
75+
self = .waitingForRequestBody
76+
}
77+
78+
mutating func requestComplete() {
79+
precondition(
80+
self == .waitingForRequestBody,
81+
"Invalid state for request complete: \(self)"
82+
)
83+
self = .sendingResponse
84+
}
85+
86+
mutating func responseComplete() {
87+
precondition(self == .sendingResponse, "Invalid state for response complete: \(self)")
88+
self = .idle
89+
}
90+
}
91+
92+
private let logger: Logger
93+
private let mode: Mode
94+
95+
private var buffer: ByteBuffer! = nil
96+
private var state: HTTPHandler.State = .idle
97+
private var keepAlive = false
98+
99+
private var requestHead: HTTPRequestHead?
100+
private var requestBodyBytes: Int = 0
101+
102+
init(mode: Mode, logger: Logger) {
103+
self.mode = mode
104+
self.logger = logger
105+
}
106+
107+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
108+
let reqPart = Self.unwrapInboundIn(data)
109+
handle(context: context, request: reqPart)
110+
}
111+
112+
func channelReadComplete(context: ChannelHandlerContext) {
113+
context.flush()
114+
self.buffer.clear()
115+
}
116+
117+
func handlerAdded(context: ChannelHandlerContext) {
118+
self.buffer = context.channel.allocator.buffer(capacity: 0)
119+
}
120+
121+
private func handle(context: ChannelHandlerContext, request: HTTPServerRequestPart) {
122+
switch request {
123+
case .head(let request):
124+
logger.trace("Received request .head")
125+
self.requestHead = request
126+
self.requestBodyBytes = 0
127+
self.keepAlive = request.isKeepAlive
128+
self.state.requestReceived()
129+
case .body(buffer: var buf):
130+
logger.trace("Received request .body")
131+
self.requestBodyBytes += buf.readableBytes
132+
self.buffer.writeBuffer(&buf)
133+
case .end:
134+
logger.trace("Received request .end")
135+
self.state.requestComplete()
136+
137+
precondition(requestHead != nil, "Received .end without .head")
138+
let (responseStatus, responseHeaders, responseBody) = self.processRequest(
139+
requestHead: self.requestHead!,
140+
requestBody: self.buffer
141+
)
142+
143+
self.buffer.clear()
144+
self.buffer.writeString(responseBody)
145+
146+
var headers = HTTPHeaders(responseHeaders)
147+
headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)")
148+
149+
// write the response
150+
context.write(
151+
Self.wrapOutboundOut(
152+
.head(
153+
httpResponseHead(
154+
request: self.requestHead!,
155+
status: responseStatus,
156+
headers: headers
157+
)
158+
)
159+
),
160+
promise: nil
161+
)
162+
context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil)
163+
self.completeResponse(context, trailers: nil, promise: nil)
164+
}
165+
}
166+
167+
private func processRequest(
168+
requestHead: HTTPRequestHead,
169+
requestBody: ByteBuffer
170+
) -> (HTTPResponseStatus, [(String, String)], String) {
171+
var responseStatus: HTTPResponseStatus = .ok
172+
var responseBody: String = ""
173+
var responseHeaders: [(String, String)] = []
174+
175+
logger.trace("Processing request for : \(requestHead) - \(requestBody.getString(at: 0, length: self.requestBodyBytes) ?? "")")
176+
177+
if requestHead.uri.hasSuffix("/next") {
178+
logger.trace("URI /next")
179+
180+
responseStatus = .accepted
181+
182+
let requestId = UUID().uuidString
183+
switch self.mode {
184+
case .string:
185+
responseBody = "\"\(requestId)\"" // must be a valid JSON string
186+
case .json:
187+
responseBody = "{ \"body\": \"\(requestId)\" }"
188+
}
189+
let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000)
190+
responseHeaders = [
191+
// ("Connection", "close"),
192+
(AmazonHeaders.requestID, requestId),
193+
(AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"),
194+
(AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"),
195+
(AmazonHeaders.deadline, String(deadline)),
196+
]
197+
} else if requestHead.uri.hasSuffix("/response") {
198+
logger.trace("URI /response")
199+
responseStatus = .accepted
200+
} else if requestHead.uri.hasSuffix("/error") {
201+
logger.trace("URI /error")
202+
responseStatus = .ok
203+
} else {
204+
logger.trace("Unknown URI : \(requestHead)")
205+
responseStatus = .notFound
206+
}
207+
logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)")
208+
return (responseStatus, responseHeaders, responseBody)
209+
}
210+
211+
private func completeResponse(
212+
_ context: ChannelHandlerContext,
213+
trailers: HTTPHeaders?,
214+
promise: EventLoopPromise<Void>?
215+
) {
216+
self.state.responseComplete()
217+
218+
let eventLoop = context.eventLoop
219+
let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop)
220+
221+
let promise = self.keepAlive ? promise : (promise ?? context.eventLoop.makePromise())
222+
if !self.keepAlive {
223+
promise!.futureResult.whenComplete { (_: Result<Void, Error>) in
224+
let context = loopBoundContext.value
225+
context.close(promise: nil)
226+
}
227+
}
228+
229+
context.writeAndFlush(Self.wrapOutboundOut(.end(trailers)), promise: promise)
230+
}
231+
232+
private func httpResponseHead(
233+
request: HTTPRequestHead,
234+
status: HTTPResponseStatus,
235+
headers: HTTPHeaders = HTTPHeaders()
236+
) -> HTTPResponseHead {
237+
var head = HTTPResponseHead(version: request.version, status: status, headers: headers)
238+
let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map {
239+
$0.lowercased()
240+
}
241+
242+
if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") {
243+
// the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers
244+
245+
switch (request.isKeepAlive, request.version.major, request.version.minor) {
246+
case (true, 1, 0):
247+
// HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that
248+
head.headers.add(name: "Connection", value: "keep-alive")
249+
case (false, 1, let n) where n >= 1:
250+
// HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that
251+
head.headers.add(name: "Connection", value: "close")
252+
default:
253+
// we should match the default or are dealing with some HTTP that we don't support, let's leave as is
254+
()
255+
}
256+
}
257+
return head
258+
}
259+
260+
private enum ServerError: Error {
261+
case notReady
262+
case cantBind
263+
}
264+
265+
private enum AmazonHeaders {
266+
static let requestID = "Lambda-Runtime-Aws-Request-Id"
267+
static let traceID = "Lambda-Runtime-Trace-Id"
268+
static let clientContext = "X-Amz-Client-Context"
269+
static let cognitoIdentity = "X-Amz-Cognito-Identity"
270+
static let deadline = "Lambda-Runtime-Deadline-Ms"
271+
static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn"
272+
}
273+
}
274+
275+
private enum Mode: String {
276+
case string
277+
case json
278+
}
279+
280+
private func env(_ name: String) -> String? {
281+
guard let value = getenv(name) else {
282+
return nil
283+
}
284+
return String(cString: value)
285+
}

0 commit comments

Comments
 (0)