Skip to content

Commit c3ca509

Browse files
Merge pull request #785 from Shreyas-Ekanathan/main
Implement Butterfly Factorization method using RecursiveFactorization.jl
2 parents 96b7a54 + d188cf4 commit c3ca509

File tree

6 files changed

+108
-9
lines changed

6 files changed

+108
-9
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"
2727
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2828
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2929
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
30+
TriangularSolve = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
3031
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3132

3233
[weakdeps]
@@ -125,7 +126,7 @@ PrecompileTools = "1.2"
125126
Preferences = "1.4"
126127
Random = "1.10"
127128
RecursiveArrayTools = "3.37"
128-
RecursiveFactorization = "0.2.23"
129+
RecursiveFactorization = "0.2.26"
129130
Reexport = "1.2.2"
130131
SafeTestsets = "0.1"
131132
SciMLBase = "2.70"
@@ -138,6 +139,7 @@ StableRNGs = "1.0"
138139
StaticArrays = "1.9"
139140
StaticArraysCore = "1.4.3"
140141
Test = "1.10"
142+
TriangularSolve = "0.2.1"
141143
UnPack = "1.0.2"
142144
Zygote = "0.7"
143145
blis_jll = "0.9.0"

benchmarks/lu.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using BenchmarkTools, Random, VectorizationBase
22
using LinearAlgebra, LinearSolve, MKL_jll
3+
using RecursiveFactorization
4+
35
nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
4-
BLAS.set_num_threads(nc)
56
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
67

78
function luflop(m, n = m; innerflop = 2)
@@ -24,10 +25,10 @@ algs = [
2425
RFLUFactorization(),
2526
MKLLUFactorization(),
2627
FastLUFactorization(),
27-
SimpleLUFactorization()
28+
SimpleLUFactorization(),
29+
ButterflyFactorization(Val(true))
2830
]
2931
res = [Float64[] for i in 1:length(algs)]
30-
3132
ns = 4:8:500
3233
for i in 1:length(ns)
3334
n = ns[i]
@@ -65,3 +66,4 @@ p
6566

6667
savefig("lubench.png")
6768
savefig("lubench.pdf")
69+

ext/LinearSolveRecursiveFactorizationExt.jl

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module LinearSolveRecursiveFactorizationExt
22

33
using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval,
4-
RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
5-
default_alias_b, LinearVerbosity
4+
RFLUFactorization, ButterflyFactorization, RF32MixedLUFactorization,
5+
default_alias_A, default_alias_b, LinearVerbosity
66
using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
77
using SciMLBase: SciMLBase, ReturnCode
88
using SciMLLogging: @SciMLMessage
9+
using TriangularSolve
910

1011
LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true
1112

@@ -20,7 +21,6 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
2021
end
2122
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
2223
cache.cacheval = (fact, ipiv)
23-
2424
if !LinearAlgebra.issuccess(fact)
2525
@SciMLMessage("Solver failed", cache.verbose, :solver_failure)
2626
return SciMLBase.build_linear_solution(
@@ -107,4 +107,41 @@ function SciMLBase.solve!(
107107
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
108108
end
109109

110+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactorization;
111+
kwargs...)
112+
cache_A = cache.A
113+
cache_A = convert(AbstractMatrix, cache_A)
114+
cache_b = cache.b
115+
M, N = size(cache_A)
116+
workspace = cache.cacheval[1]
117+
thread = alg.thread
118+
119+
if cache.isfresh
120+
@assert M==N "A must be square"
121+
if (size(workspace.A, 1) != M)
122+
workspace = RecursiveFactorization.🦋workspace(cache_A, cache_b)
123+
end
124+
(;A, b, ws, U, V, out, tmp, n) = workspace
125+
RecursiveFactorization.🦋mul!(A, ws)
126+
F = RecursiveFactorization.lu!(A, Val(false), thread)
127+
cache.cacheval = (workspace, F)
128+
cache.isfresh = false
129+
end
130+
131+
workspace, F = cache.cacheval
132+
(;A, b, ws, U, V, out, tmp, n) = workspace
133+
b[1:M] .= cache_b
134+
mul!(tmp, U', b)
135+
TriangularSolve.ldiv!(F, tmp, thread)
136+
mul!(b, V, tmp)
137+
out .= @view b[1:n]
138+
SciMLBase.build_linear_solution(alg, out, nothing, cache)
139+
end
140+
141+
function LinearSolve.init_cacheval(alg::ButterflyFactorization, A, b, u, Pl, Pr, maxiters::Int,
142+
abstol, reltol, verbose::Bool, assumptions::LinearSolve.OperatorAssumptions)
143+
ws = RecursiveFactorization.🦋workspace(A, b), RecursiveFactorization.lu!(rand(1, 1), Val(false), alg.thread)
110144
end
145+
146+
end
147+

src/LinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ for kralg in (Krylov.lsmr!, Krylov.craigmr!)
446446
end
447447
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
448448
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
449-
:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
449+
:RFLUFactorization, :ButterflyFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
450450
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
451451
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
452452
:MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization)
@@ -480,7 +480,7 @@ cudss_loaded(A) = false
480480
is_cusparse(A) = false
481481

482482
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
483-
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
483+
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,
484484
NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
485485
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
486486
SparspakFactorization, DiagonalFactorization, CholeskyFactorization,

src/extension_algs.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,29 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror =
254254
RFLUFactorization(pivot, thread; throwerror)
255255
end
256256

257+
"""
258+
`ButterflyFactorization()`
259+
260+
A fast pure Julia LU-factorization implementation
261+
using RecursiveFactorization.jl. This method utilizes a butterly
262+
factorization approach rather than pivoting.
263+
"""
264+
struct ButterflyFactorization{T} <: AbstractDenseFactorization
265+
thread::Val{T}
266+
function ButterflyFactorization(::Val{T}; throwerror = true) where {T}
267+
if !userecursivefactorization(nothing)
268+
throwerror &&
269+
error("ButterflyFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
270+
end
271+
new{T}()
272+
end
273+
end
274+
275+
function ButterflyFactorization(; thread = Val(true), throwerror = true)
276+
ButterflyFactorization(thread; throwerror)
277+
end
278+
279+
257280
# There's no options like pivot here.
258281
# But I'm not sure it makes sense as a GenericFactorization
259282
# since it just uses `LAPACK.getrf!`.

test/butterfly.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using LinearAlgebra, LinearSolve
2+
using Test
3+
using RecursiveFactorization
4+
5+
@testset "Random Matricies" begin
6+
for i in 490 : 510
7+
A = rand(i, i)
8+
b = rand(i)
9+
prob = LinearProblem(A, b)
10+
x = solve(prob, ButterflyFactorization())
11+
@test norm(A * x .- b) <= 1e-6
12+
end
13+
end
14+
15+
function wilkinson(N)
16+
A = zeros(N, N)
17+
A[1:(N+1):N*N] .= 1
18+
A[:, end] .= 1
19+
for n in 1:(N - 1)
20+
for r in (n + 1):N
21+
@inbounds A[r, n] = -1
22+
end
23+
end
24+
A
25+
end
26+
27+
@testset "Wilkinson" begin
28+
for i in 790 : 810
29+
A = wilkinson(i)
30+
b = rand(i)
31+
prob = LinearProblem(A, b)
32+
x = solve(prob, ButterflyFactorization())
33+
@test norm(A * x .- b) <= 1e-10
34+
end
35+
end

0 commit comments

Comments
 (0)