diff --git a/src/Utilities/copy/index_map.jl b/src/Utilities/copy/index_map.jl index 7fe2aa00d9..3149886f7a 100644 --- a/src/Utilities/copy/index_map.jl +++ b/src/Utilities/copy/index_map.jl @@ -20,6 +20,7 @@ struct IndexMap <: AbstractDict{MOI.Index,MOI.Index} typeof(CleverDicts.index_to_key), } con_map::DoubleDicts.IndexDoubleDict + nl_cache::Dict{MOI.ScalarNonlinearFunction,MOI.ScalarNonlinearFunction} end """ @@ -30,7 +31,8 @@ The dictionary-like object returned by [`MOI.copy_to`](@ref). function IndexMap() var_map = CleverDicts.CleverDict{MOI.VariableIndex,MOI.VariableIndex}() con_map = DoubleDicts.IndexDoubleDict() - return IndexMap(var_map, con_map) + nl_cache = Dict{MOI.ScalarNonlinearFunction,MOI.ScalarNonlinearFunction}() + return IndexMap(var_map, con_map, nl_cache) end function _identity_constraints_map( @@ -104,3 +106,7 @@ Base.length(map::IndexMap) = length(map.var_map) + length(map.con_map) function Base.iterate(map::IndexMap, args...) return iterate(Base.Iterators.flatten((map.var_map, map.con_map)), args...) end + +function map_indices(index_map::IndexMap, f::MOI.ScalarNonlinearFunction) + return map_indices(Base.Fix1(getindex, index_map), f, index_map.nl_cache) +end diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 5826319ec4..977b70ca00 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -346,7 +346,11 @@ end function map_indices( index_map::F, f::MOI.ScalarNonlinearFunction, + nl_cache = nothing, ) where {F<:Function} + if !isnothing(nl_cache) && haskey(nl_cache, f) + return nl_cache[f] + end root = MOI.ScalarNonlinearFunction(f.head, similar(f.args)) stack = Tuple{MOI.ScalarNonlinearFunction,Int,MOI.ScalarNonlinearFunction}[] for (i, fi) in enumerate(f.args) @@ -359,6 +363,10 @@ function map_indices( while !isempty(stack) parent, i, arg = pop!(stack) if arg isa MOI.ScalarNonlinearFunction + if !isnothing(nl_cache) && haskey(nl_cache, arg) + parent.args[i] = nl_cache[arg] + continue + end child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args)) for (j, argj) in enumerate(arg.args) if argj isa MOI.ScalarNonlinearFunction @@ -368,10 +376,16 @@ function map_indices( end end parent.args[i] = child + if !isnothing(nl_cache) + nl_cache[arg] = child + end else parent.args[i] = MOI.Utilities.map_indices(index_map, arg) end end + if !isnothing(nl_cache) + nl_cache[f] = root + end return root end