Skip to content

Commit a4ae171

Browse files
authored
Merge pull request #85973 from bnbarham/convert-async-shorthand
[SourceKit] Allow converting functions containing shorthand ifs to async
2 parents 3754042 + 7887ce8 commit a4ae171

File tree

6 files changed

+70
-31
lines changed

6 files changed

+70
-31
lines changed

include/swift/AST/Stmt.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,10 @@ class alignas(1 << PatternAlignInBits) StmtConditionElement {
725725
bool rebindsSelf(ASTContext &Ctx, bool requiresCaptureListRef = false,
726726
bool requireLoadExpr = false) const;
727727

728+
/// Returns the synthesized RHS for a shorthand if let (eg. `if let x`), or
729+
/// null if this element does not represent a shorthand if let.
730+
Expr *getSynthesizedShorthandInitOrNull() const;
731+
728732
SourceLoc getStartLoc() const;
729733
SourceLoc getEndLoc() const;
730734
SourceRange getSourceRange() const;

lib/AST/Stmt.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,28 @@ bool StmtConditionElement::rebindsSelf(ASTContext &Ctx,
597597
return false;
598598
}
599599

600+
Expr *StmtConditionElement::getSynthesizedShorthandInitOrNull() const {
601+
auto *init = getInitializerOrNull();
602+
if (!init)
603+
return nullptr;
604+
605+
auto *pattern = dyn_cast_or_null<OptionalSomePattern>(getPattern());
606+
if (!pattern)
607+
return nullptr;
608+
609+
auto *var = pattern->getSubPattern()->getSingleVar();
610+
if (!var)
611+
return nullptr;
612+
613+
// If the right-hand side has the same location as the variable, it was
614+
// synthesized.
615+
if (var->getLoc().isValid() && var->getLoc() == init->getStartLoc() &&
616+
init->getStartLoc() == init->getEndLoc()) {
617+
return init;
618+
}
619+
return nullptr;
620+
}
621+
600622
SourceRange ConditionalPatternBindingInfo::getSourceRange() const {
601623
SourceLoc Start;
602624
if (IntroducerLoc.isValid())

lib/Refactoring/Async/AsyncConverter.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,11 @@ bool AsyncConverter::walkToDeclPost(Decl *D) {
424424
#define PLACEHOLDER_START "<#"
425425
#define PLACEHOLDER_END "#>"
426426
bool AsyncConverter::walkToExprPre(Expr *E) {
427+
// We've already added any shorthand if declaration, don't add its
428+
// synthesized initializer as well.
429+
if (shorthandIfInits.contains(E))
430+
return true;
431+
427432
// TODO: Handle Result.get as well
428433
if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
429434
if (auto *D = DRE->getDecl()) {
@@ -530,6 +535,15 @@ bool AsyncConverter::walkToExprPost(Expr *E) {
530535
#undef PLACEHOLDER_END
531536

532537
bool AsyncConverter::walkToStmtPre(Stmt *S) {
538+
// Keep track of any shorthand initializer expressions
539+
if (auto *labeledConditional = dyn_cast<LabeledConditionalStmt>(S)) {
540+
for (const auto &condition : labeledConditional->getCond()) {
541+
if (auto *init = condition.getSynthesizedShorthandInitOrNull()) {
542+
shorthandIfInits.insert(init);
543+
}
544+
}
545+
}
546+
533547
// CaseStmt has an implicit BraceStmt inside it, which *should* start a new
534548
// scope, so don't check isImplicit here.
535549
if (startsNewScope(S)) {

lib/Refactoring/Async/AsyncRefactoring.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,10 @@ class AsyncConverter : private SourceEntityWalker {
969969
SmallString<0> Buffer;
970970
llvm::raw_svector_ostream OS;
971971

972+
// Any initializer expressions in a shorthand if that we need to skip (as it
973+
// points to the same identifier as the declaration itself).
974+
llvm::DenseSet<const Expr *> shorthandIfInits;
975+
972976
// Decls where any force unwrap or optional chain of that decl should be
973977
// elided, e.g for a previously optional closure parameter that has become a
974978
// non-optional local.

lib/Sema/TypeCheckEffects.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4715,37 +4715,7 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
47154715
// Make a note of any initializers that are the synthesized right-hand side
47164716
// for an "if let x".
47174717
for (const auto &condition: stmt->getCond()) {
4718-
switch (condition.getKind()) {
4719-
case StmtConditionElement::CK_Availability:
4720-
case StmtConditionElement::CK_Boolean:
4721-
case StmtConditionElement::CK_HasSymbol:
4722-
continue;
4723-
4724-
case StmtConditionElement::CK_PatternBinding:
4725-
break;
4726-
}
4727-
4728-
auto init = condition.getInitializer();
4729-
if (!init)
4730-
continue;
4731-
4732-
auto pattern = condition.getPattern();
4733-
if (!pattern)
4734-
continue;
4735-
4736-
auto optPattern = dyn_cast<OptionalSomePattern>(pattern);
4737-
if (!optPattern)
4738-
continue;
4739-
4740-
auto var = optPattern->getSubPattern()->getSingleVar();
4741-
if (!var)
4742-
continue;
4743-
4744-
// If the right-hand side has the same location as the variable, it was
4745-
// synthesized.
4746-
if (var->getLoc().isValid() &&
4747-
var->getLoc() == init->getStartLoc() &&
4748-
init->getStartLoc() == init->getEndLoc())
4718+
if (auto *init = condition.getSynthesizedShorthandInitOrNull())
47494719
synthesizedIfLetInitializers.insert(init);
47504720
}
47514721
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// REQUIRES: concurrency
2+
3+
// RUN: %empty-directory(%t)
4+
5+
func foo(_ fn: @escaping (String, Error?) -> Void) {}
6+
func foo() async throws -> String { return "" }
7+
8+
// RUN: %refactor-check-compiles -convert-to-async -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck %s
9+
func shorthandIf(completion: @escaping (String?, Error?) -> Void) {
10+
foo { str, error in
11+
if let error {
12+
completion(nil, error)
13+
} else {
14+
completion(str, nil)
15+
}
16+
}
17+
}
18+
// CHECK: func shorthandIf() async throws -> String {
19+
// CHECK-NEXT: return try await withCheckedThrowingContinuation { continuation in
20+
// CHECK-NEXT: foo { str, error in
21+
// CHECK-NEXT: if let error {
22+
// CHECK-NEXT: continuation.resume(throwing: error)
23+
// CHECK-NEXT: } else {
24+
// CHECK-NEXT: continuation.resume(returning: str)
25+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)