Skip to content

Commit a78fc15

Browse files
committed
Fix
1 parent 9d8a931 commit a78fc15

File tree

2 files changed

+34
-20
lines changed

2 files changed

+34
-20
lines changed

src/Nonlinear/parse.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,34 @@ end
3939
function _extract_subexpression!(expr::Expression, root::Int)
4040
n = length(expr.nodes)
4141
# The whole subexpression is continuous in the tape
42-
first_out = findfirst((root+1):n) do i
43-
return expr.nodes[i].parent < root
42+
first_out = first_value = last_value = nothing
43+
for i in root:n
44+
node = expr.nodes[i]
45+
if i != root && node.parent < root
46+
first_out = i
47+
break
48+
end
49+
index = node.index
50+
if node.type == NODE_VALUE
51+
if isnothing(first_value)
52+
first_value = node.index
53+
last_value = first_value
54+
else
55+
last_value = node.index
56+
end
57+
index -= first_value - 1
58+
end
59+
expr.nodes[i] = Node(node.type, index, i == root ? -1 : node.parent - root + 1)
4460
end
4561
if isnothing(first_out)
4662
I = root:n
4763
else
48-
I = root .+ (0:(first_out-1))
49-
end
50-
first_value = findfirst(I) do i
51-
return expr.nodes[i].type == NODE_VALUE
64+
I = root:(first_out-1)
5265
end
5366
if isnothing(first_value)
5467
V = nothing
5568
else
56-
last_value = findlast(I) do i
57-
return expr.nodes[i].type == NODE_VALUE
58-
end
59-
V = expr.nodes[I[first_value]].index:expr.nodes[I[last_value]].index
69+
V = first_value:last_value
6070
end
6171
if !isnothing(first_out)
6272
for i in (last(I)+1):n
@@ -129,9 +139,12 @@ function parse_expression(
129139
if val isa Tuple{Expression,Int}
130140
__expr, __node = val
131141
if _expr === __expr && __node > first(I)
132-
@assert __node > last(I)
133-
data.cache[key] =
134-
(__expr, __node - length(I) + 1)
142+
if __node <= last(I)
143+
data.cache[key] = (data.expressions[subexpr.value], __node - first(I) + 1)
144+
else
145+
data.cache[key] =
146+
(__expr, __node - length(I) + 1)
147+
end
135148
end
136149
end
137150
end

test/Nonlinear/Nonlinear.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,7 @@ function test_extract_subexpression()
14521452
sub = MOI.ScalarNonlinearFunction(:^, Any[x, 3])
14531453
f = MOI.ScalarNonlinearFunction(:+, Any[sub, sub])
14541454
expr = Nonlinear.parse_expression(model, f)
1455+
display(expr.nodes)
14551456
@test expr == Nonlinear.Expression(
14561457
[
14571458
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 1, -1),
@@ -1462,9 +1463,9 @@ function test_extract_subexpression()
14621463
)
14631464
expected_sub = Nonlinear.Expression(
14641465
[
1465-
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 4, 1)
1466-
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 2)
1467-
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 2)
1466+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 4, -1)
1467+
Nonlinear.Node(Nonlinear.NODE_MOI_VARIABLE, 1, 1)
1468+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 1)
14681469
],
14691470
[3.0],
14701471
)
@@ -1511,10 +1512,10 @@ function test_extract_subexpression()
15111512
expected_sub,
15121513
Nonlinear.Expression(
15131514
[
1514-
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, 1),
1515-
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 3),
1516-
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 3),
1517-
Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 3),
1515+
Nonlinear.Node(Nonlinear.NODE_CALL_MULTIVARIATE, 3, -1),
1516+
Nonlinear.Node(Nonlinear.NODE_VALUE, 1, 1),
1517+
Nonlinear.Node(Nonlinear.NODE_SUBEXPRESSION, 1, 1),
1518+
Nonlinear.Node(Nonlinear.NODE_VALUE, 2, 1),
15181519
],
15191520
[2.0, 1.0],
15201521
),

0 commit comments

Comments
 (0)