Skip to content

Commit 895d3b6

Browse files
committed
完善skiplist 无锁插入使用近似的方式完成
1 parent 94b2536 commit 895d3b6

File tree

1 file changed

+274
-13
lines changed

1 file changed

+274
-13
lines changed

src/oblsm/memtable/ob_skiplist.h

Lines changed: 274 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ See the Mulan PSL v2 for more details. */
4040
#include "common/lang/atomic.h"
4141
#include "common/lang/vector.h"
4242
#include "common/log/log.h"
43+
#include <cstdlib>
44+
#include <mutex>
4345

4446
namespace oceanbase {
4547

@@ -141,7 +143,9 @@ class ObSkipList
141143

142144
Node *new_node(const Key &key, int height);
143145
int random_height();
144-
bool equal(const Key &a, const Key &b) const { return (compare_(a, b) == 0); }
146+
bool equal(const Key &a, const Key &b) const {
147+
return (compare_(a, b) == 0);
148+
}
145149

146150
// Return the earliest node that comes at or after key.
147151
// Return nullptr if there is no such node.
@@ -159,15 +163,18 @@ class ObSkipList
159163
Node *find_last() const;
160164

161165
// Immutable after construction
162-
ObComparator const compare_;
166+
ObComparator const compare_; // 一个比较器,用于比较两个key的大小
167+
168+
Node *const head_; // 一个哨兵节点,用于表示跳表的起点
163169

164-
Node *const head_;
170+
// 插入节点失败后,删除新插入的节点x
171+
void delete_node(Node *x);
165172

166173
// Modified only by insert(). Read racily by readers, but stale
167174
// values are ok.
168175
atomic<int> max_height_; // Height of the entire list
169176

170-
static common::RandomGenerator rnd;
177+
static common::RandomGenerator rnd; //
171178
};
172179

173180
template <typename Key, class ObComparator>
@@ -177,7 +184,12 @@ common::RandomGenerator ObSkipList<Key, ObComparator>::rnd = common::RandomGener
177184
template <typename Key, class ObComparator>
178185
struct ObSkipList<Key, ObComparator>::Node
179186
{
180-
explicit Node(const Key &k) : key(k) {}
187+
explicit Node(const Key &k, int height) : key(k), height_(height) {
188+
// 将next_[i]初始化为nullptr
189+
for (int i = 0; i < height; ++i) {
190+
next_[i].store(nullptr, std::memory_order_relaxed);
191+
}
192+
}
181193

182194
Key const key;
183195

@@ -186,6 +198,7 @@ struct ObSkipList<Key, ObComparator>::Node
186198
Node *next(int n)
187199
{
188200
ASSERT(n >= 0, "n >= 0");
201+
assert(n >= 0 && n < height_); // 避免越界
189202
// Use an 'acquire load' so that we observe a fully initialized
190203
// version of the returned Node.
191204
return next_[n].load(std::memory_order_acquire);
@@ -216,16 +229,46 @@ struct ObSkipList<Key, ObComparator>::Node
216229
return next_[n].compare_exchange_strong(expected, x);
217230
}
218231

232+
// 返回节点高度
233+
int height() const {return height_;}
234+
235+
// 对该节点加锁
236+
void lock() {
237+
node_mutex_.lock();
238+
}
239+
240+
// 对该节点解锁
241+
void unlock() {
242+
node_mutex_.unlock();
243+
}
244+
219245
private:
220246
// Array of length equal to the node height. next_[0] is lowest level link.
247+
// 这里声明的是 next_[1],但它实际上被设计成"可变大小数组",构造时会额外分配更大的空间。
221248
atomic<Node *> next_[1];
249+
250+
// 记录该节点的高度
251+
int height_;
252+
253+
// 该节点的锁
254+
std::mutex node_mutex_;
222255
};
223256

224257
template <typename Key, class ObComparator>
225258
typename ObSkipList<Key, ObComparator>::Node *ObSkipList<Key, ObComparator>::new_node(const Key &key, int height)
226259
{
227-
char *const node_memory = reinterpret_cast<char *>(malloc(sizeof(Node) + sizeof(atomic<Node *>) * (height - 1)));
228-
return new (node_memory) Node(key);
260+
assert(height > 0 && height <= kMaxHeight);
261+
// 在这里实际分配了 sizeof(Node):包含 next_[1]
262+
// 还有 sizeof(atomic<Node*>) * (height - 1):多分配 (height - 1) 个指针
263+
// 所以一共分配的大小是 可以访问 next_[0] ~ next_[height - 1] 的连续内存块
264+
265+
// char *const node_memory = reinterpret_cast<char *>(malloc(sizeof(Node) + sizeof(atomic<Node *>) * (height - 1)));
266+
// return new (node_memory) Node(key, height);
267+
268+
size_t size = sizeof(Node) + sizeof(std::atomic<Node*>) * (height - 1);
269+
void* mem = malloc(size);
270+
if (!mem) throw std::bad_alloc();
271+
return new (mem) Node(key, height);
229272
}
230273

