Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
786bc0d
solve_shifted_system!
farhadrclass Sep 23, 2024
f95d210
Update src/utilities.jl
farhadrclass Sep 23, 2024
bac9f6f
Update src/utilities.jl
farhadrclass Sep 23, 2024
62d35f6
Update src/utilities.jl
farhadrclass Sep 23, 2024
186b6c5
Update src/utilities.jl
farhadrclass Sep 23, 2024
0d4f952
Update src/utilities.jl
farhadrclass Sep 23, 2024
4672bd3
Update src/utilities.jl
farhadrclass Sep 23, 2024
f3ea738
Update src/utilities.jl
farhadrclass Sep 23, 2024
4bf8f55
Update src/utilities.jl
farhadrclass Sep 23, 2024
025645e
Update src/utilities.jl
farhadrclass Sep 23, 2024
94fb35f
changes after review
farhadrclass Sep 23, 2024
ded60c1
Update src/lbfgs.jl
farhadrclass Sep 28, 2024
e5b054a
Update src/lbfgs.jl
farhadrclass Sep 28, 2024
6c40d0a
Update src/utilities.jl
farhadrclass Sep 28, 2024
773f2fd
Update src/utilities.jl
farhadrclass Sep 28, 2024
13162dc
Update src/utilities.jl
farhadrclass Sep 28, 2024
2cf4246
updated according to the PR review
farhadrclass Oct 2, 2024
5c1eeae
updated according to PR, renamed variable
farhadrclass Oct 2, 2024
1323a98
added ldiv!
farhadrclass Oct 3, 2024
d1da6ef
Update src/utilities.jl
farhadrclass Oct 4, 2024
ad9567f
Update src/utilities.jl
farhadrclass Oct 4, 2024
d5c1cfd
Update src/utilities.jl
farhadrclass Oct 4, 2024
759ab02
Update src/utilities.jl
farhadrclass Oct 4, 2024
3094644
Update src/utilities.jl
farhadrclass Oct 4, 2024
88479e5
Update src/utilities.jl
farhadrclass Oct 4, 2024
bc6ba12
Update src/utilities.jl
farhadrclass Oct 4, 2024
5a173ef
Update src/utilities.jl
farhadrclass Oct 4, 2024
4ad62e6
Update src/utilities.jl
farhadrclass Oct 4, 2024
17ffb34
Update src/utilities.jl
farhadrclass Oct 4, 2024
87bd7b6
Update src/utilities.jl
farhadrclass Oct 4, 2024
e039299
Update test/test_solve_shifted_system.jl
farhadrclass Oct 4, 2024
89d28d6
Update test/test_solve_shifted_system.jl
farhadrclass Oct 4, 2024
aa4c8a3
added the example
farhadrclass Oct 4, 2024
90084b9
Update utilities.jl
farhadrclass Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Function | Description
`size` | Return the size of a linear operator
`symmetric` | Determine whether the operator is symmetric
`normest` | Estimate the 2-norm
`solve_shifted_system!` | Computes the regularized L-BFGS step


## Other Operations on Operators
Expand Down
95 changes: 94 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export check_ctranspose, check_hermitian, check_positive_definite, normest
export check_ctranspose, check_hermitian, check_positive_definite, normest, solve_shifted_system!

"""
normest(S) estimates the matrix 2-norm of S.
Expand Down Expand Up @@ -145,3 +145,96 @@ end

check_positive_definite(M::AbstractMatrix; kwargs...) =
check_positive_definite(LinearOperator(M); kwargs...)


"""
solve_shifted_system!(op::LBFGSOperator{T,I,F1,F2,F3},
z::AbstractVector{T},
σ::T,
γ_inv::T,
inv_Cz::AbstractVector{T},
p::AbstractVector{T},
v::AbstractVector{T},
u::AbstractVector{T}) where {T,I,F1,F2,F3}

Computes the regularized L-BFGS step by solving the linear system:

(B_k + σ_k * I) s = -∇f(x_k)

where `B_k` is the L-BFGS approximation of the Hessian, `σ_k` is a regularization parameter, and `∇f(x_k)` is the gradient at `x_k`.

