Skip to content

Commit 0e60d36

Browse files
committed
Deadlock: if the callback on Each panics and the app has a recovery, this will ensure that the thread-safe set is left in a usable state by ensuring the read lock is Unlocked. This also has a minor tweak on ToSlice to ensure the capacity hint is correct
1 parent 4c06bfc commit 0e60d36

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

threadsafe.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,12 @@ func (t *threadSafeSet[T]) Cardinality() int {
208208

209209
func (t *threadSafeSet[T]) Each(cb func(T) bool) {
210210
t.RLock()
211+
defer t.RUnlock()
211212
for elem := range *t.uss {
212213
if cb(elem) {
213214
break
214215
}
215216
}
216-
t.RUnlock()
217217
}
218218

219219
func (t *threadSafeSet[T]) Iter() <-chan T {
@@ -286,8 +286,9 @@ func (t *threadSafeSet[T]) Pop() (T, bool) {
286286
}
287287

288288
func (t *threadSafeSet[T]) ToSlice() []T {
289-
keys := make([]T, 0, t.Cardinality())
290289
t.RLock()
290+
l := len(*t.uss)
291+
keys := make([]T, 0, l)
291292
for elem := range *t.uss {
292293
keys = append(keys, elem)
293294
}

threadsafe_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ package mapset
2727

2828
import (
2929
"encoding/json"
30+
"fmt"
3031
"math/rand"
3132
"runtime"
3233
"sync"
@@ -636,3 +637,50 @@ func Test_MarshalJSON(t *testing.T) {
636637
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
637638
}
638639
}
640+
641+
// Test_DeadlockOnEachCallbackWhenPanic ensures that should a panic occur within the context
642+
// of the Each callback, progress can still be made on recovery. This is an edge case
643+
// that was called out on issue: https://github.com/deckarep/golang-set/issues/163.
644+
func Test_DeadlockOnEachCallbackWhenPanic(t *testing.T) {
645+
numbers := []int{1, 2, 3, 4}
646+
widgets := NewSet[*int]()
647+
widgets.Append(&numbers[0], &numbers[1], nil, &numbers[2])
648+
649+
var panicOccured = false
650+
651+
doWork := func(s Set[*int]) (err error) {
652+
defer func() {
653+
if r := recover(); r != nil {
654+
panicOccured = true
655+
err = fmt.Errorf("failed to print IDs: panicked: %v", r)
656+
}
657+
}()
658+
659+
s.Each(func(n *int) bool {
660+
// NOTE: this will throw a panic once we get to the nil element.
661+
result := *n * 2
662+
fmt.Println("result is dereferenced doubled:", result)
663+
return false
664+
})
665+
666+
return nil
667+
}
668+
669+
card := widgets.Cardinality()
670+
if widgets.Cardinality() != 4 {
671+
t.Errorf("Expected widgets to have 4 elements, but has %d", card)
672+
}
673+
674+
doWork(widgets)
675+
676+
if !panicOccured {
677+
t.Error("Expected a panic to occur followed by recover for test to be valid")
678+
}
679+
680+
widgets.Add(&numbers[3])
681+
682+
card = widgets.Cardinality()
683+
if widgets.Cardinality() != 5 {
684+
t.Errorf("Expected widgets to have 5 elements, but has %d", card)
685+
}
686+
}

0 commit comments

Comments
 (0)