Skip to content

Commit 23cb0da

Browse files
committed
Add tests
1 parent 2f1c1b3 commit 23cb0da

File tree

2 files changed

+156
-48
lines changed

2 files changed

+156
-48
lines changed

src/Nonlinear/parse.jl

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,46 +37,62 @@ function parse_expression(::Model, ::Expression, x::Any, ::Int)
3737
end
3838

3939
function _extract_subexpression!(expr::Expression, root::Int)
40-
nodes_idx = [root]
41-
values_idx = Int[]
42-
for i in (root+1):length(expr.nodes)
43-
node = expr.nodes[i]
44-
j = searchsortedlast(nodes_idx, node.parent)
45-
if j == 0
46-
continue
40+
n = length(expr.nodes)
41+
# The whole subexpression is continuous in the tape
42+
first_out = findfirst((root+1):n) do i
43+
expr.nodes[i].parent < root
44+
end
45+
if isnothing(first_out)
46+
I = root:n
47+
else
48+
I = root .+ (0:(first_out - 1))
49+
end
50+
first_value = findfirst(I) do i
51+
expr.nodes[i].type == NODE_VALUE
52+
end
53+
if isnothing(first_value)
54+
V = nothing
55+
else
56+
last_value = findlast(I) do i
57+
expr.nodes[i].type == NODE_VALUE
4758
end
48-
if node.parent == nodes_idx[j]
49-
push!(nodes_idx, i)
59+
V = expr.nodes[I[first_value]].index:expr.nodes[I[last_value]].index
60+
end
61+
if !isnothing(first_out)
62+
for i in last(I)+1:n
63+
node = expr.nodes[i]
5064
index = node.index
51-
if node.type == NODE_VALUE
52-
push!(values_idx, node.index)
53-
index = length(values_idx)
65+
if node.type == NODE_VALUE && !isnothing(V)
66+
@assert index >= last(V)
67+
index -= length(V)
5468
end
55-
expr.nodes[i] = Node(node.type, node.index, j)
56-
else
57-
index = node.index
58-
if node.type == NODE_VALUE
59-
# We use the fact that values of `node.index` are increasing
60-
# along the nodes of `expr.nodes` for which `node.type` is `NODE_VALUE`
61-
index -= length(values_idx)
69+
parent = node.parent
70+
if parent > root
71+
@assert parent > last(I)
72+
parent -= length(I) - 1
6273
end
63-
expr.nodes[i] = Node(node.type, index, node.parent - j)
74+
expr.nodes[i] = Node(node.type, index, parent)
6475
end
6576
end
66-
subexpr = Expression(expr.nodes[nodes_idx], expr.values[values_idx])
67-
deleteat!(expr.nodes, nodes_idx)
68-
deleteat!(expr.values, values_idx)
69-
return subexpr
77+
return I, V
7078
end
7179

7280
function _extract_subexpression!(data::Model, expr::Expression, root::Int)
7381
parent = expr.nodes[root].parent
74-
push!(data.expressions, _extract_subexpression!(expr, root))
82+
I, V = _extract_subexpression!(expr, root)
83+
subexpr = Expression(expr.nodes[I], isnothing(V) ? Float64[] : expr.values[V])
84+
push!(data.expressions, subexpr)
7585
index = ExpressionIndex(length(data.expressions))
76-
if parent != 0
77-
push!(expr.nodes, Node(NODE_SUBEXPRESSION, index.value, parent))
86+
expr.nodes[root] = Node(NODE_SUBEXPRESSION, index.value, parent)
87+
if length(I) > 1
88+
deleteat!(expr.nodes, I[2:end])
89+
if !isnothing(V)
90+
deleteat!(expr.values, V)
91+
end
92+
else
93+
@assert isnothing(V)
7894
end
79-
return index
95+
return index, I
8096
end
8197