### Parameters
- `op::LBFGSOperator{T,I,F1,F2,F3}`: The L-BFGS operator `B_k`. Encodes the curvature information used to approximate the Hessian.
- `z::AbstractVector{T}`: The vector representing `-∇f(x_k)` (negative gradient at the current iterate).
- `σ::T`: The regularization parameter, used to shift the Hessian approximation `B_k` by adding a multiple of the identity matrix.
- `γ_inv::T`: The inverse of the initial curvature `γ_0`, used to initialize the L-BFGS matrix.
- `inv_Cz::AbstractVector{T}`: A preallocated vector used to store the result of the solution. It will be overwritten in the function to hold the computed `s`.
- `p::AbstractVector{T}`: A temporary matrix used in the computation to avoid allocating new memory during the solve process.
- `v::AbstractVector{T}`: A temporary vector for intermediate values during the solution process.
- `u::AbstractVector{T}`: A temporary vector that stores elements of the L-BFGS vectors `a_k` or `b_k` used in the algorithm.

### Returns
- `inv_Cz::AbstractVector{T}`: The solution vector `s` such that `(B_k + σ * I) s = -∇f(x_k)`.

### Method
The function solves the system efficiently without allocating new memory, by reusing preallocated arrays `inv_Cz`, `p`, `v`, and `u`. It computes the solution iteratively, making use of the structure of the L-BFGS approximation to the Hessian.

The method uses a two-loop recursion-like approach with modifications to handle the regularization term `σ`. The initial inverse Hessian approximation `B_0` is initialized using `γ_inv`.

The solution is computed on the CPU, but the function can be extended for GPU compatibility.

### Notes
- `data.a[k]` and `data.b[k]` represent stored L-BFGS vectors used in the computation of `B_k`. They are loaded into `u` and updated in each iteration.
- The computation involves vector dot products and matrix updates which are efficiently performed in place.

### To-Do
- Implement GPU support for further acceleration.

