Skip to content

Commit 6421248

Browse files
jgu222 and ssahasra
authored
[Uniformity] Fixed control-div early stop (#139667)
Control-divergence finds joins by propagating labels from the divergent control branch. The code that checks the early stop for propagation is not correct in some cases. This PR, which also includes changes from ssahasra, fixes this issue by stopping no earlier than the post-dominator of the divergent branch. #137277 --------- Co-authored-by: Sameer Sahasrabuddhe <sameer.sahasrabuddhe@amd.com>
1 parent c842705 commit 6421248

File tree

7 files changed

+428
-33
lines changed

7 files changed

+428
-33
lines changed

llvm/include/llvm/ADT/GenericUniformityImpl.h

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -610,13 +610,29 @@ template <typename ContextT> class DivergencePropagator {
610610
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
611611
<< Context.print(&DivTermBlock) << "\n");
612612

613-
// Early stopping criterion
614-
int FloorIdx = CyclePOT.size() - 1;
615-
const BlockT *FloorLabel = nullptr;
616613
int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
617614

618615
// Bootstrap with branch targets
619616
auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
617+
618+
// Locate the largest irreducible ancestor cycle that can be reached without
619+
// passing through a reducible cycle. This is done with a lambda that is
620+
// defined and invoked in the same statement.
621+
const CycleT *IrreducibleAncestor = [](const CycleT *C) -> const CycleT * {
622+
if (!C)
623+
return nullptr;
624+
if (C->isReducible())
625+
return nullptr;
626+
while (const CycleT *P = C->getParentCycle()) {
627+
if (P->isReducible())
628+
return C;
629+
C = P;
630+
}
631+
assert(!C->getParentCycle());
632+
assert(!C->isReducible());
633+
return C;
634+
}(DivTermCycle);
635+
620636
for (const auto *SuccBlock : successors(&DivTermBlock)) {
621637
if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
622638
// If DivTerm exits the cycle immediately, computeJoin() might
@@ -626,14 +642,24 @@ template <typename ContextT> class DivergencePropagator {
626642
LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
627643
<< Context.print(SuccBlock) << "\n");
628644
}
629-
auto SuccIdx = CyclePOT.getIndex(SuccBlock);
630645
visitEdge(*SuccBlock, *SuccBlock);
631-
FloorIdx = std::min<int>(FloorIdx, SuccIdx);
632646
}
633647

648+
// Technically propagation can continue until it reaches the last node.
649+
//
650+
// For efficiency, propagation can stop if FreshLabels.count()==1. But
651+
// for irreducible cycles, let propagation continue until it leaves
652+
// the irreducible cycles (see code for details).
634653
while (true) {
635654
auto BlockIdx = FreshLabels.find_last();
636-
if (BlockIdx == -1 || BlockIdx < FloorIdx)
655+
if (BlockIdx == -1)
656+
break;
657+
658+
const auto *Block = CyclePOT[BlockIdx];
659+
// If no irreducible cycle, stop if FreshLabels.count() == 1 and Block
660+
// is the IPD. If it is in any irreducible cycle, continue propagation.
661+
if (FreshLabels.count() == 1 &&
662+
(!IrreducibleAncestor || !IrreducibleAncestor->contains(Block)))
637663
break;
638664

639665
LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
@@ -644,16 +670,12 @@ template <typename ContextT> class DivergencePropagator {
644670
continue;
645671
}
646672

647-
const auto *Block = CyclePOT[BlockIdx];
648673
LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
649674
<< BlockIdx << "\n");
650675

651676
const auto *Label = BlockLabels[Block];
652677
assert(Label);
653678

