diff --git a/threadsafe.go b/threadsafe.go index 664fc61..0f3e593 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -208,12 +208,12 @@ func (t *threadSafeSet[T]) Cardinality() int { func (t *threadSafeSet[T]) Each(cb func(T) bool) { t.RLock() + defer t.RUnlock() for elem := range *t.uss { if cb(elem) { break } } - t.RUnlock() } func (t *threadSafeSet[T]) Iter() <-chan T { @@ -286,8 +286,9 @@ func (t *threadSafeSet[T]) Pop() (T, bool) { } func (t *threadSafeSet[T]) ToSlice() []T { - keys := make([]T, 0, t.Cardinality()) t.RLock() + l := len(*t.uss) + keys := make([]T, 0, l) for elem := range *t.uss { keys = append(keys, elem) } diff --git a/threadsafe_test.go b/threadsafe_test.go index 9037616..ed15d02 100644 --- a/threadsafe_test.go +++ b/threadsafe_test.go @@ -27,6 +27,7 @@ package mapset import ( "encoding/json" + "fmt" "math/rand" "runtime" "sync" @@ -636,3 +637,49 @@ func Test_MarshalJSON(t *testing.T) { t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) } } + +// Test_DeadlockOnEachCallbackWhenPanic ensures that should a panic occur within the context +// of the Each callback, progress can still be made on recovery. This is an edge case +// that was called out on issue: https://github.com/deckarep/golang-set/issues/163. +func Test_DeadlockOnEachCallbackWhenPanic(t *testing.T) { + numbers := []int{1, 2, 3, 4} + widgets := NewSet[*int]() + widgets.Append(&numbers[0], &numbers[1], nil, &numbers[2]) + + var panicOccured = false + + doWork := func(s Set[*int]) (err error) { + defer func() { + if r := recover(); r != nil { + panicOccured = true + err = fmt.Errorf("failed to do work: %v", r) + } + }() + + s.Each(func(n *int) bool { + // NOTE: this will throw a panic once we get to the nil element. + _ = *n * 2 + return false + }) + + return nil + } + + card := widgets.Cardinality() + if widgets.Cardinality() != 4 { + t.Errorf("Expected widgets to have 4 elements, but has %d", card) + } + + doWork(widgets) + + if !panicOccured { + t.Error("Expected a panic to occur followed by recover for test to be valid") + } + + widgets.Add(&numbers[3]) + + card = widgets.Cardinality() + if widgets.Cardinality() != 5 { + t.Errorf("Expected widgets to have 5 elements, but has %d", card) + } +}