Skip to content

Commit 3eb9e77

Browse files
[CIR] Implement switch case simplify (#140649)
This PR introduces a new **CIR simplify for `switch` cases**, which folds multiple **cascading `Equal` cases** (that contain only a `YieldOp`) into a single `CaseOp` of kind `AnyOf`. This logic is based on the suggestion from this discussion: #138003 (comment)
1 parent 35434f2 commit 3eb9e77

File tree

4 files changed

+300
-9
lines changed

4 files changed

+300
-9
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ struct MissingFeatures {
120120
static bool opUnaryPromotionType() { return false; }
121121

122122
// SwitchOp handling
123-
static bool foldCascadingCases() { return false; }
124123
static bool foldRangeCase() { return false; }
125124

126125
// Clang early optimizations or things defered to LLVM lowering.

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,12 +533,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
533533
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
534534
cir::IntAttr::get(condType, endVal)});
535535
kind = cir::CaseOpKind::Range;
536-
537-
// We don't currently fold case range statements with other case statements.
538-
// TODO(cir): Add this capability. Folding these cases is going to be
539-
// implemented in CIRSimplify when it is upstreamed.
540-
assert(!cir::MissingFeatures::foldRangeCase());
541-
assert(!cir::MissingFeatures::foldCascadingCases());
542536
} else {
543537
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
544538
kind = cir::CaseOpKind::Equal;

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
159159
}
160160
};
161161

