Skip to content

Commit 91eefb9

Browse files
committed
[CSBindings] Store the originator type variable in a transitive binding
Instead of using a flag in `addBinding`, let's record what type variable does the binding belong to in `PotentialBinding` itself. This is going to help remove bindings introduced transitively when the inference is done lazily.
1 parent 96cd2e1 commit 91eefb9

File tree

4 files changed

+58
-34
lines changed

4 files changed

+58
-34
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,20 @@ struct PotentialBinding {
8484
/// because they are synthetic, they have a locator instead.
8585
PointerUnion<Constraint *, ConstraintLocator *> BindingSource;
8686

87+
/// When the binding is transferred through a subtype chain, this
88+
/// marks a type variable for which it was originally inferred.
89+
TypeVariableType *Originator;
90+
8791
PotentialBinding(Type type, AllowedBindingKind kind,
88-
PointerUnion<Constraint *, ConstraintLocator *> source)
89-
: BindingType(type), Kind(kind), BindingSource(source) {}
92+
PointerUnion<Constraint *, ConstraintLocator *> source,
93+
TypeVariableType *originator)
94+
: BindingType(type), Kind(kind), BindingSource(source),
95+
Originator(originator) {}
9096

9197
PotentialBinding(Type type, AllowedBindingKind kind, Constraint *source)
9298
: PotentialBinding(
93-
type, kind,
94-
PointerUnion<Constraint *, ConstraintLocator *>(source)) {}
99+
type, kind, PointerUnion<Constraint *, ConstraintLocator *>(source),
100+
/*originator=*/nullptr) {}
95101

96102
bool isDefaultableBinding() const {
97103
if (auto *constraint = BindingSource.dyn_cast<Constraint *>())
@@ -124,13 +130,20 @@ struct PotentialBinding {
124130
Constraint *getSource() const { return cast<Constraint *>(BindingSource); }
125131

126132
PotentialBinding withType(Type type) const {
127-
return {type, Kind, BindingSource};
133+
return {type, Kind, BindingSource, Originator};
128134
}
129135

130136
PotentialBinding withSameSource(Type type, AllowedBindingKind kind) const {
131-
return {type, kind, BindingSource};
137+
return {type, kind, BindingSource, Originator};
138+
}
139+
140+
PotentialBinding asTransitiveFrom(TypeVariableType *originator) const {
141+
ASSERT(originator);
142+
return {BindingType, Kind, BindingSource, originator};
132143
}
133144

145+
bool isTransitive() const { return bool(Originator); }
146+
134147
/// Determine whether this binding could be a viable candidate
135148
/// to be "joined" with some other binding. It has to be at least
136149
/// a non-default r-value supertype binding with no type variables.
@@ -140,12 +153,13 @@ struct PotentialBinding {
140153
ConstraintLocator *locator) {
141154
return {PlaceholderType::get(typeVar->getASTContext(), typeVar),
142155
AllowedBindingKind::Exact,
143-
/*source=*/locator};
156+
/*source=*/locator, /*originator=*/nullptr};
144157
}
145158

146159
static PotentialBinding forPlaceholder(Type placeholderTy) {
147160
return {placeholderTy, AllowedBindingKind::Exact,
148-
PointerUnion<Constraint *, ConstraintLocator *>()};
161+
PointerUnion<Constraint *, ConstraintLocator *>(),
162+
/*originator=*/nullptr};
149163
}
150164

151165
void print(llvm::raw_ostream &out, const PrintOptions &PO) const;
@@ -473,7 +487,7 @@ class BindingSet {
473487
/// \param isTransitive Indicates whether this binding has been
474488
/// acquired through transitive inference and requires extra
475489
/// checking.
476-
bool isViable(PotentialBinding &binding, bool isTransitive);
490+
bool isViable(PotentialBinding &binding);
477491

478492
/// Determine whether this set has any "viable" (or non-hole) bindings.
479493
///
@@ -602,10 +616,7 @@ class BindingSet {
602616
/// Add a new binding to the set.
603617
///
604618
/// \param binding The binding to add.
605-
/// \param isTransitive Indicates whether this binding has been
606-
/// acquired through transitive inference and requires validity
607-
/// checking.
608-
void addBinding(PotentialBinding binding, bool isTransitive);
619+
void addBinding(PotentialBinding binding);
609620

610621
void addLiteralRequirement(Constraint *literal);
611622

include/swift/Sema/CSTrail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ class SolverTrail {
145145
/// of a PotentialBinding.
146146
Type BindingType;
147147
PointerUnion<Constraint *, ConstraintLocator *> BindingSource;
148+
TypeVariableType *Originator;
148149
} Binding;
149150

150151
ConstraintFix *TheFix;

lib/Sema/CSBindings.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar,
4343
: CS(CS), TypeVar(TypeVar), Info(info) {
4444

4545
for (const auto &binding : info.Bindings)
46-
addBinding(binding, /*isTransitive=*/false);
46+
addBinding(binding);
4747

4848
for (auto *constraint : info.Constraints) {
4949
switch (constraint->getKind()) {
@@ -596,7 +596,7 @@ void BindingSet::inferTransitiveKeyPathBindings() {
596596

597597
// Copy the bindings over to the root.
598598
for (const auto &binding : bindings.Bindings)
599-
addBinding(binding, /*isTransitive=*/true);
599+
addBinding(binding.asTransitiveFrom(contextualRootVar));
600600

601601
// Make a note that the key path root is transitively adjacent
602602
// to contextual root type variable and all of its variables.
@@ -606,9 +606,9 @@ void BindingSet::inferTransitiveKeyPathBindings() {
606606
bindings.AdjacentVars.end());
607607
}
608608
} else {
609-
addBinding(
610-
binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact),
611-
/*isTransitive=*/true);
609+
auto newBinding = binding.withSameSource(
610+
inferredRootTy, AllowedBindingKind::Exact);
611+
addBinding(newBinding.asTransitiveFrom(keyPathTy));
612612
}
613613
}
614614
}
@@ -679,8 +679,9 @@ void BindingSet::inferTransitiveSupertypeBindings() {
679679
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
680680
continue;
681681

682-
addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes),
683-
/*isTransitive=*/true);
682+
auto newBinding =
683+
binding.withSameSource(type, AllowedBindingKind::Supertypes);
684+
addBinding(newBinding.asTransitiveFrom(entry.first));
684685
}
685686
}
686687
}
@@ -713,8 +714,7 @@ void BindingSet::inferTransitiveUnresolvedMemberRefBindings() {
713714
continue;
714715
}
715716

