Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ext/LinearSolveRecursiveFactorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactoriz
kwargs...)
cache_A = cache.A
cache_A = convert(AbstractMatrix, cache_A)
b = cache.b
cache_b = cache.b
M, N = size(cache_A)
workspace = cache.cacheval[1]
thread = alg.thread

if cache.isfresh
@assert M==N "A must be square"
if (size(workspace.A, 1) != M)
workspace = RecursiveFactorization.🦋workspace(cache_A, b)
workspace = RecursiveFactorization.🦋workspace(cache_A, cache_b)
end
(;A, b, ws, U, V, out, tmp, n) = workspace
RecursiveFactorization.🦋mul!(A, ws)
Expand All @@ -130,6 +130,9 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactoriz

workspace, F = cache.cacheval
(;A, b, ws, U, V, out, tmp, n) = workspace
for i in 1:M
@inbounds b[i] = cache_b[i]
end
mul!(tmp, U', b)
TriangularSolve.ldiv!(F, tmp, thread)
mul!(b, V, tmp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutating b will make it incorrect for the second solve. It seems like you need another temporary vector to do this right?

Expand Down
Loading