Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 406cc6b

Browse files
Merge pull request #148 from SciML/ChrisRackauckas-patch-1
Unstrict Zygote
2 parents 2a37950 + 430f17f commit 406cc6b

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

ext/OptimizationZygoteExt.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function OptimizationBase.instantiate_function(
3030
adtype, soadtype = OptimizationBase.generate_adtype(adtype)
3131

3232
if g == true && f.grad === nothing
33-
prep_grad = prepare_gradient(f.f, adtype, x, Constant(p))
33+
prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict=Val(false))
3434
function grad(res, θ)
3535
gradient!(f.f, res, prep_grad, adtype, θ, Constant(p))
3636
end
@@ -47,7 +47,7 @@ function OptimizationBase.instantiate_function(
4747

4848
if fg == true && f.fg === nothing
4949
if g == false
50-
prep_grad = prepare_gradient(f.f, adtype, x, Constant(p))
50+
prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict=Val(false))
5151
end
5252
function fg!(res, θ)
5353
(y, _) = value_and_gradient!(f.f, res, prep_grad, adtype, θ, Constant(p))
@@ -68,7 +68,7 @@ function OptimizationBase.instantiate_function(
6868
hess_sparsity = f.hess_prototype
6969
hess_colors = f.hess_colorvec
7070
if h == true && f.hess === nothing
71-
prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p))
71+
prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict=Val(false))
7272
function hess(res, θ)
7373
hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
7474
end
@@ -143,7 +143,7 @@ function OptimizationBase.instantiate_function(
143143
cons_jac_prototype = f.cons_jac_prototype
144144
cons_jac_colorvec = f.cons_jac_colorvec
145145
if cons !== nothing && cons_j == true && f.cons_j === nothing
146-
prep_jac = prepare_jacobian(cons_oop, adtype, x)
146+
prep_jac = prepare_jacobian(cons_oop, adtype, x, strict=Val(false))
147147
function cons_j!(J, θ)
148148
jacobian!(cons_oop, J, prep_jac, adtype, θ)
149149
if size(J, 1) == 1
@@ -157,7 +157,7 @@ function OptimizationBase.instantiate_function(
157157
end
158158

159159
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
160-
prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),))
160+
prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),), strict=Val(false))
161161
function cons_vjp!(J, θ, v)
162162
pullback!(cons_oop, (J,), prep_pullback, adtype, θ, (v,))
163163
end
@@ -169,7 +169,7 @@ function OptimizationBase.instantiate_function(
169169

170170
if cons !== nothing && f.cons_jvp === nothing && cons_jvp == true
171171
prep_pushforward = prepare_pushforward(
172-
cons_oop, adtype, x, (ones(eltype(x), length(x)),))
172+
cons_oop, adtype, x, (ones(eltype(x), length(x)),), strict=Val(false))
173173
function cons_jvp!(J, θ, v)
174174
pushforward!(cons_oop, (J,), prep_pushforward, adtype, θ, (v,))
175175
end
@@ -182,7 +182,7 @@ function OptimizationBase.instantiate_function(
182182
conshess_sparsity = f.cons_hess_prototype
183183
conshess_colors = f.cons_hess_colorvec
184184
if cons !== nothing && cons_h == true && f.cons_h === nothing
185-
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
185+
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i), strict=Val(false))
186186
for i in 1:num_cons]
187187

188188
function cons_h!(H, θ)
@@ -201,7 +201,7 @@ function OptimizationBase.instantiate_function(
201201
if f.lag_h === nothing && cons !== nothing && lag_h == true
202202
lag_extras = prepare_hessian(
203203
lagrangian, soadtype, x, Constant(one(eltype(x))),
204-
Constant(ones(eltype(x), num_cons)), Constant(p))
204+
Constant(ones(eltype(x), num_cons)), Constant(p), strict=Val(false))
205205
lag_hess_prototype = zeros(Bool, num_cons, length(x))
206206

207207
function lag_h!(H::AbstractMatrix, θ, σ, λ)
@@ -294,7 +294,7 @@ function OptimizationBase.instantiate_function(
294294
adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype)
295295

296296
if g == true && f.grad === nothing
297-
extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
297+
extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p), strict=Val(false))
298298
function grad(res, θ)
299299
gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
300300
end
@@ -311,7 +311,7 @@ function OptimizationBase.instantiate_function(
311311

312312
if fg == true && f.fg === nothing
313313
if g == false
314-
extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
314+
extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p), strict=Val(false))
315315
end
316316
function fg!(res, θ)
317317
(y, _) = value_and_gradient!(
@@ -334,7 +334,7 @@ function OptimizationBase.instantiate_function(
334334
hess_sparsity = f.hess_prototype
335335
hess_colors = f.hess_colorvec
336336
if h == true && f.hess === nothing
337-
prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p))
337+
prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict=Val(false))
338338
function hess(res, θ)
339339
hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
340340
end
@@ -458,7 +458,7 @@ function OptimizationBase.instantiate_function(
458458
conshess_sparsity = f.cons_hess_prototype
459459
conshess_colors = f.cons_hess_colorvec
460460
if cons !== nothing && f.cons_h === nothing && cons_h == true
461-
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
461+
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i), strict=Val(false))
462462
for i in 1:num_cons]
463463
colores = getfield.(prep_cons_hess, :coloring_result)
464464
conshess_sparsity = getfield.(colores, :A)
@@ -479,7 +479,7 @@ function OptimizationBase.instantiate_function(
479479
if cons !== nothing && f.lag_h === nothing && lag_h == true
480480
lag_extras = prepare_hessian(
481481
lagrangian, soadtype, x, Constant(one(eltype(x))),
482-
Constant(ones(eltype(x), num_cons)), Constant(p))
482+
Constant(ones(eltype(x), num_cons)), Constant(p), strict=Val(false))
483483
lag_hess_prototype = lag_extras.coloring_result.A
484484
lag_hess_colors = lag_extras.coloring_result.color
485485

0 commit comments

Comments
 (0)