24
24
#include " llvm/Transforms/IPO/AlwaysInliner.h"
25
25
#include " llvm/Transforms/InstCombine/InstCombine.h"
26
26
#include < csignal>
27
+ #include < memory>
27
28
#include < pybind11/pybind11.h>
28
29
#include < pybind11/stl.h>
29
30
#include < stdexcept>
@@ -39,6 +40,30 @@ struct BreakStructPhiNodesPass : PassInfoMixin<BreakStructPhiNodesPass> {
39
40
40
41
using namespace llvm ;
41
42
43
+ std::unique_ptr<TargetMachine>
44
+ createTargetMachine (llvm::Module *module, std::string proc,
45
+ bool enable_fp_fusion, const std::string &features) {
46
+ std::string error;
47
+ auto target =
48
+ llvm::TargetRegistry::lookupTarget (module->getTargetTriple (), error);
49
+ llvm::TargetOptions opt;
50
+ bool disableLLVMOpt = mlir::triton::tools::getBoolEnv (" DISABLE_LLVM_OPT" );
51
+ if (enable_fp_fusion)
52
+ opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
53
+ opt.UnsafeFPMath = false ;
54
+ opt.NoInfsFPMath = false ;
55
+ opt.NoNaNsFPMath = true ;
56
+ opt.TrapUnreachable = true ;
57
+ opt.MCOptions .AsmVerbose = true ;
58
+ opt.MCOptions .PreserveAsmComments = true ;
59
+ std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine (
60
+ module->getTargetTriple (), proc, features, opt, llvm::Reloc::PIC_,
61
+ std::nullopt,
62
+ disableLLVMOpt ? llvm::CodeGenOptLevel::None
63
+ : llvm::CodeGenOptLevel::Aggressive)};
64
+ return machine;
65
+ }
66
+
42
67
std::string translateLLVMIRToASM (llvm::Module &module,
43
68
const std::string &triple,
44
69
const std::string &proc,
@@ -106,23 +131,7 @@ std::string translateLLVMIRToASM(llvm::Module &module,
106
131
107
132
// create machine
108
133
module.setTargetTriple (triple);
109
- std::string error;
110
- auto target =
111
- llvm::TargetRegistry::lookupTarget (module.getTargetTriple (), error);
112
- llvm::TargetOptions opt;
113
- if (enable_fp_fusion)
114
- opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
115
- opt.UnsafeFPMath = false ;
116
- opt.NoInfsFPMath = false ;
117
- opt.NoNaNsFPMath = true ;
118
- opt.TrapUnreachable = true ;
119
- opt.MCOptions .AsmVerbose = true ;
120
- opt.MCOptions .PreserveAsmComments = true ;
121
- std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine (
122
- module.getTargetTriple (), proc, features, opt, llvm::Reloc::PIC_,
123
- std::nullopt,
124
- disableLLVMOpt ? llvm::CodeGenOptLevel::None
125
- : llvm::CodeGenOptLevel::Aggressive)};
134
+ auto machine = createTargetMachine (&module, proc, enable_fp_fusion, features);
126
135
// set data layout
127
136
module.setDataLayout (machine->createDataLayout ());
128
137
// emit machine code
@@ -267,94 +276,111 @@ void init_triton_llvm(py::module &&m) {
267
276
mod->setDataLayout (machine->createDataLayout ());
268
277
});
269
278
270
- m.def (" optimize_module" , [](llvm::Module *mod,
271
- const llvm::OptimizationLevel &opt) {
272
- if (mlir::triton::tools::getBoolEnv (" DISABLE_LLVM_OPT" ))
273
- return ;
274
- // Check to see if we are passing a list of flags to disable optimizations.
275
- auto flagList = mlir::triton::tools::getStrEnv (" DISABLE_LLVM_OPT" );
276
- if (!flagList.empty ()) {
277
- auto options = llvm::cl::getRegisteredOptions ();
278
- llvm::SmallVector<StringRef, 3 > split;
279
- StringRef (flagList.c_str ()).split (split, ' ,' );
280
- for (auto flag : split) {
281
- auto optIt = options.find (flag);
282
- if (optIt != options.end ()) {
283
- auto optPtr = static_cast <llvm::cl::opt<bool > *>(optIt->second );
284
- *optPtr = true ;
279
+ m.def (
280
+ " optimize_module" ,
281
+ [](llvm::Module *mod, const llvm::OptimizationLevel &opt,
282
+ std::string arch, std::string features, std::vector<std::string> flags,
283
+ bool enable_fp_fusion) {
284
+ if (mlir::triton::tools::getBoolEnv (" DISABLE_LLVM_OPT" ))
285
+ return ;
286
+ // Check to see if we are passing a list of flags to disable
287
+ // optimizations.
288
+ auto flagList = mlir::triton::tools::getStrEnv (" DISABLE_LLVM_OPT" );
289
+ if (!flagList.empty ()) {
290
+ auto options = llvm::cl::getRegisteredOptions ();
291
+ llvm::SmallVector<StringRef, 3 > split;
292
+ StringRef (flagList.c_str ()).split (split, ' ,' );
293
+ for (auto flag : split) {
294
+ auto optIt = options.find (flag);
295
+ if (optIt != options.end ()) {
296
+ auto optPtr = static_cast <llvm::cl::opt<bool > *>(optIt->second );
297
+ *optPtr = true ;
298
+ }
299
+ }
300
+ }
301
+ using namespace llvm ;
302
+ LoopAnalysisManager lam;
303
+ FunctionAnalysisManager fam;
304
+ CGSCCAnalysisManager cgam;
305
+ ModuleAnalysisManager mam;
306
+
307
+ PassInstrumentationCallbacks *instrCbPtr = nullptr ;
308
+ PassInstrumentationCallbacks passInstrCb;
309
+ StandardInstrumentations standardInstr (mod->getContext (),
310
+ /* DebugLogging*/ true );
311
+ if (mlir::triton::tools::getBoolEnv (" LLVM_IR_ENABLE_DUMP" )) {
312
+ auto optMap = llvm::cl::getRegisteredOptions ();
313
+ auto optIt = optMap.find (" print-after-all" );
314
+ if (optIt != optMap.end ()) {
315
+ auto optPtr = static_cast <llvm::cl::opt<bool > *>(optIt->second );
316
+ *optPtr = true ;
317
+ }
318
+ standardInstr.registerCallbacks (passInstrCb, &mam);
319
+ instrCbPtr = &passInstrCb;
285
320
}
286
- }
287
- }
288
- using namespace llvm ;
289
- LoopAnalysisManager lam;
290
- FunctionAnalysisManager fam;
291
- CGSCCAnalysisManager cgam;
292
- ModuleAnalysisManager mam;
293
-
294
- PassInstrumentationCallbacks *instrCbPtr = nullptr ;
295
- PassInstrumentationCallbacks passInstrCb;
296
- StandardInstrumentations standardInstr (mod->getContext (),
297
- /* DebugLogging*/ true );
298
- if (mlir::triton::tools::getBoolEnv (" LLVM_IR_ENABLE_DUMP" )) {
299
- auto optMap = llvm::cl::getRegisteredOptions ();
300
- auto optIt = optMap.find (" print-after-all" );
301
- if (optIt != optMap.end ()) {
302
- auto optPtr = static_cast <llvm::cl::opt<bool > *>(optIt->second );
303
- *optPtr = true ;
304
- }
305
- standardInstr.registerCallbacks (passInstrCb, &mam);
306
- instrCbPtr = &passInstrCb;
307
- }
308
321
309
- PipelineTuningOptions tuningOptions;
310
- tuningOptions.LoopUnrolling = true ;
311
- tuningOptions.LoopInterleaving = true ;
312
- tuningOptions.LoopVectorization = true ;
313
- // TODO: currently we run SLP vectorizer with an empty target machine.
314
- // This cause the vectorizer to create larger vector which could be bad.
315
- // Disabling it would currently cause regressions as this pass also applies
316
- // some scheduling that helps performance in some cases. We should work on
317
- // using NVPTX target instead and address the performance regressions with
318
- // some scheduling solution.
319
- tuningOptions.SLPVectorization = true ;
320
-
321
- PassBuilder pb (nullptr /* targetMachine*/ , tuningOptions, std::nullopt,
322
- instrCbPtr);
323
-
324
- std::string pluginFile =
325
- mlir::triton::tools::getStrEnv (" LLVM_PASS_PLUGIN_PATH" );
326
-
327
- if (!pluginFile.empty ()) {
328
- // TODO: Add some logging here that we inserted a pass into the LLVM
329
- // pass pipeline
330
- auto passPlugin = llvm::PassPlugin::Load (pluginFile);
331
- if (!passPlugin) {
332
- llvm::Error Err = passPlugin.takeError ();
333
- std::string ErrMsg =
334
- " Pass Plugin Error: " + llvm::toString (std::move (Err));
335
- throw std::runtime_error (ErrMsg);
336
- }
337
- passPlugin->registerPassBuilderCallbacks (pb);
338
- }
322
+ PipelineTuningOptions tuningOptions;
323
+ tuningOptions.LoopUnrolling = true ;
324
+ tuningOptions.LoopInterleaving = true ;
325
+ tuningOptions.LoopVectorization = true ;
326
+ // TODO: currently we run SLP vectorizer with an empty target machine.
327
+ // This cause the vectorizer to create larger vector which could be bad.
328
+ // Disabling it would currently cause regressions as this pass also
329
+ // applies some scheduling that helps performance in some cases. We
330
+ // should work on using NVPTX target instead and address the performance
331
+ // regressions with some scheduling solution.
332
+ tuningOptions.SLPVectorization = true ;
333
+
334
+ // We don't pass the targetMachine to the LLVM-IR pass builder, unless
335
+ // `arch` is specified
336
+ std::unique_ptr<TargetMachine> targetMachine = nullptr ;
337
+ if (!arch.empty ())
338
+ targetMachine = std::move (
339
+ createTargetMachine (mod, arch, enable_fp_fusion, features));
340
+ PassBuilder pb (/* targetMachine=*/ targetMachine.get (), tuningOptions,
341
+ std::nullopt, instrCbPtr);
342
+
343
+ std::string pluginFile =
344
+ mlir::triton::tools::getStrEnv (" LLVM_PASS_PLUGIN_PATH" );
345
+
346
+ if (!pluginFile.empty ()) {
347
+ // TODO: Add some logging here that we inserted a pass into the LLVM
348
+ // pass pipeline
349
+ auto passPlugin = llvm::PassPlugin::Load (pluginFile);
350
+ if (!passPlugin) {
351
+ llvm::Error Err = passPlugin.takeError ();
352
+ std::string ErrMsg =
353
+ " Pass Plugin Error: " + llvm::toString (std::move (Err));
354
+ throw std::runtime_error (ErrMsg);
355
+ }
356
+ passPlugin->registerPassBuilderCallbacks (pb);
357
+ }
339
358
340
- pb.registerModuleAnalyses (mam);
341
- pb.registerCGSCCAnalyses (cgam);
342
- pb.registerFunctionAnalyses (fam);
343
- pb.registerLoopAnalyses (lam);
344
- pb.crossRegisterProxies (lam, fam, cgam, mam);
345
-
346
- ModulePassManager mpm;
347
- pb.registerVectorizerStartEPCallback (
348
- [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
349
- // Triton generates large structure of scalars which may pessimise
350
- // optimizations, we run a pass to break up phi of struct to make
351
- // sure all the struct are removed for the following passes.
352
- fpm.addPass (BreakStructPhiNodesPass ());
353
- fpm.addPass (InstCombinePass ());
354
- });
355
- mpm.addPass (pb.buildPerModuleDefaultPipeline (opt));
356
- mpm.run (*mod, mam);
357
- });
359
+ pb.registerModuleAnalyses (mam);
360
+ pb.registerCGSCCAnalyses (cgam);
361
+ pb.registerFunctionAnalyses (fam);
362
+ pb.registerLoopAnalyses (lam);
363
+ pb.crossRegisterProxies (lam, fam, cgam, mam);
364
+
365
+ ModulePassManager mpm;
366
+ pb.registerVectorizerStartEPCallback (
367
+ [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
368
+ // Triton generates large structure of scalars which may pessimise
369
+ // optimizations, we run a pass to break up phi of struct to make
370
+ // sure all the struct are removed for the following passes.
371
+ fpm.addPass (BreakStructPhiNodesPass ());
372
+ fpm.addPass (InstCombinePass ());
373
+ });
374
+ mpm.addPass (pb.buildPerModuleDefaultPipeline (opt));
375
+ mpm.run (*mod, mam);
376
+ },
377
+ // Mandatory parameters
378
+ py::arg (" mod" ), py::arg (" opt" ),
379
+ // If we want to specify the target machine, we require additional
380
+ // (optional) parameters
381
+ py::arg (" arch" ) = " " , py::arg (" features" ) = " " ,
382
+ py::arg (" flags" ) = std::vector<std::string>{},
383
+ py::arg (" enable_fp_fusion" ) = false );
358
384
359
385
m.def (
360
386
" translate_to_asm" ,
0 commit comments