Skip to content

Commit af961e5

Browse files
oscarddssmithoscardssmith
authored andcommitted
prepare for switching to Linsolve Interface
1 parent 8d830c9 commit af961e5

File tree

7 files changed

+106
-235
lines changed

7 files changed

+106
-235
lines changed

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,15 @@ function wrapprecs(_Pl, _Pr, weight, u)
5858
Pl, Pr
5959
end
6060

61+
function wrapprecs(linsolver, W, weight)
62+
if hasproperty(linsolver, :precs) && isnothing(linsolver.precs)
63+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
64+
Pr = Diagonal(_vec(weight))
65+
precs = Returns((Pl, Pr))
66+
return remake(linsolver; precs)
67+
else
68+
return linsolver
69+
end
70+
end
71+
6172
Base.resize!(p::LinearSolve.LinearCache, i) = p

lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,9 @@ end
252252
reltol = eps(eltype(dz))
253253
end
254254

255-
if is_always_new(nlsolver) || (iter == 1 && new_W)
256-
linres = dolinsolve(integrator, linsolve; A = W, b = _vec(b), linu = _vec(dz),
257-
reltol = reltol)
258-
else
259-
linres = dolinsolve(
260-
integrator, linsolve; A = nothing, b = _vec(b), linu = _vec(dz),
261-
reltol = reltol)
262-
end
255+
make_new_W = is_always_new(nlsolver) || (iter == 1 && new_W)
256+
linres = dolinsolve(integrator, linsolve; A = make_new_W ? W : nothing, b = _vec(b),
257+
linu = _vec(dz), reltol)
263258

