Skip to content

[SimplifyCFG] Avoid branch threading of divergent conditionals #141867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion llvm/include/llvm/Transforms/Utils/Local.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class StoreInst;
class TargetLibraryInfo;
class TargetTransformInfo;

template <typename T> class GenericSSAContext;
using SSAContext = GenericSSAContext<Function>;
template <typename T> class GenericUniformityInfo;
using UniformityInfo = GenericUniformityInfo<SSAContext>;

//===----------------------------------------------------------------------===//
// Local constant propagation.
//
Expand Down Expand Up @@ -183,7 +188,7 @@ bool EliminateDuplicatePHINodes(BasicBlock *BB,
/// providing the set of loop headers that SimplifyCFG should not eliminate.
extern cl::opt<bool> RequireAndPreserveDomTree;
bool simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI,
DomTreeUpdater *DTU = nullptr,
DomTreeUpdater *DTU = nullptr, UniformityInfo *UI = nullptr,
const SimplifyCFGOptions &Options = {},
ArrayRef<WeakVH> LoopHeaders = {});

Expand Down
25 changes: 16 additions & 9 deletions llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DebugInfoMetadata.h"
Expand Down Expand Up @@ -229,7 +230,7 @@ static bool tailMergeBlocksWithSimilarFunctionTerminators(Function &F,
/// Call SimplifyCFG on all the blocks in the function,
/// iterating until no more changes are made.
static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
DomTreeUpdater *DTU,
DomTreeUpdater *DTU, UniformityInfo *UI,
const SimplifyCFGOptions &Options) {
bool Changed = false;
bool LocalChange = true;
Expand Down Expand Up @@ -261,7 +262,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
while (BBIt != F.end() && DTU->isBBPendingDeletion(&*BBIt))
++BBIt;
}
if (simplifyCFG(&BB, TTI, DTU, Options, LoopHeaders)) {
if (simplifyCFG(&BB, TTI, DTU, UI, Options, LoopHeaders)) {
LocalChange = true;
++NumSimpl;
}
Expand All @@ -272,14 +273,15 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
}

static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI,
DominatorTree *DT,
DominatorTree *DT, UniformityInfo *UI,
const SimplifyCFGOptions &Options) {
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);

bool EverChanged = removeUnreachableBlocks(F, DT ? &DTU : nullptr);
EverChanged |=
tailMergeBlocksWithSimilarFunctionTerminators(F, DT ? &DTU : nullptr);
EverChanged |= iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, Options);
EverChanged |=
iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, UI, Options);

// If neither pass changed anything, we're done.
if (!EverChanged) return false;
Expand All @@ -293,21 +295,22 @@ static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI,
return true;

do {
EverChanged = iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, Options);
EverChanged =
iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, UI, Options);
EverChanged |= removeUnreachableBlocks(F, DT ? &DTU : nullptr);
} while (EverChanged);

return true;
}

static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,
DominatorTree *DT,
DominatorTree *DT, UniformityInfo *UI,
const SimplifyCFGOptions &Options) {
assert((!RequireAndPreserveDomTree ||
(DT && DT->verify(DominatorTree::VerificationLevel::Full))) &&
"Original domtree is invalid?");

bool Changed = simplifyFunctionCFGImpl(F, TTI, DT, Options);
bool Changed = simplifyFunctionCFGImpl(F, TTI, DT, UI, Options);

assert((!RequireAndPreserveDomTree ||
(DT && DT->verify(DominatorTree::VerificationLevel::Full))) &&
Expand Down Expand Up @@ -378,7 +381,8 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F,
DominatorTree *DT = nullptr;
if (RequireAndPreserveDomTree)
DT = &AM.getResult<DominatorTreeAnalysis>(F);
if (!simplifyFunctionCFG(F, TTI, DT, Options))
auto *UA = &AM.getResult<UniformityInfoAnalysis>(F);
if (!simplifyFunctionCFG(F, TTI, DT, UA, Options))
Comment on lines +384 to +385
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Naming inconsistency. This variable is called UA, but it's UI everywhere else.

return PreservedAnalyses::all();
PreservedAnalyses PA;
if (RequireAndPreserveDomTree)
Expand Down Expand Up @@ -412,7 +416,8 @@ struct CFGSimplifyPass : public FunctionPass {
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return simplifyFunctionCFG(F, TTI, DT, Options);
auto &UI = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
return simplifyFunctionCFG(F, TTI, DT, &UI, Options);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
Expand All @@ -422,6 +427,7 @@ struct CFGSimplifyPass : public FunctionPass {
if (RequireAndPreserveDomTree)
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addRequired<UniformityInfoWrapperPass>();
}
};
}
Expand All @@ -432,6 +438,7 @@ INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
false)

