diff --git a/src/Nonlinear/parse.jl b/src/Nonlinear/parse.jl
index 66dee02eab..7aa3c4ea0b 100644
--- a/src/Nonlinear/parse.jl
+++ b/src/Nonlinear/parse.jl
@@ -36,6 +36,95 @@ function parse_expression(::Model, ::Expression, x::Any, ::Int)
     )
 end
 
+# Renumber the subtree of `expr` rooted at node `root` so that it can stand
+# alone as an expression: the root gets parent `-1`, the other parents become
+# relative to `root`, and `NODE_VALUE` indices are shifted to start at 1. The
+# nodes after the subtree are renumbered as if the subtree had already been
+# collapsed into a single node and its values removed.
+#
+# Nothing is deleted here; the caller does that. Returns `(I, V)`, where `I` is
+# the range of the subtree in `expr.nodes` and `V` is the range of
+# `expr.values` spanned by its `NODE_VALUE` nodes (`nothing` if it has none).
+function _extract_subexpression!(expr::Expression, root::Int)
+    n = length(expr.nodes)
+    # The whole subexpression is contiguous in the tape
+    first_out = first_value = last_value = nothing
+    for i in root:n
+        node = expr.nodes[i]
+        # A parent located before `root` means we have left the subtree.
+        if i != root && node.parent < root
+            first_out = i
+            break
+        end
+        index = node.index
+        if node.type == NODE_VALUE
+            if isnothing(first_value)
+                first_value = node.index
+                last_value = first_value
+            else
+                last_value = node.index
+            end
+            index -= first_value - 1
+        end
+        expr.nodes[i] =
+            Node(node.type, index, i == root ? -1 : node.parent - root + 1)
+    end
+    if isnothing(first_out)
+        I = root:n
+    else
+        I = root:(first_out-1)
+    end
+    if isnothing(first_value)
+        V = nothing
+    else
+        V = first_value:last_value
+    end
+    # Shift the nodes that follow the subtree.
+    if !isnothing(first_out)
+        for i in (last(I)+1):n
+            node = expr.nodes[i]
+            index = node.index
+            if node.type == NODE_VALUE && !isnothing(V)
+                @assert index > last(V)
+                index -= length(V)
+            end
+            parent = node.parent
+            if parent > root
+                @assert parent > last(I)
+                parent -= length(I) - 1
+            end
+            expr.nodes[i] = Node(node.type, index, parent)
+        end
+    end
+    return I, V
+end
+
+# Move the subtree of `expr` rooted at node `root` into a new expression that
+# is appended to `data.expressions`, and replace it in `expr` by a single
+# `NODE_SUBEXPRESSION` node that keeps the parent of the original root.
+#
+# Returns `(index, I)`: the `ExpressionIndex` of the new subexpression and the
+# range that the subtree occupied in `expr.nodes` before it was removed.
+function _extract_subexpression!(data::Model, expr::Expression, root::Int)
+    # Save the parent now: the call below overwrites it with `-1`.
+    parent = expr.nodes[root].parent
+    I, V = _extract_subexpression!(expr, root)
+    subexpr =
+        Expression(expr.nodes[I], isnothing(V) ? Float64[] : expr.values[V])
+    push!(data.expressions, subexpr)
+    index = ExpressionIndex(length(data.expressions))
+    expr.nodes[root] = Node(NODE_SUBEXPRESSION, index.value, parent)
+    if length(I) > 1
+        deleteat!(expr.nodes, I[2:end])
+        if !isnothing(V)
+            deleteat!(expr.values, V)
+        end
+    else
+        @assert isnothing(V)
+    end
+    return index, I
+end
+
 function parse_expression(
     data::Model,
     expr::Expression,
@@ -46,7 +135,65 @@ function parse_expression(
     while !isempty(stack)
         parent_node, arg = pop!(stack)
         if arg isa MOI.ScalarNonlinearFunction
-            _parse_without_recursion_inner(stack, data, expr, arg, parent_node)
+            if haskey(data.cache, arg)
+                subexpr = data.cache[arg]
+                if subexpr isa Tuple{Expression,Int}
+                    # `arg` has been parsed inline once already: turn that first
+                    # occurrence into a shared subexpression.
+                    _expr, _node = subexpr
+                    subexpr, I = _extract_subexpression!(data, _expr, _node)
+                    if expr === _expr
+                        # Nodes were removed from the tape being built: shift the
+                        # parents that are still waiting to be used.
+                        if parent_node > first(I)
+                            @assert parent_node > last(I)
+                            parent_node -= length(I) - 1
+                        end
+                        for i in eachindex(stack)
+                            _parent_node = stack[i][1]
+                            if _parent_node > first(I)
+                                @assert _parent_node > last(I)
+                                stack[i] =
+                                    (_parent_node - length(I) + 1, stack[i][2])
+                            end
+                        end
+                    end
+                    # Cached positions inside `_expr` moved as well: re-point them
+                    # either into the new subexpression or to their shifted index.
+                    for (key, val) in data.cache
+                        if val isa Tuple{Expression,Int}
+                            __expr, __node = val
+                            if _expr === __expr && __node > first(I)
+                                if __node <= last(I)
+                                    data.cache[key] = (
+                                        data.expressions[subexpr.value],
+                                        __node - first(I) + 1,
+                                    )
+                                else
+                                    data.cache[key] =
+                                        (__expr, __node - length(I) + 1)
+                                end
+                            end
+                        end
+                    end
+                    data.cache[arg] = subexpr
+                end
+                parse_expression(
+                    data,
+                    expr,
+                    subexpr::ExpressionIndex,
+                    parent_node,
+                )
+            else
+                _parse_without_recursion_inner(
+                    stack,
+                    data,
+                    expr,
+                    arg,
+                    parent_node,
+                )
+                data.cache[arg] = (expr, length(expr.nodes))
+            end
         else
             # We can use recursion here, because ScalarNonlinearFunction only
             # occur in other ScalarNonlinearFunction.
@@ -82,7 +229,7 @@ function _parse_without_recursion_inner(stack, data, expr, x, parent)
     parent = length(expr.nodes)
     # Args need to be pushed onto the stack in reverse because the stack is a
     # first-in last-out datastructure.
-    for arg in reverse(x.args)
+    for arg in Iterators.reverse(x.args)
         push!(stack, (parent, arg))
     end
     return
diff --git a/src/Nonlinear/types.jl b/src/Nonlinear/types.jl
index 36ac83a245..3ae8abde91 100644
--- a/src/Nonlinear/types.jl
+++ b/src/Nonlinear/types.jl
@@ -76,9 +76,10 @@ tree.
 struct Expression
     nodes::Vector{Node}
     values::Vector{Float64}
-    Expression() = new(Node[], Float64[])
 end
 
+Expression() = Expression(Node[], Float64[])
+
 function Base.:(==)(x::Expression, y::Expression)
     return x.nodes == y.nodes && x.values == y.values
 end
@@ -165,6 +166,11 @@ mutable struct Model
     operators::OperatorRegistry
     # This is a private field, used only to increment the ConstraintIndex.
     last_constraint_index::Int64
+    # This is a private field, used to detect common subexpressions.
+    cache::Dict{
+        MOI.ScalarNonlinearFunction,
+        Union{ExpressionIndex,Tuple{Expression,Int}},
+    }
     function Model()
         return new(
             nothing,
@@ -173,6 +179,10 @@ mutable struct Model
             Float64[],
             OperatorRegistry(),
             0,
+            Dict{
+                MOI.ScalarNonlinearFunction,
+                Union{ExpressionIndex,Tuple{Expression,Int}},
+            }(),
         )
     end
 end
diff --git a/src/Utilities/copy/index_map.jl b/src/Utilities/copy/index_map.jl
index 7fe2aa00d9..3149886f7a 100644
--- a/src/Utilities/copy/index_map.jl
+++ b/src/Utilities/copy/index_map.jl
@@ -20,6 +20,7 @@ struct IndexMap <: AbstractDict{MOI.Index,MOI.Index}
         typeof(CleverDicts.index_to_key),
     }
     con_map::DoubleDicts.IndexDoubleDict
+    nl_cache::Dict{MOI.ScalarNonlinearFunction,MOI.ScalarNonlinearFunction}
 end
 
 """
@@ -30,7 +31,8 @@ The dictionary-like object returned by [`MOI.copy_to`](@ref).
function IndexMap() var_map = CleverDicts.CleverDict{MOI.VariableIndex,MOI.VariableIndex}() con_map = DoubleDicts.IndexDoubleDict() - return IndexMap(var_map, con_map) + nl_cache = Dict{MOI.ScalarNonlinearFunction,MOI.ScalarNonlinearFunction}() + return IndexMap(var_map, con_map, nl_cache) end function _identity_constraints_map( @@ -104,3 +106,7 @@ Base.length(map::IndexMap) = length(map.var_map) + length(map.con_map) function Base.iterate(map::IndexMap, args...) return iterate(Base.Iterators.flatten((map.var_map, map.con_map)), args...) end + +function map_indices(index_map::IndexMap, f::MOI.ScalarNonlinearFunction) + return map_indices(Base.Fix1(getindex, index_map), f, index_map.nl_cache) +end diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 5826319ec4..b1aa0a5e55 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -346,7 +346,11 @@ end function map_indices( index_map::F, f::MOI.ScalarNonlinearFunction, + nl_cache = nothing, ) where {F<:Function} + if !isnothing(nl_cache) && haskey(nl_cache, f) + return nl_cache[f] + end root = MOI.ScalarNonlinearFunction(f.head, similar(f.args)) stack = Tuple{MOI.ScalarNonlinearFunction,Int,MOI.ScalarNonlinearFunction}[] for (i, fi) in enumerate(f.args) @@ -359,6 +363,10 @@ function map_indices( while !isempty(stack) parent, i, arg = pop!(stack) if arg isa MOI.ScalarNonlinearFunction + if !isnothing(nl_cache) && haskey(nl_cache, arg) + parent.args[i] = nl_cache[arg] + continue + end child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args)) for (j, argj) in enumerate(arg.args) if argj isa MOI.ScalarNonlinearFunction @@ -368,10 +376,16 @@ function map_indices( end end parent.args[i] = child + if !isnothing(nl_cache) + nl_cache[arg] = child + end else parent.args[i] = MOI.Utilities.map_indices(index_map, arg) end end + if !isnothing(nl_cache) + nl_cache[f] = root + end return root end @@ -1054,6 +1068,8 @@ function canonical(f::MOI.AbstractFunction) return g end 
+canonical(f::MOI.ScalarNonlinearFunction) = f
+
 canonicalize!(f::Union{MOI.VectorOfVariables,MOI.VariableIndex}) = f
 
 """
diff --git a/src/Utilities/vector_of_constraints.jl b/src/Utilities/vector_of_constraints.jl
index b12b97bd85..339b32543c 100644
--- a/src/Utilities/vector_of_constraints.jl
+++ b/src/Utilities/vector_of_constraints.jl
@@ -103,6 +103,12 @@ function MOI.get(
 ) where {F,S}
     MOI.throw_if_not_valid(v, ci)
     f, _ = v.constraints[ci]::Tuple{F,S}
+    # NOTE(review): the stored function is returned without a copy, so callers
+    # share the stored object. Presumably this keeps shared subexpressions
+    # identical across constraints; confirm caller mutation is not a concern.
+    if f isa MOI.ScalarNonlinearFunction
+        return f
+    end
     return copy(f)
 end
 
diff --git a/test/Nonlinear/Nonlinear.jl b/test/Nonlinear/Nonlinear.jl
index 5c85b4b0b6..105d503be0 100644
--- a/test/Nonlinear/Nonlinear.jl
+++ b/test/Nonlinear/Nonlinear.jl
@@ -1446,6 +1446,93 @@ function test_intercept_ForwardDiff_MethodError()
     return
 end
 
+function test_extract_subexpression()
+    model = Nonlinear.Model()
+    x = MOI.VariableIndex(1)
+    sub = MOI.ScalarNonlinearFunction(:^, Any[x, 3])
+    f = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
+    expr = Nonlinear.parse_expression(model, f)
+    @test expr == Nonlinear.Expression(
+        [
+            Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
+            Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
+            Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
+        ],
+        Float64[],
+    )
+    expected_sub = Nonlinear.Expression(
+        [
+            Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 4, -1),
+            Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
+            Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 1),
+        ],
+        [3.0],
+    )
+    @test model.expressions == [expected_sub]
+    @test model.cache[sub] == Nonlinear.ExpressionIndex(1)
+
+    h = MOI.ScalarNonlinearFunction(:*, Any[2, sub, 1])
+    g = MOI.ScalarNonlinearFunction(:+, Any[sub, h])
+    expr = MOI.Nonlinear.parse_expression(model, g)
+    expected_g = Nonlinear.Expression(
+        [
+            Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
+            Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
+            Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1),
+            Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 3),
+            Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 3),
+            Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 3),
+        ],
+        [2.0, 1.0],
+    )
+    @test expr == expected_g
+    # It should have detected the sub-expression that is the same as in `f`
+    @test model.expressions == [expected_sub]
+    # This means that nothing was extracted from `g`, so let's also test
+    # the extraction by starting with an empty model
+
+    model = Nonlinear.Model()
+    MOI.Nonlinear.set_objective(model, g)
+    @test model.objective == expected_g
+    @test model.expressions == [expected_sub]
+    # Test that the objective function gets rewritten as we reuse `h`
+    # Also test that we don't change the parents in the stack of `h`
+    # by creating a long stack
+    product = MOI.ScalarNonlinearFunction(:*, Any[h, x])
+    outer = MOI.ScalarNonlinearFunction(:*, Any[x, x, x, x, product])
+    expr = Nonlinear.parse_expression(model, outer)
+    @test isempty(model.objective.values)
+    @test model.objective.nodes == [
+        Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
+        Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
+        Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 2, 1),
+    ]
+    @test model.expressions == [
+        expected_sub,
+        Nonlinear.Expression(
+            [
+                Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, -1),
+                Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 1),
+                Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
+                Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 1),
+            ],
+            [2.0, 1.0],
+        ),
+    ]
+    @test isempty(expr.values)
+    @test expr.nodes == [
+        Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, -1),
+        Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
+        Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
+        Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
+        Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
+        Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1),
+        Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 2, 6),
+        Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 6),
+    ]
+    return
+end
+
 end # TestNonlinear
 
 TestNonlinear.runtests()