### References
@misc{erway2013shiftedlbfgssystems,
title={Shifted L-BFGS Systems},
author={Jennifer B. Erway and Vibhor Jain and Roummel F. Marcia},
year={2013},
eprint={1209.5141},
archivePrefix={arXiv},
primaryClass={math.NA},
url={https://arxiv.org/abs/1209.5141},
}
"""

function solve_shifted_system!(
op::LBFGSOperator{T, I, F1, F2, F3},
z::AbstractVector{T},
σ::T,
γ_inv::T,
inv_Cz::AbstractVector{T},
p::AbstractArray{T},
v::AbstractVector{T},
u::AbstractVector{T},
) where {T, I, F1, F2, F3}
data = op.data
insert = data.insert

inv_c0 = 1 / (γ_inv + σ)
@. inv_Cz = inv_c0 * z

max_i = 2 * data.mem
for i = 1:max_i
j = (i + 1) ÷ 2
k = mod(insert + j - 1, data.mem) + 1
u .= ((i % 2) == 0 ? data.b[k] : data.a[k])

@. p[:, i] = inv_c0 * u

for t = 1:(i - 1)
c0 = dot(view(p, :, t), u)
c1 = (-1)^(t + 1) .* v[t]
c2 = c1 * c0
view(p, :, i) .+= c2 .* view(p, :, t)
end

v[i] = 1 / (1 + (-1)^i * dot(u, view(p, :, i)))
inv_Cz .+= (-1)^(i + 1) * v[i] * (view(p, :, i)' * z) .* view(p, :, i)
end
return inv_Cz
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ include("test_deprecated.jl")
include("test_normest.jl")
include("test_diag.jl")
include("test_chainrules.jl")
include("test_solve_shifted_system.jl")
132 changes: 132 additions & 0 deletions test/test_solve_shifted_system.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
using Test
using LinearOperators
using LinearAlgebra

function setup_test_val(; M = 5, n = 100, scaling = false)
σ = 0.1
B = LBFGSOperator(n, mem = 5, scaling = scaling)
s = ones(n)
y = ones(n)

S = randn(n, M)
Y = randn(n, M)

# make sure it is positive
for i = 1:M
if dot(S[:, i], Y[:, i]) < 0
S[:, i] = -S[:, i]
end
end
for i = 1:M
s = S[:, i]
y = Y[:, i]
if dot(s, y) > 1.0e-20
push!(B, s, y)
end
end

γ_inv = 1 / B.data.scaling_factor

x = randn(n)
z = B * x + σ .* x # so we know the true answer is x

data = B.data
# Preallocate vectors for efficiency
p = zeros(size(data.a[1], 1), 2 * (data.mem))
v = zeros(2 * (data.mem))
u = zeros(size(data.a[1], 1))

return B, z, σ, γ_inv, zeros(n), p, v, u, x
end

function test_solve_shifted_system()
@testset "solve_shifted_system! Default setup test" begin
# Setup Test Case 1: Default setup from setup_test_val
B, z, σ, γ_inv, inv_Cz, p, v, u, x = setup_test_val(n = 100, M = 5)

result = solve_shifted_system!(B, z, σ, γ_inv, inv_Cz, p, v, u)

# Test 1: Check if result is a vector of the same size as z
@test length(result) == length(z)

# Test 2: Verify that inv_Cz (result) is modified in place
@test result === inv_Cz

# Test 3: Check if the function produces finite values
@test all(isfinite, result)

# Test 4: Check if inv_Cz is close to the known solution x
# x = B \ (z ./ (1 + σ)) # Known true solution
@test isapprox(inv_Cz, x, atol = 1e-6, rtol = 1e-6)
end
@testset "solve_shifted_system! Larger dimensional system tests" begin

# Test Case 2: Larger dimensional system
dim = 10
mem_size = 5
B, z, σ, γ_inv, inv_Cz, p, v, u, x = setup_test_val(n = dim, M = mem_size)

result = solve_shifted_system!(B, z, σ, γ_inv, inv_Cz, p, v, u)

# Test 5: Check if result is a vector of the same size as z (larger case)
@test length(result) == length(z)

# Test 6: Verify that inv_Cz is modified in place (larger case)
@test result === inv_Cz

# Test 7: Check if the function produces finite values (larger case)
@test all(isfinite, result)

# Test 8: Check if inv_Cz is close to the known solution x (larger case)
# x = B \ (z ./ (1 + σ))
@test isapprox(inv_Cz, x, atol = 1e-6, rtol = 1e-6)
end

@testset "solve_shifted_system! Minimal memory size test" begin
# Test Case 3: Minimal memory size case (memory size = 1)
dim = 4
mem_size = 1
B, z, σ, γ_inv, inv_Cz, p, v, u, x = setup_test_val(n = dim, M = mem_size)

result = solve_shifted_system!(B, z, σ, γ_inv, inv_Cz, p, v, u)

# Test 9: Check if result is a vector of the same size as z (minimal memory)
@test length(result) == length(z)

# Test 10: Verify that inv_Cz is modified in place (minimal case)
@test result === inv_Cz

# Test 11: Check if the function produces finite values (minimal case)
@test all(isfinite, result)

# Test 12: Check if inv_Cz is close to the known solution x (minimal memory)
# x = B \ (z ./ (1 + σ))
@test isapprox(inv_Cz, x, atol = 1e-6, rtol = 1e-6)
end

@testset "solve_shifted_system! Extra large memory size test" begin

# Test Case 4: Even larger system with more memory (case 4)
dim = 50
mem_size = 10
B, z, σ, γ_inv, inv_Cz, p, v, u, x = setup_test_val(n = dim, M = mem_size)

# Call the function
result = solve_shifted_system!(B, z, σ, γ_inv, inv_Cz, p, v, u)

# Test 13: Check if result is a vector of the same size as z (case 4)
@test length(result) == length(z)

# Test 14: Verify that inv_Cz is modified in place (case 4)
@test result === inv_Cz

# Test 15: Check if the function produces finite values (case 4)
@test all(isfinite, result)

# Test 16: Check if inv_Cz is close to the known solution x (case 4)
# x = B \ (z ./ (1 + σ))
@test isapprox(inv_Cz, x, atol = 1e-6, rtol = 1e-6)
end
end

test_solve_shifted_system()
Loading