Skip to content

Commit 965f32d

Browse files
authored
Merge pull request swiftlang#41773 from slavapestov/rqm-concrete-protocol-typealias-minimization
RequirementMachine: Overhaul handling of protocol typealiases with concrete underlying type
2 parents b6047e2 + 1f83cd0 commit 965f32d

14 files changed

+254
-64
lines changed

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ RequirementMachine::getLongestValidPrefix(const MutableTerm &term) const {
239239
case Symbol::Kind::Superclass:
240240
case Symbol::Kind::ConcreteType:
241241
case Symbol::Kind::ConcreteConformance:
242-
llvm_unreachable("Property symbol cannot appear in a type term");
242+
llvm::errs() <<"Invalid symbol in a type term: " << term << "\n";
243+
abort();
243244
}
244245

245246
// This symbol is valid, add it to the longest prefix.
@@ -265,6 +266,9 @@ bool RequirementMachine::isCanonicalTypeInContext(Type type) const {
265266
explicit Walker(const RequirementMachine &self) : Self(self) {}
266267

267268
Action walkToTypePre(Type component) override {
269+
if (!component->hasTypeParameter())
270+
return Action::SkipChildren;
271+
268272
if (!component->isTypeParameter())
269273
return Action::Continue;
270274

@@ -305,6 +309,9 @@ Type RequirementMachine::getCanonicalTypeInContext(
305309
TypeArrayView<GenericTypeParamType> genericParams) const {
306310

307311
return type.transformRec([&](Type t) -> Optional<Type> {
312+
if (!t->hasTypeParameter())
313+
return t;
314+
308315
if (!t->isTypeParameter())
309316
return None;
310317

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,30 @@
6565
using namespace swift;
6666
using namespace rewriting;
6767

68-
/// Recompute Useful, RulesInEmptyContext, ProjectionCount and DecomposeCount
69-
/// if needed.
68+
/// Recompute various cached values if needed.
7069
void RewriteLoop::recompute(const RewriteSystem &system) {
7170
if (!Dirty)
7271
return;
7372
Dirty = 0;
7473

74+
Useful = 0;
7575
ProjectionCount = 0;
7676
DecomposeCount = 0;
77-
Useful = false;
77+
HasConcreteTypeAliasRule = 0;
7878

7979
RewritePathEvaluator evaluator(Basepoint);
8080
for (auto step : Path) {
8181
switch (step.Kind) {
82-
case RewriteStep::Rule:
82+
case RewriteStep::Rule: {
8383
Useful |= (!step.isInContext() && !evaluator.isInContext());
84+
85+
const auto &rule = system.getRule(step.getRuleID());
86+
if (rule.isProtocolTypeAliasRule() &&
87+
rule.getLHS().size() == 3)
88+
HasConcreteTypeAliasRule = 1;
89+
8490
break;
91+
}
8592

8693
case RewriteStep::LeftConcreteProjection:
8794
++ProjectionCount;
@@ -130,6 +137,14 @@ unsigned RewriteLoop::getDecomposeCount(
130137
return DecomposeCount;
131138
}
132139

140+
/// Returns true if the loop contains at least one concrete protocol typealias rule,
141+
/// which have the form ([P].A.[concrete: C] => [P].A).
142+
bool RewriteLoop::hasConcreteTypeAliasRule(
143+
const RewriteSystem &system) const {
144+
const_cast<RewriteLoop *>(this)->recompute(system);
145+
return HasConcreteTypeAliasRule;
146+
}
147+
133148
/// The number of Decompose steps, used by the elimination order to prioritize
134149
/// loops that are not concrete simplifications.
135150
bool RewriteLoop::isUseful(
@@ -488,7 +503,7 @@ RewritePath::getRulesInEmptyContext(const MutableTerm &term,
488503
/// \p redundantConformances equal to the set of conformance rules that are
489504
/// not minimal conformances.
490505
Optional<std::pair<unsigned, unsigned>> RewriteSystem::
491-
findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
506+
findRuleToDelete(EliminationPredicate isRedundantRuleFn) {
492507
SmallVector<std::pair<unsigned, unsigned>, 2> redundancyCandidates;
493508
for (unsigned loopID : indices(Loops)) {
494509
auto &loop = Loops[loopID];
@@ -520,7 +535,10 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
520535
}
521536

522537
for (const auto &pair : redundancyCandidates) {
538+
unsigned loopID = pair.first;
523539
unsigned ruleID = pair.second;
540+
541+
const auto &loop = Loops[loopID];
524542
const auto &rule = getRule(ruleID);
525543

526544
// We should not find a rule that has already been marked redundant
@@ -538,18 +556,18 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
538556
// Homotopy reduction runs multiple passes with different filters to
539557
// prioritize the deletion of certain rules ahead of others. Apply
540558
// the filter now.
541-
if (!isRedundantRuleFn(ruleID)) {
559+
if (!isRedundantRuleFn(loopID, ruleID)) {
542560
if (Debug.contains(DebugFlags::HomotopyReductionDetail)) {
543561
llvm::dbgs() << "** Skipping rule " << rule << " from loop #"
544-
<< pair.first << "\n";
562+
<< loopID << "\n";
545563
}
546564

547565
continue;
548566
}
549567

550568
if (Debug.contains(DebugFlags::HomotopyReductionDetail)) {
551569
llvm::dbgs() << "** Candidate rule " << rule << " from loop #"
552-
<< pair.first << "\n";
570+
<< loopID << "\n";
553571
}
554572

555573
if (!found) {
@@ -561,7 +579,6 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
561579
// we've found so far.
562580
const auto &otherRule = getRule(found->second);
563581

564-
const auto &loop = Loops[pair.first];
565582
const auto &otherLoop = Loops[found->first];
566583

567584
{
@@ -712,7 +729,7 @@ void RewriteSystem::deleteRule(unsigned ruleID,
712729
}
713730

714731
void RewriteSystem::performHomotopyReduction(
715-
llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
732+
EliminationPredicate isRedundantRuleFn) {
716733
while (true) {
717734
auto optPair = findRuleToDelete(isRedundantRuleFn);
718735

@@ -803,14 +820,21 @@ void RewriteSystem::minimizeRewriteSystem() {
803820
// First pass:
804821
// - Eliminate all LHS-simplified non-conformance rules.
805822
// - Eliminate all RHS-simplified and substitution-simplified rules.
806-
// - Eliminate all rules with unresolved symbols.
823+
//
824+
// An example of a conformance rule that is LHS-simplified but not
825+
// RHS-simplified is (T.[P] => T) where T is irreducible, but there
826+
// is a rule (V.[P] => V) for some V with T == U.V.
827+
//
828+
// Such conformance rules can still be minimal, as part of a hack to
829+
// maintain compatibility with the GenericSignatureBuilder's minimization
830+
// algorithm.
807831
if (Debug.contains(DebugFlags::HomotopyReduction)) {
808-
llvm::dbgs() << "---------------------------------------------\n";
809-
llvm::dbgs() << "First pass: simplified and unresolved rules -\n";
810-
llvm::dbgs() << "---------------------------------------------\n";
832+
llvm::dbgs() << "------------------------------\n";
833+
llvm::dbgs() << "First pass: simplified rules -\n";
834+
llvm::dbgs() << "------------------------------\n";
811835
}
812836

813-
performHomotopyReduction([&](unsigned ruleID) -> bool {
837+
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
814838
const auto &rule = getRule(ruleID);
815839

816840
if (rule.isLHSSimplified() &&
@@ -821,8 +845,31 @@ void RewriteSystem::minimizeRewriteSystem() {
821845
rule.isSubstitutionSimplified())
822846
return true;
823847

824-
if (rule.containsUnresolvedSymbols() &&
825-
!rule.isProtocolTypeAliasRule())
848+
return false;
849+
});
850+
851+
// Second pass:
852+
// - Eliminate all rules with unresolved symbols which were *not*
853+
// simplified.
854+
//
855+
// Two examples of such rules:
856+
//
857+
// - (T.X => T.[P:X]) obtained from resolving the overlap between
858+
// (T.[P] => T) and ([P].X => [P:X]).
859+
//
860+
// - (T.X.[concrete: C] => T.X) obtained from resolving the overlap
861+
// between (T.[P] => T) and a protocol typealias rule
862+
// ([P].X.[concrete: C] => [P].X).
863+
if (Debug.contains(DebugFlags::HomotopyReduction)) {
864+
llvm::dbgs() << "-------------------------------\n";
865+
llvm::dbgs() << "Second pass: unresolved rules -\n";
866+
llvm::dbgs() << "-------------------------------\n";
867+
}
868+
869+
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
870+
const auto &rule = getRule(ruleID);
871+
872+
if (rule.containsUnresolvedSymbols())
826873
return true;
827874

828875
return false;
@@ -838,14 +885,14 @@ void RewriteSystem::minimizeRewriteSystem() {
838885
llvm::DenseSet<unsigned> redundantConformances;
839886
computeMinimalConformances(redundantConformances);
840887

841-
// Second pass: Eliminate all non-minimal conformance rules.
888+
// Third pass: Eliminate all non-minimal conformance rules.
842889
if (Debug.contains(DebugFlags::HomotopyReduction)) {
843-
llvm::dbgs() << "--------------------------------------------\n";
844-
llvm::dbgs() << "Second pass: non-minimal conformance rules -\n";
845-
llvm::dbgs() << "--------------------------------------------\n";
890+
llvm::dbgs() << "-------------------------------------------\n";
891+
llvm::dbgs() << "Third pass: non-minimal conformance rules -\n";
892+
llvm::dbgs() << "-------------------------------------------\n";
846893
}
847894

848-
performHomotopyReduction([&](unsigned ruleID) -> bool {
895+
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
849896
const auto &rule = getRule(ruleID);
850897

851898
if (rule.isAnyConformanceRule() &&
@@ -855,17 +902,22 @@ void RewriteSystem::minimizeRewriteSystem() {
855902
return false;
856903
});
857904

858-
// Third pass: Eliminate all other redundant non-conformance rules.
905+
// Fourth pass: Eliminate all remaining redundant non-conformance rules.
859906
if (Debug.contains(DebugFlags::HomotopyReduction)) {
860-
llvm::dbgs() << "---------------------------------------\n";
861-
llvm::dbgs() << "Third pass: all other redundant rules -\n";
862-
llvm::dbgs() << "---------------------------------------\n";
907+
llvm::dbgs() << "----------------------------------------\n";
908+
llvm::dbgs() << "Fourth pass: all other redundant rules -\n";
909+
llvm::dbgs() << "----------------------------------------\n";
863910
}
864911

865-
performHomotopyReduction([&](unsigned ruleID) -> bool {
912+
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
913+
const auto &loop = Loops[loopID];
866914
const auto &rule = getRule(ruleID);
867915

868-
if (!rule.isAnyConformanceRule())
916+
if (rule.isProtocolTypeAliasRule())
917+
return true;
918+
919+
if (!loop.hasConcreteTypeAliasRule(*this) &&
920+
!rule.isAnyConformanceRule())
869921
return true;
870922

871923
return false;

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
311311
case Symbol::Kind::Superclass:
312312
case Symbol::Kind::ConcreteType:
313313
case Symbol::Kind::ConcreteConformance:
314-
llvm_unreachable("Term has invalid root symbol");
314+
llvm::errs() << "Invalid root symbol: " << MutableTerm(begin, end) << "\n";
315+
abort();
315316
}
316317
}
317318

lib/AST/RequirementMachine/MinimalConformances.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ static const ProtocolDecl *getParentConformanceForTerm(Term lhs) {
365365
break;
366366
}
367367

368-
llvm_unreachable("Bad symbol kind");
368+
llvm::errs() << "Bad symbol in " << lhs << "\n";
369+
abort();
369370
}
370371

371372
/// Collect conformance rules and parent paths, and record an initial

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ class RewriteContext final {
5555
/// Cache for associated type declarations.
5656
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;
5757

58-
/// Cache for merged associated type symbols.
59-
llvm::DenseMap<std::pair<Symbol, Symbol>, Symbol> MergedAssocTypes;
60-
6158
/// Requirement machines built from generic signatures.
6259
llvm::DenseMap<GenericSignature, RequirementMachine *> Machines;
6360

lib/AST/RequirementMachine/RewriteLoop.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ class RewriteLoop {
457457
/// Cached value for getDecomposeCount().
458458
unsigned DecomposeCount : 15;
459459

460+
/// Cached value for hasConcreteTypeAliasRule().
461+
unsigned HasConcreteTypeAliasRule : 1;
462+
460463
/// A useful loop contains at least one rule in empty context, even if that
461464
/// rule appears multiple times or also in non-empty context. The only loops
462465
/// that are elimination candidates contain a rule in empty context *exactly
@@ -478,6 +481,7 @@ class RewriteLoop {
478481
: Basepoint(basepoint), Path(path) {
479482
ProjectionCount = 0;
480483
DecomposeCount = 0;
484+
HasConcreteTypeAliasRule = 0;
481485
Useful = 0;
482486
Deleted = 0;
483487

@@ -509,6 +513,8 @@ class RewriteLoop {
509513

510514
unsigned getDecomposeCount(const RewriteSystem &system) const;
511515

516+
bool hasConcreteTypeAliasRule(const RewriteSystem &system) const;
517+
512518
void findProtocolConformanceRules(
513519
llvm::SmallDenseMap<const ProtocolDecl *,
514520
ProtocolConformanceRules, 2> &result,

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,13 @@ Optional<Identifier> Rule::isProtocolTypeAliasRule() const {
174174
//
175175
// We shouldn't have unresolved symbols on the right hand side;
176176
// they should have been simplified away.
177-
if (RHS.containsUnresolvedSymbols())
178-
return None;
177+
if (RHS.containsUnresolvedSymbols()) {
178+
if (RHS.size() != 2 ||
179+
RHS[0] != LHS[0] ||
180+
RHS[1].getKind() != Symbol::Kind::Name) {
181+
return None;
182+
}
183+
}
179184
} else {
180185
// This is the case where the underlying type is concrete.
181186
assert(LHS.size() == 3);
@@ -660,6 +665,28 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
660665
for (unsigned index : indices(lhs)) {
661666
auto symbol = lhs[index];
662667

668+
// The left hand side can contain a single name symbol if it has the form
669+
// T.N or T.N.[p], where T is some prefix that does not contain name
670+
// symbols, N is a name symbol, and [p] is an optional property symbol.
671+
//
672+
// In the latter case, we have a protocol typealias, or a rule derived
673+
// via resolving a critical pair involving a protocol typealias.
674+
//
675+
// Any other valid occurrence of a name symbol should have been reduced by
676+
// an associated type introduction rule [P].N, marking the rule as
677+
// LHS-simplified.
678+
if (!rule.isLHSSimplified() &&
679+
(rule.isPropertyRule()
680+
? index != lhs.size() - 2
681+
: index != lhs.size() - 1)) {
682+
// This is only true if the input requirements were valid.
683+
if (policy == DisallowInvalidRequirements) {
684+
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name);
685+
} else {
686+
// FIXME: Assert that we diagnosed an error
687+
}
688+
}
689+
663690
if (index != lhs.size() - 1) {
664691
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout);
665692
ASSERT_RULE(!symbol.hasSubstitutions());
@@ -677,14 +704,18 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
677704
for (unsigned index : indices(rhs)) {
678705
auto symbol = rhs[index];
679706

680-
// RHS-simplified rules might have unresolved name symbols on the
681-
// right hand side. Also, completion can introduce rules of the
682-
// form T.X.[concrete: C] => T.X, where T is some resolved term,
683-
// and X is a name symbol for a protocol typealias.
684-
if (!rule.isLHSSimplified() &&
685-
!rule.isRHSSimplified() &&
686-
!(rule.isPropertyRule() &&
687-
index == rhs.size() - 1)) {
707+
// The right hand side can contain a single name symbol if it has the form
708+
// T.N, where T is some prefix that does not contain name symbols, and
709+
// N is a name symbol.
710+
//
711+
// In this case, we have a protocol typealias, or a rule derived via
712+
// resolving a critical pair involving a protocol typealias.
713+
//
714+
// Any other valid occurrence of a name symbol should have been reduced by
715+
// an associated type introduction rule [P].N, marking the rule as
716+
// RHS-simplified.
717+
if (!rule.isRHSSimplified() &&
718+
index != rhs.size() - 1) {
688719
// This is only true if the input requirements were valid.
689720
if (policy == DisallowInvalidRequirements) {
690721
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name);

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -506,13 +506,15 @@ class RewriteSystem final {
506506

507507
void processConflicts();
508508

509+
using EliminationPredicate = llvm::function_ref<bool(unsigned loopID,
510+
unsigned ruleID)>;
511+
509512
Optional<std::pair<unsigned, unsigned>>
510-
findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn);
513+
findRuleToDelete(EliminationPredicate isRedundantRuleFn);
511514

512515
void deleteRule(unsigned ruleID, const RewritePath &replacementPath);
513516

514-
void performHomotopyReduction(
515-
llvm::function_ref<bool(unsigned)> isRedundantRuleFn);
517+
void performHomotopyReduction(EliminationPredicate isRedundantRuleFn);
516518

517519
void computeMinimalConformances(
518520
llvm::DenseSet<unsigned> &redundantConformances);

0 commit comments

Comments
 (0)