@@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt
22
33@info " IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"
44
5+ import Base: show
6+
57using DifferentialEquations
68import DifferentialEquations: solve
79
@@ -15,10 +17,30 @@ using DocStringExtensions
1517
1618export DERelative
1719
20+ import Manifolds: allocate, compose, hat, Identity, vee, log
1821
1922
2023getManifold (de:: DERelative{T} ) where {T} = getManifold (de. domain)
2124
25+
26+ function Base. show (
27+ io:: IO ,
28+ :: Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}
29+ ) where {T,O}
30+ println (io, " DERelative{" )
31+ println (io, " " , T)
32+ println (io, " " , O. name. name)
33+ println (io, " }" )
34+ nothing
35+ end
36+
37+ Base. show (
38+ io:: IO ,
39+ :: MIME"text/plain" ,
40+ der:: DERelative
41+ ) = show (io, der)
42+
43+
2244"""
2345$SIGNATURES
2446
@@ -28,7 +50,9 @@ DevNotes
2850- TODO does not yet incorporate Xi.nanosecond field.
2951- TODO does not handle timezone crossing properly yet.
3052"""
31- function _calcTimespan (Xi:: AbstractVector{<:DFGVariable} )
53+ function _calcTimespan (
54+ Xi:: AbstractVector{<:DFGVariable}
55+ )
3256 #
3357 tsmps = getTimestamp .(Xi[1 : 2 ]) .| > DateTime .| > datetime2unix
3458 # toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3)
@@ -47,10 +71,10 @@ function DERelative(
4771 f:: Function ,
4872 data = () -> ();
4973 dt:: Real = 1 ,
50- state0:: AbstractVector{<:Real} = zeros (getDimension (domain)),
51- state1:: AbstractVector{<:Real} = zeros (getDimension (domain)),
74+ state0:: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
75+ state1:: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
5276 tspan:: Tuple{<:Real, <:Real} = _calcTimespan (Xi),
53- problemType = DiscreteProblem,
77+ problemType = ODEProblem, # DiscreteProblem,
5478)
5579 #
5680 datatuple = if 2 < length (Xi)
@@ -60,11 +84,11 @@ function DERelative(
6084 data
6185 end
6286 # forward time problem
63- fproblem = problemType (f, state0, tspan, datatuple; dt = dt )
87+ fproblem = problemType (f, state0, tspan, datatuple; dt)
6488 # backward time problem
6589 bproblem = problemType (f, state1, (tspan[2 ], tspan[1 ]), datatuple; dt = - dt)
6690 # build the IIF recognizable object
67- return DERelative (domain, fproblem, bproblem, datatuple, getSample)
91+ return DERelative (domain, fproblem, bproblem, datatuple) # , getSample)
6892end
6993
7094function DERelative (
@@ -75,8 +99,8 @@ function DERelative(
7599 data = () -> ();
76100 Xi:: AbstractArray{<:DFGVariable} = getVariable .(dfg, labels),
77101 dt:: Real = 1 ,
78- state0 :: AbstractVector{<:Real} = zeros (getDimension (domain)),
79- state1 :: AbstractVector{<:Real} = zeros (getDimension (domain)),
102+ state1 :: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
103+ state0 :: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
80104 tspan:: Tuple{<:Real, <:Real} = _calcTimespan (Xi),
81105 problemType = DiscreteProblem,
82106)
@@ -85,26 +109,32 @@ function DERelative(
85109 domain,
86110 f,
87111 data;
88- dt = dt ,
89- state0 = state0 ,
90- state1 = state1 ,
91- tspan = tspan ,
92- problemType = problemType ,
112+ dt,
113+ state0,
114+ state1,
115+ tspan,
116+ problemType,
93117 )
94118end
95119#
96120#
97121
98122# n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
99- function _solveFactorODE! (measArr, prob, u0pts, Xtra... )
123+ function _solveFactorODE! (
124+ measArr,
125+ prob,
126+ u0pts,
127+ Xtra...
128+ )
100129 # happens when more variables (n-ary) must be included in DE solve
101130 for (xid, xtra) in enumerate (Xtra)
102131 # update the data register before ODE solver calls the function
103- prob. p[xid + 1 ][:] = xtra[:]
132+ prob. p[xid + 1 ][:] = xtra[:] # FIXME , unlikely to work with ArrayPartition, maybe use MArray and `.=`
104133 end
105134
106135 # set the initial condition
107- prob. u0[:] = u0pts[:]
136+ prob. u0 .= u0pts
137+
108138 sol = DifferentialEquations. solve (prob)
109139
110140 # extract solution from solved ode
@@ -155,21 +185,21 @@ end
155185
156186
157187# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE
158- function (cf:: CalcFactor{<:DERelative} )(measurement, X... )
188+ function (cf:: CalcFactor{<:DERelative} )(
189+ measurement,
190+ X...
191+ )
159192 #
193+ # numerical measurement values
160194 meas1 = measurement[1 ]
161- diffOp = measurement[2 ]
162-
195+ # work on-manifold via sampleFactor piggy back of particular manifold definition
196+ M = measurement[2 ]
197+ # lazy factor pointer
163198 oderel = cf. factor
164-
165- # work on-manifold
166- # diffOp = meas[2]
167- # if backwardSolve else forward
168-
169199 # check direction
170-
171200 solveforIdx = cf. solvefor
172-
201+
202+ # if backwardSolve else forward
173203 if solveforIdx > 2
174204 # need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
175205 solveforIdx = 2
@@ -185,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
185215 end
186216
187217 # find the difference between measured and predicted.
188- # # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
189- # # FIXME , obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
190- # @show cf._sampleIdx, solveforIdx, meas1
191-
192- # FIXME
193- res = zeros (size (X[2 ], 1 ))
194- for i = 1 : size (X[2 ], 1 )
195- # diffop( reference?, test? ) <===> ΔX = test \ reference
196- res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
197- end
218+ # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
219+ res_ = compose (M, inv (M, X[solveforIdx]), meas1)
220+ res = vee (M, Identity (M), log (M, Identity (M), res_))
221+
198222 return res
199223end
200224
@@ -249,28 +273,32 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
249273 oder = cf. factor
250274
251275 # how many trajectories to propagate?
252- # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
253- meas = [zeros (getDimension (cf. fullvariables[2 ])) for _ = 1 : N]
276+ #
277+ v2T = getVariableType (cf. fullvariables[2 ])
278+ meas = [allocate (getPointIdentity (v2T)) for _ = 1 : N]
279+ # meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
254280
255281 # pick forward or backward direction
256282 # set boundary condition
257- u0pts = if cf. solvefor == 1
283+ u0pts, M = if cf. solvefor == 1
258284 # backward direction
259285 prob = oder. backwardProblem
286+ M_ = getManifold (getVariableType (cf. fullvariables[1 ]))
260287 addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks (
261- convert (Tuple, getManifold ( getVariableType (cf . fullvariables[ 1 ])) ),
288+ convert (Tuple, M_ ),
262289 )
263290 # getBelief(cf.fullvariables[2]) |> getPoints
264- cf. _legacyParams[2 ]
291+ cf. _legacyParams[2 ], M_
265292 else
266293 # forward backward
267294 prob = oder. forwardProblem
295+ M_ = getManifold (getVariableType (cf. fullvariables[2 ]))
268296 # buffer manifold operations for use during factor evaluation
269297 addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks (
270- convert (Tuple, getManifold ( getVariableType (cf . fullvariables[ 2 ])) ),
298+ convert (Tuple, M_ ),
271299 )
272300 # getBelief(cf.fullvariables[1]) |> getPoints
273- cf. _legacyParams[1 ]
301+ cf. _legacyParams[1 ], M_
274302 end
275303
276304 # solve likely elements
@@ -281,17 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
281309 # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
282310 end
283311
284- return map (x -> (x, diffOp), meas)
312+ # return meas, M
313+ return map (x -> (x, M), meas)
285314end
286315# getDimension(oderel.domain)
287316
288317
289318
290-
291-
292- # # the function
293- # ode.problem.f.f
294-
295- #
296-
297319end # module
0 commit comments