11module LinearSolveRecursiveFactorizationExt
22
33using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval ,
4- RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
5- default_alias_b, LinearVerbosity
4+ RFLUFactorization, ButterflyFactorization, RF32MixedLUFactorization,
5+ default_alias_A, default_alias_b, LinearVerbosity
66using LinearSolve. LinearAlgebra, LinearSolve. ArrayInterface, RecursiveFactorization
77using SciMLBase: SciMLBase, ReturnCode
88using SciMLLogging: @SciMLMessage
9+ using TriangularSolve
910
1011LinearSolve. userecursivefactorization (A:: Union{Nothing, AbstractMatrix} ) = true
1112
@@ -20,7 +21,6 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
2021 end
2122 fact = RecursiveFactorization. lu! (A, ipiv, Val (P), Val (T), check = false )
2223 cache. cacheval = (fact, ipiv)
23-
2424 if ! LinearAlgebra. issuccess (fact)
2525 @SciMLMessage (" Solver failed" , cache. verbose, :solver_failure )
2626 return SciMLBase. build_linear_solution (
@@ -107,4 +107,41 @@ function SciMLBase.solve!(
107107 alg, cache. u, nothing , cache; retcode = ReturnCode. Success)
108108end
109109
110+ function SciMLBase. solve! (cache:: LinearSolve.LinearCache , alg:: ButterflyFactorization ;
111+ kwargs... )
112+ cache_A = cache. A
113+ cache_A = convert (AbstractMatrix, cache_A)
114+ cache_b = cache. b
115+ M, N = size (cache_A)
116+ workspace = cache. cacheval[1 ]
117+ thread = alg. thread
118+
119+ if cache. isfresh
120+ @assert M== N " A must be square"
121+ if (size (workspace. A, 1 ) != M)
122+ workspace = RecursiveFactorization.🦋workspace (cache_A, cache_b)
123+ end
124+ (;A, b, ws, U, V, out, tmp, n) = workspace
125+ RecursiveFactorization.🦋mul! (A, ws)
126+ F = RecursiveFactorization. lu! (A, Val (false ), thread)
127+ cache. cacheval = (workspace, F)
128+ cache. isfresh = false
129+ end
130+
131+ workspace, F = cache. cacheval
132+ (;A, b, ws, U, V, out, tmp, n) = workspace
133+ b[1 : M] .= cache_b
134+ mul! (tmp, U' , b)
135+ TriangularSolve. ldiv! (F, tmp, thread)
136+ mul! (b, V, tmp)
137+ out .= @view b[1 : n]
138+ SciMLBase. build_linear_solution (alg, out, nothing , cache)
139+ end
140+
141+ function LinearSolve. init_cacheval (alg:: ButterflyFactorization , A, b, u, Pl, Pr, maxiters:: Int ,
142+ abstol, reltol, verbose:: Bool , assumptions:: LinearSolve.OperatorAssumptions )
143+ ws = RecursiveFactorization.🦋workspace (A, b), RecursiveFactorization. lu! (rand (1 , 1 ), Val (false ), alg. thread)
110144end
145+
146+ end
147+
0 commit comments