From a70882ee95518085063cd7d6d01206548639ba3d Mon Sep 17 00:00:00 2001 From: Robert Widmann Date: Fri, 19 Oct 2018 17:49:39 -0700 Subject: [PATCH] Update to ORCJIT --- Sources/LLVM/JIT.swift | 361 +++++++++++++++++++++---------- Sources/LLVM/MemoryBuffer.swift | 4 + Sources/LLVM/Module.swift | 4 + Sources/LLVM/TargetMachine.swift | 5 + Tests/LLVMTests/JITSpec.swift | 313 +++++++++++++++++++++------ Tests/LinuxMain.swift | 3 +- 6 files changed, 510 insertions(+), 180 deletions(-) diff --git a/Sources/LLVM/JIT.swift b/Sources/LLVM/JIT.swift index 55261d99..4f930aed 100644 --- a/Sources/LLVM/JIT.swift +++ b/Sources/LLVM/JIT.swift @@ -5,21 +5,12 @@ import cllvm /// JITError represents the different kinds of errors the JIT compiler can /// throw. public enum JITError: Error, CustomStringConvertible { - /// The JIT was unable to be initialized. A message is provided explaining - /// the failure. - case couldNotInitialize(String) + case generic(String) - /// The JIT was unable to remove the provided module. A message is provided - /// explaining the failure - case couldNotRemoveModule(Module, String) - - /// A human-readable description of the error. public var description: String { switch self { - case .couldNotRemoveModule(let module, let message): - return "could not remove module '\(module.name)': \(message)" - case .couldNotInitialize(let message): - return "could not initialize JIT: \(message)" + case let .generic(desc): + return desc } } } @@ -28,123 +19,263 @@ public enum JITError: Error, CustomStringConvertible { /// that has been generated in a `Module`. It can execute arbitrary functions /// and return the value the function generated, allowing you to write /// interactive programs that will run as soon as they are compiled. +/// +/// The JIT is fundamentally lazy, and allows control over when and how symbols +/// are resolved. public final class JIT { + public typealias TargetAddress = LLVMOrcTargetAddress + public struct ModuleHandle { + fileprivate var llvm: LLVMOrcModuleHandle + } + /// The underlying LLVMExecutionEngineRef backing this JIT. - internal let llvm: LLVMExecutionEngineRef - - private static var linkOnce: () = { - return LLVMLinkInMCJIT() - }() - - /// Creates a Just In Time compiler that will compile the code in the - /// provided `Module` to the architecture of the provided `TargetMachine`, - /// and execute it. - /// - /// - parameters: - /// - module: The module containing code you wish to execute - /// - machine: The target machine which you're compiling for - /// - throws: JITError - public init(module: Module, machine: TargetMachine) throws { - _ = JIT.linkOnce - - var jit: LLVMExecutionEngineRef? - var error: UnsafeMutablePointer? - if LLVMCreateExecutionEngineForModule(&jit, module.llvm, &error) != 0 { - let str = String(cString: error!) - throw JITError.couldNotInitialize(str) + internal let llvm: LLVMOrcJITStackRef + private let ownsContext: Bool + + internal init(llvm: LLVMOrcJITStackRef, ownsContext: Bool) { + self.llvm = llvm + self.ownsContext = ownsContext + } + + /// Create and initialize a `JIT` with this target machine's representation. + public convenience init(machine: TargetMachine) { + // The JIT stack takes ownership of the target machine. + machine.ownsContext = false + self.init(llvm: LLVMOrcCreateInstance(machine.llvm), ownsContext: true) + } + + deinit { + guard self.ownsContext else { + return } - guard let _jit = jit else { - throw JITError.couldNotInitialize("JIT was NULL") + _ = LLVMOrcDisposeInstance(self.llvm) + } + + // MARK: Symbols + + /// Mangles the given symbol name according to the data layout of the JIT's + /// target machine. + /// + /// - parameter symbol: The symbol name to mangle. + /// - returns: A mangled representation of the given symbol name. + public func mangle(symbol: String) -> String { + var mangledResult: UnsafeMutablePointer? = nil + LLVMOrcGetMangledSymbol(self.llvm, &mangledResult, symbol) + guard let result = mangledResult else { + fatalError("Mangled name should never be nil!") } - self.llvm = _jit - LLVMRunStaticConstructors(self.llvm) - } - - /// Retrieves a pointer to the function compiled by this JIT. - /// - parameter name: The name of the function you wish to look up. - /// - returns: A pointer to the result of compiling the specified function. - /// - note: You will have to `unsafeBitCast` this pointer to - /// the appropriate `@convention(c)` function type to be - /// able to run it from Swift. - /// - /// ``` - /// typealias FnPtr = @convention(c) () -> Double - /// let fnAddr = jit.addressOfFunction(name: "test") - /// let fn = unsafeBitCast(fnAddr, to: FnPtr.self) - /// ``` - public func addressOfFunction(name: String) -> OpaquePointer? { - let addr = LLVMGetFunctionAddress(llvm, name) - guard addr != 0 else { return nil } - return OpaquePointer(bitPattern: UInt(addr)) - } - - /// Adds the provided module, and all top-level declarations into this JIT. - /// - parameter module: The module you wish to add. - public func addModule(_ module: Module) { - LLVMAddModule(llvm, module.llvm) - } - - /// Removes the provided module, and all top-level declarations, from this - /// JIT. - public func removeModule(_ module: Module) throws { - var outMod: LLVMModuleRef? = module.llvm - var outError: UnsafeMutablePointer? - LLVMRemoveModule(llvm, module.llvm, &outMod, &outError) - if let err = outError { - defer { LLVMDisposeMessage(err) } - throw JITError.couldNotRemoveModule(module, String(cString: err)) + defer { LLVMOrcDisposeMangledSymbol(mangledResult) } + return String(cString: result) + } + + /// Computes the address of the given symbol, optionally restricting the + /// search for its address to a particular module. If this symbol does not + /// exist, an address of `0` is returned. + /// + /// - parameter symbol: The symbol name to search for. + /// - parameter module: An optional value describing the module in which to + /// restrict the search, if any. + /// - returns: The address of the symbol, or 0 if it does not exist. + public func address(of symbol: String, in module: ModuleHandle? = nil) throws -> TargetAddress { + var retAddr: TargetAddress = 0 + if let targetModule = module { + try checkForJITError(LLVMOrcGetSymbolAddressIn(self.llvm, &retAddr, targetModule.llvm, symbol)) + } else { + try checkForJITError(LLVMOrcGetSymbolAddress(self.llvm, &retAddr, symbol)) } + return retAddr + } + + // MARK: Lazy Compilation + + /// Registers a lazy compile callback that can be used to get the target + /// address of a trampoline function. When that trampoline address is + /// called, the given compilation callback is fired. + /// + /// Normally, the trampoline function is a known stub that has been previously + /// registered with the JIT. The callback then computes the address of a + /// known entry point and sets the address of the stub to it. See + /// `JIT.createIndirectStub` to create a stub function and + /// `JIT.setIndirectStubPointer` to set the address of a stub. + /// + /// - parameter callback: A callback that returns the actual address of the + /// trampoline function. + /// - returns: The target address representing a stub. Calling this stub + /// forces the given compilation callback to fire. + public func registerLazyCompile(_ callback: @escaping (JIT) -> TargetAddress) throws -> TargetAddress { + var addr: TargetAddress = 0 + let callbackContext = ORCLazyCompileCallbackContext(callback) + let contextPtr = Unmanaged.passRetained(callbackContext).toOpaque() + try checkForJITError(LLVMOrcCreateLazyCompileCallback(self.llvm, &addr, lazyCompileBlockTrampoline, contextPtr)) + return addr + } + + // MARK: Stubs + + /// Creates a new named indirect stub pointing to the given target address. + /// + /// An indirect stub may be resolved to a different address at any time by + /// invoking `JIT.setIndirectStubPointer`. + /// + /// - parameter name: The name of the indirect stub. + /// - parameter address: The address of the indirect stub. + public func createIndirectStub(named name: String, address: TargetAddress) throws { + try checkForJITError(LLVMOrcCreateIndirectStub(self.llvm, name, address)) + } + + /// Resets the address of an indirect stub. + /// + /// - warning: The indirect stub must be registered with a call to + /// `JIT.createIndirectStub`. Failure to do so will result in undefined + /// behavior. + /// + /// - parameter name: The name of an indirect stub. + /// - parameter address: The address to set the indirect stub to point to. + public func setIndirectStubPointer(named name: String, address: TargetAddress) throws { + try checkForJITError(LLVMOrcSetIndirectStubPointer(self.llvm, name, address)) + } + + // MARK: Adding Code to the JIT + + /// Adds the IR from a given module to the JIT, consuming it in the process. + /// + /// Despite the name of this function, the callback to compile the symbols in + /// the module is not necessarily called immediately. It is called at least + /// when a given symbol's address is requested, either by the JIT or by + /// the user e.g. `JIT.address(of:)`. + /// + /// The callback function is required to compute the address of the given + /// symbol. The symbols are passed in mangled form. Use + /// `JIT.mangle(symbol:)` to request the mangled name of a symbol. + /// + /// - warning: The JIT invalidates the underlying reference to the provided + /// module. Further references to the module are thus dangling pointers and + /// may be a source of subtle memory bugs. This will be addressed in a + /// future revision of LLVM. + /// + /// - parameter module: The module to compile. + /// - parameter callback: A function that is called by the JIT to compute the + /// address of symbols. + public func addEagerlyCompiledIR(_ module: Module, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle { + var handle: LLVMOrcModuleHandle = 0 + let callbackContext = ORCSymbolCallbackContext(callback) + let contextPtr = Unmanaged.passRetained(callbackContext).toOpaque() + // The JIT stack takes ownership of the given module. + module.ownsContext = false + try checkForJITError(LLVMOrcAddEagerlyCompiledIR(self.llvm, &handle, module.llvm, symbolBlockTrampoline, contextPtr)) + return ModuleHandle(llvm: handle) } - /// Runs the specified function as if it were the `main` function in an - /// executable. It takes an array of argument strings and passes them - /// into the function as `argc` and `argv`. - /// - /// - parameters: - /// - function: The `main` function you wish to execute - /// - args: The string arguments you wish to pass to the function - /// - returns: The numerical exit code returned by the function - public func runFunctionAsMain(_ function: Function, args: [String]) -> Int { - // FIXME: Also add in envp. - return withCArrayOfCStrings(args) { buf in - return Int(LLVMRunFunctionAsMain(llvm, function.asLLVM(), - UInt32(buf.count), - buf.baseAddress, nil)) + /// Adds the IR from a given module to the JIT, consuming it in the process. + /// + /// This function differs from `JIT.addEagerlyCompiledIR` in that the callback + /// to request the address of symbols is only executed when that symbol is + /// called, either in user code or by the JIT. + /// + /// The callback function is required to compute the address of the given + /// symbol. The symbols are passed in mangled form. Use + /// `JIT.mangle(symbol:)` to request the mangled name of a symbol. + /// + /// - warning: The JIT invalidates the underlying reference to the provided + /// module. Further references to the module are thus dangling pointers and + /// may be a source of subtle memory bugs. This will be addressed in a + /// future revision of LLVM. + /// + /// - parameter module: The module to compile. + /// - parameter callback: A function that is called by the JIT to compute the + /// address of symbols. + public func addLazilyCompiledIR(_ module: Module, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle { + var handle: LLVMOrcModuleHandle = 0 + let callbackContext = ORCSymbolCallbackContext(callback) + let contextPtr = Unmanaged.passRetained(callbackContext).toOpaque() + // The JIT stack takes ownership of the given module. + module.ownsContext = false + try checkForJITError(LLVMOrcAddLazilyCompiledIR(self.llvm, &handle, module.llvm, symbolBlockTrampoline, contextPtr)) + return ModuleHandle(llvm: handle) + } + + /// Adds the executable code from an object file to ths JIT, consuming it in + /// the process. + /// + /// The callback function is required to compute the address of the given + /// symbol. The symbols are passed in mangled form. Use + /// `JIT.mangle(symbol:)` to request the mangled name of a symbol. + /// + /// - warning: The JIT invalidates the underlying reference to the provided + /// memory buffer. Further references to the buffer are thus dangling + /// pointers and may be a source of subtle memory bugs. This will be + /// addressed in a future revision of LLVM. + /// + /// - parameter buffer: A buffer containing an object file. + /// - parameter callback: A function that is called by the JIT to compute the + /// address of symbols. + public func addObjectFile(_ buffer: MemoryBuffer, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle { + var handle: LLVMOrcModuleHandle = 0 + let callbackContext = ORCSymbolCallbackContext(callback) + let contextPtr = Unmanaged.passRetained(callbackContext).toOpaque() + // The JIT stack takes ownership of the given buffer. + buffer.ownsContext = false + try checkForJITError(LLVMOrcAddObjectFile(self.llvm, &handle, buffer.llvm, symbolBlockTrampoline, contextPtr)) + return ModuleHandle(llvm: handle) + } + + /// Remove previously-added code from the JIT. + /// + /// - warning: Removing a module handle consumes the handle. Further use of + /// the handle will then result in undefined behavior. + /// + /// - parameter handle: A handle to previously-added module. + public func removeModule(_ handle: ModuleHandle) throws { + try checkForJITError(LLVMOrcRemoveModule(self.llvm, handle.llvm)) + } + + private func checkForJITError(_ orcError: LLVMOrcErrorCode) throws { + switch orcError { + case LLVMOrcErrSuccess: + return + case LLVMOrcErrGeneric: + guard let msg = LLVMOrcGetErrorMsg(self.llvm) else { + fatalError("Couldn't get the error message?") + } + throw JITError.generic(String(cString: msg)) + default: + fatalError("Uncategorized ORC error code!") } } +} - deinit { - LLVMRunStaticDestructors(self.llvm) +private let lazyCompileBlockTrampoline : LLVMOrcLazyCompileCallbackFn = { (callbackJIT, callbackCtx) in + guard let jit = callbackJIT, let ctx = callbackCtx else { + fatalError("Internal JIT callback and context must be non-nil") } + + let tempJIT = JIT(llvm: jit, ownsContext: false) + let callback = Unmanaged.fromOpaque(ctx).takeUnretainedValue() + return callback.block(tempJIT) } -/// Runs the provided block with the equivalent C strings copied from the -/// passed-in array. The C strings will only be alive for the duration -/// of the block, and they will be freed when the block exits. -/// -/// - parameters: -/// - strings: The strings you intend to convert to C strings -/// - block: A block that uses the C strings -/// - returns: The result of the passed-in block. -/// - throws: Will only throw if the passed-in block throws. -internal func withCArrayOfCStrings(_ strings: [String], _ block: - (UnsafeBufferPointer?>) throws -> T) rethrows -> T { - var cStrings = [UnsafeMutablePointer?]() - for string in strings { - string.withCString { - cStrings.append(strdup($0)) - } +private let symbolBlockTrampoline : LLVMOrcSymbolResolverFn = { (callbackName, callbackCtx) in + guard let cname = callbackName, let ctx = callbackCtx else { + fatalError("Internal JIT name and context must be non-nil") } - defer { - for cStr in cStrings { - free(cStr) - } + + let name = String(cString: cname) + let callback = Unmanaged.fromOpaque(ctx).takeUnretainedValue() + return callback.block(name) +} + +private class ORCLazyCompileCallbackContext { + fileprivate let block: (JIT) -> JIT.TargetAddress + + fileprivate init(_ block: @escaping (JIT) -> JIT.TargetAddress) { + self.block = block } - return try cStrings.withUnsafeBufferPointer { buf in - // We need to make this "immutable" but that doesn't change - // their size or contents. - let constPtr = unsafeBitCast(buf.baseAddress, - to: UnsafePointer?>.self) - return try block(UnsafeBufferPointer(start: constPtr, count: buf.count)) +} + +private class ORCSymbolCallbackContext { + fileprivate let block: (String) -> JIT.TargetAddress + + fileprivate init(_ block: @escaping (String) -> JIT.TargetAddress) { + self.block = block } } diff --git a/Sources/LLVM/MemoryBuffer.swift b/Sources/LLVM/MemoryBuffer.swift index 417ceabc..7ee70ea3 100644 --- a/Sources/LLVM/MemoryBuffer.swift +++ b/Sources/LLVM/MemoryBuffer.swift @@ -21,6 +21,7 @@ public enum MemoryBufferError: Error { /// position to see if it has reached the end of the file. public class MemoryBuffer: Sequence { let llvm: LLVMMemoryBufferRef + internal var ownsContext: Bool = true /// Creates a `MemoryBuffer` with the contents of `stdin`, stopping once /// `EOF` is read. @@ -115,6 +116,9 @@ public class MemoryBuffer: Sequence { } deinit { + guard self.ownsContext else { + return + } LLVMDisposeMemoryBuffer(llvm) } } diff --git a/Sources/LLVM/Module.swift b/Sources/LLVM/Module.swift index d8a6bb27..faebe392 100644 --- a/Sources/LLVM/Module.swift +++ b/Sources/LLVM/Module.swift @@ -59,6 +59,7 @@ public enum ModuleError: Error, CustomStringConvertible { /// units merged together. public final class Module: CustomStringConvertible { internal let llvm: LLVMModuleRef + internal var ownsContext: Bool = true /// Creates a `Module` with the given name. /// @@ -261,6 +262,9 @@ public final class Module: CustomStringConvertible { } deinit { + guard self.ownsContext else { + return + } LLVMDisposeModule(llvm) } } diff --git a/Sources/LLVM/TargetMachine.swift b/Sources/LLVM/TargetMachine.swift index 9c528218..66cfb313 100644 --- a/Sources/LLVM/TargetMachine.swift +++ b/Sources/LLVM/TargetMachine.swift @@ -117,6 +117,8 @@ public class TargetMachine { return String(validatingUTF8: UnsafePointer(str)) ?? "" } + internal var ownsContext: Bool = true + /// Creates a Target Machine with information about its target environment. /// /// - parameter triple: An optional target triple to target. If this is not @@ -217,6 +219,9 @@ public class TargetMachine { } deinit { + guard self.ownsContext else { + return + } LLVMDisposeTargetMachine(llvm) } } diff --git a/Tests/LLVMTests/JITSpec.swift b/Tests/LLVMTests/JITSpec.swift index 06c2e9ae..519d208b 100644 --- a/Tests/LLVMTests/JITSpec.swift +++ b/Tests/LLVMTests/JITSpec.swift @@ -3,76 +3,261 @@ import XCTest import FileCheck import Foundation +// NB: Marking this function `public` is the safest way to make sure it gets +// emitted. +public func calculateFibs(_ forward: Bool) -> Double { + if forward { + return 1/109 + } else { + return 1/89 + } +} + +typealias FnPtr = @convention(c) (Bool) -> Double +private func getUnderlyingCDecl(_ function: FnPtr) -> JIT.TargetAddress { + return withoutActuallyEscaping(function) { fn in + return unsafeBitCast(fn, to: JIT.TargetAddress.self) + } +} + class JITSpec : XCTestCase { - func testFibonacci() { - XCTAssert(fileCheckOutput(withPrefixes: ["JIT"]) { - let module = Module(name: "Fibonacci") - let builder = IRBuilder(module: module) - - let function = builder.addFunction( - "calculateFibs", - type: FunctionType(argTypes: [IntType.int1], - returnType: FloatType.double) - ) - let entryBB = function.appendBasicBlock(named: "entry") - builder.positionAtEnd(of: entryBB) - - // allocate space for a local value - let local = builder.buildAlloca(type: FloatType.double, name: "local") - - // Compare to the condition - let test = builder.buildICmp(function.parameters[0], IntType.int1.zero(), .equal) - - // Create basic blocks for "then", "else", and "merge" - let thenBB = function.appendBasicBlock(named: "then") - let elseBB = function.appendBasicBlock(named: "else") - let mergeBB = function.appendBasicBlock(named: "merge") - - builder.buildCondBr(condition: test, then: thenBB, else: elseBB) - - // MARK: Then Block - - builder.positionAtEnd(of: thenBB) - // local = 1/89, the fibonacci series (sort of) - let thenVal = FloatType.double.constant(1/89) - // Branch to the merge block - builder.buildBr(mergeBB) - - // MARK: Else Block - builder.positionAtEnd(of: elseBB) - // local = 1/109, the fibonacci series (sort of) backwards - let elseVal = FloatType.double.constant(1/109) - // Branch to the merge block - builder.buildBr(mergeBB) - - // MARK: Merge Block - - builder.positionAtEnd(of: mergeBB) - let phi = builder.buildPhi(FloatType.double, name: "phi_example") - phi.addIncoming([ - (thenVal, thenBB), - (elseVal, elseBB), - ]) - builder.buildStore(phi, to: local) - let ret = builder.buildLoad(local, name: "ret") - builder.buildRet(ret) - - // Setup the JIT - let jit = try! JIT(module: module, machine: TargetMachine()) - typealias FnPtr = @convention(c) (Bool) -> Double - // Retrieve a handle to the function we're going to invoke - let fnAddr = jit.addressOfFunction(name: "calculateFibs") - let fn = unsafeBitCast(fnAddr, to: FnPtr.self) - // JIT: 0.009174311926605505 - print(fn(true)) - // JIT-NEXT: 0.011235955056179775 - print(fn(false)) + typealias MainFnPtr = @convention(c) () -> () + + func buildTestModule() -> Module { + let module = Module(name: "Fibonacci") + let builder = IRBuilder(module: module) + + var llvmSwiftFn = builder.addFunction( + "calculateSwiftFibs", + type: FunctionType(argTypes: [IntType.int1], + returnType: FloatType.double) + ) + llvmSwiftFn.linkage = .external + + let function = builder.addFunction( + "calculateFibs", + type: FunctionType(argTypes: [IntType.int1], + returnType: FloatType.double) + ) + let entryBB = function.appendBasicBlock(named: "entry") + builder.positionAtEnd(of: entryBB) + + // allocate space for a local value + let local = builder.buildAlloca(type: FloatType.double, name: "local") + + // Compare to the condition + let test = builder.buildICmp(function.parameters[0], IntType.int1.zero(), .equal) + + // Create basic blocks for "then", "else", and "merge" + let thenBB = function.appendBasicBlock(named: "then") + let elseBB = function.appendBasicBlock(named: "else") + let mergeBB = function.appendBasicBlock(named: "merge") + + builder.buildCondBr(condition: test, then: thenBB, else: elseBB) + + // MARK: Then Block + + builder.positionAtEnd(of: thenBB) + // local = 1/89, the fibonacci series (sort of) + let thenVal = FloatType.double.constant(1/89) + // Branch to the merge block + builder.buildBr(mergeBB) + + // MARK: Else Block + builder.positionAtEnd(of: elseBB) + // local = 1/109, the fibonacci series (sort of) backwards + let elseVal = FloatType.double.constant(1/109) + // Branch to the merge block + builder.buildBr(mergeBB) + + // MARK: Merge Block + + builder.positionAtEnd(of: mergeBB) + let phi = builder.buildPhi(FloatType.double, name: "phi_example") + phi.addIncoming([ + (thenVal, thenBB), + (elseVal, elseBB), + ]) + builder.buildStore(phi, to: local) + let ret = builder.buildLoad(local, name: "ret") + builder.buildRet(ret) + + let main = builder.addFunction("main", type: FunctionType(argTypes: [], returnType: VoidType())) + let mainEntry = main.appendBasicBlock(named: "entry") + builder.positionAtEnd(of: mainEntry) + _ = builder.buildCall(llvmSwiftFn, args: [ IntType.int1.constant(1) ]) + _ = builder.buildCall(function, args: [ IntType.int1.constant(1) ]) + builder.buildRetVoid() + + return module + } + + + func testEagerIRCompilation() { + XCTAssert(fileCheckOutput(withPrefixes: ["JIT-EAGER-COMPILE"]) { + do { + let jit = try JIT(machine: TargetMachine()) + let module = buildTestModule() + + let testFuncName = jit.mangle(symbol: "calculateSwiftFibs") + var gotForced = false + _ = try jit.addEagerlyCompiledIR(module) { (name) -> JIT.TargetAddress in + gotForced = true + guard name == testFuncName else { + return 0 + } + return getUnderlyingCDecl(calculateFibs) + } + + XCTAssertFalse(gotForced) + let fibsAddr = try jit.address(of: "calculateFibs") + XCTAssertTrue(gotForced) + let fibFn = unsafeBitCast(fibsAddr, to: FnPtr.self) + // JIT-EAGER-COMPILE: 0.009174311926605505 + print(fibFn(true)) + // JIT-EAGER-COMPILE: 0.011235955056179775 + print(fibFn(false)) + } catch _ { + XCTFail() + } + }) + } + + func testLazyIRCompilation() { + XCTAssert(fileCheckOutput(withPrefixes: ["JIT-LAZY-COMPILE"]) { + do { + let jit = try JIT(machine: TargetMachine()) + let module = buildTestModule() + + let testFuncName = jit.mangle(symbol: "calculateSwiftFibs") + var gotForced = false + _ = try jit.addLazilyCompiledIR(module) { (name) -> JIT.TargetAddress in + gotForced = true + guard name == testFuncName else { + return 0 + } + return getUnderlyingCDecl(calculateFibs) + } + + XCTAssertFalse(gotForced) + let fibsAddr = try jit.address(of: "calculateFibs") + let fibFn = unsafeBitCast(fibsAddr, to: FnPtr.self) + // JIT-LAZY-COMPILE: 0.009174311926605505 + print(fibFn(true)) + XCTAssertFalse(gotForced) + // JIT-LAZY-COMPILE-NEXT: 0.011235955056179775 + print(fibFn(false)) + XCTAssertFalse(gotForced) + + let mainAddr = try jit.address(of: "main") + let mainFn = unsafeBitCast(mainAddr, to: MainFnPtr.self) + mainFn() + XCTAssertTrue(gotForced) + } catch _ { + XCTFail() + } + }) + } + + func testAddObjectFile() { + do { + let module = buildTestModule() + let targetMachine = try TargetMachine() + let objBuffer = try targetMachine.emitToMemoryBuffer(module: module, type: .object) + + let jit = JIT(machine: targetMachine) + let testFuncName = jit.mangle(symbol: "calculateSwiftFibs") + _ = try jit.addObjectFile(objBuffer) { (name) -> JIT.TargetAddress in + guard name == testFuncName else { + return 0 + } + return getUnderlyingCDecl(calculateFibs) + } + let mainAddr = try jit.address(of: "main") + XCTAssert(mainAddr != 0) + } catch _ { + XCTFail() + } + } + + func testDirectCallbacks() { + XCTAssert(fileCheckOutput(withPrefixes: ["JIT-DIRECT-CALLBACK"]) { + do { + let jit = try JIT(machine: TargetMachine()) + + let testFuncName = jit.mangle(symbol: "calculateSwiftFibs") + let ccAddr = try jit.registerLazyCompile({ (jit) -> JIT.TargetAddress in + let sm = self.buildTestModule() + _ = try! jit.addEagerlyCompiledIR(sm) { (name) -> JIT.TargetAddress in + guard name == testFuncName else { + return 0 + } + return getUnderlyingCDecl(calculateFibs) + } + let fibsAddr = try! jit.address(of: "calculateFibs") + try! jit.setIndirectStubPointer(named: "force", address: fibsAddr) + return fibsAddr + }) + try jit.createIndirectStub(named: "force", address: ccAddr) + let fooAddr = try jit.address(of: "force") + let fooFn = unsafeBitCast(fooAddr, to: FnPtr.self) + // JIT-DIRECT-CALLBACK: 0.009174311926605505 + print(fooFn(true)) + // JIT-DIRECT-CALLBACK-NEXT: 0.011235955056179775 + print(fooFn(false)) + } catch _ { + XCTFail() + } + }) + } + + func testDirectCallBackToSwift() { + XCTAssert(fileCheckOutput(withPrefixes: ["JIT-SWIFT-CALLBACK"]) { + do { + let jit = try JIT(machine: TargetMachine()) + + let testFuncName = jit.mangle(symbol: "calculateSwiftFibs") + var gotForced = false + let ccAddr = try jit.registerLazyCompile { (jit) -> JIT.TargetAddress in + gotForced = true + let sm = self.buildTestModule() + _ = try! jit.addEagerlyCompiledIR(sm) { (name) -> JIT.TargetAddress in + guard name == testFuncName else { + return 0 + } + return getUnderlyingCDecl(calculateFibs) + } + let mainAddr = getUnderlyingCDecl(calculateFibs) + try! jit.setIndirectStubPointer(named: "force", address: mainAddr) + return mainAddr + } + + // Ensure the main entry point is compiled, causing calculateSwiftFibs + // to be lazily compiled. + try jit.createIndirectStub(named: "force", address: ccAddr) + let forceAddr = try jit.address(of: "force") + let forceFn = unsafeBitCast(forceAddr, to: FnPtr.self) + + XCTAssertFalse(gotForced) + // JIT-SWIFT-CALLBACK: 0.009174311926605505 + print(forceFn(true)) + // JIT-SWIFT-CALLBACK-NEXT: 0.011235955056179775 + print(forceFn(false)) + XCTAssertTrue(gotForced) + } catch _ { + XCTFail() + } }) } + // FIXME: These tests cannot run on Linux without SEGFAULT'ing. #if !os(macOS) static var allTests = testCase([ - ("testFibonacci", testFibonacci), + ("testEagerIRCompilation", testEagerIRCompilation), + ("testLazyIRCompilation", testLazyIRCompilation), + ("testAddObjectFile", testAddObjectFile), + ("testDirectCallbacks", testDirectCallbacks), + ("testDirectCallBackToSwift", testDirectCallBackToSwift), ]) #endif } diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 777946cb..a612e8a7 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -11,7 +11,8 @@ XCTMain([ IRExceptionSpec.allTests, IRGlobalSpec.allTests, IROperationSpec.allTests, - JITSpec.allTests, + // FIXME: These tests cannot run on Linux without SEGFAULT'ing. + // JITSpec.allTests, ModuleLinkSpec.allTests, ]) #endif