Skip to content

Commit 9ec668b

Browse files
committed
use Parameter set
1 parent 91865cf commit 9ec668b

File tree

5 files changed

+129
-70
lines changed

5 files changed

+129
-70
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ optimize!(model)
6464

6565
# differentiate w.r.t. p
6666
direction_p = 3.0
67-
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(p), direction_p)
67+
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(p), Parameter(direction_p))
6868
DiffOpt.forward_differentiate!(model)
6969
@show MOI.get(model, DiffOpt.ForwardVariablePrimal(), x) == direction_p * 3 / pc_val
7070

@@ -82,7 +82,7 @@ optimize!(model)
8282
DiffOpt.empty_input_sensitivities!(model)
8383
# differentiate w.r.t. pc
8484
direction_pc = 10.0
85-
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(pc), direction_pc)
85+
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(pc), Parameter(direction_pc))
8686
DiffOpt.forward_differentiate!(model)
8787
@show abs(MOI.get(model, DiffOpt.ForwardVariablePrimal(), x) -
8888
-direction_pc * 3 * p_val / pc_val^2) < 1e-5
@@ -93,8 +93,8 @@ DiffOpt.empty_input_sensitivities!(model)
9393
direction_x = 10.0
9494
MOI.set(model, DiffOpt.ReverseVariablePrimal(), x, direction_x)
9595
DiffOpt.reverse_differentiate!(model)
96-
@show MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(p)) == direction_x * 3 / pc_val
97-
@show abs(MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(pc)) -
96+
@show MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(p)) == MOI.Parameter(direction_x * 3 / pc_val)
97+
@show abs(MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(pc)).value -
9898
-direction_x * 3 * p_val / pc_val^2) < 1e-5
9999
```
100100

src/DiffOpt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ include("utils.jl")
1919
include("product_of_sets.jl")
2020
include("diff_opt.jl")
2121
include("moi_wrapper.jl")
22-
include("jump_moi_overloads.jl")
2322
include("parameters.jl")
23+
include("jump_moi_overloads.jl")
2424

2525
include("copy_dual.jl")
2626
include("bridges.jl")

src/jump_moi_overloads.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,37 @@ function MOI.get(
8080
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
8181
end
8282

83+
# extras to handle model_dirty
84+
85+
function MOI.get(
86+
model::JuMP.Model,
87+
attr::ReverseConstraintSet,
88+
var_ref::JuMP.ConstraintRef,
89+
)
90+
JuMP.check_belongs_to_model(var_ref, model)
91+
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
92+
end
93+
94+
function MOI.set(
95+
model::JuMP.Model,
96+
attr::ForwardConstraintSet,
97+
con_ref::JuMP.ConstraintRef,
98+
set::MOI.AbstractScalarSet,
99+
)
100+
JuMP.check_belongs_to_model(con_ref, model)
101+
return MOI.set(JuMP.backend(model), attr, JuMP.index(con_ref), set)
102+
end
103+
104+
function MOI.set(
105+
model::JuMP.Model,
106+
attr::ForwardConstraintSet,
107+
con_ref::JuMP.ConstraintRef,
108+
set::JuMP.AbstractScalarSet,
109+
)
110+
JuMP.check_belongs_to_model(con_ref, model)
111+
return MOI.set(JuMP.backend(model), attr, JuMP.index(con_ref), JuMP.moi_set(set))
112+
end
113+
83114
"""
84115
abstract type AbstractLazyScalarFunction <: MOI.AbstractScalarFunction end
85116

src/parameters.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,14 @@ function MOI.set(
308308
model::POI.Optimizer,
309309
::ForwardConstraintSet,
310310
ci::MOI.ConstraintIndex{MOI.VariableIndex,MOI.Parameter{T}},
311-
value::Number,
311+
set::MOI.Parameter,
312312
) where {T}
313313
variable = MOI.VariableIndex(ci.value)
314314
if _is_variable(model, variable)
315315
error("Trying to set a forward parameter sensitivity for a variable")
316316
end
317317
sensitivity_data = _get_sensitivity_data(model)
318-
sensitivity_data.parameter_input_forward[variable] = value
318+
sensitivity_data.parameter_input_forward[variable] = set.value
319319
return
320320
end
321321

@@ -573,16 +573,7 @@ function MOI.get(
573573
error("Trying to get a backward parameter sensitivity for a variable")
574574
end
575575
sensitivity_data = _get_sensitivity_data(model)
576-
return get(sensitivity_data.parameter_output_backward, variable, 0.0)
577-
end
578-
579-
# extras to handle model_dirty
580-
581-
function MOI.get(
582-
model::JuMP.Model,
583-
attr::ReverseConstraintSet,
584-
var_ref::JuMP.ConstraintRef,
585-
)
586-
JuMP.check_belongs_to_model(var_ref, model)
587-
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
576+
return MOI.Parameter{T}(
577+
get(sensitivity_data.parameter_output_backward, variable, 0.0),
578+
)
588579
end

0 commit comments

Comments
 (0)