Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comms/ctran/backends/ib/CtranIb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ CtranIbSingleton::CtranIbSingleton() {
auto ibvInitResult = ibverbx::ibvInit();
FOLLY_EXPECTED_CHECKTHROW(ibvInitResult);
auto maybeDeviceList = ibverbx::IbvDevice::ibvGetDeviceList(
NCCL_IB_HCA, NCCL_IB_HCA_PREFIX, CTRAN_IB_ANY_PORT);
NCCL_IB_HCA, NCCL_IB_HCA_PREFIX, CTRAN_IB_ANY_PORT, NCCL_IB_DATA_DIRECT);
FOLLY_EXPECTED_CHECKTHROW(maybeDeviceList);
ibvDevices = std::move(*maybeDeviceList);

Expand Down
8 changes: 5 additions & 3 deletions comms/ctran/backends/ib/IbvWrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <sys/types.h>

#include "comms/ctran/backends/ib/IbvWrap.h"
#include "comms/ctran/ibverbx/Ibverbx.h"
#include "comms/ctran/ibverbx/IbverbxSymbols.h"
#include "comms/utils/logger/LogUtils.h"

#include "comms/ctran/utils/Checks.h"
Expand Down Expand Up @@ -128,8 +128,9 @@ commResult_t wrap_ibv_get_device_list(
struct ibv_device*** ret,
int* num_devices) {
*ret = ibvSymbols.ibv_internal_get_device_list(num_devices);
if (*ret == nullptr)
if (*ret == nullptr) {
*num_devices = 0;
}
return commSuccess;
}