264259
if !SciMLBase.successful_retcode(linres.retcode) &&
265260
linres.retcode != SciMLBase.ReturnCode.Default

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,9 @@ function build_nlsolver(
211211
jac_config = build_jac_config(alg, nf, uf, du1, uprev, u, ztmp, dz)
212212
end
213213
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
214-
linprob = LinearProblem(W, _vec(k); u0 = _vec(dz))
215-
Pl, Pr = wrapprecs(
216-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
217-
nothing)...,
218-
weight, dz)
219-
linsolve = init(linprob, alg.linsolve,
214+
linprob = LinearProblem(W, _vec(k), (isdae ? du1 : nothing,u,p,t); u0 = _vec(dz))
215+
linsolve = init(linprob, wrapprecs(alg.linsolve, W, weight),
220216
alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
221-
Pl = Pl, Pr = Pr,
222217
assumptions = LinearSolve.OperatorAssumptions(true))
223218

224219
tType = typeof(t)

lib/OrdinaryDiffEqRosenbrock/Project.toml

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,75 +4,73 @@ authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
44
version = "1.18.1"
55

66
[deps]
7-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
9+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
810
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
911
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
10-
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
11-
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
12+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1214
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
15+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
16+
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
17+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
18+
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
1319
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
1420
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
15-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
16-
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
17-
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
18-
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
19-
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2021
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
21-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
22-
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2322
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
24-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2523
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
24+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
25+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2626

27-
[extras]
28-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
29-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
30-
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
31-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
32-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
34-
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
35-
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
36-
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
37-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
27+
[sources]
28+
OrdinaryDiffEqCore = {path = "../OrdinaryDiffEqCore"}
29+
OrdinaryDiffEqDifferentiation = {path = "../OrdinaryDiffEqDifferentiation"}
3830

3931
[compat]
40-
ForwardDiff = "0.10.38, 1"
41-
Test = "<0.0.1, 1"
42-
FastBroadcast = "0.3"
43-
Random = "<0.0.1, 1"
32+
ADTypes = "1.16"
33+
AllocCheck = "0.2"
34+
Aqua = "0.8.11"
35+
DiffEqBase = "6.176"
4436
DiffEqDevTools = "2.44.4"
45-
FiniteDiff = "2.27"
46-
MuladdMacro = "0.2"
4737
DifferentiationInterface = "0.6.54, 0.7"
38+
Enzyme = "0.13"
39+
FastBroadcast = "0.3"
40+
FiniteDiff = "2.27"
41+
ForwardDiff = "0.10.38, 1"
42+
JET = "0.9.18, 0.10.4"
43+
LinearAlgebra = "1.10"
4844
LinearSolve = "3.26"
45+
MacroTools = "0.5"
46+
MuladdMacro = "0.2"
47+
ODEProblemLibrary = "0.1.8"
48+
OrdinaryDiffEqCore = "1.29.0"
49+
OrdinaryDiffEqDifferentiation = "1.12.0"
50+
OrdinaryDiffEqNonlinearSolve = "1.13.0"
4951
Polyester = "0.7"
5052
PrecompileTools = "1.2"
51-
LinearAlgebra = "1.10"
52-
OrdinaryDiffEqDifferentiation = "1.12.0"
53-
SciMLBase = "2.99"
54-
OrdinaryDiffEqCore = "1.29.0"
55-
Static = "1.2"
56-
Aqua = "0.8.11"
5753
Preferences = "1.4"
58-
Enzyme = "0.13"
59-
MacroTools = "0.5"
60-
julia = "1.10"
61-
JET = "0.9.18, 0.10.4"
62-
ADTypes = "1.16"
54+
Random = "<0.0.1, 1"
6355
RecursiveArrayTools = "3.36"
64-
ODEProblemLibrary = "0.1.8"
65-
OrdinaryDiffEqNonlinearSolve = "1.13.0"
66-
AllocCheck = "0.2"
67-
DiffEqBase = "6.176"
6856
Reexport = "1.2"
6957
SafeTestsets = "0.1.0"
58+
SciMLBase = "2.120.0"
59+
Static = "1.2"
60+
Test = "<0.0.1, 1"
61+
julia = "1.10"
62+
63+
[extras]
64+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
65+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
66+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
67+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
68+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
69+
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
70+
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
71+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
72+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
73+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7074

7175
[targets]
7276
test = ["DiffEqDevTools", "Random", "OrdinaryDiffEqNonlinearSolve", "SafeTestsets", "Test", "ODEProblemLibrary", "Enzyme", "JET", "Aqua", "AllocCheck"]
73-
74-
[sources.OrdinaryDiffEqDifferentiation]
75-
path = "../OrdinaryDiffEqDifferentiation"
76-
77-
[sources.OrdinaryDiffEqCore]
78-
path = "../OrdinaryDiffEqCore"

lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab
227227
function alg_cache(alg::$algname,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
228228
tf = TimeDerivativeWrapper(f,u,p)
229229
uf = UDerivativeWrapper(f,t,p)
230-
J,W = build_J_W(alg,u,uprev,p,t,dt,f, nothing, uEltypeNoUnits,Val(false))
230+
J,W = build_J_W(alg,u,uprev,p,t,dt,f,nothing, uEltypeNoUnits,Val(false))
231231
$constcachename(tf,uf,$tabname(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)),J,W,nothing)
232232
end
233233
function alg_cache(alg::$algname,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
@@ -238,6 +238,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab
238238
fsalfirst = zero(rate_prototype)
239239
fsallast = zero(rate_prototype)
240240
dT = zero(rate_prototype)
241+
241242
tmp = zero(rate_prototype)
242243
atmp = similar(u, uEltypeNoUnits)
243244
weight = similar(u, uEltypeNoUnits)
@@ -250,11 +251,8 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab
250251
grad_config = build_grad_config(alg,f,tf,du1,t)
251252
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2)
252253
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
253-
254-
linprob = LinearProblem(W,_vec(linsolve_tmp); u0=_vec(tmp))
255-
linsolve = init(linprob,alg.linsolve,alias = LinearAliasSpecifier(alias_A=true,alias_b=true),
256-
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
257-
Pr = Diagonal(_vec(weight)))
254+
linprob = LinearProblem(W,_vec(linsolve_tmp), (nothing, u, p, t); u0=_vec(tmp))
255+
linsolve = init(linprob,alg.linsolve, alias = LinearAliasSpecifier(alias_A=true,alias_b=true))
258256
$cachename($(valsyms...))
259257
end
260258
end
@@ -915,7 +913,6 @@ end
915913

916914

917915

918-
919916
"""
920917
@ROS2(part)
921918

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 20 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,12 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
154154

155155
grad_config = build_grad_config(alg, f, tf, du1, t)
156156
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
157-
158157
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
159-
160-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
161-
Pl, Pr = wrapprecs(
162-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
163-
nothing)..., weight, tmp)
164-
linsolve = init(
165-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
166-
Pl = Pl, Pr = Pr,
158+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing,u,p,t); u0 = _vec(tmp))
159+
linsolve = init(linprob, alg.linsolve,
160+
alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
167161
assumptions = LinearSolve.OperatorAssumptions(true))
168162

169-
170163
algebraic_vars = f.mass_matrix === I ? nothing :
171164
[all(iszero, x) for x in eachcol(f.mass_matrix)]
172165

@@ -201,22 +194,12 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
201194
tf = TimeGradientWrapper(f, uprev, p)
202195
uf = UJacobianWrapper(f, t, p)
203196
linsolve_tmp = zero(rate_prototype)
204-
205197
grad_config = build_grad_config(alg, f, tf, du1, t)
206198
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
207-
208199
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
209-
210-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
211-
212-
Pl, Pr = wrapprecs(
213-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
214-
nothing)..., weight, tmp)
215-
linsolve = init(
216-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
217-
Pl = Pl, Pr = Pr,
200+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing,u,p,t); u0 = _vec(tmp))
201+
linsolve = init(linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
218202
assumptions = LinearSolve.OperatorAssumptions(true))
219-
220203
algebraic_vars = f.mass_matrix === I ? nothing :
221204
[all(iszero, x) for x in eachcol(f.mass_matrix)]
222205

@@ -355,20 +338,13 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
355338
tf = TimeGradientWrapper(f, uprev, p)
356339
uf = UJacobianWrapper(f, t, p)
357340
linsolve_tmp = zero(rate_prototype)
358-
341+
359342
grad_config = build_grad_config(alg, f, tf, du1, t)
360343
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
361344
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
362-
363-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
364-
Pl, Pr = wrapprecs(
365-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
366-
nothing)..., weight, tmp)
367-
linsolve = init(
368-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
369-
Pl = Pl, Pr = Pr,
345+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp))
346+
linsolve = init(linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
370347
assumptions = LinearSolve.OperatorAssumptions(true))
371-
372348
Rosenbrock33Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
373349
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
374350
linsolve_tmp,
@@ -445,21 +421,13 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
445421
tf = TimeGradientWrapper(f, uprev, p)
446422
uf = UJacobianWrapper(f, t, p)
447423
linsolve_tmp = zero(rate_prototype)
448-
424+
449425
grad_config = build_grad_config(alg, f, tf, du1, t)
450426
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
451-
452427
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
453-
454-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
455-
Pl, Pr = wrapprecs(
456-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
457-
nothing)..., weight, tmp)
458-
linsolve = init(
459-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
460-
Pl = Pl, Pr = Pr,
428+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp))
429+
linsolve = init(linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
461430
assumptions = LinearSolve.OperatorAssumptions(true))
462-
463431
Rosenbrock34Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
464432
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
465433
linsolve_tmp,
@@ -643,21 +611,12 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
643611
tf = TimeGradientWrapper(f, uprev, p)
644612
uf = UJacobianWrapper(f, t, p)
645613
linsolve_tmp = zero(rate_prototype)
646-
647614
grad_config = build_grad_config(alg, f, tf, du1, t)
648615
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
649-
650616
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
651-
652-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
653-
Pl, Pr = wrapprecs(
654-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
655-
nothing)..., weight, tmp)
656-
linsolve = init(
657-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
658-
Pl = Pl, Pr = Pr,
617+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp))
618+
linsolve = init(linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
659619
assumptions = LinearSolve.OperatorAssumptions(true))
660-
661620
Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
662621
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
663622
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
@@ -691,22 +650,13 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
691650

692651
tf = TimeGradientWrapper(f, uprev, p)
693652
uf = UJacobianWrapper(f, t, p)
694-
653+
linsolve_tmp = zero(rate_prototype)
695654
grad_config = build_grad_config(alg, f, tf, du1, t)
696655
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
697-
698656
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
699-
700-
linsolve_tmp = zero(rate_prototype)
701-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
702-
Pl, Pr = wrapprecs(
703-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
704-
nothing)..., weight, tmp)
705-
linsolve = init(
706-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
707-
Pl = Pl, Pr = Pr,
657+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp))
658+
linsolve = init(linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
708659
assumptions = LinearSolve.OperatorAssumptions(true))
709-
710660
Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
711661
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
712662
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
@@ -804,25 +754,15 @@ function alg_cache(
804754

805755
tf = TimeGradientWrapper(f, uprev, p)
806756
uf = UJacobianWrapper(f, t, p)
807-
757+
linsolve_tmp = zero(rate_prototype)
808758
grad_config = build_grad_config(alg, f, tf, du1, t)
809759
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
810-
811760
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
761+
linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp))
762+
linsolve = init(linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
763+
assumptions = LinearSolve.OperatorAssumptions(true))
812764

813-
Pl, Pr = wrapprecs(
814-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
815-
nothing)..., weight, tmp)
816-
817-
linsolve_tmp = zero(rate_prototype)
818-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0=_vec(tmp))
819-
820-
linsolve = init(
821-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A=true, alias_b=true),
822-
Pl=Pl, Pr=Pr,
823-
assumptions=LinearSolve.OperatorAssumptions(true))
824765

825-
826766
# Return the cache struct with vectors
827767
RosenbrockCache(
828768
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,

0 commit comments

Comments
 (0)