162+
/// Simplify `cir.switch` operations by folding cascading cases
163+
/// into a single `cir.case` with the `anyof` kind.
164+
///
165+
/// This pattern identifies cascading cases within a `cir.switch` operation.
166+
/// Cascading cases are defined as consecutive `cir.case` operations of kind
167+
/// `equal`, each containing a single `cir.yield` operation in their body.
168+
///
169+
/// The pattern merges these cascading cases into a single `cir.case` operation
170+
/// with kind `anyof`, aggregating all the case values.
171+
///
172+
/// The merging process continues until a `cir.case` with a different body
173+
/// (e.g., containing `cir.break` or compound stmt) is encountered, which
174+
/// breaks the chain.
175+
///
176+
/// Example:
177+
///
178+
/// Before:
179+
/// cir.case equal, [#cir.int<0> : !s32i] {
180+
/// cir.yield
181+
/// }
182+
/// cir.case equal, [#cir.int<1> : !s32i] {
183+
/// cir.yield
184+
/// }
185+
/// cir.case equal, [#cir.int<2> : !s32i] {
186+
/// cir.break
187+
/// }
188+
///
189+
/// After applying SimplifySwitch:
190+
/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
191+
/// !s32i] {
192+
/// cir.break
193+
/// }
194+
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
195+
using OpRewritePattern<SwitchOp>::OpRewritePattern;
196+
LogicalResult matchAndRewrite(SwitchOp op,
197+
PatternRewriter &rewriter) const override {
198+
199+
LogicalResult changed = mlir::failure();
200+
SmallVector<CaseOp, 8> cases;
201+
SmallVector<CaseOp, 4> cascadingCases;
202+
SmallVector<mlir::Attribute, 4> cascadingCaseValues;
203+
204+
op.collectCases(cases);
205+
if (cases.empty())
206+
return mlir::failure();
207+
208+
auto flushMergedOps = [&]() {
209+
for (CaseOp &c : cascadingCases)
210+
rewriter.eraseOp(c);
211+
cascadingCases.clear();
212+
cascadingCaseValues.clear();
213+
};
214+
215+
auto mergeCascadingInto = [&](CaseOp &target) {
216+
rewriter.modifyOpInPlace(target, [&]() {
217+
target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
218+
target.setKind(CaseOpKind::Anyof);
219+
});
220+
changed = mlir::success();
221+
};
222+
223+
for (CaseOp c : cases) {
224+
cir::CaseOpKind kind = c.getKind();
225+
if (kind == cir::CaseOpKind::Equal &&
226+
isa<YieldOp>(c.getCaseRegion().front().front())) {
227+
// If the case contains only a YieldOp, collect it for cascading merge
228+
cascadingCases.push_back(c);
229+
cascadingCaseValues.push_back(c.getValue()[0]);
230+
} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
231+
// merge previously collected cascading cases
232+
cascadingCaseValues.push_back(c.getValue()[0]);
233+
mergeCascadingInto(c);
234+
flushMergedOps();
235+
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
236+
// If a Default, Anyof or Range case is found and there are previous
237+
// cascading cases, merge all of them into the last cascading case.
238+
// We don't currently fold case range statements with other case
239+
// statements.
240+
assert(!cir::MissingFeatures::foldRangeCase());
241+
CaseOp lastCascadingCase = cascadingCases.back();
242+
mergeCascadingInto(lastCascadingCase);
243+
cascadingCases.pop_back();
244+
flushMergedOps();
245+
} else {
246+
cascadingCases.clear();
247+
cascadingCaseValues.clear();
248+
}
249+
}
250+
251+
// Edge case: all cases are simple cascading cases
252+
if (cascadingCases.size() == cases.size()) {
253+
CaseOp lastCascadingCase = cascadingCases.back();
254+
mergeCascadingInto(lastCascadingCase);
255+
cascadingCases.pop_back();
256+
flushMergedOps();
257+
}
258+
259+
return changed;
260+
}
261+
};
262+
162263
//===----------------------------------------------------------------------===//
163264
// CIRSimplifyPass
164265
//===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
173274
// clang-format off
174275
patterns.add<
175276
SimplifyTernary,
176-
SimplifySelect
277+
SimplifySelect,
278+
SimplifySwitch
177279
>(patterns.getContext());
178280
// clang-format on
179281
}
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
186288
// Collect operations to apply patterns.
187289
llvm::SmallVector<Operation *, 16> ops;
188290
getOperation()->walk([&](Operation *op) {
189-
if (isa<TernaryOp, SelectOp>(op))
291+
if (isa<TernaryOp, SelectOp, SwitchOp>(op))
190292
ops.push_back(op);
191293
});
192294

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @foldCascade(%arg0: !s32i) {
8+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
9+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
10+
cir.scope {
11+
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
12+
cir.switch (%1 : !s32i) {
13+
cir.case(equal, [#cir.int<1> : !s32i]) {
14+
cir.yield
15+
}
16+
cir.case(equal, [#cir.int<2> : !s32i]) {
17+
cir.yield
18+
}
19+
cir.case(equal, [#cir.int<3> : !s32i]) {
20+
%2 = cir.const #cir.int<2> : !s32i
21+
cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
22+
cir.break
23+
}
24+
cir.yield
25+
}
26+
}
27+
cir.return
28+
}
29+
//CHECK: cir.func @foldCascade
30+
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
31+
//CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
32+
//CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
33+
//CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
34+
//CHECK-NEXT: cir.break
35+
//CHECK-NEXT: }
36+
//CHECK-NEXT: cir.yield
37+
//CHECK-NEXT: }
38+
39+
cir.func @foldCascade2(%arg0: !s32i) {
40+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
41+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
42+
cir.scope {
43+
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
44+
cir.switch (%1 : !s32i) {
45+
cir.case(equal, [#cir.int<0> : !s32i]) {
46+
cir.yield
47+
}
48+
cir.case(equal, [#cir.int<2> : !s32i]) {
49+
cir.yield
50+
}
51+
cir.case(equal, [#cir.int<4> : !s32i]) {
52+
cir.break
53+
}
54+
cir.case(equal, [#cir.int<1> : !s32i]) {
55+
cir.yield
56+
}
57+
cir.case(equal, [#cir.int<3> : !s32i]) {
58+
cir.yield
59+
}
60+
cir.case(equal, [#cir.int<5> : !s32i]) {
61+
cir.break
62+
}
63+
cir.yield
64+
}
65+
}
66+
cir.return
67+
}
68+
//CHECK: @foldCascade2
69+
//CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
70+
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<2> : !s32i, #cir.int<4> : !s32i]) {
71+
//CHECK: cir.break
72+
//cehck: }
73+
//CHECK: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i]) {
74+
//CHECK: cir.break
75+
//CHECK: }
76+
//CHECK: cir.yield
77+
//CHECK: }
78+
cir.func @foldCascade3(%arg0: !s32i ) {
79+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
80+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
81+
cir.scope {
82+
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
83+
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
84+
cir.switch (%2 : !s32i) {
85+
cir.case(equal, [#cir.int<0> : !s32i]) {
86+
cir.yield
87+
}
88+
cir.case(equal, [#cir.int<1> : !s32i]) {
89+
cir.yield
90+
}
91+
cir.case(equal, [#cir.int<2> : !s32i]) {
92+
cir.yield
93+
}
94+
cir.case(equal, [#cir.int<3> : !s32i]) {
95+
cir.yield
96+
}
97+
cir.case(equal, [#cir.int<4> : !s32i]) {
98+
cir.yield
99+
}
100+
cir.case(equal, [#cir.int<5> : !s32i]) {
101+
cir.break
102+
}
103+
cir.yield
104+
}
105+
}
106+
cir.return
107+
}
108+
//CHECK: cir.func @foldCascade3
109+
//CHECK: cir.switch (%[[COND3:.*]] : !s32i) {
110+
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
111+
//CHECK: cir.break
112+
//CHECK: }
113+
//CHECK: cir.yield
114+
//CHECK: }
115+
cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
116+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
117+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
118+
cir.scope {
119+
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
120+
cir.switch (%1 : !s32i) {
121+
cir.case(equal, [#cir.int<3> : !s32i]) {
122+
cir.break
123+
}
124+
cir.case(equal, [#cir.int<4> : !s32i]) {
125+
cir.yield
126+
}
127+
cir.case(equal, [#cir.int<5> : !s32i]) {
128+
cir.yield
129+
}
130+
cir.case(default, []) {
131+
cir.yield
132+
}
133+
cir.case(equal, [#cir.int<6> : !s32i]) {
134+
cir.yield
135+
}
136+
cir.case(equal, [#cir.int<7> : !s32i]) {
137+
cir.break
138+
}
139+
cir.yield
140+
}
141+
}
142+
cir.return
143+
}
144+
//CHECK: cir.func @foldCascadeWithDefault
145+
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
146+
//CHECK: cir.case(equal, [#cir.int<3> : !s32i]) {
147+
//CHECK: cir.break
148+
//CHECK: }
149+
//CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
150+
//CHECK: cir.yield
151+
//CHECK: }
152+
//CHECK: cir.case(default, []) {
153+
//CHECK: cir.yield
154+
//CHECK: }
155+
//CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
156+
//CHECK: cir.break
157+
//CHECK: }
158+
//CHECK: cir.yield
159+
//CHECK: }
160+
cir.func @foldAllCascade(%arg0: !s32i ) {
161+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
162+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
163+
cir.scope {
164+
%1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
165+
cir.switch (%1 : !s32i) {
166+
cir.case(equal, [#cir.int<0> : !s32i]) {
167+
cir.yield
168+
}
169+
cir.case(equal, [#cir.int<1> : !s32i]) {
170+
cir.yield
171+
}
172+
cir.case(equal, [#cir.int<2> : !s32i]) {
173+
cir.yield
174+
}
175+
cir.case(equal, [#cir.int<3> : !s32i]) {
176+
cir.yield
177+
}
178+
cir.case(equal, [#cir.int<4> : !s32i]) {
179+
cir.yield
180+
}
181+
cir.case(equal, [#cir.int<5> : !s32i]) {
182+
cir.yield
183+
}
184+
cir.yield
185+
}
186+
}
187+
cir.return
188+
}
189+
//CHECK: cir.func @foldAllCascade
190+
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
191+
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
192+
//CHECK: cir.yield
193+
//CHECK: }
194+
//CHECK: cir.yield
195+
//CHECK: }
196+
}

0 commit comments

Comments
 (0)