Skip to content

Commit fa0cafa

Browse files
giuserosvlad-penkin
authored and committed
Pass the target machine to the LLVM pass builder (#4655)
In a recent change to the [LLVM AMD backend](llvm/llvm-project#83131), we moved the `AMDGPUAttributor` pass into the optimization pipeline (as opposed to the codegen pipeline). Since this is a pass specific for `AMD` targets, we want to pass the `TargetMachine` when building the pipeline, i.e., during the call to `optimize_module`. Failure to do so will result in an increase of number of registers used. Also, we spoke with our LLVM backend team, and they advised to always pass the `TargetMachine` when building the LLVM optimization pipeline. This PR is addressing this issue, in the following way: - I added optional parameters to the `optimize_module` funciton (similar to those passed to `translate_to_asm`) - if those params are passed in, then we will create the `TargetMachine` and pass it to the `PassBuilder` - Otherwise the `TargetMachine` will still be `nullptr` (as it was before) Please note that, as it stands now, this change will only effect the AMD backend.
1 parent 2528dd7 commit fa0cafa

File tree

2 files changed

+129
-103
lines changed

2 files changed

+129
-103
lines changed

python/src/llvm.cc

Lines changed: 128 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/Transforms/IPO/AlwaysInliner.h"
2525
#include "llvm/Transforms/InstCombine/InstCombine.h"
2626
#include <csignal>
27+
#include <memory>
2728
#include <pybind11/pybind11.h>
2829
#include <pybind11/stl.h>
2930
#include <stdexcept>
@@ -39,6 +40,30 @@ struct BreakStructPhiNodesPass : PassInfoMixin<BreakStructPhiNodesPass> {
3940

4041
using namespace llvm;
4142

43+
std::unique_ptr<TargetMachine>
44+
createTargetMachine(llvm::Module *module, std::string proc,
45+
bool enable_fp_fusion, const std::string &features) {
46+
std::string error;
47+
auto target =
48+
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
49+
llvm::TargetOptions opt;
50+
bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
51+
if (enable_fp_fusion)
52+
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
53+
opt.UnsafeFPMath = false;
54+
opt.NoInfsFPMath = false;
55+
opt.NoNaNsFPMath = true;
56+
opt.TrapUnreachable = true;
57+
opt.MCOptions.AsmVerbose = true;
58+
opt.MCOptions.PreserveAsmComments = true;
59+
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
60+
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
61+
std::nullopt,
62+
disableLLVMOpt ? llvm::CodeGenOptLevel::None
63+
: llvm::CodeGenOptLevel::Aggressive)};
64+
return machine;
65+
}
66+
4267
std::string translateLLVMIRToASM(llvm::Module &module,
4368
const std::string &triple,
4469
const std::string &proc,
@@ -106,23 +131,7 @@ std::string translateLLVMIRToASM(llvm::Module &module,
106131

107132
// create machine
108133
module.setTargetTriple(triple);
109-
std::string error;
110-
auto target =
111-
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
112-
llvm::TargetOptions opt;
113-
if (enable_fp_fusion)
114-
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
115-
opt.UnsafeFPMath = false;
116-
opt.NoInfsFPMath = false;
117-
opt.NoNaNsFPMath = true;
118-
opt.TrapUnreachable = true;
119-
opt.MCOptions.AsmVerbose = true;
120-
opt.MCOptions.PreserveAsmComments = true;
121-
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
122-
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
123-
std::nullopt,
124-
disableLLVMOpt ? llvm::CodeGenOptLevel::None
125-
: llvm::CodeGenOptLevel::Aggressive)};
134+
auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features);
126135
// set data layout
127136
module.setDataLayout(machine->createDataLayout());
128137
// emit machine code
@@ -267,94 +276,111 @@ void init_triton_llvm(py::module &&m) {
267276
mod->setDataLayout(machine->createDataLayout());
268277
});
269278

