diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b81b769..86f8d3b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,7 +75,7 @@ to turn off this check") endif() set(GINKGO_CHECKOUT_VERSION - "ogl_0600_gko190" + "ogl_0600_gko110" CACHE STRING "Use specific version of ginkgo") if(OGL_DEVEL_TOOLS) diff --git a/include/OGL/MatrixWrapper/Distributed.hpp b/include/OGL/MatrixWrapper/Distributed.hpp index bda53b87..7060bca5 100644 --- a/include/OGL/MatrixWrapper/Distributed.hpp +++ b/include/OGL/MatrixWrapper/Distributed.hpp @@ -13,7 +13,6 @@ #include "OGL/MatrixWrapper/HostMatrix.hpp" #include "OGL/Repartitioner.hpp" - /* The RepartDistMatrix class is a wrapper around Ginkgos distributed Matrix * class * @@ -38,20 +37,60 @@ class RepartDistMatrix using reorder_map_type = std::tuple>, scalar *, std::shared_ptr>>; - using all_to_all_data = std::tuple; + using all_to_all_data = std::tuple; struct pairwise_data { - label id; // original interface id on orig rank - label send; // 0 - send, 1, receive, 2 same_rank - label comm_rank; // other side of communication - label length; // length of the interface to communicate - label send_id; // - scalar *recv_ptr; // + label id; // original interface id on orig rank + label send; // 0 - send, 1, receive, 2 same_rank + label comm_rank; // other side of communication + label length; // length of the interface to communicate + label send_id; // + scalar *recv_ptr; // where to put the received data (linop data) + label recv_offset; // offset to begin of linop }; using gko::EnableLinOp::convert_to; using gko::EnableLinOp::move_to; + /* @brief replaces all data ptrs in all_to_all_data with new_ptr + */ + std::vector update_reorder_map_ptr( + std::vector in, scalar *new_ptr) const + { + std::vector out; + auto [map, _, pad] = in[0]; //) { + out.emplace_back(map, new_ptr, pad); + // } + return out; + } + + /* @brief replaces all data ptrs in all_to_all_data with new_ptr + */ + std::vector update_all_to_all_recv_ptr( + std::vector in, scalar *new_ptr) const + { + std::vector out; + for (auto [id, comm_pattern, data_ptr, offset] : in) { + out.emplace_back(id, comm_pattern, new_ptr, offset); + } + return out; + } + + /* @brief replaces all data ptrs in pairwise_data with new_ptr + */ + std::vector update_pairwise_recv_ptr( + std::vector in, scalar *new_ptr) const + { + std::vector out; + for (auto [linop_id, mode, comm_rank, length, send_id, recv_ptr, + offset] : in) { + out.push_back(pairwise_data{linop_id, mode, comm_rank, length, + send_id, new_ptr, offset}); + } + return out; + } + + std::shared_ptr get_dist_matrix() const { return this->dist_mtx_; @@ -80,6 +119,33 @@ class RepartDistMatrix } } + std::shared_ptr clone() const + { + auto new_dist_mtx = gko::share(this->dist_mtx_->clone()); + auto new_local_ptr = gko::as>( + new_dist_mtx->get_local_matrix()); + const scalar *new_data_ptr = new_local_ptr->get_const_values(); + + auto all_to_all_data = update_all_to_all_recv_ptr( + this->all_to_all_update_data_, const_cast(new_data_ptr)); + + auto new_non_local_ptr = gko::as>( + new_dist_mtx->get_non_local_matrix()); + const scalar *new_non_local_data_ptr = + new_non_local_ptr->get_const_values(); + auto pairwise_update_data = update_pairwise_recv_ptr( + this->pairwise_update_data_, + const_cast(new_non_local_data_ptr)); + + auto new_reorder_map = update_reorder_map_ptr( + this->reorder_maps_, const_cast(new_data_ptr)); + + return std::make_shared( + this->get_executor(), this->get_communicator(), + this->matrix_format_, new_dist_mtx, this->repartitioner_, + this->fuse_, all_to_all_data, pairwise_update_data, new_reorder_map, + this->compress_to_global_); + } /** * Copy-assigns a CombinationMatrix matrix. Preserves executor, copies @@ -101,7 +167,6 @@ class RepartDistMatrix this->pairwise_update_data_ = other.pairwise_update_data_; this->reorder_maps_ = other.reorder_maps_; this->compress_to_global_ = other.compress_to_global_; - this->linops_ = other.linops_; } return *this; } @@ -126,7 +191,6 @@ class RepartDistMatrix std::move(other.pairwise_update_data_); this->reorder_maps_ = std::move(other.reorder_maps_); this->compress_to_global_ = std::move(other.compress_to_global_); - this->linops_ = std::move(other.linops_); } return *this; } @@ -135,6 +199,9 @@ class RepartDistMatrix void update(const ExecutorHandler &exec_handler, std::shared_ptr host_A, label verbose); + void update(const ExecutorHandler &exec_handler, const scalar *diag_ptr, + const scalar *face_ptr, label verbose); + word get_matrix_format() const { return matrix_format_; } std::shared_ptr get_dist_mtx() const { return dist_mtx_; } @@ -147,8 +214,10 @@ class RepartDistMatrix std::vector all_to_all_update_data, std::vector pairwise_update_data, std::vector reorder_maps, - std::vector