Skip to content
This repository was archived by the owner on Apr 23, 2020. It is now read-only.

Commit 07b40ea

Browse files
committed
[SDA] Don't stop divergence propagation at the IPD.
Summary: This fixes B42473 and B42706. This patch makes the SDA propagate branch divergence until the end of the RPO traversal. Before, the SyncDependenceAnalysis propagated divergence only until the IPD in rpo order. RPO is incompatible with post dominance in the presence of loops. This made the SDA crash because blocks were missed in the propagation. Reviewers: foad, nhaehnle Reviewed By: foad Subscribers: jvesely, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D65274 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@372223 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 901ea18 commit 07b40ea

File tree

2 files changed

+137
-35
lines changed

2 files changed

+137
-35
lines changed

lib/Analysis/SyncDependenceAnalysis.cpp

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -218,23 +218,31 @@ struct DivergencePropagator {
218218
template <typename SuccessorIterable>
219219
std::unique_ptr<ConstBlockSet>
220220
computeJoinPoints(const BasicBlock &RootBlock,
221-
SuccessorIterable NodeSuccessors, const Loop *ParentLoop, const BasicBlock * PdBoundBlock) {
221+
SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
222222
assert(JoinBlocks);
223223

224+
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " << (ParentLoop ? ParentLoop->getName() : "<null>") << "\n" );
225+
224226
// bootstrap with branch targets
225227
for (const auto *SuccBlock : NodeSuccessors) {
226228
DefMap.emplace(SuccBlock, SuccBlock);
227229

228230
if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
229231
// immediate loop exit from node.
230232
ReachedLoopExits.insert(SuccBlock);
231-
continue;
232233
} else {
233234
// regular successor
234235
PendingUpdates.insert(SuccBlock);
235236
}
236237
}
237238

239+
LLVM_DEBUG(
240+
dbgs() << "SDA: rpo order:\n";
241+
for (const auto * RpoBlock : FuncRPOT) {
242+
dbgs() << "- " << RpoBlock->getName() << "\n";
243+
}
244+
);
245+
238246
auto ItBeginRPO = FuncRPOT.begin();
239247

240248
// skip until term (TODO RPOT won't let us start at @term directly)
@@ -245,16 +253,18 @@ struct DivergencePropagator {
245253

246254
// propagate definitions at the immediate successors of the node in RPO
247255
auto ItBlockRPO = ItBeginRPO;
248-
while (++ItBlockRPO != ItEndRPO && *ItBlockRPO != PdBoundBlock) {
256+
while ((++ItBlockRPO != ItEndRPO) &&
257+
!PendingUpdates.empty()) {
249258
const auto *Block = *ItBlockRPO;
259+
LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
250260

251-
// skip @block if not pending update
261+
// skip Block if not pending update
252262
auto ItPending = PendingUpdates.find(Block);
253263
if (ItPending == PendingUpdates.end())
254264
continue;
255265
PendingUpdates.erase(ItPending);
256266

257-
// propagate definition at @block to its successors
267+
// propagate definition at Block to its successors
258268
auto ItDef = DefMap.find(Block);
259269
const auto *DefBlock = ItDef->second;
260270
assert(DefBlock);
@@ -278,6 +288,8 @@ struct DivergencePropagator {
278288
}
279289
}
280290

291+
LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
292+
281293
// We need to know the definition at the parent loop header to decide
282294
// whether the definition at the header is different from the definition at
283295
// the loop exits, which would indicate a divergent loop exits.
@@ -292,24 +304,17 @@ struct DivergencePropagator {
292304
// |
293305
// proper exit from both loops
294306
//
295-
// D post-dominates B as it is the only proper exit from the "A loop".
296-
// If C has a divergent branch, propagation will therefore stop at D.
297-
// That implies that B will never receive a definition.
298-
// But that definition can only be the same as at D (D itself in thise case)
299-
// because all paths to anywhere have to pass through D.
300-
//
301-
const BasicBlock *ParentLoopHeader =
302-
ParentLoop ? ParentLoop->getHeader() : nullptr;
303-
if (ParentLoop && ParentLoop->contains(PdBoundBlock)) {
304-
DefMap[ParentLoopHeader] = DefMap[PdBoundBlock];
305-
}
306-
307307
// analyze reached loop exits
308308
if (!ReachedLoopExits.empty()) {
309+
const BasicBlock *ParentLoopHeader =
310+
ParentLoop ? ParentLoop->getHeader() : nullptr;
311+
309312
assert(ParentLoop);
310-
const auto *HeaderDefBlock = DefMap[ParentLoopHeader];
313+
auto ItHeaderDef = DefMap.find(ParentLoopHeader);
314+
const auto *HeaderDefBlock = (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second;
315+
311316
LLVM_DEBUG(printDefs(dbgs()));
312-
assert(HeaderDefBlock && "no definition in header of carrying loop");
317+
assert(HeaderDefBlock && "no definition at header of carrying loop");
313318

314319
for (const auto *ExitBlock : ReachedLoopExits) {
315320
auto ItExitDef = DefMap.find(ExitBlock);
@@ -339,19 +344,10 @@ const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
339344
return *ItCached->second;
340345
}
341346

342-
// dont propagte beyond the immediate post dom of the loop
343-
const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(Loop.getHeader()));
344-
const auto *IpdNode = PdNode->getIDom();
345-
const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
346-
while (PdBoundBlock && Loop.contains(PdBoundBlock)) {
347-
IpdNode = IpdNode->getIDom();
348-
PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
349-
}
350-
351347
// compute all join points
352348
DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
353349
auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
354-
*Loop.getHeader(), LoopExits, Loop.getParentLoop(), PdBoundBlock);
350+
*Loop.getHeader(), LoopExits, Loop.getParentLoop());
355351

356352
auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
357353
assert(ItInserted.second);
@@ -370,16 +366,11 @@ SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
370366
if (ItCached != CachedBranchJoins.end())
371367
return *ItCached->second;
372368

373-
// dont propagate beyond the immediate post dominator of the branch
374-
const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(Term.getParent()));
375-
const auto *IpdNode = PdNode->getIDom();
376-
const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
377-
378369
// compute all join points
379370
DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
380371
const auto &TermBlock = *Term.getParent();
381372
auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>(
382-
TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock), PdBoundBlock);
373+
TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
383374

384375
auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
385376
assert(ItInserted.second);
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
; RUN: opt -mtriple amdgcn-unknown-amdhsa -analyze -divergence -use-gpu-divergence-analysis %s | FileCheck %s
2+
3+
declare i32 @gf2(i32)
4+
declare i32 @gf1(i32)
5+
6+
define void @tw1(i32 addrspace(4)* noalias nocapture readonly %A, i32 addrspace(4)* noalias nocapture %B) local_unnamed_addr #2 {
7+
; CHECK: Printing analysis 'Legacy Divergence Analysis' for function 'tw1':
8+
; CHECK: DIVERGENT: i32 addrspace(4)* %A
9+
; CHECK: DIVERGENT: i32 addrspace(4)* %B
10+
entry:
11+
; CHECK: DIVERGENT: %call = tail call i32 @gf2(i32 0) #0
12+
; CHECK: DIVERGENT: %cmp = icmp ult i32 %call, 16
13+
; CHECK: DIVERGENT: br i1 %cmp, label %if.then, label %new_exit
14+
%call = tail call i32 @gf2(i32 0) #3
15+
%cmp = icmp ult i32 %call, 16
16+
br i1 %cmp, label %if.then, label %new_exit
17+
18+
if.then:
19+
; CHECK: DIVERGENT: %call1 = tail call i32 @gf1(i32 0) #0
20+
; CHECK: DIVERGENT: %arrayidx = getelementptr inbounds i32, i32 addrspace(4)* %A, i32 %call1
21+
; CHECK: DIVERGENT: %0 = load i32, i32 addrspace(4)* %arrayidx, align 4
22+
; CHECK: DIVERGENT: %cmp225 = icmp sgt i32 %0, 0
23+
; CHECK: DIVERGENT: %arrayidx10 = getelementptr inbounds i32, i32 addrspace(4)* %B, i32 %call1
24+
; CHECK: DIVERGENT: br i1 %cmp225, label %while.body.preheader, label %if.then.while.end_crit_edge
25+
%call1 = tail call i32 @gf1(i32 0) #4
26+
%arrayidx = getelementptr inbounds i32, i32 addrspace(4)* %A, i32 %call1
27+
%0 = load i32, i32 addrspace(4)* %arrayidx, align 4
28+
%cmp225 = icmp sgt i32 %0, 0
29+
%arrayidx10 = getelementptr inbounds i32, i32 addrspace(4)* %B, i32 %call1
30+
br i1 %cmp225, label %while.body.preheader, label %if.then.while.end_crit_edge
31+
32+
while.body.preheader:
33+
br label %while.body
34+
35+
if.then.while.end_crit_edge:
36+
; CHECK: DIVERGENT: %.pre = load i32, i32 addrspace(4)* %arrayidx10, align 4
37+
%.pre = load i32, i32 addrspace(4)* %arrayidx10, align 4
38+
br label %while.end
39+
40+
while.body:
41+
; CHECK-NOT: DIVERGENT: %i.026 = phi i32 [ %inc, %if.end.while.body_crit_edge ], [ 0, %while.body.preheader ]
42+
; CHECK: DIVERGENT: %call3 = tail call i32 @gf1(i32 0) #0
43+
; CHECK: DIVERGENT: %cmp4 = icmp ult i32 %call3, 10
44+
; CHECK: DIVERGENT: %arrayidx6 = getelementptr inbounds i32, i32 addrspace(4)* %A, i32 %i.026
45+
; CHECK: DIVERGENT: %1 = load i32, i32 addrspace(4)* %arrayidx6, align 4
46+
; CHECK: DIVERGENT: br i1 %cmp4, label %if.then5, label %if.else
47+
%i.026 = phi i32 [ %inc, %if.end.while.body_crit_edge ], [ 0, %while.body.preheader ]
48+
%call3 = tail call i32 @gf1(i32 0) #4
49+
%cmp4 = icmp ult i32 %call3, 10
50+
%arrayidx6 = getelementptr inbounds i32, i32 addrspace(4)* %A, i32 %i.026
51+
%1 = load i32, i32 addrspace(4)* %arrayidx6, align 4
52+
br i1 %cmp4, label %if.then5, label %if.else
53+
54+
if.then5:
55+
; CHECK: DIVERGENT: %mul = shl i32 %1, 1
56+
; CHECK: DIVERGENT: %2 = load i32, i32 addrspace(4)* %arrayidx10, align 4
57+
; CHECK: DIVERGENT: %add = add nsw i32 %2, %mul
58+
%mul = shl i32 %1, 1
59+
%2 = load i32, i32 addrspace(4)* %arrayidx10, align 4
60+
%add = add nsw i32 %2, %mul
61+
br label %if.end
62+
63+
if.else:
64+
; CHECK: DIVERGENT: %mul9 = shl i32 %1, 2
65+
; CHECK: DIVERGENT: %3 = load i32, i32 addrspace(4)* %arrayidx10, align 4
66+
; CHECK: DIVERGENT: %add11 = add nsw i32 %3, %mul9
67+
%mul9 = shl i32 %1, 2
68+
%3 = load i32, i32 addrspace(4)* %arrayidx10, align 4
69+
%add11 = add nsw i32 %3, %mul9
70+
br label %if.end
71+
72+
if.end:
73+
; CHECK: DIVERGENT: %storemerge = phi i32 [ %add11, %if.else ], [ %add, %if.then5 ]
74+
; CHECK: DIVERGENT: store i32 %storemerge, i32 addrspace(4)* %arrayidx10, align 4
75+
; CHECK-NOT: DIVERGENT: %inc = add nuw nsw i32 %i.026, 1
76+
; CHECK: DIVERGENT: %exitcond = icmp ne i32 %inc, %0
77+
; CHECK: DIVERGENT: br i1 %exitcond, label %if.end.while.body_crit_edge, label %while.end.loopexit
78+
%storemerge = phi i32 [ %add11, %if.else ], [ %add, %if.then5 ]
79+
store i32 %storemerge, i32 addrspace(4)* %arrayidx10, align 4
80+
%inc = add nuw nsw i32 %i.026, 1
81+
%exitcond = icmp ne i32 %inc, %0
82+
br i1 %exitcond, label %if.end.while.body_crit_edge, label %while.end.loopexit
83+
84+
if.end.while.body_crit_edge:
85+
br label %while.body
86+
87+
while.end.loopexit:
88+
; CHECK: DIVERGENT: %storemerge.lcssa = phi i32 [ %storemerge, %if.end ]
89+
%storemerge.lcssa = phi i32 [ %storemerge, %if.end ]
90+
br label %while.end
91+
92+
while.end:
93+
; CHECK: DIVERGENT: %4 = phi i32 [ %.pre, %if.then.while.end_crit_edge ], [ %storemerge.lcssa, %while.end.loopexit ]
94+
; CHECK: DIVERGENT: %i.0.lcssa = phi i32 [ 0, %if.then.while.end_crit_edge ], [ %0, %while.end.loopexit ]
95+
; CHECK: DIVERGENT: %sub = sub nsw i32 %4, %i.0.lcssa
96+
; CHECK: DIVERGENT: store i32 %sub, i32 addrspace(4)* %arrayidx10, align 4
97+
%4 = phi i32 [ %.pre, %if.then.while.end_crit_edge ], [ %storemerge.lcssa, %while.end.loopexit ]
98+
%i.0.lcssa = phi i32 [ 0, %if.then.while.end_crit_edge ], [ %0, %while.end.loopexit ]
99+
%sub = sub nsw i32 %4, %i.0.lcssa
100+
store i32 %sub, i32 addrspace(4)* %arrayidx10, align 4
101+
br label %new_exit
102+
103+
new_exit:
104+
ret void
105+
}
106+
107+
attributes #0 = { nounwind readnone }
108+
attributes #1 = { nounwind readnone }
109+
attributes #2 = { nounwind readnone }
110+
attributes #3 = { nounwind readnone }
111+
attributes #4 = { nounwind readnone }

0 commit comments

Comments
 (0)