Skip to content

Commit 2f1c1b3

Browse files
committed
Convert common sub-functions as common sub-expressions
1 parent 1694a00 commit 2f1c1b3

File tree

5 files changed

+83
-3
lines changed

5 files changed

+83
-3
lines changed

src/Nonlinear/parse.jl

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,49 @@ function parse_expression(::Model, ::Expression, x::Any, ::Int)
3636
)
3737
end
3838

39+
function _extract_subexpression!(expr::Expression, root::Int)
40+
nodes_idx = [root]
41+
values_idx = Int[]
42+
for i in (root+1):length(expr.nodes)
43+
node = expr.nodes[i]
44+
j = searchsortedlast(nodes_idx, node.parent)
45+
if j == 0
46+
continue
47+
end
48+
if node.parent == nodes_idx[j]
49+
push!(nodes_idx, i)
50+
index = node.index
51+
if node.type == NODE_VALUE
52+
push!(values_idx, node.index)
53+
index = length(values_idx)
54+
end
55+
expr.nodes[i] = Node(node.type, node.index, j)
56+
else
57+
index = node.index
58+
if node.type == NODE_VALUE
59+
# We use the fact that values of `node.index` are increasing
60+
# along the nodes of `expr.nodes` for which `node.type` is `NODE_VALUE`
61+
index -= length(values_idx)
62+
end
63+
expr.nodes[i] = Node(node.type, index, node.parent - j)
64+
end
65+
end
66+
subexpr = Expression(expr.nodes[nodes_idx], expr.values[values_idx])
67+
deleteat!(expr.nodes, nodes_idx)
68+
deleteat!(expr.values, values_idx)
69+
return subexpr
70+
end
71+
72+
function _extract_subexpression!(data::Model, expr::Expression, root::Int)
73+
parent = expr.nodes[root].parent
74+
push!(data.expressions, _extract_subexpression!(expr, root))
75+
index = ExpressionIndex(length(data.expressions))
76+
if parent != 0
77+
push!(expr.nodes, Node(NODE_SUBEXPRESSION, index.value, parent))
78+
end
79+
return index
80+
end
81+
3982
function parse_expression(
4083
data::Model,
4184
expr::Expression,
@@ -46,7 +89,16 @@ function parse_expression(
4689
while !isempty(stack)
4790
parent_node, arg = pop!(stack)
4891
if arg isa MOI.ScalarNonlinearFunction
49-
_parse_without_recursion_inner(stack, data, expr, arg, parent_node)
92+
if haskey(data.cache, arg)
93+
subexpr = data.cache[arg]
94+
if subexpr isa Tuple{Expression,Int}
95+
subexpr = _extract_subexpression!(data, subexpr...)
96+
end
97+
parse_expression(data, expr, subexpr::ExpressionIndex, parent_node)
98+
else
99+
_parse_without_recursion_inner(stack, data, expr, arg, parent_node)
100+
data.cache[arg] = (expr, length(expr.nodes))
101+
end
50102
else
51103
# We can use recursion here, because ScalarNonlinearFunction only
52104
# occur in other ScalarNonlinearFunction.
@@ -82,7 +134,7 @@ function _parse_without_recursion_inner(stack, data, expr, x, parent)
82134
parent = length(expr.nodes)
83135
# Args need to be pushed onto the stack in reverse because the stack is a
84136
# first-in last-out datastructure.
85-
for arg in reverse(x.args)
137+
for arg in Iterators.Reverse(x.args)
86138
push!(stack, (parent, arg))
87139
end
88140
return

src/Nonlinear/types.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ tree.
7676
struct Expression
7777
nodes::Vector{Node}
7878
values::Vector{Float64}
79-
Expression() = new(Node[], Float64[])
8079
end
8180

81+
Expression() = Expression(Node[], Float64[])
82+
8283
function Base.:(==)(x::Expression, y::Expression)
8384
return x.nodes == y.nodes && x.values == y.values
8485
end
@@ -165,6 +166,8 @@ mutable struct Model
165166
operators::OperatorRegistry
166167
# This is a private field, used only to increment the ConstraintIndex.
167168
last_constraint_index::Int64
169+
# This is a private field, used to detect common subexpressions.
170+
cache::Dict{MOI.ScalarNonlinearFunction,Union{ExpressionIndex,Tuple{Expression,Int}}}
168171
function Model()
169172
return new(
170173
nothing,
@@ -173,6 +176,7 @@ mutable struct Model
173176
Float64[],
174177
OperatorRegistry(),
175178
0,
179+
Dict{MOI.ScalarNonlinearFunction,Union{ExpressionIndex,Tuple{Expression,Int}}}(),
176180
)
177181
end
178182
end

src/Utilities/functions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,8 @@ function canonical(f::MOI.AbstractFunction)
10541054
return g
10551055
end
10561056

1057+
canonical(f::MOI.ScalarNonlinearFunction) = f
1058+
10571059
canonicalize!(f::Union{MOI.VectorOfVariables,MOI.VariableIndex}) = f
10581060

10591061
"""

src/Utilities/vector_of_constraints.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ function MOI.get(
103103
) where {F,S}
104104
MOI.throw_if_not_valid(v, ci)
105105
f, _ = v.constraints[ci]::Tuple{F,S}
106+
if f isa MOI.ScalarNonlinearFunction
107+
return f
108+
end
106109
return copy(f)
107110
end
108111

test/Nonlinear/Nonlinear.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,3 +1449,22 @@ end
14491449
end # TestNonlinear
14501450

14511451
TestNonlinear.runtests()
1452+
1453+
using Revise, Test
1454+
import MathOptInterface as MOI
1455+
1456+
model = MOI.Nonlinear.Model()
1457+
x = MOI.VariableIndex(1)
1458+
sub = MOI.ScalarNonlinearFunction(:^, Any[x, 3])
1459+
func = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
1460+
expr = MOI.Nonlinear.parse_expression(model, func)
1461+
1462+
func = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
1463+
g = MOI.ScalarNonlinearFunction(:+, Any[
1464+
sub,
1465+
MOI.ScalarNonlinearFunction(
1466+
:*,
1467+
Any[1, sub]
1468+
),
1469+
])
1470+
expr = MOI.Nonlinear.parse_expression(model, g)

0 commit comments

Comments
 (0)