diff --git a/Foundation/NSURLSession/NSURLSessionTask.swift b/Foundation/NSURLSession/NSURLSessionTask.swift index e44fb9da9f..7b6959af35 100644 --- a/Foundation/NSURLSession/NSURLSessionTask.swift +++ b/Foundation/NSURLSession/NSURLSessionTask.swift @@ -554,9 +554,31 @@ fileprivate extension URLSessionTask { // HTTP Options: easyHandle.set(followLocation: false) + + // The httpAdditionalHeaders from session configuration has to be added to the request. + // The request.allHTTPHeaders can override the httpAdditionalHeaders elements. Add the + // httpAdditionalHeaders from session configuration first and then append/update the + // request.allHTTPHeaders so that request.allHTTPHeaders can override httpAdditionalHeaders. + + let httpSession = session as! URLSession + var httpHeaders: [AnyHashable : Any]? + + if let hh = httpSession.configuration.httpAdditionalHeaders { + httpHeaders = hh + } + + if let hh = currentRequest?.allHTTPHeaderFields { + if httpHeaders == nil { + httpHeaders = hh + } else { + hh.forEach { + httpHeaders![$0] = $1 + } + } + } let customHeaders: [String] - let headersForRequest = curlHeaders(for: request) + let headersForRequest = curlHeaders(for: httpHeaders) if ((request.httpMethod == "POST") && (request.value(forHTTPHeaderField: "Content-Type") == nil)) { customHeaders = headersForRequest + ["Content-Type:application/x-www-form-urlencoded"] } else { @@ -570,8 +592,7 @@ fileprivate extension URLSessionTask { //set the request timeout //TODO: the timeout value needs to be reset on every data transfer - let s = session as! URLSession - let timeoutInterval = Int(s.configuration.timeoutIntervalForRequest) * 1000 + let timeoutInterval = Int(httpSession.configuration.timeoutIntervalForRequest) * 1000 let timeoutHandler = DispatchWorkItem { [weak self] in guard let currentTask = self else { fatalError("Timeout on a task that doesn't exist") } //this guard must always pass currentTask.internalState = .transferFailed @@ -597,10 +618,11 @@ fileprivate extension URLSessionTask { /// expects. /// /// - SeeAlso: https://curl.haxx.se/libcurl/c/CURLOPT_HTTPHEADER.html - func curlHeaders(for request: URLRequest) -> [String] { + func curlHeaders(for httpHeaders: [AnyHashable : Any]?) -> [String] { var result: [String] = [] var names = Set() - if let hh = currentRequest?.allHTTPHeaderFields { + if httpHeaders != nil { + let hh = httpHeaders as! [String:String] hh.forEach { let name = $0.0.lowercased() guard !names.contains(name) else { return } diff --git a/TestFoundation/TestNSURLSession.swift b/TestFoundation/TestNSURLSession.swift index 022c601577..2b5e6da134 100644 --- a/TestFoundation/TestNSURLSession.swift +++ b/TestFoundation/TestNSURLSession.swift @@ -35,6 +35,7 @@ class TestURLSession : XCTestCase { ("test_cancelTask", test_cancelTask), ("test_taskTimeout", test_taskTimeout), ("test_verifyRequestHeaders", test_verifyRequestHeaders), + ("test_verifyHttpAdditionalHeaders", test_verifyHttpAdditionalHeaders), ] } @@ -349,6 +350,41 @@ class TestURLSession : XCTestCase { waitForExpectations(timeout: 30) } + // Verify httpAdditionalHeaders from session configuration are added to the request + // and whether it is overriden by Request.allHTTPHeaderFields. + + func test_verifyHttpAdditionalHeaders() { + let serverReady = ServerSemaphore() + globalDispatchQueue.async { + do { + try self.runServer(with: serverReady) + } catch { + XCTAssertTrue(true) + return + } + } + serverReady.wait() + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = 5 + config.httpAdditionalHeaders = ["header2": "svalue2", "header3": "svalue3"] + let session = URLSession(configuration: config, delegate: nil, delegateQueue: nil) + var expect = expectation(description: "download task with handler") + var req = URLRequest(url: URL(string: "http://127.0.0.1:\(serverPort)/requestHeaders")!) + let headers = ["header1": "rvalue1", "header2": "rvalue2"] + req.httpMethod = "POST" + req.allHTTPHeaderFields = headers + var task = session.dataTask(with: req) { (data, _, error) -> Void in + defer { expect.fulfill() } + let headers = String(data: data!, encoding: String.Encoding.utf8)! + XCTAssertNotNil(headers.range(of: "header1: rvalue1")) + XCTAssertNotNil(headers.range(of: "header2: rvalue2")) + XCTAssertNotNil(headers.range(of: "header3: svalue3")) + } + task.resume() + + waitForExpectations(timeout: 30) + } + func test_taskTimeout() { let serverReady = ServerSemaphore() globalDispatchQueue.async {