Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions src/uct/cuda/gdaki/gdaki.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ __device__ static inline ucs_status_t uct_gdaki_progress_thread(uct_gdaki_dev_ep
}

cuda::atomic_ref<uint64_t, cuda::thread_scope_device> ref(qp->sq_wqe_pi);
uint64_t sq_wqe_pi = qp->sq_wqe_pi;
ref.fetch_max(((wqe_cnt - sq_wqe_pi) & 0xffff) + sq_wqe_pi + 1);
ref.fetch_add(1);

doca_gpu_dev_verbs_fence_release<DOCA_GPUNETIO_VERBS_SYNC_SCOPE_GPU>();
}
Expand Down Expand Up @@ -254,6 +253,7 @@ uct_gdaki_db(struct doca_gpu_dev_verbs_qp *qp, uint64_t wqe_base, unsigned count
{
cuda::atomic_ref<uint64_t, cuda::thread_scope_device> ref(qp->sq_ready_index);
uint64_t wqe_base_orig = wqe_base;
__threadfence();
while (!ref.compare_exchange_strong(wqe_base, wqe_base + count,
cuda::std::memory_order_relaxed)) {
wqe_base = wqe_base_orig;
Expand Down Expand Up @@ -553,7 +553,6 @@ uct_gdaki_put_batch_part_impl(uct_gdaki_batch_t *batch, uint64_t flags,
uint64_t dst;
uint32_t rkey;
int opcode;
uint32_t fc;
unsigned lane_id;
unsigned num_lanes;
#if ENABLE_PARAMS_CHECK
Expand All @@ -580,7 +579,6 @@ uct_gdaki_put_batch_part_impl(uct_gdaki_batch_t *batch, uint64_t flags,
return UCS_ERR_NO_RESOURCE;
}

fc = doca_gpu_dev_verbs_wqe_idx_inc_mask(qp->sq_wqe_pi, qp->sq_wqe_num / 2);
wqe_idx = doca_gpu_dev_verbs_wqe_idx_inc_mask(wqe_base, lane_id);
for (uint32_t i = lane_id; i < count; i += num_lanes) {
uint32_t idx;
Expand All @@ -601,11 +599,8 @@ uct_gdaki_put_batch_part_impl(uct_gdaki_batch_t *batch, uint64_t flags,
continue;
}

if (((flags & UCT_DEV_BATCH_FLAG_COMP) && (i == count - 1)) ||
(!(flags & UCT_DEV_BATCH_FLAG_COMP) && (wqe_idx == fc))) {
cflag = DOCA_GPUNETIO_MLX5_WQE_CTRL_CQ_UPDATE;
ep->ops[wqe_idx & qp->sq_wqe_mask].comp = comp;
}
cflag = DOCA_GPUNETIO_MLX5_WQE_CTRL_CQ_UPDATE;
ep->ops[wqe_idx & qp->sq_wqe_mask].comp = comp;

wqe_ptr = doca_gpu_dev_verbs_get_wqe_ptr(qp, wqe_idx);
src = batch->list[idx].src + src_off;
Expand Down