Skip to content

Commit 0690749

Browse files
committed
Add nl cache to map_indices
1 parent 9a89cd2 commit 0690749

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/Utilities/copy/index_map.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct IndexMap <: AbstractDict{MOI.Index,MOI.Index}
2020
typeof(CleverDicts.index_to_key),
2121
}
2222
con_map::DoubleDicts.IndexDoubleDict
23+
nl_cache::Dict{MOI.ScalarNonlinearFunction,MOI.ScalarNonlinearFunction}
2324
end
2425

2526
"""
@@ -30,7 +31,8 @@ The dictionary-like object returned by [`MOI.copy_to`](@ref).
3031
function IndexMap()
3132
var_map = CleverDicts.CleverDict{MOI.VariableIndex,MOI.VariableIndex}()
3233
con_map = DoubleDicts.IndexDoubleDict()
33-
return IndexMap(var_map, con_map)
34+
nl_cache = Dict{MOI.ScalarNonlinearFunction,MOI.ScalarNonlinearFunction}()
35+
return IndexMap(var_map, con_map, nl_cache)
3436
end
3537

3638
function _identity_constraints_map(
@@ -104,3 +106,10 @@ Base.length(map::IndexMap) = length(map.var_map) + length(map.con_map)
104106
function Base.iterate(map::IndexMap, args...)
105107
return iterate(Base.Iterators.flatten((map.var_map, map.con_map)), args...)
106108
end
109+
110+
function map_indices(
111+
index_map::IndexMap,
112+
f::MOI.ScalarNonlinearFunction,
113+
)
114+
return map_indices(Base.Fix1(getindex, index_map), f, index_map.nl_cache)
115+
end

src/Utilities/functions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,11 @@ end
346346
function map_indices(
347347
index_map::F,
348348
f::MOI.ScalarNonlinearFunction,
349+
nl_cache = nothing,
349350
) where {F<:Function}
351+
if !isnothing(nl_cache) && haskey(nl_cache, f)
352+
return nl_cache[f]
353+
end
350354
root = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
351355
stack = Tuple{MOI.ScalarNonlinearFunction,Int,MOI.ScalarNonlinearFunction}[]
352356
for (i, fi) in enumerate(f.args)
@@ -359,6 +363,10 @@ function map_indices(
359363
while !isempty(stack)
360364
parent, i, arg = pop!(stack)
361365
if arg isa MOI.ScalarNonlinearFunction
366+
if !isnothing(nl_cache) && haskey(nl_cache, arg)
367+
parent.args[i] = nl_cache[arg]
368+
continue
369+
end
362370
child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args))
363371
for (j, argj) in enumerate(arg.args)
364372
if argj isa MOI.ScalarNonlinearFunction
@@ -368,10 +376,16 @@ function map_indices(
368376
end
369377
end
370378
parent.args[i] = child
379+
if !isnothing(nl_cache)
380+
nl_cache[arg] = child
381+
end
371382
else
372383
parent.args[i] = MOI.Utilities.map_indices(index_map, arg)
373384
end
374385
end
386+
if !isnothing(nl_cache)
387+
nl_cache[f] = root
388+
end
375389
return root
376390
end
377391

0 commit comments

Comments
 (0)