@@ -121,13 +121,15 @@ void generate_alltoall_update_data(
121121 std::vector<RepartDistMatrix::all_to_all_data> &update_data)
122122{
123123 label linop_offset_store{0 };
124- for (size_t i = 0 ; i < 3 ; i++) {
124+ // NOTE in case of symmetric matrix 0 (upper) is same as 1 (lower)
125+ // thus we can start at 1
126+ label start = 0 ;
127+ for (size_t i = start; i < 3 ; i++) {
125128 label interface_size = in->get_rows ()[i].size ();
126129 label linop_idx = (fuse) ? 0 : in->get_id ()[i];
127130 label linop_offset = (fuse) ? linop_offset_store : 0 ;
128131 auto comm_pattern = compute_gather_to_owner_counts (
129132 exec_handler, ranks_per_owner, interface_size);
130-
131133 size_t recv_size = comm_pattern.recv_offsets .back ();
132134
133135 // NOTE Probably dont need to store linops[linop-idx] because we can
@@ -395,6 +397,9 @@ void update_impl(
395397 auto all_to_all_update = [repart_comm, ref_exec, device_exec,
396398 all_to_all_update_data, host_A, force_host_buffer,
397399 exec_handler, rank]() {
400+ // NOTE if symmetric (get it from host_A) we can skip id=0 and wait till
401+ // id=1 has been copied to use device copy
402+ //
398403 for (auto [id, comm_pattern, data_ptr] : all_to_all_update_data) {
399404 // auto start = std::chrono::steady_clock::now();
400405 auto repartAllToAll =
@@ -409,24 +414,43 @@ void update_impl(
409414 // communicate_values(ref_exec, device_exec, repart_comm,
410415 // repartAllToAll,
411416 // send_data_ptr, data_ptr, force_host_buffer);
412- // if ( repart_comm->rank() == 0 ) {
413417 // std::cout << __FILE__ <<
414418 // " Pstream::rank " << Pstream::myProcNo() <<
415419 // " repart_rank() " << repart_comm->rank() <<
416420 // " send_offsets.back() " <<
421+ // " id " << id <<
417422 // repartAllToAll.send_offsets.back() << " recv_counts: " <<
418423 // repartAllToAll.recv_counts << " recv_offsets: " <<
419424 // repartAllToAll.recv_offsets <<
420425 // std::endl;
421- // }
422- MPI_Request request;
423-
424- MPI_Igatherv (send_data_ptr, repartAllToAll.send_offsets .back (),
425- MPI_DOUBLE, data_ptr,
426- repartAllToAll.recv_counts .data (),
427- repartAllToAll.recv_offsets .data (), MPI_DOUBLE, 0 ,
428- repart_comm->get (), &request);
429- MPI_Wait (&request, MPI_STATUS_IGNORE);
426+
427+ if (id == 0 && host_A->get_symmetric ()) {
428+ } else {
429+ MPI_Request request;
430+ MPI_Igatherv (send_data_ptr, repartAllToAll.send_offsets .back (),
431+ MPI_DOUBLE, data_ptr,
432+ repartAllToAll.recv_counts .data (),
433+ repartAllToAll.recv_offsets .data (), MPI_DOUBLE, 0 ,
434+ repart_comm->get (), &request);
435+ MPI_Wait (&request, MPI_STATUS_IGNORE);
436+ }
437+
438+ // Perform symmetric inter device copy
439+ if (id == 1 && repart_comm->rank () == 0 &&
440+ host_A->get_symmetric ()) {
441+ auto [zid, zcomm_pattern, zdata_ptr] =
442+ all_to_all_update_data[0 ];
443+ // copy recv size data from data_ptr to zdata_ptr
444+ //
445+ label recv_buffer_size = repartAllToAll.recv_offsets .back ();
446+ auto l_view = gko::array<scalar>::view (
447+ device_exec, recv_buffer_size, data_ptr);
448+
449+ auto u_view = gko::array<scalar>::view (
450+ device_exec, recv_buffer_size, zdata_ptr);
451+
452+ u_view = l_view;
453+ }
430454 }
431455 };
432456
0 commit comments