|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// This source file is part of the SwiftAwsLambda open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2017-2018 Apple Inc. and the SwiftAwsLambda project authors |
| 6 | +// Licensed under Apache License v2.0 |
| 7 | +// |
| 8 | +// See LICENSE.txt for license information |
| 9 | +// See CONTRIBUTORS.txt for the list of SwiftAwsLambda project authors |
| 10 | +// |
| 11 | +// SPDX-License-Identifier: Apache-2.0 |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +import Foundation |
| 16 | +import Logging |
| 17 | +import NIO |
| 18 | +import NIOHTTP1 |
| 19 | + |
| 20 | +internal struct MockServer { |
| 21 | + private let logger: Logger |
| 22 | + private let group: EventLoopGroup |
| 23 | + private let host: String |
| 24 | + private let port: Int |
| 25 | + private let mode: Mode |
| 26 | + private let keepAlive: Bool |
| 27 | + |
| 28 | + public init() { |
| 29 | + var logger = Logger(label: "MockServer") |
| 30 | + logger.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info |
| 31 | + self.logger = logger |
| 32 | + self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) |
| 33 | + self.host = env("HOST") ?? "127.0.0.1" |
| 34 | + self.port = env("PORT").flatMap(Int.init) ?? 7000 |
| 35 | + self.mode = env("MODE").flatMap(Mode.init) ?? .string |
| 36 | + self.keepAlive = env("KEEP_ALIVE").flatMap(Bool.init) ?? true |
| 37 | + } |
| 38 | + |
| 39 | + func start() throws { |
| 40 | + let bootstrap = ServerBootstrap(group: group) |
| 41 | + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) |
| 42 | + .childChannelInitializer { channel in |
| 43 | + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in |
| 44 | + channel.pipeline.addHandler(HTTPHandler(logger: self.logger, |
| 45 | + keepAlive: self.keepAlive, |
| 46 | + mode: self.mode)) |
| 47 | + } |
| 48 | + } |
| 49 | + try bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture<Void> in |
| 50 | + guard let localAddress = channel.localAddress else { |
| 51 | + return channel.eventLoop.makeFailedFuture(ServerError.cantBind) |
| 52 | + } |
| 53 | + self.logger.info("\(self) started and listening on \(localAddress)") |
| 54 | + return channel.eventLoop.makeSucceededFuture(()) |
| 55 | + }.wait() |
| 56 | + } |
| 57 | +} |
| 58 | + |
| 59 | +internal final class HTTPHandler: ChannelInboundHandler { |
| 60 | + public typealias InboundIn = HTTPServerRequestPart |
| 61 | + public typealias OutboundOut = HTTPServerResponsePart |
| 62 | + |
| 63 | + private let logger: Logger |
| 64 | + private let mode: Mode |
| 65 | + private let keepAlive: Bool |
| 66 | + |
| 67 | + private var requestHead: HTTPRequestHead! |
| 68 | + private var requestBody: ByteBuffer? |
| 69 | + |
| 70 | + public init(logger: Logger, keepAlive: Bool, mode: Mode) { |
| 71 | + self.logger = logger |
| 72 | + self.mode = mode |
| 73 | + self.keepAlive = keepAlive |
| 74 | + } |
| 75 | + |
| 76 | + func channelRead(context: ChannelHandlerContext, data: NIOAny) { |
| 77 | + let requestPart = unwrapInboundIn(data) |
| 78 | + |
| 79 | + switch requestPart { |
| 80 | + case .head(let head): |
| 81 | + self.requestHead = head |
| 82 | + self.requestBody?.clear() |
| 83 | + case .body(var buffer): |
| 84 | + if self.requestBody == nil { |
| 85 | + self.requestBody = buffer |
| 86 | + } else { |
| 87 | + self.requestBody!.writeBuffer(&buffer) |
| 88 | + } |
| 89 | + case .end: |
| 90 | + self.processRequest(context: context) |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + func processRequest(context: ChannelHandlerContext) { |
| 95 | + self.logger.debug("\(self) processing \(self.requestHead.uri)") |
| 96 | + |
| 97 | + var responseStatus: HTTPResponseStatus |
| 98 | + var responseBody: String? |
| 99 | + var responseHeaders: [(String, String)]? |
| 100 | + |
| 101 | + if self.requestHead.uri.hasSuffix("/next") { |
| 102 | + let requestId = UUID().uuidString |
| 103 | + responseStatus = .ok |
| 104 | + switch self.mode { |
| 105 | + case .string: |
| 106 | + responseBody = requestId |
| 107 | + case .json: |
| 108 | + responseBody = "{ \"body\": \"\(requestId)\" }" |
| 109 | + } |
| 110 | + responseHeaders = [(AmazonHeaders.requestID, requestId)] |
| 111 | + } else if self.requestHead.uri.hasSuffix("/response") { |
| 112 | + responseStatus = .accepted |
| 113 | + } else { |
| 114 | + responseStatus = .notFound |
| 115 | + } |
| 116 | + self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody) |
| 117 | + } |
| 118 | + |
| 119 | + func writeResponse(context: ChannelHandlerContext, status: HTTPResponseStatus, headers: [(String, String)]? = nil, body: String? = nil) { |
| 120 | + var headers = HTTPHeaders(headers ?? []) |
| 121 | + headers.add(name: "Content-Length", value: "\(body?.utf8.count ?? 0)") |
| 122 | + headers.add(name: "Connection", value: self.keepAlive ? "keep-alive" : "close") |
| 123 | + let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: status, headers: headers) |
| 124 | + |
| 125 | + context.write(wrapOutboundOut(.head(head))).whenFailure { error in |
| 126 | + self.logger.error("\(self) write error \(error)") |
| 127 | + } |
| 128 | + |
| 129 | + if let b = body { |
| 130 | + var buffer = context.channel.allocator.buffer(capacity: b.utf8.count) |
| 131 | + buffer.writeString(b) |
| 132 | + context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in |
| 133 | + self.logger.error("\(self) write error \(error)") |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in |
| 138 | + if case .failure(let error) = result { |
| 139 | + self.logger.error("\(self) write error \(error)") |
| 140 | + } |
| 141 | + if !self.self.keepAlive { |
| 142 | + context.close().whenFailure { error in |
| 143 | + self.logger.error("\(self) close error \(error)") |
| 144 | + } |
| 145 | + } |
| 146 | + } |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +internal enum ServerError: Error { |
| 151 | + case notReady |
| 152 | + case cantBind |
| 153 | +} |
| 154 | + |
| 155 | +internal enum AmazonHeaders { |
| 156 | + static let requestID = "Lambda-Runtime-Aws-Request-Id" |
| 157 | + static let traceID = "Lambda-Runtime-Trace-Id" |
| 158 | + static let clientContext = "X-Amz-Client-Context" |
| 159 | + static let cognitoIdentity = "X-Amz-Cognito-Identity" |
| 160 | + static let deadline = "Lambda-Runtime-Deadline-Ms" |
| 161 | + static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" |
| 162 | +} |
| 163 | + |
| 164 | +internal enum Mode: String { |
| 165 | + case string |
| 166 | + case json |
| 167 | +} |
| 168 | + |
| 169 | +func env(_ name: String) -> String? { |
| 170 | + guard let value = getenv(name) else { |
| 171 | + return nil |
| 172 | + } |
| 173 | + return String(utf8String: value) |
| 174 | +} |
| 175 | + |
| 176 | +// main |
| 177 | +let server = MockServer() |
| 178 | +try! server.start() |
| 179 | +dispatchMain() |
0 commit comments