654-
bool CausedJoin = false;
655-
int LoweredFloorIdx = FloorIdx;
656-
657679
// If the current block is the header of a reducible cycle that
658680
// contains the divergent branch, then the label should be
659681
// propagated to the cycle exits. Such a header is the "last
@@ -681,28 +703,11 @@ template <typename ContextT> class DivergencePropagator {
681703
if (const auto *BlockCycle = getReducibleParent(Block)) {
682704
SmallVector<BlockT *, 4> BlockCycleExits;
683705
BlockCycle->getExitBlocks(BlockCycleExits);
684-
for (auto *BlockCycleExit : BlockCycleExits) {
685-
CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
686-
LoweredFloorIdx =
687-
std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
688-
}
706+
for (auto *BlockCycleExit : BlockCycleExits)
707+
visitCycleExitEdge(*BlockCycleExit, *Label);
689708
} else {
690-
for (const auto *SuccBlock : successors(Block)) {
691-
CausedJoin |= visitEdge(*SuccBlock, *Label);
692-
LoweredFloorIdx =
693-
std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
694-
}
695-
}
696-
697-
// Floor update
698-
if (CausedJoin) {
699-
// 1. Different labels pushed to successors
700-
FloorIdx = LoweredFloorIdx;
701-
} else if (FloorLabel != Label) {
702-
// 2. No join caused BUT we pushed a label that is different than the
703-
// last pushed label
704-
FloorIdx = LoweredFloorIdx;
705-
FloorLabel = Label;
709+
for (const auto *SuccBlock : successors(Block))
710+
visitEdge(*SuccBlock, *Label);
706711
}
707712
}
708713

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
;
2+
; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
3+
;
4+
;
5+
; Entry (div.cond)
6+
; / \
7+
; B0 B3
8+
; | |
9+
; B1 B4
10+
; | |
11+
; \ /
12+
; B5 (phi: divergent)
13+
; |
14+
; B6 (div.uni)
15+
; / \
16+
; B7 B9
17+
; | |
18+
; B8 B10
19+
; | |
20+
; \ /
21+
; B11 (phi: uniform)
22+
23+
24+
; CHECK-LABEL: 'test_ctrl_divergence':
25+
; CHECK-LABEL: BLOCK Entry
26+
; CHECK: DIVERGENT: %div.cond = icmp eq i32 %tid, 0
27+
; CHECK: DIVERGENT: br i1 %div.cond, label %B3, label %B0
28+
;
29+
; CHECK-LABEL: BLOCK B5
30+
; CHECK: DIVERGENT: %div_a = phi i32 [ %a0, %B1 ], [ %a1, %B4 ]
31+
; CHECK: DIVERGENT: %div_b = phi i32 [ %b0, %B1 ], [ %b1, %B4 ]
32+
;
33+
; CHECK-LABEL: BLOCK B6
34+
; CHECK-NOT: DIVERGENT: %uni.cond = icmp
35+
; CHECK-NOT: DIVERGENT: br i1 %div.cond
36+
;
37+
; CHECK-LABEL: BLOCK B11
38+
; CHECK-NOT: DIVERGENT: %div_d = phi i32
39+
40+
41+
define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
42+
Entry:
43+
%tid = call i32 @llvm.amdgcn.workitem.id.x()
44+
%div.cond = icmp eq i32 %tid, 0
45+
br i1 %div.cond, label %B3, label %B0 ; divergent branch
46+
47+
B0:
48+
%a0 = add i32 %a, 1
49+
br label %B1
50+
51+
B1:
52+
%b0 = add i32 %b, 2
53+
br label %B5
54+
55+
B3:
56+
%a1 = add i32 %a, 10
57+
br label %B4
58+
59+
B4:
60+
%b1 = add i32 %b, 20
61+
br label %B5
62+
63+
B5:
64+
%div_a = phi i32 [%a0, %B1], [%a1, %B4]
65+
%div_b = phi i32 [%b0, %B1], [%b1, %B4]
66+
br label %B6
67+
68+
B6:
69+
%uni.cond = icmp eq i32 %c, 0
70+
br i1 %uni.cond, label %B7, label %B9
71+
72+
B7:
73+
%d1 = add i32 %d, 1
74+
br label %B8
75+
76+
B8:
77+
br label %B11
78+
79+
B9:
80+
%d2 = add i32 %d, 3
81+
br label %B10
82+
83+
B10:
84+
br label %B11
85+
86+
B11:
87+
%div_d = phi i32 [%d1, %B8], [%d2, %B10]
88+
ret void
89+
}
90+
91+
92+
declare i32 @llvm.amdgcn.workitem.id.x() #0
93+
94+
attributes #0 = {nounwind readnone }

llvm/test/Analysis/UniformityAnalysis/AMDGPU/irreducible/diverged-entry-headers-nested.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ exit:
128128
;; CHECK-LABEL: UniformityInfo for function 'headers_b_t':
129129
;; CHECK: CYCLES ASSSUMED DIVERGENT:
130130
;; CHECK: depth=2: entries(T P) S Q R
131-
;; CHECK: CYCLES WITH DIVERGENT EXIT:
132-
;; CHECK: depth=1: entries(B A) D T S Q P R C
131+
;; CHECK-NOT: CYCLES WITH DIVERGENT EXIT:
133132

