diff --git a/.gitignore b/.gitignore index 17ee9912f..c394e1ab8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ test/bug.jl before.pdf after.pdf *.pdf +prof.txt diff --git a/examples/largeScale.jl b/examples/largeScale.jl index 81e5fc5e5..18eafd496 100644 --- a/examples/largeScale.jl +++ b/examples/largeScale.jl @@ -1,8 +1,8 @@ import FrankWolfe using LinearAlgebra -n = Int(1e9) -k = 1000 +n = Int(1e5) +k = 10000 xpi = rand(n); total = sum(xpi); @@ -16,7 +16,7 @@ end # better for memory consumption as we do coordinate-wise ops function cf(x, xp) - return norm(x .- xp)^2 + return LinearAlgebra.norm(x .- xp)^2 end function cgrad!(storage, x, xp) @@ -33,6 +33,8 @@ FrankWolfe.benchmark_oracles(x -> cf(x, xp), (str, x) -> cgrad!(str, x, xp), lmo (str, x) -> cgrad!(str, x, xp), lmo, x0, + nep=false, + L=2, max_iteration=k, line_search=FrankWolfe.agnostic, print_iter=k / 10, diff --git a/src/FrankWolfe.jl b/src/FrankWolfe.jl index 370515982..d8110d673 100644 --- a/src/FrankWolfe.jl +++ b/src/FrankWolfe.jl @@ -56,6 +56,7 @@ function fw( verbose=false, linesearch_tol=1e-7, emphasis::Emphasis=blas, + nep=false, gradient=nothing ) function print_header(data) @@ -103,7 +104,7 @@ function fw( trajData = [] time_start = time_ns() - if (line_search === shortstep || line_search === adaptive) && L == Inf + if (line_search === shortstep || line_search === adaptive || nep === true ) && L == Inf println("FATAL: Lipschitz constant not set. Prepare to blow up spectacularly.") end @@ -146,7 +147,7 @@ function fw( @emphasis(emphasis, gradient = (momentum * gradient) + (1 - momentum) * gtemp) end first_iter = false - + v = compute_extreme_point(lmo, gradient) # go easy on the memory - only compute if really needed @@ -165,6 +166,14 @@ function fw( (t, primal, primal - dual_gap, dual_gap, (time_ns() - time_start) / 1.0e9), ) end + + # build-in NEP here + if nep === true + # argmin_v v^T(1-2y) + # y = x_t - 1/L * (t+1)/2 * gradient + @. gradient = 1 - 2 * (x - 1 / (L * 2 / (t+1)) * gradient) + v = compute_extreme_point(lmo, gradient) + end if line_search === agnostic gamma = 2 // (2 + t)