Skip to content

Commit ddd0839

Browse files
authored
Merge pull request #158 from CodaFi/ogre-battle
Update to ORCJIT
2 parents faa2ada + a70882e commit ddd0839

File tree

6 files changed

+510
-180
lines changed

6 files changed

+510
-180
lines changed

Sources/LLVM/JIT.swift

Lines changed: 246 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,12 @@ import cllvm
55
/// JITError represents the different kinds of errors the JIT compiler can
66
/// throw.
77
public enum JITError: Error, CustomStringConvertible {
8-
/// The JIT was unable to be initialized. A message is provided explaining
9-
/// the failure.
10-
case couldNotInitialize(String)
8+
case generic(String)
119

12-
/// The JIT was unable to remove the provided module. A message is provided
13-
/// explaining the failure
14-
case couldNotRemoveModule(Module, String)
15-
16-
/// A human-readable description of the error.
1710
public var description: String {
1811
switch self {
19-
case .couldNotRemoveModule(let module, let message):
20-
return "could not remove module '\(module.name)': \(message)"
21-
case .couldNotInitialize(let message):
22-
return "could not initialize JIT: \(message)"
12+
case let .generic(desc):
13+
return desc
2314
}
2415
}
2516
}
@@ -28,123 +19,263 @@ public enum JITError: Error, CustomStringConvertible {
2819
/// that has been generated in a `Module`. It can execute arbitrary functions
2920
/// and return the value the function generated, allowing you to write
3021
/// interactive programs that will run as soon as they are compiled.
22+
///
23+
/// The JIT is fundamentally lazy, and allows control over when and how symbols
24+
/// are resolved.
3125
public final class JIT {
26+
public typealias TargetAddress = LLVMOrcTargetAddress
27+
public struct ModuleHandle {
28+
fileprivate var llvm: LLVMOrcModuleHandle
29+
}
30+
3231
/// The underlying LLVMExecutionEngineRef backing this JIT.
33-
internal let llvm: LLVMExecutionEngineRef
34-
35-
private static var linkOnce: () = {
36-
return LLVMLinkInMCJIT()
37-
}()
38-
39-
/// Creates a Just In Time compiler that will compile the code in the
40-
/// provided `Module` to the architecture of the provided `TargetMachine`,
41-
/// and execute it.
42-
///
43-
/// - parameters:
44-
/// - module: The module containing code you wish to execute
45-
/// - machine: The target machine which you're compiling for
46-
/// - throws: JITError
47-
public init(module: Module, machine: TargetMachine) throws {
48-
_ = JIT.linkOnce
49-
50-
var jit: LLVMExecutionEngineRef?
51-
var error: UnsafeMutablePointer<Int8>?
52-
if LLVMCreateExecutionEngineForModule(&jit, module.llvm, &error) != 0 {
53-
let str = String(cString: error!)
54-
throw JITError.couldNotInitialize(str)
32+
internal let llvm: LLVMOrcJITStackRef
33+
private let ownsContext: Bool
34+
35+
internal init(llvm: LLVMOrcJITStackRef, ownsContext: Bool) {
36+
self.llvm = llvm
37+
self.ownsContext = ownsContext
38+
}
39+
40+
/// Create and initialize a `JIT` with this target machine's representation.
41+
public convenience init(machine: TargetMachine) {
42+
// The JIT stack takes ownership of the target machine.
43+
machine.ownsContext = false
44+
self.init(llvm: LLVMOrcCreateInstance(machine.llvm), ownsContext: true)
45+
}
46+
47+
deinit {
48+
guard self.ownsContext else {
49+
return
5550
}
56-
guard let _jit = jit else {
57-
throw JITError.couldNotInitialize("JIT was NULL")
51+
_ = LLVMOrcDisposeInstance(self.llvm)
52+
}
53+
54+
// MARK: Symbols
55+
56+
/// Mangles the given symbol name according to the data layout of the JIT's
57+
/// target machine.
58+
///
59+
/// - parameter symbol: The symbol name to mangle.
60+
/// - returns: A mangled representation of the given symbol name.
61+
public func mangle(symbol: String) -> String {
62+
var mangledResult: UnsafeMutablePointer<Int8>? = nil
63+
LLVMOrcGetMangledSymbol(self.llvm, &mangledResult, symbol)
64+
guard let result = mangledResult else {
65+
fatalError("Mangled name should never be nil!")
5866
}
59-
self.llvm = _jit
60-
LLVMRunStaticConstructors(self.llvm)
61-
}
62-
63-
/// Retrieves a pointer to the function compiled by this JIT.
64-
/// - parameter name: The name of the function you wish to look up.
65-
/// - returns: A pointer to the result of compiling the specified function.
66-
/// - note: You will have to `unsafeBitCast` this pointer to
67-
/// the appropriate `@convention(c)` function type to be
68-
/// able to run it from Swift.
69-
///
70-
/// ```
71-
/// typealias FnPtr = @convention(c) () -> Double
72-
/// let fnAddr = jit.addressOfFunction(name: "test")
73-
/// let fn = unsafeBitCast(fnAddr, to: FnPtr.self)
74-
/// ```
75-
public func addressOfFunction(name: String) -> OpaquePointer? {
76-
let addr = LLVMGetFunctionAddress(llvm, name)
77-
guard addr != 0 else { return nil }
78-
return OpaquePointer(bitPattern: UInt(addr))
79-
}
80-
81-
/// Adds the provided module, and all top-level declarations into this JIT.
82-
/// - parameter module: The module you wish to add.
83-
public func addModule(_ module: Module) {
84-
LLVMAddModule(llvm, module.llvm)
85-
}
86-
87-
/// Removes the provided module, and all top-level declarations, from this
88-
/// JIT.
89-
public func removeModule(_ module: Module) throws {
90-
var outMod: LLVMModuleRef? = module.llvm
91-
var outError: UnsafeMutablePointer<Int8>?
92-
LLVMRemoveModule(llvm, module.llvm, &outMod, &outError)
93-
if let err = outError {
94-
defer { LLVMDisposeMessage(err) }
95-
throw JITError.couldNotRemoveModule(module, String(cString: err))
67+
defer { LLVMOrcDisposeMangledSymbol(mangledResult) }
68+
return String(cString: result)
69+
}
70+
71+
/// Computes the address of the given symbol, optionally restricting the
72+
/// search for its address to a particular module. If this symbol does not
73+
/// exist, an address of `0` is returned.
74+
///
75+
/// - parameter symbol: The symbol name to search for.
76+
/// - parameter module: An optional value describing the module in which to
77+
/// restrict the search, if any.
78+
/// - returns: The address of the symbol, or 0 if it does not exist.
79+
public func address(of symbol: String, in module: ModuleHandle? = nil) throws -> TargetAddress {
80+
var retAddr: TargetAddress = 0
81+
if let targetModule = module {
82+
try checkForJITError(LLVMOrcGetSymbolAddressIn(self.llvm, &retAddr, targetModule.llvm, symbol))
83+
} else {
84+
try checkForJITError(LLVMOrcGetSymbolAddress(self.llvm, &retAddr, symbol))
9685
}
86+
return retAddr
87+
}
88+
89+
// MARK: Lazy Compilation
90+
91+
/// Registers a lazy compile callback that can be used to get the target
92+
/// address of a trampoline function. When that trampoline address is
93+
/// called, the given compilation callback is fired.
94+
///
95+
/// Normally, the trampoline function is a known stub that has been previously
96+
/// registered with the JIT. The callback then computes the address of a
97+
/// known entry point and sets the address of the stub to it. See
98+
/// `JIT.createIndirectStub` to create a stub function and
99+
/// `JIT.setIndirectStubPointer` to set the address of a stub.
100+
///
101+
/// - parameter callback: A callback that returns the actual address of the
102+
/// trampoline function.
103+
/// - returns: The target address representing a stub. Calling this stub
104+
/// forces the given compilation callback to fire.
105+
public func registerLazyCompile(_ callback: @escaping (JIT) -> TargetAddress) throws -> TargetAddress {
106+
var addr: TargetAddress = 0
107+
let callbackContext = ORCLazyCompileCallbackContext(callback)
108+
let contextPtr = Unmanaged<ORCLazyCompileCallbackContext>.passRetained(callbackContext).toOpaque()
109+
try checkForJITError(LLVMOrcCreateLazyCompileCallback(self.llvm, &addr, lazyCompileBlockTrampoline, contextPtr))
110+
return addr
111+
}
112+
113+
// MARK: Stubs
114+
115+
/// Creates a new named indirect stub pointing to the given target address.
116+
///
117+
/// An indirect stub may be resolved to a different address at any time by
118+
/// invoking `JIT.setIndirectStubPointer`.
119+
///
120+
/// - parameter name: The name of the indirect stub.
121+
/// - parameter address: The address of the indirect stub.
122+
public func createIndirectStub(named name: String, address: TargetAddress) throws {
123+
try checkForJITError(LLVMOrcCreateIndirectStub(self.llvm, name, address))
124+
}
125+
126+
/// Resets the address of an indirect stub.
127+
///
128+
/// - warning: The indirect stub must be registered with a call to
129+
/// `JIT.createIndirectStub`. Failure to do so will result in undefined
130+
/// behavior.
131+
///
132+
/// - parameter name: The name of an indirect stub.
133+
/// - parameter address: The address to set the indirect stub to point to.
134+
public func setIndirectStubPointer(named name: String, address: TargetAddress) throws {
135+
try checkForJITError(LLVMOrcSetIndirectStubPointer(self.llvm, name, address))
136+
}
137+
138+
// MARK: Adding Code to the JIT
139+
140+
/// Adds the IR from a given module to the JIT, consuming it in the process.
141+
///
142+
/// Despite the name of this function, the callback to compile the symbols in
143+
/// the module is not necessarily called immediately. It is called at least
144+
/// when a given symbol's address is requested, either by the JIT or by
145+
/// the user e.g. `JIT.address(of:)`.
146+
///
147+
/// The callback function is required to compute the address of the given
148+
/// symbol. The symbols are passed in mangled form. Use
149+
/// `JIT.mangle(symbol:)` to request the mangled name of a symbol.
150+
///
151+
/// - warning: The JIT invalidates the underlying reference to the provided
152+
/// module. Further references to the module are thus dangling pointers and
153+
/// may be a source of subtle memory bugs. This will be addressed in a
154+
/// future revision of LLVM.
155+
///
156+
/// - parameter module: The module to compile.
157+
/// - parameter callback: A function that is called by the JIT to compute the
158+
/// address of symbols.
159+
public func addEagerlyCompiledIR(_ module: Module, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle {
160+
var handle: LLVMOrcModuleHandle = 0
161+
let callbackContext = ORCSymbolCallbackContext(callback)
162+
let contextPtr = Unmanaged<ORCSymbolCallbackContext>.passRetained(callbackContext).toOpaque()
163+
// The JIT stack takes ownership of the given module.
164+
module.ownsContext = false
165+
try checkForJITError(LLVMOrcAddEagerlyCompiledIR(self.llvm, &handle, module.llvm, symbolBlockTrampoline, contextPtr))
166+
return ModuleHandle(llvm: handle)
97167
}
98168

99-
/// Runs the specified function as if it were the `main` function in an
100-
/// executable. It takes an array of argument strings and passes them
101-
/// into the function as `argc` and `argv`.
102-
///
103-
/// - parameters:
104-
/// - function: The `main` function you wish to execute
105-
/// - args: The string arguments you wish to pass to the function
106-
/// - returns: The numerical exit code returned by the function
107-
public func runFunctionAsMain(_ function: Function, args: [String]) -> Int {
108-
// FIXME: Also add in envp.
109-
return withCArrayOfCStrings(args) { buf in
110-
return Int(LLVMRunFunctionAsMain(llvm, function.asLLVM(),
111-
UInt32(buf.count),
112-
buf.baseAddress, nil))
169+
/// Adds the IR from a given module to the JIT, consuming it in the process.
170+
///
171+
/// This function differs from `JIT.addEagerlyCompiledIR` in that the callback
172+
/// to request the address of symbols is only executed when that symbol is
173+
/// called, either in user code or by the JIT.
174+
///
175+
/// The callback function is required to compute the address of the given
176+
/// symbol. The symbols are passed in mangled form. Use
177+
/// `JIT.mangle(symbol:)` to request the mangled name of a symbol.
178+
///
179+
/// - warning: The JIT invalidates the underlying reference to the provided
180+
/// module. Further references to the module are thus dangling pointers and
181+
/// may be a source of subtle memory bugs. This will be addressed in a
182+
/// future revision of LLVM.
183+
///
184+
/// - parameter module: The module to compile.
185+
/// - parameter callback: A function that is called by the JIT to compute the
186+
/// address of symbols.
187+
public func addLazilyCompiledIR(_ module: Module, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle {
188+
var handle: LLVMOrcModuleHandle = 0
189+
let callbackContext = ORCSymbolCallbackContext(callback)
190+
let contextPtr = Unmanaged<ORCSymbolCallbackContext>.passRetained(callbackContext).toOpaque()
191+
// The JIT stack takes ownership of the given module.
192+
module.ownsContext = false
193+
try checkForJITError(LLVMOrcAddLazilyCompiledIR(self.llvm, &handle, module.llvm, symbolBlockTrampoline, contextPtr))
194+
return ModuleHandle(llvm: handle)
195+
}
196+
197+
/// Adds the executable code from an object file to ths JIT, consuming it in
198+
/// the process.
199+
///
200+
/// The callback function is required to compute the address of the given
201+
/// symbol. The symbols are passed in mangled form. Use
202+
/// `JIT.mangle(symbol:)` to request the mangled name of a symbol.
203+
///
204+
/// - warning: The JIT invalidates the underlying reference to the provided
205+
/// memory buffer. Further references to the buffer are thus dangling
206+
/// pointers and may be a source of subtle memory bugs. This will be
207+
/// addressed in a future revision of LLVM.
208+
///
209+
/// - parameter buffer: A buffer containing an object file.
210+
/// - parameter callback: A function that is called by the JIT to compute the
211+
/// address of symbols.
212+
public func addObjectFile(_ buffer: MemoryBuffer, _ callback: @escaping (String) -> TargetAddress) throws -> ModuleHandle {
213+
var handle: LLVMOrcModuleHandle = 0
214+
let callbackContext = ORCSymbolCallbackContext(callback)
215+
let contextPtr = Unmanaged<ORCSymbolCallbackContext>.passRetained(callbackContext).toOpaque()
216+
// The JIT stack takes ownership of the given buffer.
217+
buffer.ownsContext = false
218+
try checkForJITError(LLVMOrcAddObjectFile(self.llvm, &handle, buffer.llvm, symbolBlockTrampoline, contextPtr))
219+
return ModuleHandle(llvm: handle)
220+
}
221+
222+
/// Remove previously-added code from the JIT.
223+
///
224+
/// - warning: Removing a module handle consumes the handle. Further use of
225+
/// the handle will then result in undefined behavior.
226+
///
227+
/// - parameter handle: A handle to previously-added module.
228+
public func removeModule(_ handle: ModuleHandle) throws {
229+
try checkForJITError(LLVMOrcRemoveModule(self.llvm, handle.llvm))
230+
}
231+
232+
private func checkForJITError(_ orcError: LLVMOrcErrorCode) throws {
233+
switch orcError {
234+
case LLVMOrcErrSuccess:
235+
return
236+
case LLVMOrcErrGeneric:
237+
guard let msg = LLVMOrcGetErrorMsg(self.llvm) else {
238+
fatalError("Couldn't get the error message?")
239+
}
240+
throw JITError.generic(String(cString: msg))
241+
default:
242+
fatalError("Uncategorized ORC error code!")
113243
}
114244
}
245+
}
115246

116-
deinit {
117-
LLVMRunStaticDestructors(self.llvm)
247+
private let lazyCompileBlockTrampoline : LLVMOrcLazyCompileCallbackFn = { (callbackJIT, callbackCtx) in
248+
guard let jit = callbackJIT, let ctx = callbackCtx else {
249+
fatalError("Internal JIT callback and context must be non-nil")
118250
}
251+
252+
let tempJIT = JIT(llvm: jit, ownsContext: false)
253+
let callback = Unmanaged<ORCLazyCompileCallbackContext>.fromOpaque(ctx).takeUnretainedValue()
254+
return callback.block(tempJIT)
119255
}
120256

121-
/// Runs the provided block with the equivalent C strings copied from the
122-
/// passed-in array. The C strings will only be alive for the duration
123-
/// of the block, and they will be freed when the block exits.
124-
///
125-
/// - parameters:
126-
/// - strings: The strings you intend to convert to C strings
127-
/// - block: A block that uses the C strings
128-
/// - returns: The result of the passed-in block.
129-
/// - throws: Will only throw if the passed-in block throws.
130-
internal func withCArrayOfCStrings<T>(_ strings: [String], _ block:
131-
(UnsafeBufferPointer<UnsafePointer<Int8>?>) throws -> T) rethrows -> T {
132-
var cStrings = [UnsafeMutablePointer<Int8>?]()
133-
for string in strings {
134-
string.withCString {
135-
cStrings.append(strdup($0))
136-
}
257+
private let symbolBlockTrampoline : LLVMOrcSymbolResolverFn = { (callbackName, callbackCtx) in
258+
guard let cname = callbackName, let ctx = callbackCtx else {
259+
fatalError("Internal JIT name and context must be non-nil")
137260
}
138-
defer {
139-
for cStr in cStrings {
140-
free(cStr)
141-
}
261+
262+
let name = String(cString: cname)
263+
let callback = Unmanaged<ORCSymbolCallbackContext>.fromOpaque(ctx).takeUnretainedValue()
264+
return callback.block(name)
265+
}
266+
267+
private class ORCLazyCompileCallbackContext {
268+
fileprivate let block: (JIT) -> JIT.TargetAddress
269+
270+
fileprivate init(_ block: @escaping (JIT) -> JIT.TargetAddress) {
271+
self.block = block
142272
}
143-
return try cStrings.withUnsafeBufferPointer { buf in
144-
// We need to make this "immutable" but that doesn't change
145-
// their size or contents.
146-
let constPtr = unsafeBitCast(buf.baseAddress,
147-
to: UnsafePointer<UnsafePointer<Int8>?>.self)
148-
return try block(UnsafeBufferPointer(start: constPtr, count: buf.count))
273+
}
274+
275+
private class ORCSymbolCallbackContext {
276+
fileprivate let block: (String) -> JIT.TargetAddress
277+
278+
fileprivate init(_ block: @escaping (String) -> JIT.TargetAddress) {
279+
self.block = block
149280
}
150281
}

0 commit comments

Comments
 (0)