Skip to content

Update to ORCJIT #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 246 additions & 115 deletions Sources/LLVM/JIT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,12 @@ import cllvm
/// JITError represents the different kinds of errors the JIT compiler can
/// throw.
public enum JITError: Error, CustomStringConvertible {
/// The JIT was unable to be initialized. A message is provided explaining
/// the failure.
case couldNotInitialize(String)
case generic(String)

/// The JIT was unable to remove the provided module. A message is provided
/// explaining the failure
case couldNotRemoveModule(Module, String)

/// A human-readable description of the error.
public var description: String {
switch self {
case .couldNotRemoveModule(let module, let message):
return "could not remove module '\(module.name)': \(message)"
case .couldNotInitialize(let message):
return "could not initialize JIT: \(message)"
case let .generic(desc):
return desc
}
}
}
Expand All @@ -28,123 +19,263 @@ public enum JITError: Error, CustomStringConvertible {
/// that has been generated in a `Module`. It can execute arbitrary functions
/// and return the value the function generated, allowing you to write
/// interactive programs that will run as soon as they are compiled.
///
/// The JIT is fundamentally lazy, and allows control over when and how symbols
/// are resolved.
public final class JIT {
public typealias TargetAddress = LLVMOrcTargetAddress
public struct ModuleHandle {
fileprivate var llvm: LLVMOrcModuleHandle
}

/// The underlying LLVMExecutionEngineRef backing this JIT.
internal let llvm: LLVMExecutionEngineRef

private static var linkOnce: () = {
return LLVMLinkInMCJIT()
}()

/// Creates a Just In Time compiler that will compile the code in the
/// provided `Module` to the architecture of the provided `TargetMachine`,
/// and execute it.
///
/// - parameters:
/// - module: The module containing code you wish to execute
/// - machine: The target machine which you're compiling for
/// - throws: JITError
public init(module: Module, machine: TargetMachine) throws {
_ = JIT.linkOnce

var jit: LLVMExecutionEngineRef?
var error: UnsafeMutablePointer<Int8>?
if LLVMCreateExecutionEngineForModule(&jit, module.llvm, &error) != 0 {
let str = String(cString: error!)
throw JITError.couldNotInitialize(str)
internal let llvm: LLVMOrcJITStackRef
private let ownsContext: Bool

internal init(llvm: LLVMOrcJITStackRef, ownsContext: Bool) {
self.llvm = llvm
self.ownsContext = ownsContext
}

/// Create and initialize a `JIT` with this target machine's representation.
public convenience init(machine: TargetMachine) {
// The JIT stack takes ownership of the target machine.
machine.ownsContext = false
self.init(llvm: LLVMOrcCreateInstance(machine.llvm), ownsContext: true)
}

deinit {
guard self.ownsContext else {
return
}
guard let _jit = jit else {
throw JITError.couldNotInitialize("JIT was NULL")
_ = LLVMOrcDisposeInstance(self.llvm)
}

// MARK: Symbols

/// Mangles the given symbol name according to the data layout of the JIT's
/// target machine.
///
/// - parameter symbol: The symbol name to mangle.
/// - returns: A mangled representation of the given symbol name.
public func mangle(symbol: String) -> String {
var mangledResult: UnsafeMutablePointer<Int8>? = nil
LLVMOrcGetMangledSymbol(self.llvm, &mangledResult, symbol)
guard let result = mangledResult else {
fatalError("Mangled name should never be nil!")
}
self.llvm = _jit
LLVMRunStaticConstructors(self.llvm)
}

/// Retrieves a pointer to the function compiled by this JIT.
/// - parameter name: The name of the function you wish to look up.
/// - returns: A pointer to the result of compiling the specified function.
/// - note: You will have to `unsafeBitCast` this pointer to
/// the appropriate `@convention(c)` function type to be
/// able to run it from Swift.
///
/// ```
/// typealias FnPtr = @convention(c) () -> Double
/// let fnAddr = jit.addressOfFunction(name: "test")
/// let fn = unsafeBitCast(fnAddr, to: FnPtr.self)
/// ```
public func addressOfFunction(name: String) -> OpaquePointer? {
let addr = LLVMGetFunctionAddress(llvm, name)
guard addr != 0 else { return nil }
return OpaquePointer(bitPattern: UInt(addr))
}

/// Adds the provided module, and all top-level declarations into this JIT.
/// - parameter module: The module you wish to add.
public func addModule(_ module: Module) {
LLVMAddModule(llvm, module.llvm)
}

/// Removes the provided module, and all top-level declarations, from this
/// JIT.
public func removeModule(_ module: Module) throws {
var outMod: LLVMModuleRef? = module.llvm
var outError: UnsafeMutablePointer<Int8>?
LLVMRemoveModule(llvm, module.llvm, &outMod, &outError)
if let err = outError {
defer { LLVMDisposeMessage(err) }
throw JITError.couldNotRemoveModule(module, String(cString: err))
defer { LLVMOrcDisposeMangledSymbol(mangledResult) }
return String(cString: result)
}

/// Computes the address of the given symbol, optionally restricting the
/// search for its address to a particular module. If this symbol does not
/// exist, an address of `0` is returned.
///
/// - parameter symbol: The symbol name to search for.
/// - parameter module: An optional value describing the module in which to
/// restrict the search, if any.
/// - returns: The address of the symbol, or 0 if it does not exist.
public func address(of symbol: String, in module: ModuleHandle? = nil) throws -> TargetAddress {
var retAddr: TargetAddress = 0
if let targetModule = module {
try checkForJITError(LLVMOrcGetSymbolAddressIn(self.llvm, &retAddr, targetModule.llvm, symbol))
} else {
try checkForJITError(LLVMOrcGetSymbolAddress(self.llvm, &retAddr, symbol))
}
return retAddr
}

// MARK: Lazy Compilation

/// Registers a lazy compile callback that can be used to get the target
/// address of a trampoline function. When that trampoline address is
/// called, the given compilation callback is fired.
///
/// Normally, the trampoline function is a known stub that has been previously
/// registered with the JIT. The callback then computes the address of a
/// known entry point and sets the address of the stub to it. See
/// `JIT.createIndirectStub` to create a stub function and
/// `JIT.setIndirectStubPointer` to set the address of a stub.
///
/// - parameter callback: A callback that returns the actual address of the
/// trampoline function.
/// - returns: The target address representing a stub. Calling this stub
/// forces the given compilation callback to fire.
public func registerLazyCompile(_ callback: @escaping (JIT) -> TargetAddress) throws -> TargetAddress {
var addr: TargetAddress = 0
let callbackContext = ORCLazyCompileCallbackContext(callback)
let contextPtr = Unmanaged<ORCLazyCompileCallbackContext>.passRetained(callbackContext).toOpaque()
try checkForJITError(LLVMOrcCreateLazyCompileCallback(self.llvm, &addr, lazyCompileBlockTrampoline, contextPtr))
return addr
}

// MARK: Stubs

/// Creates a new named indirect stub pointing to the given target address.
///
/// An indirect stub may be resolved to a different address at any time by
/// invoking `JIT.setIndirectStubPointer`.
///
/// - parameter name: The name of the indirect stub.
/// - parameter address: The address of the indirect stub.
public func createIndirectStub(named name: String, address: TargetAddress) throws {
try checkForJITError(LLVMOrcCreateIndirectStub(self.llvm, name, address))
}

/// Resets the address of an indirect stub.
///
/// - warning: The indirect stub must be registered with a call to
/// `JIT.createIndirectStub`. Failure to do so will result in undefined
/// behavior.
///
/// - parameter name: The name of an indirect stub.
/// - parameter address: The address to set the indirect stub to point to.
public func setIndirectStubPointer(named name: String, address: TargetAddress) throws {
try checkForJITError(LLVMOrcSetIndirectStubPointer(self.llvm, name, address))
}

// MARK: Adding Code to the JIT

/// Adds the IR from a given module to the JIT, consuming it in the process.
///
/// Despite the name of this function, the callback to compile the symbols in
/// the module is not necessarily called immediately. It is called at least
/// when a given symbol's address is requested, either by the JIT or by
/// the user e.g. `JIT.address(of:)`.
///
/// The callback function is required to compute the address of the given
/// symbol. The symbols are passed in mangled form. Use
/// `JIT.mangle(symbol:)` to request the mangled name of a symbol.
///
/// - warning: The JIT invalidates the underlying reference to the provided
/// module. Further references to the module are thus dangling pointers and
/// may be a source of subtle memory bugs. This will be addressed in a
/// future revision of LLVM.
///
/// - parameter module: The module to compile.
/// - parameter callback: A function that is called by the JIT to compute the
/// address of symbols.
public func addEagerlyCompiledIR(_ module: Module, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle {
var handle: LLVMOrcModuleHandle = 0
let callbackContext = ORCSymbolCallbackContext(callback)
let contextPtr = Unmanaged<ORCSymbolCallbackContext>.passRetained(callbackContext).toOpaque()
// The JIT stack takes ownership of the given module.
module.ownsContext = false
try checkForJITError(LLVMOrcAddEagerlyCompiledIR(self.llvm, &handle, module.llvm, symbolBlockTrampoline, contextPtr))
return ModuleHandle(llvm: handle)
}

/// Runs the specified function as if it were the `main` function in an
/// executable. It takes an array of argument strings and passes them
/// into the function as `argc` and `argv`.
///
/// - parameters:
/// - function: The `main` function you wish to execute
/// - args: The string arguments you wish to pass to the function
/// - returns: The numerical exit code returned by the function
public func runFunctionAsMain(_ function: Function, args: [String]) -> Int {
// FIXME: Also add in envp.
return withCArrayOfCStrings(args) { buf in
return Int(LLVMRunFunctionAsMain(llvm, function.asLLVM(),
UInt32(buf.count),
buf.baseAddress, nil))
/// Adds the IR from a given module to the JIT, consuming it in the process.
///
/// This function differs from `JIT.addEagerlyCompiledIR` in that the callback
/// to request the address of symbols is only executed when that symbol is
/// called, either in user code or by the JIT.
///
/// The callback function is required to compute the address of the given
/// symbol. The symbols are passed in mangled form. Use
/// `JIT.mangle(symbol:)` to request the mangled name of a symbol.
///
/// - warning: The JIT invalidates the underlying reference to the provided
/// module. Further references to the module are thus dangling pointers and
/// may be a source of subtle memory bugs. This will be addressed in a
/// future revision of LLVM.
///
/// - parameter module: The module to compile.
/// - parameter callback: A function that is called by the JIT to compute the
/// address of symbols.
public func addLazilyCompiledIR(_ module: Module, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle {
var handle: LLVMOrcModuleHandle = 0
let callbackContext = ORCSymbolCallbackContext(callback)
let contextPtr = Unmanaged<ORCSymbolCallbackContext>.passRetained(callbackContext).toOpaque()
// The JIT stack takes ownership of the given module.
module.ownsContext = false
try checkForJITError(LLVMOrcAddLazilyCompiledIR(self.llvm, &handle, module.llvm, symbolBlockTrampoline, contextPtr))
return ModuleHandle(llvm: handle)
}

/// Adds the executable code from an object file to ths JIT, consuming it in
/// the process.
///
/// The callback function is required to compute the address of the given
/// symbol. The symbols are passed in mangled form. Use
/// `JIT.mangle(symbol:)` to request the mangled name of a symbol.
///
/// - warning: The JIT invalidates the underlying reference to the provided
/// memory buffer. Further references to the buffer are thus dangling
/// pointers and may be a source of subtle memory bugs. This will be
/// addressed in a future revision of LLVM.
///
/// - parameter buffer: A buffer containing an object file.
/// - parameter callback: A function that is called by the JIT to compute the
/// address of symbols.
public func addObjectFile(_ buffer: MemoryBuffer, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle {
var handle: LLVMOrcModuleHandle = 0
let callbackContext = ORCSymbolCallbackContext(callback)
let contextPtr = Unmanaged<ORCSymbolCallbackContext>.passRetained(callbackContext).toOpaque()
// The JIT stack takes ownership of the given buffer.
buffer.ownsContext = false
try checkForJITError(LLVMOrcAddObjectFile(self.llvm, &handle, buffer.llvm, symbolBlockTrampoline, contextPtr))
return ModuleHandle(llvm: handle)
}

/// Remove previously-added code from the JIT.
///
/// - warning: Removing a module handle consumes the handle. Further use of
/// the handle will then result in undefined behavior.
///
/// - parameter handle: A handle to previously-added module.
public func removeModule(_ handle: ModuleHandle) throws {
try checkForJITError(LLVMOrcRemoveModule(self.llvm, handle.llvm))
}

private func checkForJITError(_ orcError: LLVMOrcErrorCode) throws {
switch orcError {
case LLVMOrcErrSuccess:
return
case LLVMOrcErrGeneric:
guard let msg = LLVMOrcGetErrorMsg(self.llvm) else {
fatalError("Couldn't get the error message?")
}
throw JITError.generic(String(cString: msg))
default:
fatalError("Uncategorized ORC error code!")
}
}
}

deinit {
LLVMRunStaticDestructors(self.llvm)
private let lazyCompileBlockTrampoline : LLVMOrcLazyCompileCallbackFn = { (callbackJIT, callbackCtx) in
guard let jit = callbackJIT, let ctx = callbackCtx else {
fatalError("Internal JIT callback and context must be non-nil")
}

let tempJIT = JIT(llvm: jit, ownsContext: false)
let callback = Unmanaged<ORCLazyCompileCallbackContext>.fromOpaque(ctx).takeUnretainedValue()
return callback.block(tempJIT)
}

/// Runs the provided block with the equivalent C strings copied from the
/// passed-in array. The C strings will only be alive for the duration
/// of the block, and they will be freed when the block exits.
///
/// - parameters:
/// - strings: The strings you intend to convert to C strings
/// - block: A block that uses the C strings
/// - returns: The result of the passed-in block.
/// - throws: Will only throw if the passed-in block throws.
internal func withCArrayOfCStrings<T>(_ strings: [String], _ block:
(UnsafeBufferPointer<UnsafePointer<Int8>?>) throws -> T) rethrows -> T {
var cStrings = [UnsafeMutablePointer<Int8>?]()
for string in strings {
string.withCString {
cStrings.append(strdup($0))
}
private let symbolBlockTrampoline : LLVMOrcSymbolResolverFn = { (callbackName, callbackCtx) in
guard let cname = callbackName, let ctx = callbackCtx else {
fatalError("Internal JIT name and context must be non-nil")
}
defer {
for cStr in cStrings {
free(cStr)
}

let name = String(cString: cname)
let callback = Unmanaged<ORCSymbolCallbackContext>.fromOpaque(ctx).takeUnretainedValue()
return callback.block(name)
}

private class ORCLazyCompileCallbackContext {
fileprivate let block: (JIT) -> JIT.TargetAddress

fileprivate init(_ block: @escaping (JIT) -> JIT.TargetAddress) {
self.block = block
}
return try cStrings.withUnsafeBufferPointer { buf in
// We need to make this "immutable" but that doesn't change
// their size or contents.
let constPtr = unsafeBitCast(buf.baseAddress,
to: UnsafePointer<UnsafePointer<Int8>?>.self)
return try block(UnsafeBufferPointer(start: constPtr, count: buf.count))
}

private class ORCSymbolCallbackContext {
fileprivate let block: (String) -> JIT.TargetAddress

fileprivate init(_ block: @escaping (String) -> JIT.TargetAddress) {
self.block = block
}
}
Loading