Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/largeScale.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import FrankWolfe
using LinearAlgebra

n = Int(1e9)
k = 1000
n = Int(1e5)
k = 10000

xpi = rand(n);
total = sum(xpi);
Expand All @@ -16,7 +16,7 @@ end
# better for memory consumption as we do coordinate-wise ops

function cf(x, xp)
return norm(x .- xp)^2
return LinearAlgebra.norm(x .- xp)^2
end

function cgrad!(storage, x, xp)
Expand All @@ -33,6 +33,8 @@ FrankWolfe.benchmark_oracles(x -> cf(x, xp), (str, x) -> cgrad!(str, x, xp), lmo
(str, x) -> cgrad!(str, x, xp),
lmo,
x0,
nep=false,
L=2,
max_iteration=k,
line_search=FrankWolfe.agnostic,
print_iter=k / 10,
Expand Down
13 changes: 11 additions & 2 deletions src/FrankWolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function fw(
verbose=false,
linesearch_tol=1e-7,
emphasis::Emphasis=blas,
nep=false,
gradient=nothing
)
function print_header(data)
Expand Down Expand Up @@ -100,7 +101,7 @@ function fw(
trajData = []
time_start = time_ns()

if (line_search === shortstep || line_search === adaptive) && L == Inf
if (line_search === shortstep || line_search === adaptive || nep === true ) && L == Inf
println("FATAL: Lipschitz constant not set. Prepare to blow up spectacularly.")
end

Expand Down Expand Up @@ -143,7 +144,7 @@ function fw(
@emphasis(emphasis, gradient = (momentum * gradient) + (1 - momentum) * gtemp)
end
first_iter = false

v = compute_extreme_point(lmo, gradient)

# go easy on the memory - only compute if really needed
Expand All @@ -162,6 +163,14 @@ function fw(
(t, primal, primal - dual_gap, dual_gap, (time_ns() - time_start) / 1.0e9),
)
end

# build-in NEP here
if nep === true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if nep === true
if nep

just checking the boolean directly

# argmin_v v^T(1-2y)
# y = x_t - 1/L * (t+1)/2 * gradient
@. gradient = 1 - 2 * (x - 1 / (L * 2 / (t+1)) * gradient)
v = compute_extreme_point(lmo, gradient)
end

if line_search === agnostic
gamma = 2 // (2 + t)
Expand Down