716-
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
717-
/*isTransitive=*/false);
717+
addBinding({protocolTy, AllowedBindingKind::Exact, constraint});
718718
}
719719
}
720720
}
@@ -829,8 +829,8 @@ bool BindingSet::finalizeKeyPathBindings() {
829829
// better diagnostics.
830830
auto keyPathTy = getKeyPathType(ctx, *capability, rootTy,
831831
CS.getKeyPathValueType(keyPath));
832-
updatedBindings.insert(
833-
{keyPathTy, AllowedBindingKind::Exact, locator});
832+
updatedBindings.insert({keyPathTy, AllowedBindingKind::Exact, locator,
833+
/*originator=*/nullptr});
834834
} else if (CS.shouldAttemptFixes()) {
835835
auto fixedRootTy = CS.getFixedType(rootTy);
836836
// If key path is structurally correct and has a resolved root
@@ -889,11 +889,11 @@ void BindingSet::finalizeUnresolvedMemberChainResult() {
889889
}
890890
}
891891

892-
void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
892+
void BindingSet::addBinding(PotentialBinding binding) {
893893
if (Bindings.count(binding))
894894
return;
895895

896-
if (!isViable(binding, isTransitive))
896+
if (!isViable(binding))
897897
return;
898898

899899
SmallPtrSet<TypeVariableType *, 4> referencedTypeVars;
@@ -957,14 +957,26 @@ void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
957957
for (auto existingBinding = Bindings.begin();
958958
existingBinding != Bindings.end();) {
959959
if (existingBinding->isViableForJoin()) {
960-
auto join =
960+
auto joinType =
961961
Type::join(existingBinding->BindingType, binding.BindingType);
962962

963-
if (join && isAcceptableJoin(*join)) {
963+
if (joinType && isAcceptableJoin(*joinType)) {
964964
// Result of the join has to use new binding because it refers
965965
// to the constraint that triggered the join that replaced the
966966
// existing binding.
967-
joined.push_back(binding.withType(*join));
967+
//
968+
// For "join" to be transitive, both bindings have to be as
969+
// well, otherwise we consider it a refinement of a direct
970+
// binding.
971+
auto *origintor =
972+
binding.isTransitive() && existingBinding->isTransitive()
973+
? binding.Originator
974+
: nullptr;
975+
976+
PotentialBinding join(*joinType, binding.Kind, binding.BindingSource,
977+
origintor);
978+
979+
joined.push_back(join);
968980
// Remove existing binding from the set.
969981
// It has to be re-introduced later, since its type has been changed.
970982
existingBinding = Bindings.erase(existingBinding);
@@ -1448,15 +1460,15 @@ static bool hasConversions(Type type) {
14481460
type->is<BuiltinType>() || type->is<ArchetypeType>());
14491461
}
14501462

1451-
bool BindingSet::isViable(PotentialBinding &binding, bool isTransitive) {
1463+
bool BindingSet::isViable(PotentialBinding &binding) {
14521464
// Prevent against checking against the same opened nominal type
14531465
// over and over again. Doing so means redundant work in the best
14541466
// case. In the worst case, we'll produce lots of duplicate solutions
14551467
// for this constraint system, which is problematic for overload
14561468
// resolution.
14571469
auto type = binding.BindingType;
14581470

1459-
if (isTransitive && !checkTypeOfBinding(TypeVar, type))
1471+
if (binding.isTransitive() && !checkTypeOfBinding(TypeVar, type))
14601472
return false;
14611473

14621474
auto *NTD = type->getAnyNominal();

lib/Sema/CSTrail.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ SolverTrail::Change::RetractedBinding(TypeVariableType *typeVar,
334334
result.Binding.TypeVar = typeVar;
335335
result.Binding.BindingType = binding.BindingType;
336336
result.Binding.BindingSource = binding.BindingSource;
337+
result.Binding.Originator = binding.Originator;
337338
result.Options = unsigned(binding.Kind);
338339

339340
return result;
@@ -563,9 +564,8 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const {
563564
break;
564565

565566
case ChangeKind::RetractedBinding: {
566-
PotentialBinding binding(Binding.BindingType,
567-
AllowedBindingKind(Options),
568-
Binding.BindingSource);
567+
PotentialBinding binding(Binding.BindingType, AllowedBindingKind(Options),
568+
Binding.BindingSource, Binding.Originator);
569569

570570
auto &bindings = cg[BindingRelation.TypeVar].getPotentialBindings();
571571
bindings.Bindings.push_back(binding);

0 commit comments

Comments
 (0)