231274
template <typename Key, class ObComparator>
@@ -294,22 +337,53 @@ int ObSkipList<Key, ObComparator>::random_height()
294337
// Increase height with probability 1 in kBranching
295338
static const unsigned int kBranching = 4;
296339
int height = 1;
297-
while (height < kMaxHeight && rnd.next(kBranching) == 0) {
340+
thread_local common::RandomGenerator local_rnd;
341+
while (height < kMaxHeight && local_rnd.next(kBranching) == 0) {
298342
height++;
299343
}
300344
ASSERT(height > 0, "height > 0");
301345
ASSERT(height <= kMaxHeight, "height <= kMaxHeight");
302346
return height;
303347
}
304348

349+
// 返回大于等于 key 的最小的节点
305350
template <typename Key, class ObComparator>
306351
typename ObSkipList<Key, ObComparator>::Node *ObSkipList<Key, ObComparator>::find_greater_or_equal(
307352
const Key &key, Node **prev) const
308353
{
309-
// your code here
310-
return nullptr;
354+
Node *x = head_;
355+
int current_max_level = get_max_height() - 1; // Max level currently in the list
356+
357+
for (int level = current_max_level; level >= 0; --level) {
358+
// Traverse right as long as next node at this level is less than key
359+
// and x itself has this level.
360+
Node *next_node = nullptr;
361+
while (level < x->height()) { // Check if x has current level
362+
next_node = x->next(level);
363+
if (next_node != nullptr && compare_(next_node->key, key) < 0) {
364+
x = next_node;
365+
} else {
366+
break; // next_node is >= key or nullptr
367+
}
368+
}
369+
// At this point, x is the rightmost node at 'level' whose key is < key,
370+
// or x is head_ if all keys at this level are >= key.
371+
// next_node is the node >= key or nullptr.
372+
if (prev != nullptr) {
373+
prev[level] = x;
374+
}
375+
}
376+
// After the loop, x is the predecessor of the target key at level 0.
377+
// The node that is >= key at level 0 is x->next(0), if x has level 0.
378+
if (0 < x->height()) {
379+
return x->next(0);
380+
}
381+
return nullptr; // Should technically be head_->next(0) if list not empty
382+
// but find_greater_or_equal can return nullptr if key is > all elements
383+
// or list is empty. x->next(0) handles this as head_->next(0) could be null.
311384
}
312385