Expand Down
55 changes: 39 additions & 16 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
Expand Down Expand Up @@ -258,6 +259,7 @@ struct ValueEqualityComparisonCase {
class SimplifyCFGOpt {
const TargetTransformInfo &TTI;
DomTreeUpdater *DTU;
UniformityInfo *UI;
const DataLayout &DL;
ArrayRef<WeakVH> LoopHeaders;
const SimplifyCFGOptions &Options;
Expand Down Expand Up @@ -306,9 +308,10 @@ class SimplifyCFGOpt {

public:
SimplifyCFGOpt(const TargetTransformInfo &TTI, DomTreeUpdater *DTU,
const DataLayout &DL, ArrayRef<WeakVH> LoopHeaders,
const SimplifyCFGOptions &Opts)
: TTI(TTI), DTU(DTU), DL(DL), LoopHeaders(LoopHeaders), Options(Opts) {
const DataLayout &DL, UniformityInfo *UI,
ArrayRef<WeakVH> LoopHeaders, const SimplifyCFGOptions &Opts)
: TTI(TTI), DTU(DTU), UI(UI), DL(DL), LoopHeaders(LoopHeaders),
Options(Opts) {
assert((!DTU || !DTU->hasPostDomTree()) &&
"SimplifyCFG is not yet capable of maintaining validity of a "
"PostDomTree, so don't ask for it.");
Expand Down Expand Up @@ -3490,6 +3493,17 @@ static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {
return true;
}

/// Return true if BB contains at most one non-debug instruction other than
/// PN; return false as soon as a second such instruction is encountered.
/// NOTE(review): the one permitted instruction is presumably the block's
/// terminator, making BB "free" to duplicate when threading -- confirm.
static bool blockIsFreeToThreadThrough(BasicBlock *BB, PHINode *PN) {
unsigned Size = 0;
for (Instruction &I : BB->instructionsWithoutDebug(false)) {
// PN itself is never counted.
if (&I == PN)
continue;
// Pre-increment: the first counted instruction is tolerated, the second
// one makes the block not free.
if (++Size > 1)
return false;
Comment on lines +3499 to +3502
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd prefer "count what we want" instead of "filter unwanted instructions first, count the rest later".

Size += (&I != PN);
if (Size > 1 )
  return false;

}
return true;
}

static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From,
BasicBlock *To) {
// Don't look past the block defining the value, we might get the value from
Expand All @@ -3511,10 +3525,9 @@ static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From,
/// If we have a conditional branch on something for which we know the constant
/// value in predecessors (e.g. a phi node in the current block), thread edges
/// from the predecessor to their ultimate destination.
static std::optional<bool>
foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
const DataLayout &DL,
AssumptionCache *AC) {
static std::optional<bool> foldCondBranchOnValueKnownInPredecessorImpl(
BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL,
const TargetTransformInfo &TTI, UniformityInfo *UI, AssumptionCache *AC) {
SmallMapVector<ConstantInt *, SmallSetVector<BasicBlock *, 2>, 2> KnownValues;
BasicBlock *BB = BI->getParent();
Value *Cond = BI->getCondition();
Expand Down Expand Up @@ -3555,6 +3568,16 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
if (RealDest == BB)
continue; // Skip self loops.

// Check to see that we're not duplicating instructions into divergent
// branches. Doing so would essentially double the execution time, since
// the instructions will be executed by divergent threads serially.
if (TTI.hasBranchDivergence() && UI &&
!blockIsFreeToThreadThrough(BB, PN) &&
any_of(PredBBs, [&](BasicBlock *PredBB) {
return UI->hasDivergentTerminator(*PredBB);
}))
continue;

// Skip if the predecessor's terminator is an indirect branch.
if (any_of(PredBBs, [](BasicBlock *PredBB) {
return isa<IndirectBrInst>(PredBB->getTerminator());
Expand Down Expand Up @@ -3669,15 +3692,15 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
return false;
}

static bool foldCondBranchOnValueKnownInPredecessor(BranchInst *BI,
DomTreeUpdater *DTU,
const DataLayout &DL,
AssumptionCache *AC) {
static bool foldCondBranchOnValueKnownInPredecessor(
BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL,
const TargetTransformInfo &TTI, UniformityInfo *UI, AssumptionCache *AC) {
std::optional<bool> Result;
bool EverChanged = false;
do {
// Note that std::nullopt means "we changed things, but recurse further."
Result = foldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, AC);
Result =
foldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, TTI, UI, AC);
EverChanged |= Result == std::nullopt || *Result;
} while (Result == std::nullopt);
return EverChanged;
Expand Down Expand Up @@ -8082,7 +8105,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
// If this is a branch on something for which we know the constant value in
// predecessors (e.g. a phi node in the current block), thread control
// through this block.
if (foldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, Options.AC))
if (foldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, TTI, UI, Options.AC))
return requestResimplify();

// Scan predecessor blocks for conditional branches.
Expand Down Expand Up @@ -8402,9 +8425,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) {
}

bool llvm::simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI,
DomTreeUpdater *DTU, const SimplifyCFGOptions &Options,
DomTreeUpdater *DTU, UniformityInfo *UI,
const SimplifyCFGOptions &Options,
ArrayRef<WeakVH> LoopHeaders) {
return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), LoopHeaders,
Options)
return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), UI, LoopHeaders, Options)
.run(BB);
}
2 changes: 2 additions & 0 deletions llvm/test/Transforms/SimplifyCFG/NVPTX/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
if not "NVPTX" in config.root.targets:
config.unsupported = True
Loading
Loading