Skip to content

Commit 67696a1

Browse files
authored
[ADT] Reduce copies and recursion in StringSwitch (llvm#125362)
Optimize the `.Cases` and `.CasesLower` functions to avoid needlessly recursing on each case and copying the associated values. We can instead take `Value` by reference and short-circuit by using the `||` operator. Note that while the implementation uses variadic templates, we cannot simplify the public functions in the same way. This is because the current API forces the arguments to be converted to `StringLiterals` and places the `Value` parameter at the very end. Even if we did some tricks like split the parameter pack to separate out the `Value`, I do not see how we could force conversion to `StringLiteral`.
1 parent 22d9726 commit 67696a1

File tree

2 files changed

+73
-20
lines changed

2 files changed

+73
-20
lines changed

llvm/include/llvm/ADT/StringSwitch.h

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#define LLVM_ADT_STRINGSWITCH_H
1515

1616
#include "llvm/ADT/StringRef.h"
17-
#include "llvm/Support/Compiler.h"
1817
#include <cassert>
1918
#include <cstring>
2019
#include <optional>
@@ -67,9 +66,7 @@ class StringSwitch {
6766

6867
// Case-sensitive case matchers
6968
StringSwitch &Case(StringLiteral S, T Value) {
70-
if (!Result && Str == S) {
71-
Result = std::move(Value);
72-
}
69+
CaseImpl(Value, S);
7370
return *this;
7471
}
7572

@@ -88,61 +85,59 @@ class StringSwitch {
8885
}
8986

9087
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, T Value) {
91-
return Case(S0, Value).Case(S1, Value);
88+
return CasesImpl(Value, S0, S1);
9289
}
9390

9491
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
9592
T Value) {
96-
return Case(S0, Value).Cases(S1, S2, Value);
93+
return CasesImpl(Value, S0, S1, S2);
9794
}
9895

9996
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
10097
StringLiteral S3, T Value) {
101-
return Case(S0, Value).Cases(S1, S2, S3, Value);
98+
return CasesImpl(Value, S0, S1, S2, S3);
10299
}
103100

104101
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
105102
StringLiteral S3, StringLiteral S4, T Value) {
106-
return Case(S0, Value).Cases(S1, S2, S3, S4, Value);
103+
return CasesImpl(Value, S0, S1, S2, S3, S4);
107104
}
108105

109106
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
110107
StringLiteral S3, StringLiteral S4, StringLiteral S5,
111108
T Value) {
112-
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, Value);
109+
return CasesImpl(Value, S0, S1, S2, S3, S4, S5);
113110
}
114111

115112
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
116113
StringLiteral S3, StringLiteral S4, StringLiteral S5,
117114
StringLiteral S6, T Value) {
118-
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, Value);
115+
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6);
119116
}
120117

121118
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
122119
StringLiteral S3, StringLiteral S4, StringLiteral S5,
123120
StringLiteral S6, StringLiteral S7, T Value) {
124-
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, Value);
121+
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7);
125122
}
126123

127124
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
128125
StringLiteral S3, StringLiteral S4, StringLiteral S5,
129126
StringLiteral S6, StringLiteral S7, StringLiteral S8,
130127
T Value) {
131-
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, Value);
128+
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8);
132129
}
133130

134131
StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
135132
StringLiteral S3, StringLiteral S4, StringLiteral S5,
136133
StringLiteral S6, StringLiteral S7, StringLiteral S8,
137134
StringLiteral S9, T Value) {
138-
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, S9, Value);
135+
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8, S9);
139136
}
140137

141138
// Case-insensitive case matchers.
142139
StringSwitch &CaseLower(StringLiteral S, T Value) {
143-
if (!Result && Str.equals_insensitive(S))
144-
Result = std::move(Value);
145-
140+
CaseLowerImpl(Value, S);
146141
return *this;
147142
}
148143

@@ -161,22 +156,22 @@ class StringSwitch {
161156
}
162157

163158
StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, T Value) {
164-
return CaseLower(S0, Value).CaseLower(S1, Value);
159+
return CasesLowerImpl(Value, S0, S1);
165160
}
166161

167162
StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
168163
T Value) {
169-
return CaseLower(S0, Value).CasesLower(S1, S2, Value);
164+
return CasesLowerImpl(Value, S0, S1, S2);
170165
}
171166

172167
StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
173168
StringLiteral S3, T Value) {
174-
return CaseLower(S0, Value).CasesLower(S1, S2, S3, Value);
169+
return CasesLowerImpl(Value, S0, S1, S2, S3);
175170
}
176171

177172
StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
178173
StringLiteral S3, StringLiteral S4, T Value) {
179-
return CaseLower(S0, Value).CasesLower(S1, S2, S3, S4, Value);
174+
return CasesLowerImpl(Value, S0, S1, S2, S3, S4);
180175
}
181176

182177
[[nodiscard]] R Default(T Value) {
@@ -189,6 +184,39 @@ class StringSwitch {
189184
assert(Result && "Fell off the end of a string-switch");
190185
return std::move(*Result);
191186
}
187+
188+
private:
189+
// Returns true when `Str` matches the `S` argument, and stores the result.
190+
bool CaseImpl(T &Value, StringLiteral S) {
191+
if (!Result && Str == S) {
192+
Result = std::move(Value);
193+
return true;
194+
}
195+
return false;
196+
}
197+
198+
// Returns true when `Str` matches the `S` argument (case-insensitive), and
199+
// stores the result.
200+
bool CaseLowerImpl(T &Value, StringLiteral S) {
201+
if (!Result && Str.equals_insensitive(S)) {
202+
Result = std::move(Value);
203+
return true;
204+
}
205+
return false;
206+
}
207+
208+
template <typename... Args> StringSwitch &CasesImpl(T &Value, Args... Cases) {
209+
// Stop matching after the string is found.
210+
(... || CaseImpl(Value, Cases));
211+
return *this;
212+
}
213+
214+
template <typename... Args>
215+
StringSwitch &CasesLowerImpl(T &Value, Args... Cases) {
216+
// Stop matching after the string is found.
217+
(... || CaseLowerImpl(Value, Cases));
218+
return *this;
219+
}
192220
};
193221

194222
} // end namespace llvm

llvm/unittests/ADT/StringSwitchTest.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,28 @@ TEST(StringSwitchTest, CasesLower) {
205205
EXPECT_EQ(OSType::Unknown, Translate("wind"));
206206
EXPECT_EQ(OSType::Unknown, Translate(""));
207207
}
208+
209+
TEST(StringSwitchTest, CasesCopies) {
210+
struct Copyable {
211+
unsigned &NumCopies;
212+
Copyable(unsigned &Value) : NumCopies(Value) {}
213+
Copyable(const Copyable &Other) : NumCopies(Other.NumCopies) {
214+
++NumCopies;
215+
}
216+
Copyable &operator=(const Copyable &Other) {
217+
++NumCopies;
218+
return *this;
219+
}
220+
};
221+
222+
// Check that evaluating multiple cases does not cause unnecessary copies.
223+
unsigned NumCopies = 0;
224+
llvm::StringSwitch<Copyable, void>("baz").Cases("foo", "bar", "baz", "qux",
225+
Copyable{NumCopies});
226+
EXPECT_EQ(NumCopies, 1u);
227+
228+
NumCopies = 0;
229+
llvm::StringSwitch<Copyable, void>("baz").CasesLower(
230+
"Foo", "Bar", "Baz", "Qux", Copyable{NumCopies});
231+
EXPECT_EQ(NumCopies, 1u);
232+
}

0 commit comments

Comments
 (0)