diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h index 9214aad4e6aec..cdbe3c4d5bc4a 100644 --- a/llvm/include/llvm/Transforms/Utils/Local.h +++ b/llvm/include/llvm/Transforms/Utils/Local.h @@ -49,6 +49,11 @@ class StoreInst; class TargetLibraryInfo; class TargetTransformInfo; +template class GenericSSAContext; +using SSAContext = GenericSSAContext; +template class GenericUniformityInfo; +using UniformityInfo = GenericUniformityInfo; + //===----------------------------------------------------------------------===// // Local constant propagation. // @@ -183,7 +188,7 @@ bool EliminateDuplicatePHINodes(BasicBlock *BB, /// providing the set of loop headers that SimplifyCFG should not eliminate. extern cl::opt RequireAndPreserveDomTree; bool simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - DomTreeUpdater *DTU = nullptr, + DomTreeUpdater *DTU = nullptr, UniformityInfo *UI = nullptr, const SimplifyCFGOptions &Options = {}, ArrayRef LoopHeaders = {}); diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 4e437e9abeb43..200ba3a49e647 100644 --- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -29,6 +29,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/UniformityAnalysis.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" @@ -229,7 +230,7 @@ static bool tailMergeBlocksWithSimilarFunctionTerminators(Function &F, /// Call SimplifyCFG on all the blocks in the function, /// iterating until no more changes are made. static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, - DomTreeUpdater *DTU, + DomTreeUpdater *DTU, UniformityInfo *UI, const SimplifyCFGOptions &Options) { bool Changed = false; bool LocalChange = true; @@ -261,7 +262,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, while (BBIt != F.end() && DTU->isBBPendingDeletion(&*BBIt)) ++BBIt; } - if (simplifyCFG(&BB, TTI, DTU, Options, LoopHeaders)) { + if (simplifyCFG(&BB, TTI, DTU, UI, Options, LoopHeaders)) { LocalChange = true; ++NumSimpl; } @@ -272,14 +273,15 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, } static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI, - DominatorTree *DT, + DominatorTree *DT, UniformityInfo *UI, const SimplifyCFGOptions &Options) { DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); bool EverChanged = removeUnreachableBlocks(F, DT ? &DTU : nullptr); EverChanged |= tailMergeBlocksWithSimilarFunctionTerminators(F, DT ? &DTU : nullptr); - EverChanged |= iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, Options); + EverChanged |= + iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, UI, Options); // If neither pass changed anything, we're done. if (!EverChanged) return false; @@ -293,7 +295,8 @@ static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI, return true; do { - EverChanged = iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, Options); + EverChanged = + iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, UI, Options); EverChanged |= removeUnreachableBlocks(F, DT ? &DTU : nullptr); } while (EverChanged); @@ -301,13 +304,13 @@ static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI, } static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, - DominatorTree *DT, + DominatorTree *DT, UniformityInfo *UI, const SimplifyCFGOptions &Options) { assert((!RequireAndPreserveDomTree || (DT && DT->verify(DominatorTree::VerificationLevel::Full))) && "Original domtree is invalid?"); - bool Changed = simplifyFunctionCFGImpl(F, TTI, DT, Options); + bool Changed = simplifyFunctionCFGImpl(F, TTI, DT, UI, Options); assert((!RequireAndPreserveDomTree || (DT && DT->verify(DominatorTree::VerificationLevel::Full))) && @@ -378,7 +381,8 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F, DominatorTree *DT = nullptr; if (RequireAndPreserveDomTree) DT = &AM.getResult(F); - if (!simplifyFunctionCFG(F, TTI, DT, Options)) + auto *UA = &AM.getResult(F); + if (!simplifyFunctionCFG(F, TTI, DT, UA, Options)) return PreservedAnalyses::all(); PreservedAnalyses PA; if (RequireAndPreserveDomTree) @@ -412,7 +416,8 @@ struct CFGSimplifyPass : public FunctionPass { DT = &getAnalysis().getDomTree(); auto &TTI = getAnalysis().getTTI(F); - return simplifyFunctionCFG(F, TTI, DT, Options); + auto &UI = getAnalysis().getUniformityInfo(); + return simplifyFunctionCFG(F, TTI, DT, &UI, Options); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); @@ -422,6 +427,7 @@ struct CFGSimplifyPass : public FunctionPass { if (RequireAndPreserveDomTree) AU.addPreserved(); AU.addPreserved(); + AU.addRequired(); } }; } @@ -432,6 +438,7 @@ INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass) INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 7840601d341b8..b6f0f74125774 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -32,6 +32,7 @@ #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/UniformityAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -258,6 +259,7 @@ struct ValueEqualityComparisonCase { class SimplifyCFGOpt { const TargetTransformInfo &TTI; DomTreeUpdater *DTU; + UniformityInfo *UI; const DataLayout &DL; ArrayRef LoopHeaders; const SimplifyCFGOptions &Options; @@ -306,9 +308,10 @@ class SimplifyCFGOpt { public: SimplifyCFGOpt(const TargetTransformInfo &TTI, DomTreeUpdater *DTU, - const DataLayout &DL, ArrayRef LoopHeaders, - const SimplifyCFGOptions &Opts) - : TTI(TTI), DTU(DTU), DL(DL), LoopHeaders(LoopHeaders), Options(Opts) { + const DataLayout &DL, UniformityInfo *UI, + ArrayRef LoopHeaders, const SimplifyCFGOptions &Opts) + : TTI(TTI), DTU(DTU), UI(UI), DL(DL), LoopHeaders(LoopHeaders), + Options(Opts) { assert((!DTU || !DTU->hasPostDomTree()) && "SimplifyCFG is not yet capable of maintaining validity of a " "PostDomTree, so don't ask for it."); @@ -3490,6 +3493,17 @@ static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { return true; } +static bool blockIsFreeToThreadThrough(BasicBlock *BB, PHINode *PN) { + unsigned Size = 0; + for (Instruction &I : BB->instructionsWithoutDebug(false)) { + if (&I == PN) + continue; + if (++Size > 1) + return false; + } + return true; +} + static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From, BasicBlock *To) { // Don't look past the block defining the value, we might get the value from @@ -3511,10 +3525,9 @@ static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From, /// If we have a conditional branch on something for which we know the constant /// value in predecessors (e.g. a phi node in the current block), thread edges /// from the predecessor to their ultimate destination. -static std::optional -foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, - const DataLayout &DL, - AssumptionCache *AC) { +static std::optional foldCondBranchOnValueKnownInPredecessorImpl( + BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL, + const TargetTransformInfo &TTI, UniformityInfo *UI, AssumptionCache *AC) { SmallMapVector, 2> KnownValues; BasicBlock *BB = BI->getParent(); Value *Cond = BI->getCondition(); @@ -3555,6 +3568,16 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, if (RealDest == BB) continue; // Skip self loops. + // Check to see that we're not duplicating instructions into divergent + // branches. Doing so would essentially double the execution time, since + // the instructions will be executed by divergent threads serially. + if (TTI.hasBranchDivergence() && UI && + !blockIsFreeToThreadThrough(BB, PN) && + any_of(PredBBs, [&](BasicBlock *PredBB) { + return UI->hasDivergentTerminator(*PredBB); + })) + continue; + // Skip if the predecessor's terminator is an indirect branch. if (any_of(PredBBs, [](BasicBlock *PredBB) { return isa(PredBB->getTerminator()); @@ -3669,15 +3692,15 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, return false; } -static bool foldCondBranchOnValueKnownInPredecessor(BranchInst *BI, - DomTreeUpdater *DTU, - const DataLayout &DL, - AssumptionCache *AC) { +static bool foldCondBranchOnValueKnownInPredecessor( + BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL, + const TargetTransformInfo &TTI, UniformityInfo *UI, AssumptionCache *AC) { std::optional Result; bool EverChanged = false; do { // Note that None means "we changed things, but recurse further." - Result = foldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, AC); + Result = + foldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, TTI, UI, AC); EverChanged |= Result == std::nullopt || *Result; } while (Result == std::nullopt); return EverChanged; @@ -8082,7 +8105,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this is a branch on something for which we know the constant value in // predecessors (e.g. a phi node in the current block), thread control // through this block. - if (foldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, Options.AC)) + if (foldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, TTI, UI, Options.AC)) return requestResimplify(); // Scan predecessor blocks for conditional branches. @@ -8402,9 +8425,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { } bool llvm::simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - DomTreeUpdater *DTU, const SimplifyCFGOptions &Options, + DomTreeUpdater *DTU, UniformityInfo *UI, + const SimplifyCFGOptions &Options, ArrayRef LoopHeaders) { - return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), LoopHeaders, - Options) + return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), UI, LoopHeaders, Options) .run(BB); } diff --git a/llvm/test/Transforms/SimplifyCFG/NVPTX/lit.local.cfg b/llvm/test/Transforms/SimplifyCFG/NVPTX/lit.local.cfg new file mode 100644 index 0000000000000..0d37b86e1c8e6 --- /dev/null +++ b/llvm/test/Transforms/SimplifyCFG/NVPTX/lit.local.cfg @@ -0,0 +1,2 @@ +if not "NVPTX" in config.root.targets: + config.unsupported = True diff --git a/llvm/test/Transforms/SimplifyCFG/NVPTX/uniformity.ll b/llvm/test/Transforms/SimplifyCFG/NVPTX/uniformity.ll new file mode 100644 index 0000000000000..c23bbfd41d8fe --- /dev/null +++ b/llvm/test/Transforms/SimplifyCFG/NVPTX/uniformity.ll @@ -0,0 +1,148 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=simplifycfg -S | FileCheck %s + +target triple = "nvptx64-nvidia-cuda" + +;; Branch threading in cases where the condition is divergent is bad because the +;; divergent code that is duplicated will be executed serially by the diverged +;; threads essentially doubling execution time. +define ptx_kernel void @test_01(ptr %ptr) { +; CHECK-LABEL: @test_01( +; CHECK-NEXT: [[ID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() +; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[ID]], 0 +; CHECK-NEXT: br i1 [[C]], label [[TRUE2_CRITEDGE:%.*]], label [[FALSE1:%.*]] +; CHECK: true1: +; CHECK-NEXT: store volatile i64 0, ptr [[PTR:%.*]], align 8 +; CHECK-NEXT: store volatile i64 -1, ptr [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, ptr [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, ptr [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, ptr [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, ptr [[PTR]], align 8 +; CHECK-NEXT: br i1 [[C]], label [[TRUE2:%.*]], label [[FALSE2:%.*]] +; CHECK: false1: +; CHECK-NEXT: store volatile i64 1, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[TRUE2_CRITEDGE]] +; CHECK: true2: +; CHECK-NEXT: store volatile i64 2, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[JOIN2:%.*]] +; CHECK: false2: +; CHECK-NEXT: store volatile i64 3, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[JOIN2]] +; CHECK: join2: +; CHECK-NEXT: ret void +; + %id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %c = icmp eq i32 %id, 0 + br i1 %c, label %true1, label %false1 + +true1: + store volatile i64 0, ptr %ptr + store volatile i64 -1, ptr %ptr + store volatile i64 -1, ptr %ptr + store volatile i64 -1, ptr %ptr + store volatile i64 -1, ptr %ptr + store volatile i64 -1, ptr %ptr + br i1 %c, label %true2, label %false2 + +false1: + store volatile i64 1, ptr %ptr + br label %true1 + +true2: + store volatile i64 2, ptr %ptr + br label %join2 + +false2: + store volatile i64 3, ptr %ptr + br label %join2 + +join2: + ret void +} + +;; This case isn't as bad but still costly enough that we should avoid threading +;; through the divergent edge. +define ptx_kernel void @test_02(ptr %ptr) { +; CHECK-LABEL: @test_02( +; CHECK-NEXT: [[ID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() +; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[ID]], 0 +; CHECK-NEXT: br i1 [[C]], label [[TRUE2_CRITEDGE:%.*]], label [[FALSE1:%.*]] +; CHECK: true1: +; CHECK-NEXT: store volatile i64 0, ptr [[PTR:%.*]], align 8 +; CHECK-NEXT: br i1 [[C]], label [[TRUE2:%.*]], label [[FALSE2:%.*]] +; CHECK: false1: +; CHECK-NEXT: store volatile i64 1, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[TRUE2_CRITEDGE]] +; CHECK: true2: +; CHECK-NEXT: store volatile i64 2, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[JOIN2:%.*]] +; CHECK: false2: +; CHECK-NEXT: store volatile i64 3, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[JOIN2]] +; CHECK: join2: +; CHECK-NEXT: ret void +; + %id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %c = icmp eq i32 %id, 0 + br i1 %c, label %true1, label %false1 + +true1: + store volatile i64 0, ptr %ptr + br i1 %c, label %true2, label %false2 + +false1: + store volatile i64 1, ptr %ptr + br label %true1 + +true2: + store volatile i64 2, ptr %ptr + br label %join2 + +false2: + store volatile i64 3, ptr %ptr + br label %join2 + +join2: + ret void +} + +;; This case is simple enough that branch threading is still a good idea. +define ptx_kernel void @test_03(ptr %ptr) { +; CHECK-LABEL: @test_03( +; CHECK-NEXT: [[ID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() +; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[ID]], 0 +; CHECK-NEXT: br i1 [[C]], label [[TRUE1:%.*]], label [[FALSE2:%.*]] +; CHECK: true1: +; CHECK-NEXT: store volatile i64 1, ptr [[PTR:%.*]], align 8 +; CHECK-NEXT: store volatile i64 2, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[JOIN2:%.*]] +; CHECK: false2: +; CHECK-NEXT: store volatile i64 3, ptr [[PTR]], align 8 +; CHECK-NEXT: br label [[JOIN2]] +; CHECK: join2: +; CHECK-NEXT: ret void +; + %id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %c = icmp eq i32 %id, 0 + br i1 %c, label %true1, label %join1 + +true1: + store volatile i64 1, ptr %ptr + br label %join1 + +join1: + br i1 %c, label %true2, label %false2 + +true2: + store volatile i64 2, ptr %ptr + br label %join2 + +false2: + store volatile i64 3, ptr %ptr + br label %join2 + +join2: + ret void +} + +declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()