270-
m.def("optimize_module", [](llvm::Module *mod,
271-
const llvm::OptimizationLevel &opt) {
272-
if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"))
273-
return;
274-
// Check to see if we are passing a list of flags to disable optimizations.
275-
auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT");
276-
if (!flagList.empty()) {
277-
auto options = llvm::cl::getRegisteredOptions();
278-
llvm::SmallVector<StringRef, 3> split;
279-
StringRef(flagList.c_str()).split(split, ',');
280-
for (auto flag : split) {
281-
auto optIt = options.find(flag);
282-
if (optIt != options.end()) {
283-
auto optPtr = static_cast<llvm::cl::opt<bool> *>(optIt->second);
284-
*optPtr = true;
279+
m.def(
280+
"optimize_module",
281+
[](llvm::Module *mod, const llvm::OptimizationLevel &opt,
282+
std::string arch, std::string features, std::vector<std::string> flags,
283+
bool enable_fp_fusion) {
284+
if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"))
285+
return;
286+
// Check to see if we are passing a list of flags to disable
287+
// optimizations.
288+
auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT");
289+
if (!flagList.empty()) {
290+
auto options = llvm::cl::getRegisteredOptions();
291+
llvm::SmallVector<StringRef, 3> split;
292+
StringRef(flagList.c_str()).split(split, ',');
293+
for (auto flag : split) {
294+
auto optIt = options.find(flag);
295+
if (optIt != options.end()) {
296+
auto optPtr = static_cast<llvm::cl::opt<bool> *>(optIt->second);
297+
*optPtr = true;
298+
}
299+
}
300+
}
301+
using namespace llvm;
302+
LoopAnalysisManager lam;
303+
FunctionAnalysisManager fam;
304+
CGSCCAnalysisManager cgam;
305+
ModuleAnalysisManager mam;
306+
307+
PassInstrumentationCallbacks *instrCbPtr = nullptr;
308+
PassInstrumentationCallbacks passInstrCb;
309+
StandardInstrumentations standardInstr(mod->getContext(),
310+
/*DebugLogging*/ true);
311+
if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
312+
auto optMap = llvm::cl::getRegisteredOptions();
313+
auto optIt = optMap.find("print-after-all");
314+
if (optIt != optMap.end()) {
315+
auto optPtr = static_cast<llvm::cl::opt<bool> *>(optIt->second);
316+
*optPtr = true;
317+
}
318+
standardInstr.registerCallbacks(passInstrCb, &mam);
319+
instrCbPtr = &passInstrCb;
285320
}
286-
}
287-
}
288-
using namespace llvm;
289-
LoopAnalysisManager lam;
290-
FunctionAnalysisManager fam;
291-
CGSCCAnalysisManager cgam;
292-
ModuleAnalysisManager mam;
293-
294-
PassInstrumentationCallbacks *instrCbPtr = nullptr;
295-
PassInstrumentationCallbacks passInstrCb;
296-
StandardInstrumentations standardInstr(mod->getContext(),
297-
/*DebugLogging*/ true);
298-
if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
299-
auto optMap = llvm::cl::getRegisteredOptions();
300-
auto optIt = optMap.find("print-after-all");
301-
if (optIt != optMap.end()) {
302-
auto optPtr = static_cast<llvm::cl::opt<bool> *>(optIt->second);
303-
*optPtr = true;
304-
}
305-
standardInstr.registerCallbacks(passInstrCb, &mam);
306-
instrCbPtr = &passInstrCb;
307-
}
308321

309-
PipelineTuningOptions tuningOptions;
310-
tuningOptions.LoopUnrolling = true;
311-
tuningOptions.LoopInterleaving = true;
312-
tuningOptions.LoopVectorization = true;
313-
// TODO: currently we run SLP vectorizer with an empty target machine.
314-
// This cause the vectorizer to create larger vector which could be bad.
315-
// Disabling it would currently cause regressions as this pass also applies
316-
// some scheduling that helps performance in some cases. We should work on
317-
// using NVPTX target instead and address the performance regressions with
318-
// some scheduling solution.
319-
tuningOptions.SLPVectorization = true;
320-
321-
PassBuilder pb(nullptr /*targetMachine*/, tuningOptions, std::nullopt,
322-
instrCbPtr);
323-
324-
std::string pluginFile =
325-
mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH");
326-
327-
if (!pluginFile.empty()) {
328-
// TODO: Add some logging here that we inserted a pass into the LLVM
329-
// pass pipeline
330-
auto passPlugin = llvm::PassPlugin::Load(pluginFile);
331-
if (!passPlugin) {
332-
llvm::Error Err = passPlugin.takeError();
333-
std::string ErrMsg =
334-
"Pass Plugin Error: " + llvm::toString(std::move(Err));
335-
throw std::runtime_error(ErrMsg);
336-
}
337-
passPlugin->registerPassBuilderCallbacks(pb);
338-
}
322+
PipelineTuningOptions tuningOptions;
323+
tuningOptions.LoopUnrolling = true;
324+
tuningOptions.LoopInterleaving = true;
325+
tuningOptions.LoopVectorization = true;
326+
// TODO: currently we run SLP vectorizer with an empty target machine.
327+
// This cause the vectorizer to create larger vector which could be bad.
328+
// Disabling it would currently cause regressions as this pass also
329+
// applies some scheduling that helps performance in some cases. We
330+
// should work on using NVPTX target instead and address the performance
331+
// regressions with some scheduling solution.
332+
tuningOptions.SLPVectorization = true;
333+
334+
// We don't pass the targetMachine to the LLVM-IR pass builder, unless
335+
// `arch` is specified
336+
std::unique_ptr<TargetMachine> targetMachine = nullptr;
337+
if (!arch.empty())
338+
targetMachine = std::move(
339+
createTargetMachine(mod, arch, enable_fp_fusion, features));
340+
PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions,
341+
std::nullopt, instrCbPtr);
342+
343+
std::string pluginFile =
344+
mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH");
345+
346+
if (!pluginFile.empty()) {
347+
// TODO: Add some logging here that we inserted a pass into the LLVM
348+
// pass pipeline
349+
auto passPlugin = llvm::PassPlugin::Load(pluginFile);
350+
if (!passPlugin) {
351+
llvm::Error Err = passPlugin.takeError();
352+
std::string ErrMsg =
353+
"Pass Plugin Error: " + llvm::toString(std::move(Err));
354+
throw std::runtime_error(ErrMsg);
355+
}
356+
passPlugin->registerPassBuilderCallbacks(pb);
357+
}
339358

340-
pb.registerModuleAnalyses(mam);
341-
pb.registerCGSCCAnalyses(cgam);
342-
pb.registerFunctionAnalyses(fam);
343-
pb.registerLoopAnalyses(lam);
344-
pb.crossRegisterProxies(lam, fam, cgam, mam);
345-
346-
ModulePassManager mpm;
347-
pb.registerVectorizerStartEPCallback(
348-
[&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
349-
// Triton generates large structure of scalars which may pessimise
350-
// optimizations, we run a pass to break up phi of struct to make
351-
// sure all the struct are removed for the following passes.
352-
fpm.addPass(BreakStructPhiNodesPass());
353-
fpm.addPass(InstCombinePass());
354-
});
355-
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
356-
mpm.run(*mod, mam);
357-
});
359+
pb.registerModuleAnalyses(mam);
360+
pb.registerCGSCCAnalyses(cgam);
361+
pb.registerFunctionAnalyses(fam);
362+
pb.registerLoopAnalyses(lam);
363+
pb.crossRegisterProxies(lam, fam, cgam, mam);
364+
365+
ModulePassManager mpm;
366+
pb.registerVectorizerStartEPCallback(
367+
[&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
368+
// Triton generates large structure of scalars which may pessimise
369+
// optimizations, we run a pass to break up phi of struct to make
370+
// sure all the struct are removed for the following passes.
371+
fpm.addPass(BreakStructPhiNodesPass());
372+
fpm.addPass(InstCombinePass());
373+
});
374+
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
375+
mpm.run(*mod, mam);
376+
},
377+
// Mandatory parameters
378+
py::arg("mod"), py::arg("opt"),
379+
// If we want to specify the target machine, we require additional
380+
// (optional) parameters
381+
py::arg("arch") = "", py::arg("features") = "",
382+
py::arg("flags") = std::vector<std::string>{},
383+
py::arg("enable_fp_fusion") = false);
358384

359385
m.def(
360386
"translate_to_asm",

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def make_llir(src, metadata, options):
255255
paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
256256
llvm.link_extern_libs(llvm_mod, paths)
257257

258-
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
258+
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
259259

260260
# Get some metadata
261261
metadata["shared"] = src.get_int_attr("triton_gpu.shared")

0 commit comments

Comments
 (0)