Expand Down Expand Up @@ -480,8 +481,9 @@ static void ibvModifyQpLog(
remoteGidRes =
ibvGetGidStr(remoteGid, remoteGidName, sizeof(remoteGidName));
// we need pd->context to retrieve local GID, skip if not there
if (!qp->pd->context)
if (!qp->pd->context) {
goto print;
}
gidIndex = avAttr->ah_attr.grh.sgid_index;
union ibv_gid localGid;
FB_COMMCHECKGOTO(
Expand Down
179 changes: 179 additions & 0 deletions comms/ctran/ibverbx/Coordinator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include "comms/ctran/ibverbx/Coordinator.h"

#include <folly/Singleton.h>
#include "comms/ctran/ibverbx/IbvVirtualQp.h"

namespace ibverbx {

namespace {
// folly::Singleton holder for the Coordinator instance that
// Coordinator::getCoordinator() hands out. It lives in an anonymous namespace
// so it is only reachable from this translation unit.
folly::Singleton<Coordinator> coordinatorSingleton{};
}

/*** Coordinator ***/

// Returns the shared Coordinator managed by the file-local folly::Singleton.
// NOTE(review): folly's try_get() can yield an empty pointer when the
// singleton is unavailable (e.g. during shutdown) -- callers should check the
// result before dereferencing; confirm against call sites.
std::shared_ptr<Coordinator> Coordinator::getCoordinator() {
  auto instance = coordinatorSingleton.try_get();
  return instance;
}

// Register APIs for mapping management
// Records `virtualQp` as the object behind `virtualQpNum`. An existing entry
// for the same number is overwritten.
void Coordinator::registerVirtualQp(
    uint32_t virtualQpNum,
    IbvVirtualQp* virtualQp) {
  virtualQpNumToVirtualQp_.insert_or_assign(virtualQpNum, virtualQp);
}

// Records `virtualCq` as the object behind `virtualCqNum`. An existing entry
// for the same number is overwritten.
void Coordinator::registerVirtualCq(
    uint32_t virtualCqNum,
    IbvVirtualCq* virtualCq) {
  virtualCqNumToVirtualCq_.insert_or_assign(virtualCqNum, virtualCq);
}

// Associates a physical QP number with the virtual QP number it belongs to.
// A previous association for the same physical QP number is replaced.
void Coordinator::registerPhysicalQpToVirtualQp(
    int physicalQpNum,
    uint32_t virtualQpNum) {
  physicalQpNumToVirtualQpNum_.insert_or_assign(physicalQpNum, virtualQpNum);
}

// Associates a virtual QP number with the number of its virtual send CQ.
// A previous association for the same virtual QP number is replaced.
void Coordinator::registerVirtualQpToVirtualSendCq(
    uint32_t virtualQpNum,
    uint32_t virtualSendCqNum) {
  virtualQpNumToVirtualSendCqNum_.insert_or_assign(
      virtualQpNum, virtualSendCqNum);
}

// Associates a virtual QP number with the number of its virtual recv CQ.
// A previous association for the same virtual QP number is replaced.
void Coordinator::registerVirtualQpToVirtualRecvCq(
    uint32_t virtualQpNum,
    uint32_t virtualRecvCqNum) {
  virtualQpNumToVirtualRecvCqNum_.insert_or_assign(
      virtualQpNum, virtualRecvCqNum);
}

// One-shot registration of a virtual QP:
//   1. the virtual QP pointer, keyed by its own virtual QP number;
//   2. every physical QP returned by getQpsRef(), plus the notify QP, each
//      mapped back to that virtual QP number;
//   3. the send and recv virtual CQ numbers attached to the virtual QP.
// `virtualQp` is dereferenced unconditionally, so it must not be null.
void Coordinator::registerVirtualQpWithVirtualCqMappings(
    IbvVirtualQp* virtualQp,
    uint32_t virtualSendCqNum,
    uint32_t virtualRecvCqNum) {
  const uint32_t vqpNum = virtualQp->getVirtualQpNum();

  // Pointer map first, so the number resolves to an object.
  registerVirtualQp(vqpNum, virtualQp);

  // Data-path physical QPs, followed by the notify QP.
  for (const auto& physicalQp : virtualQp->getQpsRef()) {
    registerPhysicalQpToVirtualQp(physicalQp.qp()->qp_num, vqpNum);
  }
  const auto notifyQpNum = virtualQp->getNotifyQpRef().qp()->qp_num;
  registerPhysicalQpToVirtualQp(notifyQpNum, vqpNum);

  // Completion-queue relationships.
  registerVirtualQpToVirtualSendCq(vqpNum, virtualSendCqNum);
  registerVirtualQpToVirtualRecvCq(vqpNum, virtualRecvCqNum);
}

// Access APIs for testing and internal use
const std::unordered_map<uint32_t, IbvVirtualQp*>&
Coordinator::getVirtualQpMap() const {
return virtualQpNumToVirtualQp_;
}

const std::unordered_map<uint32_t, IbvVirtualCq*>&
Coordinator::getVirtualCqMap() const {
return virtualCqNumToVirtualCq_;
}

const std::unordered_map<int, uint32_t>&
Coordinator::getPhysicalQpToVirtualQpMap() const {
return physicalQpNumToVirtualQpNum_;
}

const std::unordered_map<uint32_t, uint32_t>&
Coordinator::getVirtualQpToVirtualSendCqMap() const {
return virtualQpNumToVirtualSendCqNum_;
}

const std::unordered_map<uint32_t, uint32_t>&
Coordinator::getVirtualQpToVirtualRecvCqMap() const {
return virtualQpNumToVirtualRecvCqNum_;
}

// Update API for move operations - only need to update pointer maps
// Points `virtualQpNum` at `newPtr`; only the pointer table is touched.
// If no entry exists for `virtualQpNum` yet, one is created.
void Coordinator::updateVirtualQpPointer(
    uint32_t virtualQpNum,
    IbvVirtualQp* newPtr) {
  virtualQpNumToVirtualQp_.insert_or_assign(virtualQpNum, newPtr);
}

// Points `virtualCqNum` at `newPtr`; only the pointer table is touched.
// If no entry exists for `virtualCqNum` yet, one is created.
void Coordinator::updateVirtualCqPointer(
    uint32_t virtualCqNum,
    IbvVirtualCq* newPtr) {
  virtualCqNumToVirtualCq_.insert_or_assign(virtualCqNum, newPtr);
}

// Drops every mapping that belongs to `virtualQpNum`, but only when the
// registered pointer is exactly `ptr`. If the entry is missing or refers to a
// different object (e.g. the QP was moved and the table already points at the
// new instance), the call is a no-op.
void Coordinator::unregisterVirtualQp(
    uint32_t virtualQpNum,
    IbvVirtualQp* ptr) {
  const auto registered = virtualQpNumToVirtualQp_.find(virtualQpNum);
  const bool ownsEntry = registered != virtualQpNumToVirtualQp_.end() &&
      registered->second == ptr;
  if (!ownsEntry) {
    return;
  }

  // Per-virtual-QP tables: pointer, send CQ relation, recv CQ relation.
  virtualQpNumToVirtualQp_.erase(registered);
  virtualQpNumToVirtualSendCqNum_.erase(virtualQpNum);
  virtualQpNumToVirtualRecvCqNum_.erase(virtualQpNum);

  // Sweep out every physical QP that was routed to this virtual QP.
  auto route = physicalQpNumToVirtualQpNum_.begin();
  while (route != physicalQpNumToVirtualQpNum_.end()) {
    if (route->second != virtualQpNum) {
      ++route;
      continue;
    }
    route = physicalQpNumToVirtualQpNum_.erase(route);
  }
}

// Drops the pointer entry for `virtualCqNum` and every virtual-QP -> CQ
// relation (send and recv) that refers to it, but only when the registered
// pointer is exactly `ptr`. If the entry is missing or refers to a different
// object (e.g. the CQ was moved and the table already points at the new
// instance), the call is a no-op.
void Coordinator::unregisterVirtualCq(
    uint32_t virtualCqNum,
    IbvVirtualCq* ptr) {
  const auto registered = virtualCqNumToVirtualCq_.find(virtualCqNum);
  const bool ownsEntry = registered != virtualCqNumToVirtualCq_.end() &&
      registered->second == ptr;
  if (!ownsEntry) {
    return;
  }

  virtualCqNumToVirtualCq_.erase(registered);

  // Removes every (virtual QP -> CQ) relation whose CQ is the one going away.
  const auto dropRelationsTo =
      [virtualCqNum](std::unordered_map<uint32_t, uint32_t>& qpToCq) {
        auto relation = qpToCq.begin();
        while (relation != qpToCq.end()) {
          if (relation->second != virtualCqNum) {
            ++relation;
            continue;
          }
          relation = qpToCq.erase(relation);
        }
      };

  dropRelationsTo(virtualQpNumToVirtualSendCqNum_);
  dropRelationsTo(virtualQpNumToVirtualRecvCqNum_);
}

} // namespace ibverbx
148 changes: 148 additions & 0 deletions comms/ctran/ibverbx/Coordinator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include <memory>
#include <unordered_map>

#include <folly/Expected.h>
#include "comms/ctran/ibverbx/IbvCommon.h"

namespace ibverbx {

class IbvVirtualQp;
class IbvVirtualCq;

// Coordinator class responsible for routing commands and responses between
// IbvVirtualQp and IbvVirtualCq. It keeps five lookup tables:
//   - virtual QP number  -> IbvVirtualQp pointer
//   - virtual CQ number  -> IbvVirtualCq pointer
//   - virtual QP number  -> virtual send CQ number
//   - virtual QP number  -> virtual recv CQ number
//   - physical QP number -> virtual QP number
// Lookups that start from a physical QP number or from a QP's send/recv CQ
// are resolved in two hops (number -> number -> pointer).
//
// The stored IbvVirtualQp* / IbvVirtualCq* are raw pointers; the register,
// update and unregister APIs below are how the tables are kept in sync with
// the objects they refer to.
//
// NOTE: The Coordinator APIs are NOT thread-safe. Users must ensure proper
// synchronization when accessing Coordinator methods from multiple threads.
// Thread-safe support can be added in the future if needed.
class Coordinator {
public:
Coordinator() = default;
~Coordinator() = default;

// Disable copy constructor and assignment operator
Coordinator(const Coordinator&) = delete;
Coordinator& operator=(const Coordinator&) = delete;

// Allow default move constructor and assignment operator
Coordinator(Coordinator&&) = default;
Coordinator& operator=(Coordinator&&) = default;

// Request-forwarding entry points. They are declared inline, but their
// definitions are not part of this header section.
inline void submitRequestToVirtualCq(VirtualCqRequest&& request);
inline folly::Expected<VirtualQpResponse, Error> submitRequestToVirtualQp(
VirtualQpRequest&& request);

// Register APIs for mapping management. Each call inserts the mapping or
// overwrites an existing one for the same key.
void registerVirtualQp(uint32_t virtualQpNum, IbvVirtualQp* virtualQp);
void registerVirtualCq(uint32_t virtualCqNum, IbvVirtualCq* virtualCq);
void registerPhysicalQpToVirtualQp(int physicalQpNum, uint32_t virtualQpNum);
void registerVirtualQpToVirtualSendCq(
uint32_t virtualQpNum,
uint32_t virtualSendCqNum);
void registerVirtualQpToVirtualRecvCq(
uint32_t virtualQpNum,
uint32_t virtualRecvCqNum);

// Consolidated registration API for IbvVirtualQp - registers the virtual QP
// along with all its physical QPs (including the notify QP) and CQ
// relationships in one call. `virtualQp` must not be null.
void registerVirtualQpWithVirtualCqMappings(
IbvVirtualQp* virtualQp,
uint32_t virtualSendCqNum,
uint32_t virtualRecvCqNum);

// Getter APIs for accessing mappings. Each returns nullptr when the key (or,
// for the two-hop lookups, the intermediate number) is not registered.
inline IbvVirtualCq* getVirtualSendCq(uint32_t virtualQpNum) const;
inline IbvVirtualCq* getVirtualRecvCq(uint32_t virtualQpNum) const;
inline IbvVirtualQp* getVirtualQpByPhysicalQpNum(int physicalQpNum) const;
inline IbvVirtualQp* getVirtualQpById(uint32_t virtualQpNum) const;
inline IbvVirtualCq* getVirtualCqById(uint32_t virtualCqNum) const;

// Access APIs for testing and internal use. These expose the underlying
// tables by const reference.
const std::unordered_map<uint32_t, IbvVirtualQp*>& getVirtualQpMap() const;
const std::unordered_map<uint32_t, IbvVirtualCq*>& getVirtualCqMap() const;
const std::unordered_map<int, uint32_t>& getPhysicalQpToVirtualQpMap() const;
const std::unordered_map<uint32_t, uint32_t>& getVirtualQpToVirtualSendCqMap()
const;
const std::unordered_map<uint32_t, uint32_t>& getVirtualQpToVirtualRecvCqMap()
const;

// Update API for move operations - only need to update pointer maps.
// The number-to-number relationship maps are keyed by QP/CQ number and are
// left untouched. An entry is created if the number is not yet registered.
void updateVirtualQpPointer(uint32_t virtualQpNum, IbvVirtualQp* newPtr);
void updateVirtualCqPointer(uint32_t virtualCqNum, IbvVirtualCq* newPtr);

// Unregister API for cleanup during destruction. Each call is a no-op unless
// the pointer currently registered under the given number equals `ptr`.
void unregisterVirtualQp(uint32_t virtualQpNum, IbvVirtualQp* ptr);
void unregisterVirtualCq(uint32_t virtualCqNum, IbvVirtualCq* ptr);

// Returns the shared Coordinator instance (backed by a folly::Singleton in
// Coordinator.cc).
static std::shared_ptr<Coordinator> getCoordinator();

private:
// Map 1: Virtual QP Num -> Virtual QP pointer
std::unordered_map<uint32_t, IbvVirtualQp*> virtualQpNumToVirtualQp_;

// Map 2: Virtual CQ Num -> Virtual CQ pointer
std::unordered_map<uint32_t, IbvVirtualCq*> virtualCqNumToVirtualCq_;

// Map 3: Virtual QP Num -> Virtual Send CQ Num (relationship)
std::unordered_map<uint32_t, uint32_t> virtualQpNumToVirtualSendCqNum_;

// Map 4: Virtual QP Num -> Virtual Recv CQ Num (relationship)
std::unordered_map<uint32_t, uint32_t> virtualQpNumToVirtualRecvCqNum_;

// Map 5: Physical QP number -> Virtual QP Num (for routing)
std::unordered_map<int, uint32_t> physicalQpNumToVirtualQpNum_;
};

// Coordinator inline functions
// Resolves the virtual send CQ attached to `virtualQpNum` in two hops:
// QP number -> send CQ number -> IbvVirtualCq*. Returns nullptr when either
// hop has no entry.
inline IbvVirtualCq* Coordinator::getVirtualSendCq(
    uint32_t virtualQpNum) const {
  const auto relation = virtualQpNumToVirtualSendCqNum_.find(virtualQpNum);
  return relation != virtualQpNumToVirtualSendCqNum_.end()
      ? getVirtualCqById(relation->second)
      : nullptr;
}

// Resolves the virtual recv CQ attached to `virtualQpNum` in two hops:
// QP number -> recv CQ number -> IbvVirtualCq*. Returns nullptr when either
// hop has no entry.
inline IbvVirtualCq* Coordinator::getVirtualRecvCq(
    uint32_t virtualQpNum) const {
  const auto relation = virtualQpNumToVirtualRecvCqNum_.find(virtualQpNum);
  return relation != virtualQpNumToVirtualRecvCqNum_.end()
      ? getVirtualCqById(relation->second)
      : nullptr;
}

// Resolves the virtual QP that owns `physicalQpNum` in two hops:
// physical QP number -> virtual QP number -> IbvVirtualQp*. Returns nullptr
// when either hop has no entry.
inline IbvVirtualQp* Coordinator::getVirtualQpByPhysicalQpNum(
    int physicalQpNum) const {
  const auto route = physicalQpNumToVirtualQpNum_.find(physicalQpNum);
  return route != physicalQpNumToVirtualQpNum_.end()
      ? getVirtualQpById(route->second)
      : nullptr;
}

// Returns the IbvVirtualQp registered under `virtualQpNum`, or nullptr when
// that number is not registered.
inline IbvVirtualQp* Coordinator::getVirtualQpById(
    uint32_t virtualQpNum) const {
  const auto entry = virtualQpNumToVirtualQp_.find(virtualQpNum);
  return entry != virtualQpNumToVirtualQp_.end() ? entry->second : nullptr;
}

// Returns the IbvVirtualCq registered under `virtualCqNum`, or nullptr when
// that number is not registered.
inline IbvVirtualCq* Coordinator::getVirtualCqById(
    uint32_t virtualCqNum) const {
  const auto entry = virtualCqNumToVirtualCq_.find(virtualCqNum);
  return entry != virtualCqNumToVirtualCq_.end() ? entry->second : nullptr;
}

} // namespace ibverbx
Loading