-
Notifications
You must be signed in to change notification settings - Fork 25
Open
Description
ref JuliaDiff/ChainRulesTestUtils.jl#258
FiniteDifferences.jl/src/to_vec.jl
Lines 36 to 57 in 5c2979e
# Fallback method for `to_vec`. Won't always do what you wanted, but should be fine a decent | |
# chunk of the time. | |
function to_vec(x::T) where {T} | |
Base.isstructtype(T) || throw(error("Expected a struct type")) | |
isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types | |
val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T)) | |
vals = first.(val_vecs_and_backs) | |
backs = last.(val_vecs_and_backs) | |
v, vals_from_vec = to_vec(vals) | |
function structtype_from_vec(v::Vector{<:Real}) | |
val_vecs = vals_from_vec(v) | |
values = map((b, v) -> b(v), backs, val_vecs) | |
try | |
T(values...) | |
catch MethodError | |
return _force_construct(T, values...) | |
end | |
end | |
return v, structtype_from_vec | |
end |
Metadata
Metadata
Assignees
Labels
No labels