Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import SwiftUI
@MainActor
public struct SignInWithAppleButton {
@Environment(AuthService.self) private var authService
@Environment(\.signInWithMergeConflictHandler) private var signInHandler
let provider: AuthProviderSwift
public init(provider: AuthProviderSwift) {
self.provider = provider
Expand All @@ -34,7 +35,13 @@ extension SignInWithAppleButton: View {
accessibilityId: "sign-in-with-apple-button"
) {
Task {
try? await authService.signIn(provider)
if let handler = signInHandler {
try? await handler(authService) {
try await authService.signIn(provider)
}
} else {
try? await authService.signIn(provider)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ public protocol AuthProviderUI {
var provider: AuthProviderSwift { get }
}

public protocol PhoneAuthProviderSwift: AuthProviderSwift {
@MainActor func verifyPhoneNumber(phoneNumber: String) async throws -> String
func setVerificationCode(verificationID: String, code: String)
public protocol PhoneAuthProviderSwift: AuthProviderSwift, AnyObject {
// Phone auth provider that presents its own UI flow in createAuthCredential()
// Internal use only: AuthService will be injected automatically by AuthService.signIn()
var authService: AuthService? { get set }
}

public enum AuthenticationState {
Expand All @@ -50,8 +51,6 @@ public enum AuthView: Hashable {
case mfaEnrollment
case mfaManagement
case mfaResolution
case enterPhoneNumber
case enterVerificationCode(verificationID: String, fullPhoneNumber: String)
}

public enum SignInOutcome: @unchecked Sendable {
Expand Down Expand Up @@ -144,10 +143,6 @@ public final class AuthService {

private var providers: [AuthProviderUI] = []

public var currentPhoneProvider: PhoneAuthProviderSwift? {
providers.compactMap { $0.provider as? PhoneAuthProviderSwift }.first
}

public func registerProvider(providerWithButton: AuthProviderUI) {
providers.append(providerWithButton)
}
Expand All @@ -171,11 +166,17 @@ public final class AuthService {

public func signIn(_ provider: AuthProviderSwift) async throws -> SignInOutcome {
do {
// Automatically inject AuthService for phone provider
if let phoneProvider = provider as? PhoneAuthProviderSwift {
phoneProvider.authService = self
}

let credential = try await provider.createAuthCredential()
let result = try await signIn(credentials: credential)
return result
} catch {
updateError(message: string.localizedErrorMessage(for: error))
// Always pass the underlying error - view decides what to show
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -206,8 +207,8 @@ public final class AuthService {
currentError = nil
}

func updateError(title: String = "Error", message: String) {
currentError = AlertError(title: title, message: message)
func updateError(title: String = "Error", message: String, underlyingError: Error? = nil) {
currentError = AlertError(title: title, message: message, underlyingError: underlyingError)
}

public var shouldHandleAnonymousUpgrade: Bool {
Expand All @@ -217,9 +218,11 @@ public final class AuthService {
public func signOut() async throws {
do {
try await auth.signOut()
// Cannot wait for auth listener to change, feedback needs to be immediate
currentUser = nil
updateAuthenticationState()
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand All @@ -237,12 +240,12 @@ public final class AuthService {
updateAuthenticationState()
} catch {
authenticationState = .unauthenticated
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}

public func handleAutoUpgradeAnonymousUser(credentials: AuthCredential) async throws
private func handleAutoUpgradeAnonymousUser(credentials: AuthCredential) async throws
-> SignInOutcome {
if currentUser == nil {
throw AuthServiceError.noCurrentUser
Expand All @@ -252,11 +255,27 @@ public final class AuthService {
updateAuthenticationState()
return .signedIn(result)
} catch let error as NSError {
// Handle credentialAlreadyInUse error
if error.code == AuthErrorCode.credentialAlreadyInUse.rawValue {
// Extract the updated credential from the error
let updatedCredential = error.userInfo["FIRAuthUpdatedCredentialKey"] as? AuthCredential
?? credentials

let context = AccountMergeConflictContext(
credential: updatedCredential,
underlyingError: error,
message: "Unable to merge accounts. The credential is already associated with a different account.",
uid: currentUser?.uid
)
throw AuthServiceError.accountMergeConflict(context: context)
}

// Handle emailAlreadyInUse error
if error.code == AuthErrorCode.emailAlreadyInUse.rawValue {
let context = AccountMergeConflictContext(
credential: credentials,
underlyingError: error,
message: "Unable to merge accounts. Use the credential in the context to resolve the conflict.",
message: "Unable to merge accounts. This email is already associated with a different account.",
uid: currentUser?.uid
)
throw AuthServiceError.accountMergeConflict(context: context)
Expand Down Expand Up @@ -285,7 +304,7 @@ public final class AuthService {
}
} else {
// Don't want error modal on MFA error so we only update here
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
}

throw error
Expand All @@ -307,7 +326,7 @@ public final class AuthService {
}
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand All @@ -326,7 +345,7 @@ public extension AuthService {
try await user.delete()
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand All @@ -341,7 +360,7 @@ public extension AuthService {
try await user.updatePassword(to: password)
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -374,7 +393,7 @@ public extension AuthService {
}
} catch {
authenticationState = .unauthenticated
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand All @@ -383,7 +402,7 @@ public extension AuthService {
do {
try await auth.sendPasswordReset(withEmail: email)
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand All @@ -400,7 +419,7 @@ public extension AuthService {
actionCodeSettings: actionCodeSettings
)
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -433,7 +452,7 @@ public extension AuthService {
emailLink = nil
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -502,7 +521,7 @@ public extension AuthService {
changeRequest.photoURL = url
try await changeRequest.commitChanges()
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand All @@ -517,7 +536,7 @@ public extension AuthService {
changeRequest.displayName = name
try await changeRequest.commitChanges()
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -609,7 +628,7 @@ public extension AuthService {
)
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -661,7 +680,7 @@ public extension AuthService {

return verificationID
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -740,7 +759,7 @@ public extension AuthService {
}
currentUser = auth.currentUser
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -829,7 +848,7 @@ public extension AuthService {

return freshFactors
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -899,7 +918,7 @@ public extension AuthService {
}
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down Expand Up @@ -952,7 +971,7 @@ public extension AuthService {
.multiFactorAuth("Failed to resolve MFA challenge: \(error.localizedDescription)")
}
} catch {
updateError(message: string.localizedErrorMessage(for: error))
updateError(message: string.localizedErrorMessage(for: error), underlyingError: error)
throw error
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,59 @@ import FirebaseAuthUIComponents
import FirebaseCore
import SwiftUI

// MARK: - Merge Conflict Handling

/// Helper function to handle sign-in with automatic merge conflict resolution.
///
/// This function attempts to sign in with the provided action. If a merge conflict occurs
/// (when an anonymous user is being upgraded and the credential is already associated with
/// an existing account), it automatically signs out the anonymous user and signs in with
/// the existing account's credential.
///
/// - Parameters:
/// - authService: The AuthService instance to use for sign-in operations
/// - signInAction: An async closure that performs the sign-in operation
/// - Returns: The SignInOutcome from the successful sign-in
/// - Throws: Re-throws any errors except accountMergeConflict (which is handled internally)
@MainActor
public func signInWithMergeConflictHandling(authService: AuthService,
signInAction: () async throws
-> SignInOutcome) async throws -> SignInOutcome {
do {
return try await signInAction()
} catch let error as AuthServiceError {
if case let .accountMergeConflict(context) = error {
// The anonymous account conflicts with an existing account
// Sign out the anonymous user
try await authService.signOut()

// Sign in with the existing account's credential
// This works because shouldHandleAnonymousUpgrade is now false after sign out
return try await authService.signIn(credentials: context.credential)
}
throw error
}
}

// MARK: - Environment Key for Sign-In Handler

/// Environment key for a sign-in handler that includes merge conflict resolution
private struct SignInHandlerKey: EnvironmentKey {
static let defaultValue: (@MainActor (AuthService, () async throws -> SignInOutcome) async throws
-> SignInOutcome)? = nil
}

public extension EnvironmentValues {
/// A sign-in handler that automatically handles merge conflicts for anonymous user upgrades.
/// When set in the environment, views should use this handler to wrap their sign-in calls.
var signInWithMergeConflictHandler: (@MainActor (AuthService,
() async throws -> SignInOutcome) async throws
-> SignInOutcome)? {
get { self[SignInHandlerKey.self] }
set { self[SignInHandlerKey.self] = newValue }
}
}

@MainActor
public struct AuthPickerView<Content: View> {
public init(@ViewBuilder content: @escaping () -> Content = { EmptyView() }) {
Expand Down Expand Up @@ -54,22 +107,6 @@ extension AuthPickerView: View {
MFAManagementView()
case AuthView.mfaResolution:
MFAResolutionView()
case AuthView.enterPhoneNumber:
if let phoneProvider = authService.currentPhoneProvider {
EnterPhoneNumberView(phoneProvider: phoneProvider)
} else {
EmptyView()
}
case let .enterVerificationCode(verificationID, fullPhoneNumber):
if let phoneProvider = authService.currentPhoneProvider {
EnterVerificationCodeView(
verificationID: verificationID,
fullPhoneNumber: fullPhoneNumber,
phoneProvider: phoneProvider
)
} else {
EmptyView()
}
}
}
}
Expand Down Expand Up @@ -117,7 +154,10 @@ extension AuthPickerView: View {
.aspectRatio(contentMode: .fit)
.frame(width: 100, height: 100)
if authService.emailSignInEnabled {
EmailAuthView()
EmailAuthView().environment(
\.signInWithMergeConflictHandler,
signInWithMergeConflictHandling
)
}
Divider()
otherSignInOptions(proxy)
Expand All @@ -133,6 +173,7 @@ extension AuthPickerView: View {
authService.renderButtons()
}
.padding(.horizontal, proxy.size.width * 0.18)
.environment(\.signInWithMergeConflictHandler, signInWithMergeConflictHandling)
}
}

Expand Down
Loading
Loading