8298
function parse_expression(
@@ -92,7 +108,31 @@ function parse_expression(
92108
if haskey(data.cache, arg)
93109
subexpr = data.cache[arg]
94110
if subexpr isa Tuple{Expression,Int}
95-
subexpr = _extract_subexpression!(data, subexpr...)
111+
_expr, _node = subexpr
112+
subexpr, I = _extract_subexpression!(data, _expr, _node)
113+
if expr === _expr
114+
if parent_node > first(I)
115+
@assert parent_node > last(I)
116+
parent_node -= length(I) - 1
117+
end
118+
for i in eachindex(stack)
119+
_parent_node = stack[i][1]
120+
if _parent_node > first(I)
121+
@assert _parent_node > last(I)
122+
stack[i] = (_parent_node - length(I) + 1, stack[i][2])
123+
end
124+
end
125+
end
126+
for (key, val) in data.cache
127+
if val isa Tuple{Expression,Int}
128+
__expr, __node = val
129+
if _expr === __expr && __node > first(I)
130+
@assert __node > last(I)
131+
data.cache[key] = (__expr, __node - length(I) + 1)
132+
end
133+
end
134+
end
135+
data.cache[arg] = subexpr
96136
end
97137
parse_expression(data, expr, subexpr::ExpressionIndex, parent_node)
98138
else

test/Nonlinear/Nonlinear.jl

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,25 +1446,93 @@ function test_intercept_ForwardDiff_MethodError()
14461446
return
14471447
end
14481448

1449-
end # TestNonlinear
1450-
1451-
TestNonlinear.runtests()
1449+
function test_extract_subexpression()
1450+
model = Nonlinear.Model()
1451+
x = MOI.VariableIndex(1)
1452+
sub = MOI.ScalarNonlinearFunction(:^, Any[x, 3])
1453+
f = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
1454+
expr = Nonlinear.parse_expression(model, f)
1455+
@test expr == Nonlinear.Expression(
1456+
[
1457+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
1458+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1459+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1460+
],
1461+
Float64[],
1462+
)
1463+
expected_sub = Nonlinear.Expression(
1464+
[
1465+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 4, 1)
1466+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 2)
1467+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 2)
1468+
],
1469+
[3.0],
1470+
)
1471+
@test model.expressions == [expected_sub]
1472+
@test model.cache[sub] == Nonlinear.ExpressionIndex(1)
1473+
1474+
h = MOI.ScalarNonlinearFunction(:*, Any[2, sub, 1])
1475+
g = MOI.ScalarNonlinearFunction(:+, Any[sub, h])
1476+
expr = MOI.Nonlinear.parse_expression(model, g)
1477+
expected_g = Nonlinear.Expression(
1478+
[
1479+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1)
1480+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1)
1481+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1)
1482+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 3)
1483+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 3)
1484+
Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 3)
1485+
],
1486+
[2.0, 1.0],
1487+
)
1488+
@test expr == expected_g
1489+
# It should have detected the sub-expressions that was the same as `f`
1490+
@test model.expressions == [expected_sub]
1491+
# This means that it didn't get to extract from `g`, let's also test
1492+
# with extraction by starting with an empty model
14521493

1453-
using Revise, Test
1454-
import MathOptInterface as MOI
1494+
model = Nonlinear.Model()
1495+
MOI.Nonlinear.set_objective(model, g)
1496+
@test model.objective == expected_g
1497+
@test model.expressions == [expected_sub]
1498+
# Test that the objective function gets rewritten as we reuse `h`
1499+
# Also test that we don't change the parents in the stack of `h`
1500+
# by creating a long stack
1501+
prod = MOI.ScalarNonlinearFunction(
1502+
:*,
1503+
[h, x],
1504+
)
1505+
sum = MOI.ScalarNonlinearFunction(
1506+
:*,
1507+
[x, x, x, x, prod],
1508+
)
1509+
expr = Nonlinear.parse_expression(model, sum)
1510+
@test isempty(model.objective.values)
1511+
@test model.objective.nodes == [
1512+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
1513+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1514+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 2, 1),
1515+
]
1516+
@test model.expressions == [expected_sub, Nonlinear.Expression([
1517+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1),
1518+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 3),
1519+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 3),
1520+
Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 3),
1521+
], [2.0, 1.0])]
1522+
@test isempty(expr.values)
1523+
@test expr.nodes == [
1524+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, -1),
1525+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1526+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1527+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1528+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1),
1529+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1),
1530+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 2, 6),
1531+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 6),
1532+
]
1533+
return
1534+
end
14551535

1456-
model = MOI.Nonlinear.Model()
1457-
x = MOI.VariableIndex(1)
1458-
sub = MOI.ScalarNonlinearFunction(:^, Any[x, 3])
1459-
func = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
1460-
expr = MOI.Nonlinear.parse_expression(model, func)
1536+
end # TestNonlinear
14611537

1462-
func = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
1463-
g = MOI.ScalarNonlinearFunction(:+, Any[
1464-
sub,
1465-
MOI.ScalarNonlinearFunction(
1466-
:*,
1467-
Any[1, sub]
1468-
),
1469-
])
1470-
expr = MOI.Nonlinear.parse_expression(model, g)
1538+
TestNonlinear.runtests()

0 commit comments

Comments
 (0)