Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ComponentArrays

import ChainRulesCore
import StaticArrayInterface, ArrayInterface, Functors
import Base.merge

using LinearAlgebra

Expand Down
43 changes: 42 additions & 1 deletion src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ directly on an `AbstractAxis`.

# Examples

```julia-repl
```jldoctest
julia> using ComponentArrays

julia> ca = ComponentArray(a=1, b=[1,2,3], c=(a=4,))
Expand All @@ -336,3 +336,44 @@ julia> sum(prod(ca[k]) for k in valkeys(ca))
return :($k)
end
valkeys(ca::ComponentVector) = valkeys(getaxes(ca)[1])

"""
merge(cvec1::ComponentVector, cvecs::ComponentVector...)

Construct a new ComponentVector by merging two or more existing ones, in a left-associative
manner. If a key is present in two or more CVectors, the right-most CVector takes priority.
The type of the resulting CVector will be promoted to accomodate all the types of the merged
CVectors

# Examples
```jldoctest
julia> c1 = ComponentArray(a=1.2, b=2.3)
ComponentVector{Float64}(a = 1.2, b = 2.3)

julia> c2 = ComponentArray(a=1,h=4)
ComponentVector{Int64}(a = 1, h = 4)

julia> merge(c1,c2)
ComponentVector{Float64}(a = 1.0, b = 2.3, h = 4.0)

julia> merge(c2,c1)
ComponentVector{Float64}(a = 1.2, h = 4.0, b = 2.3)
```
"""
function Base.merge(ca::ComponentVector{T}, ca2::ComponentVector{T}) where T
ax = getaxes(ca)
ax2 = getaxes(ca2)
vks = valkeys(ax[1])
vks2 = valkeys(ax2[1])
_p = Vector{T}()
sizehint!(_p, length(c1) + length(ca2))
for vk in vks
if vk in vks2
_p = vcat(_p, ca2[vk])
else
_p = vcat(_p, ca[vk])
end
end
ComponentArray(_p, merged_ax)
end
Base.merge(a::ComponentVector, b::ComponentVector, cs::ComponentVector) = merge(merge(a,b), cs...)