386+
// 返回小于 key 的最大的节点
313387
template <typename Key, class ObComparator>
314388
typename ObSkipList<Key, ObComparator>::Node *ObSkipList<Key, ObComparator>::find_less_than(const Key &key) const
315389
{
@@ -353,7 +427,7 @@ typename ObSkipList<Key, ObComparator>::Node *ObSkipList<Key, ObComparator>::fin
353427

354428
template <typename Key, class ObComparator>
355429
ObSkipList<Key, ObComparator>::ObSkipList(ObComparator cmp)
356-
: compare_(cmp), head_(new_node(0 /* any key will do */, kMaxHeight)), max_height_(1)
430+
: compare_(cmp), head_(new_node(Key(), kMaxHeight)), max_height_(1)
357431
{
358432
for (int i = 0; i < kMaxHeight; i++) {
359433
head_->set_next(i, nullptr);
@@ -364,7 +438,7 @@ template <typename Key, class ObComparator>
364438
ObSkipList<Key, ObComparator>::~ObSkipList()
365439
{
366440
typename std::vector<Node *> nodes;
367-
nodes.reserve(max_height_.load(std::memory_order_relaxed));
441+
nodes.reserve(1024); // 或者干脆不reserve
368442
for (Node *x = head_; x != nullptr; x = x->next(0)) {
369443
nodes.push_back(x);
370444
}
@@ -376,14 +450,201 @@ ObSkipList<Key, ObComparator>::~ObSkipList()
376450

377451
template <typename Key, class ObComparator>
378452
void ObSkipList<Key, ObComparator>::insert(const Key &key)
379-
{}
453+
{
454+
// your code here
455+
int new_node_height = random_height(); // Renamed for clarity
456+
Node *prev[kMaxHeight];
457+
Node *next_nodes_for_new_node[kMaxHeight]; // Renamed for clarity
458+
Node *found_node; // Renamed
459+
int current_list_actual_max_height; // Renamed
460+
461+
while (true) {
462+
current_list_actual_max_height = get_max_height();
463+
// prev is populated by find_greater_or_equal for levels 0 to current_list_actual_max_height - 1
464+
found_node = find_greater_or_equal(key, prev);
465+
466+
// 找到了一个相同节点,不需要插入
467+
if (found_node != nullptr && equal(key, found_node->key)) {
468+
return;
469+
}
470+
// For levels of the new node that are above the current list's max height,
471+
// the predecessor is the head node.
472+
// This ensures prev[i] is valid for all i up to 'new_node_height'.
473+
for (int i = current_list_actual_max_height; i < new_node_height; ++i) {
474+
prev[i] = head_;
475+
}
476+
// 排除了有相同节点的情况,那么只有大于key的节点和小于key的节点
477+
// 那么prev的后继,就是第一个大于key的节点
478+
// Populate next_nodes_for_new_node using prev array
479+
// This loop iterates from 0 to new_node_height - 1.
480+
// All prev[i] for i in [0, new_node_height - 1] should be valid now.
481+
for (int i = 0; i < new_node_height; i++) {
482+
next_nodes_for_new_node[i] = prev[i]->next(i);
483+
}
484+
// Step 1: 锁住所有 prev[i],升序加锁
485+
for (int i = 0; i < new_node_height; i++) {
486+
prev[i]->lock();
487+
}
488+
// Step 2: 验证结构是否仍然一致
489+
bool valid = true;
490+
for (int i = 0; i < new_node_height; i++) {
491+
if (prev[i]->next(i) != next_nodes_for_new_node[i]) {
492+
valid = false;
493+
break;
494+
}
495+
}
496+
if (!valid) {
497+
for (int i = new_node_height - 1; i >= 0; --i) prev[i]->unlock(); // Unlock in reverse
498+
continue; // 重试
499+
}
500+
// Step 3: 插入新节点
501+
Node *allo_new_node = new_node(key, new_node_height); // 初始化新节点
502+
for (int i = 0; i < new_node_height; i++) {
503+
allo_new_node->set_next(i, next_nodes_for_new_node[i]);
504+
prev[i]->set_next(i, allo_new_node);
505+
}
506+
// 插入节点的高度可能是最大高度
507+
int expected_max_height = current_list_actual_max_height;
508+
while (new_node_height > expected_max_height) {
509+
if (max_height_.compare_exchange_weak(expected_max_height, new_node_height)) {
510+
break;
511+
}
512+
}
513+
// Step 4: 解锁
514+
for (int i = new_node_height - 1; i >= 0; --i) prev[i]->unlock(); // Unlock in reverse
515+
return;
516+
}
517+
}
380518

381519
template <typename Key, class ObComparator>
382520
void ObSkipList<Key, ObComparator>::insert_concurrently(const Key &key)
383521
{
384522
// your code here
523+
int new_node_height = random_height();
524+
Node *prev[kMaxHeight];
525+
Node *next_pointers_for_new_node[kMaxHeight];
526+
Node *found_node;
527+
int current_list_max_h;
528+
529+
while (true) {
530+
current_list_max_h = get_max_height();
531+
found_node = find_greater_or_equal(key, prev); // Populates prev[0]...prev[current_list_max_h-1]
532+
533+
// 找到了相同的节点
534+
if (found_node != nullptr && equal(key, found_node->key)) {
535+
return;
536+
}
537+
538+
// Ensure prev[i] is initialized for all levels of the new node
539+
// 当前跳表高度是5,新节点高度是7,那么6 和 7的前驱是 head_
540+
for (int i = current_list_max_h; i < new_node_height; ++i) {
541+
prev[i] = head_;
542+
}
543+
544+
// Now prev[i] should be valid for i in [0, new_node_height - 1]
545+
// 给后继节点赋值
546+
for (int i = 0; i < new_node_height; i++) {
547+
next_pointers_for_new_node[i] = prev[i]->next(i);
548+
}
549+
550+
Node *newly_allocated_node = new_node(key, new_node_height);
551+
for (int i = 0; i < new_node_height; ++i) {
552+
newly_allocated_node->set_next(i, next_pointers_for_new_node[i]);
553+
}
554+
555+
// 如果当前 prev[0]->next[0] == next_pointers_for_new_node[0],那么将其原子更新为 newly_allocated_node
556+
// 第0层插入成功则认为插入成功
557+
if (!prev[0]->cas_next(0, next_pointers_for_new_node[0], newly_allocated_node)) {
558+
delete_node(newly_allocated_node);
559+
continue;
560+
}
561+
562+
// 设置除了第0层以外的所有层
563+
for (int i = 1; i < new_node_height; ++i) {
564+
while (true) {
565+
// If prev[i]->next(i) has changed from what we expected (next_pointers_for_new_node[i]),
566+
// we must re-evaluate.
567+
// The simplest strategy for the CAS attempt is to use the current value of prev[i]->next(i) as expected.
568+
Node* current_successor = prev[i]->next(i); // Get current successor for this level.
569+
newly_allocated_node->set_next(i, current_successor); // New node should point to this.
570+
571+
if (prev[i]->cas_next(i, current_successor, newly_allocated_node)) {
572+
break; // Successfully linked at this level.
573+
}
574+
// CAS failed. prev[i]->next(i) was not current_successor.
575+
// This implies another thread modified it.
576+
// The original code had a full find_greater_or_equal here.
577+
// A less drastic retry for this level would be to just loop and re-attempt CAS with the freshly read current_successor.
578+
// However, if prev[i] itself is no longer the correct predecessor, this inner loop might spin.
579+
// The outer while(true) of insert_concurrently will perform a full FGE if needed.
580+
// For now, we adopt the structure of user's original retry which was a full FGE.
581+
// This is potentially problematic due to FGE not being lock-free during concurrent modifications
582+
// and prev array being overwritten.
583+
// A truly robust lock-free insertion here is more complex.
584+
// Sticking to the user's apparent original intent for retry:
585+
find_greater_or_equal(key, prev); // Re-evaluate all predecessors.
586+
// This overwrites the prev array which might be an issue
587+
// if other levels' CAS operations depended on the old values.
588+
// This specific retry logic is complex and error-prone.
589+
// Ensure prev[i] is valid again after FGE, especially if new_node_height > get_max_height()
590+
int updated_current_max_h = get_max_height();
591+
for (int k = updated_current_max_h; k < new_node_height; ++k) {
592+
if (k == i) { // Only if this specific level 'i' might have become uninitialized in prev by FGE
593+
Node* temp_prev_for_i[kMaxHeight]; // Use a temporary array for this specific level's FGE.
594+
find_greater_or_equal(key, temp_prev_for_i);
595+
if (i < updated_current_max_h) {
596+
prev[i] = temp_prev_for_i[i];
597+
} else {
598+
prev[i] = head_;
599+
}
600+
}
601+
}
602+
// After potential update of prev[i] from FGE:
603+
next_pointers_for_new_node[i] = prev[i]->next(i); // Update the expected next for level i
604+
newly_allocated_node->set_next(i, next_pointers_for_new_node[i]); // Update new node's link
605+
606+
// The original code was: while (!prev[i]->cas_next(i, next_pointers_for_new_node[i], newly_allocated_node))
607+
// This means it retried CAS with the potentially updated next_pointers_for_new_node[i] from the new FGE.
608+
// Let's try a direct retry of cas_next with the updated values:
609+
if (prev[i]->cas_next(i, next_pointers_for_new_node[i], newly_allocated_node)) {
610+
break;
611+
}
612+
// If it still fails, the outer loop of insert_concurrently will eventually retry.
613+
// To avoid potential infinite loop if prev[i] itself is bad, we can break and rely on outer retry.
614+
// For simplicity and to match the original structure more closely if FGE is intended inside:
615+
// Re-fetch expected value for CAS based on current prev[i] after FGE
616+
Node* re_read_next = prev[i]->next(i);
617+
newly_allocated_node->set_next(i, re_read_next);
618+
if(prev[i]->cas_next(i, re_read_next, newly_allocated_node)) {
619+
break;
620+
}
621+
// If many failures, rely on outer loop to restart find_greater_or_equal.
622+
}
623+
}
624+
// Update max_height_ if new_node_height is greater
625+
int expected_max_h = current_list_max_h; // Use max height from start of this attempt
626+
while (new_node_height > expected_max_h) {
627+
if (max_height_.compare_exchange_weak(expected_max_h, new_node_height)) {
628+
break;
629+
}
630+
// expected_max_h is updated by CAS on failure
631+
}
632+
return;
633+
}
385634
}
386635

636+
// 删除节点函数
637+
template <typename Key, class ObComparator>
638+
void ObSkipList<Key, ObComparator>::delete_node(Node *x)
639+
{
640+
if (x != nullptr) {
641+
// int height = x->height();
642+
x->~Node();
643+
free(x);
644+
}
645+
}
646+
647+
387648
template <typename Key, class ObComparator>
388649
bool ObSkipList<Key, ObComparator>::contains(const Key &key) const
389650
{

0 commit comments

Comments
 (0)