diff --git a/README.md b/README.md index 7a0d2686..1bfe2189 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ Function | Description `size` | Return the size of a linear operator `symmetric` | Determine whether the operator is symmetric `normest` | Estimate the 2-norm +`solve_shifted_system!` | Solves linear system $(B + \sigma I) x = b$, where $B$ is a forward L-BFGS operator and $\sigma \geq 0$. ## Other Operations on Operators diff --git a/src/lbfgs.jl b/src/lbfgs.jl index f03cad23..067f76e0 100644 --- a/src/lbfgs.jl +++ b/src/lbfgs.jl @@ -16,6 +16,9 @@ mutable struct LBFGSData{T, I <: Integer} b::Vector{Vector{T}} insert::I Ax::Vector{T} + shifted_p::Matrix{T} # Temporary matrix used in the computation solve_shifted_system! + shifted_v::Vector{T} + shifted_u::Vector{T} end function LBFGSData( @@ -43,6 +46,9 @@ function LBFGSData( inverse ? Vector{T}(undef, 0) : [zeros(T, n) for _ = 1:mem], 1, Vector{T}(undef, n), + Array{T}(undef, (n, 2*mem)), + Vector{T}(undef, 2*mem), + Vector{T}(undef, n) ) end diff --git a/src/utilities.jl b/src/utilities.jl index 81f953f7..0f53ea64 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -1,4 +1,5 @@ -export check_ctranspose, check_hermitian, check_positive_definite, normest +export check_ctranspose, check_hermitian, check_positive_definite, normest, solve_shifted_system!, ldiv! +import LinearAlgebra.ldiv! """ normest(S) estimates the matrix 2-norm of S. @@ -145,3 +146,141 @@ end check_positive_definite(M::AbstractMatrix; kwargs...) = check_positive_definite(LinearOperator(M); kwargs...) + + + """ + solve_shifted_system!(x, B, b, σ) + +Solve linear system (B + σI) x = b, where B is a forward L-BFGS operator and σ ≥ 0. + +### Parameters + +- `x::AbstractVector{T}`: preallocated vector of length n that is used to store the solution x. +- `B::LBFGSOperator`: forward L-BFGS operator that models a matrix of size n x n. +- `b::AbstractVector{T}`: right-hand side vector of length n. +- `σ::T`: nonnegative shift. + +### Returns + +- `x::AbstractVector{T}`: solution vector `x` of length n. + +### Method + +The method uses a two-loop recursion-like approach with modifications to handle the shift `σ`. + +### Example + +```julia +using Random + +# Problem setup +n = 100 # size of the problem +mem = 10 # L-BFGS memory size +scaling = true # enable scaling + +# Create an L-BFGS operator +B = LBFGSOperator(n, mem = mem, scaling = scaling) + +# Add random {s, y} pairs to the L-BFGS operator +for _ = 1:10 + s = rand(n) + y = rand(n) + push!(B, s, y) # Add the {s, y} pair to B +end + +# Prepare vectors for the system +x = zeros(n) # Preallocated solution vector +b = rand(n) # Right-hand side vector +σ = 0.1 # Small shift value + +# Solve the shifted system +result = solve_shifted_system!(x, B, b, σ) + +# Check that the solution is close enough (residual test) +@assert norm(B * x + σ * x - b) / norm(b) < 1e-8 +``` + +### References + +Erway, J. B., Jain, V., & Marcia, R. F. Shifted L-BFGS Systems. Optimization Methods and Software, 29(5), pp. 992-1004, 2014. +""" +function solve_shifted_system!( + x::AbstractVector{T}, + B::LBFGSOperator{T, I, F1, F2, F3}, + b::AbstractVector{T}, + σ::T, + ) where {T, I, F1, F2, F3} + + if σ < 0 + throw(ArgumentError("σ must be nonnegative")) + end + data = B.data + insert = data.insert + + γ_inv = 1 / data.scaling_factor + x_0 = 1 / (γ_inv + σ) + @. x = x_0 * b + + max_i = 2 * data.mem + sign_i = 1 + + for i = 1:max_i + j = (i + 1) ÷ 2 + k = mod(insert + j - 1, data.mem) + 1 + data.shifted_u .= ((sign_i == -1) ? data.b[k] : data.a[k]) + + @. data.shifted_p[:, i] = x_0 * data.shifted_u + + sign_t = 1 + for t = 1:(i - 1) + c0 = dot(view(data.shifted_p, :, t), data.shifted_u) + c1= sign_t .*data.shifted_v[t] + c2 = c1 * c0 + view(data.shifted_p, :, i) .+= c2 .* view(data.shifted_p, :, t) + sign_t = -sign_t + end + + data.shifted_v[i] = 1 / (1 - sign_i * dot(data.shifted_u, view(data.shifted_p, :, i))) + x .+= sign_i *data.shifted_v[i] * (view(data.shifted_p, :, i)' * b) .* view(data.shifted_p, :, i) + sign_i = -sign_i + end + return x +end + + +""" + ldiv!(x, B, b) + +Solves the linear system Bx = b. + +### Arguments: + +- `x::AbstractVector{T}`: preallocated vector of length n that is used to store the solution x. +- `B::LBFGSOperator`: forward L-BFGS operator that models a matrix of size n x n. +- `b::AbstractVector{T}`: right-hand side vector of length n. +### Returns: + +- `x::AbstractVector{T}`: The modified solution vector containing the solution to the linear system. + +### Examples: + +```julia + +# Create an L-BFGS operator +B = LBFGSOperator(10) + +# Generate random vectors +x = rand(10) +b = rand(10) + +# Solve the linear system +ldiv!(x, B, b) + +# The vector `x` now contains the solution +""" + +function ldiv!(x::AbstractVector{T}, B::LBFGSOperator{T, I, F1, F2, F3}, b::AbstractVector{T}) where {T, I, F1, F2, F3} + # Call solve_shifted_system! with σ = 0 + solve_shifted_system!(x, B, b, T(0.0)) + return x +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 57a4851c..eea61206 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,3 +15,4 @@ include("test_deprecated.jl") include("test_normest.jl") include("test_diag.jl") include("test_chainrules.jl") +include("test_solve_shifted_system.jl") \ No newline at end of file diff --git a/test/test_solve_shifted_system.jl b/test/test_solve_shifted_system.jl new file mode 100644 index 00000000..caafea53 --- /dev/null +++ b/test/test_solve_shifted_system.jl @@ -0,0 +1,64 @@ +using Test +using LinearOperators +using LinearAlgebra + +function setup_test_val(; M = 5, n = 100, scaling = false, σ = 0.1) + B = LBFGSOperator(n, mem = M, scaling = scaling) + H = InverseLBFGSOperator(n, mem = M, scaling = false) + + for _ = 1:10 + s = rand(n) + y = rand(n) + push!(B, s, y) + push!(H, s, y) + end + + x = randn(n) + b = B * x + σ .* x # so we know the true answer is x + + return B, H , b, σ, zeros(n), x +end + +function test_solve_shifted_system() + @testset "solve_shifted_system! Default setup test" begin + # Setup Test Case 1: Default setup from setup_test_val + B,_, b, σ, x_sol, x_true = setup_test_val(n = 100, M = 5) + + result = solve_shifted_system!(x_sol, B, b, σ) + + # Test 1: Check if result is a vector of the same size as z + @test length(result) == length(b) + + # Test 2: Verify that x_sol (result) is modified in place + @test result === x_sol + + # Test 3: Check if the function produces finite values + @test all(isfinite, result) + + # Test 4: Check if x_sol is close to the known solution x + @test isapprox(x_sol, x_true, atol = 1e-6, rtol = 1e-6) + end + @testset "solve_shifted_system! Negative σ test" begin + # Setup Test Case 2: Negative σ + B,_, b, _, x_sol, _ = setup_test_val(n = 100, M = 5) + σ = -0.1 + + # Expect an ArgumentError to be thrown + @test_throws ArgumentError solve_shifted_system!(x_sol, B, b, σ) + end + + @testset "ldiv! test" begin + # Setup Test Case 1: Default setup from setup_test_val + B, H, b, _, x_sol, x_true = setup_test_val(n = 100, M = 5, σ = 0.0) + + # Solve the system using solve_shifted_system! + result = ldiv!(x_sol, B, b) + + # Check consistency with operator-vector product using H + x_H = H * b + @test isapprox(x_sol, x_H, atol = 1e-6, rtol = 1e-6) + @test isapprox(x_sol, x_true, atol = 1e-6, rtol = 1e-6) + end +end + +test_solve_shifted_system()