@@ -1599,9 +1599,23 @@ end
1599
1599
# Assumptions. Either:
1600
1600
# A) model.parts MPI tasks are included in parts_redistributed_model MPI tasks; or
1601
1601
# B) model.parts MPI tasks include parts_redistributed_model MPI tasks
1602
+ const WeightsArrayType= Union{Nothing,MPIArray{<: Vector{<:Integer} }}
1602
1603
function GridapDistributed. redistribute (model:: OctreeDistributedDiscreteModel{Dc,Dp} ,
1603
- parts_redistributed_model= model. parts) where {Dc,Dp}
1604
+ parts_redistributed_model= model. parts;
1605
+ weights:: WeightsArrayType = nothing ) where {Dc,Dp}
1604
1606
parts = (parts_redistributed_model === model. parts) ? model. parts : parts_redistributed_model
1607
+ _weights= nothing
1608
+ if (weights != = nothing )
1609
+ Gridap. Helpers. @notimplementedif parts!= = model. parts
1610
+ _weights= map (model. dmodel. models,weights) do lmodel,weights
1611
+ # The length of the local weights array has to match the number of
1612
+ # cells in the model. This includes both owned and ghost cells.
1613
+ # Only the flags for owned cells are actually taken into account.
1614
+ @assert num_cells (lmodel)== length (weights)
1615
+ convert (Vector{Cint},weights)
1616
+ end
1617
+ end
1618
+
1605
1619
comm = parts. comm
1606
1620
if (GridapDistributed. i_am_in (model. parts. comm) || GridapDistributed. i_am_in (parts. comm))
1607
1621
if (parts_redistributed_model != = model. parts)
@@ -1610,7 +1624,7 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc
1610
1624
@assert A || B
1611
1625
end
1612
1626
if (parts_redistributed_model=== model. parts || A)
1613
- _redistribute_parts_subseteq_parts_redistributed (model,parts_redistributed_model)
1627
+ _redistribute_parts_subseteq_parts_redistributed (model,parts_redistributed_model,_weights )
1614
1628
else
1615
1629
_redistribute_parts_supset_parts_redistributed (model, parts_redistributed_model)
1616
1630
end
@@ -1619,7 +1633,9 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc
1619
1633
end
1620
1634
end
1621
1635
1622
- function _redistribute_parts_subseteq_parts_redistributed (model:: OctreeDistributedDiscreteModel{Dc,Dp} , parts_redistributed_model) where {Dc,Dp}
1636
+ function _redistribute_parts_subseteq_parts_redistributed (model:: OctreeDistributedDiscreteModel{Dc,Dp} ,
1637
+ parts_redistributed_model,
1638
+ _weights:: WeightsArrayType ) where {Dc,Dp}
1623
1639
parts = (parts_redistributed_model === model. parts) ? model. parts : parts_redistributed_model
1624
1640
if (parts_redistributed_model === model. parts)
1625
1641
ptr_pXest_old = model. ptr_pXest
@@ -1631,7 +1647,15 @@ function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistribut
1631
1647
parts. comm)
1632
1648
end
1633
1649
ptr_pXest_new = pXest_copy (model. pXest_type, ptr_pXest_old)
1634
- pXest_partition! (model. pXest_type, ptr_pXest_new)
1650
+ if (_weights != = nothing )
1651
+ init_fn_callback_c = pXest_reset_callbacks (model. pXest_type)
1652
+ map (_weights) do _weights
1653
+ pXest_reset_data! (model. pXest_type, ptr_pXest_new, Cint (sizeof (Cint)), init_fn_callback_c, pointer (_weights))
1654
+ end
1655
+ pXest_partition! (model. pXest_type, ptr_pXest_new; weights_set= true )
1656
+ else
1657
+ pXest_partition! (model. pXest_type, ptr_pXest_new; weights_set= false )
1658
+ end
1635
1659
1636
1660
# Compute RedistributeGlue
1637
1661
parts_snd, lids_snd, old2new = pXest_compute_migration_control_data (model. pXest_type,ptr_pXest_old,ptr_pXest_new)
0 commit comments