134133
define amdgpu_kernel void @headers_b_t(i32 %a, i32 %b, i32 %c) {
135134
entry:
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
; RUN: opt %s -mtriple amdgcn-- -passes='print<uniformity>' -disable-output 2>&1 | FileCheck %s
2+
3+
define amdgpu_kernel void @cycle_inner_ipd(i32 %n, i32 %a, i32 %b) #0 {
4+
;
5+
; entry
6+
; / \
7+
; E2<------E1
8+
; | \ ^^
9+
; | \ / |
10+
; | v/ |
11+
; | A |
12+
; | / |
13+
; | / |
14+
; vv |
15+
; B------->C
16+
; |
17+
; X
18+
;
19+
;
20+
; CHECK-LABEL: BLOCK entry
21+
; CHECK: DIVERGENT: %tid = call i32 @llvm.amdgcn.workitem.id.x()
22+
; CHECK: DIVERGENT: %div.cond = icmp slt i32 %tid, 0
23+
; CHECK: END BLOCK
24+
;
25+
; CHECK-LABEL: BLOCK B
26+
; CHECK: DIVERGENT: %div.merge = phi i32 [ 0, %A ], [ %b, %E2 ]
27+
; CHECK: END BLOCK
28+
29+
entry:
30+
%tid = call i32 @llvm.amdgcn.workitem.id.x()
31+
%div.cond = icmp slt i32 %tid, 0
32+
%uni.cond = icmp slt i32 %a, 0
33+
%uni.cond1 = icmp slt i32 %a, 2
34+
%uni.cond2 = icmp slt i32 %a, 10
35+
br i1 %uni.cond, label %E2, label %E1
36+
37+
E1:
38+
br label %E2
39+
40+
E2:
41+
br i1 %uni.cond1, label %A, label %B
42+
43+
44+
A:
45+
br i1 %div.cond, label %E1, label %B
46+
47+
B:
48+
%div.merge = phi i32 [ 0, %A ], [ %b, %E2 ]
49+
br label %C
50+
51+
C:
52+
br i1 %uni.cond2, label %E1, label %X
53+
54+
X:
55+
ret void
56+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
2+
;
3+
; This is to test an if-then-else case with some unmerged basic blocks
4+
; (https://github.com/llvm/llvm-project/issues/137277)
5+
;
6+
; Entry (div.cond)
7+
; / \
8+
; B0 B3
9+
; | |
10+
; B1 B4
11+
; | |
12+
; B2 B5
13+
; \ /
14+
; B6 (phi: divergent)
15+
;
16+
17+
18+
; CHECK-LABEL: 'test_ctrl_divergence':
19+
; CHECK-LABEL: BLOCK Entry
20+
; CHECK: DIVERGENT: %div.cond = icmp eq i32 %tid, 0
21+
; CHECK: DIVERGENT: br i1 %div.cond, label %B3, label %B0
22+
;
23+
; CHECK-LABEL: BLOCK B6
24+
; CHECK: DIVERGENT: %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ]
25+
; CHECK: DIVERGENT: %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ]
26+
; CHECK: DIVERGENT: %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ]
27+
28+
29+
define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
30+
Entry:
31+
%tid = call i32 @llvm.amdgcn.workitem.id.x()
32+
%div.cond = icmp eq i32 %tid, 0
33+
br i1 %div.cond, label %B3, label %B0 ; divergent branch
34+
35+
B0:
36+
%a0 = add i32 %a, 1
37+
br label %B1
38+
39+
B1:
40+
%b0 = add i32 %b, 2
41+
br label %B2
42+
43+
B2:
44+
%c0 = add i32 %c, 3
45+
br label %B6
46+
47+
B3:
48+
%a1 = add i32 %a, 10
49+
br label %B4
50+
51+
B4:
52+
%b1 = add i32 %b, 20
53+
br label %B5
54+
55+
B5:
56+
%c1 = add i32 %c, 30
57+
br label %B6
58+
59+
B6:
60+
%div_a = phi i32 [%a0, %B2], [%a1, %B5]
61+
%div_b = phi i32 [%b0, %B2], [%b1, %B5]
62+
%div_c = phi i32 [%c0, %B2], [%c1, %B5]
63+
br i1 %div.cond, label %B8, label %B7 ; divergent branch
64+
65+
B7:
66+
%d1 = add i32 %d, 1
67+
br label %B8
68+
69+
B8:
70+
%div_d = phi i32 [%d1, %B7], [%d, %B6]
71+
ret void
72+
}
73+
74+
75+
declare i32 @llvm.amdgcn.workitem.id.x()

0 commit comments

Comments
 (0)