diff --git a/network/p2p/gossip/bloom.go b/network/p2p/gossip/bloom.go index 9c05c8db78e0..4b4d8e6af3cc 100644 --- a/network/p2p/gossip/bloom.go +++ b/network/p2p/gossip/bloom.go @@ -5,6 +5,8 @@ package gossip import ( "crypto/rand" + "iter" + "sync" "github.com/prometheus/client_golang/prometheus" @@ -16,8 +18,7 @@ import ( // anticipated at any moment, and a false positive probability of [targetFalsePositiveProbability]. If the // false positive probability exceeds [resetFalsePositiveProbability], the bloom filter will be reset. // -// Invariant: The returned bloom filter is not safe to reset concurrently with -// other operations. However, it is otherwise safe to access concurrently. +// The returned bloom filter is safe for concurrent usage. func NewBloomFilter( registerer prometheus.Registerer, namespace string, @@ -36,12 +37,8 @@ func NewBloomFilter( metrics: metrics, } - err = resetBloomFilter( - filter, - minTargetElements, - targetFalsePositiveProbability, - resetFalsePositiveProbability, - ) + // A lock is unnecessary as no other goroutine could have access. + err = filter.resetWhenLocked(minTargetElements) return filter, err } @@ -52,6 +49,11 @@ type BloomFilter struct { metrics *bloom.Metrics + // [bloom.Filter] itself is threadsafe, but resetting requires replacing it + // entirely. This mutex protects the [BloomFilter] fields, not the + // [bloom.Filter], so resetting is a write while everything else is a read. + resetMu sync.RWMutex + maxCount int bloom *bloom.Filter // salt is provided to eventually unblock collisions in Bloom. 
It's possible @@ -61,17 +63,29 @@ type BloomFilter struct { } func (b *BloomFilter) Add(gossipable Gossipable) { + b.resetMu.RLock() + b.addWhenLocked(gossipable) + b.resetMu.RUnlock() +} + +func (b *BloomFilter) addWhenLocked(gossipable Gossipable) { h := gossipable.GossipID() bloom.Add(b.bloom, h[:], b.salt[:]) b.metrics.Count.Inc() } func (b *BloomFilter) Has(gossipable Gossipable) bool { + b.resetMu.RLock() + defer b.resetMu.RUnlock() + h := gossipable.GossipID() return bloom.Contains(b.bloom, h[:], b.salt[:]) } func (b *BloomFilter) Marshal() ([]byte, []byte) { + b.resetMu.RLock() + defer b.resetMu.RUnlock() + bloomBytes := b.bloom.Marshal() // salt must be copied here to ensure the bytes aren't overwritten if salt // is later modified. @@ -81,37 +95,61 @@ func (b *BloomFilter) Marshal() ([]byte, []byte) { // ResetBloomFilterIfNeeded resets a bloom filter if it breaches [targetFalsePositiveProbability]. // -// If [targetElements] exceeds [minTargetElements], the size of the bloom filter will grow to maintain -// the same [targetFalsePositiveProbability]. -// -// Returns true if the bloom filter was reset. +// Deprecated: use [BloomFilter.ResetIfNeeded]. func ResetBloomFilterIfNeeded( bloomFilter *BloomFilter, targetElements int, ) (bool, error) { - if bloomFilter.bloom.Count() <= bloomFilter.maxCount { + return bloomFilter.ResetIfNeeded(targetElements, nil) +} + +// ResetIfNeeded resets the bloom filter if it breaches [resetFalsePositiveProbability]. +// +// If [targetElements] exceeds [minTargetElements], the size of the bloom filter will grow to maintain +// the same [targetFalsePositiveProbability]. +// +// Returns true if the bloom filter was reset, in which case the elements +// yielded by `refillWith` are added to the filter. 
+func (b *BloomFilter) ResetIfNeeded(targetElements int, refillWith iter.Seq[Gossipable]) (bool, error) { + mu := &b.resetMu + + // Although this pattern requires checking the same property twice, + // it's cheap and avoids unnecessarily locking out all other goroutines on + // every call to this method. + isResetNeeded := func() bool { + return b.bloom.Count() > b.maxCount + } + mu.RLock() + reset := isResetNeeded() + mu.RUnlock() + if !reset { return false, nil } - targetElements = max(bloomFilter.minTargetElements, targetElements) - err := resetBloomFilter( - bloomFilter, - targetElements, - bloomFilter.targetFalsePositiveProbability, - bloomFilter.resetFalsePositiveProbability, - ) - return err == nil, err + mu.Lock() + defer mu.Unlock() + // Another goroutine may have beaten us to acquiring the write lock. + if !isResetNeeded() { + return false, nil + } + + targetElements = max(b.minTargetElements, targetElements) + if err := b.resetWhenLocked(targetElements); err != nil { + return false, err + } + + if refillWith != nil { + for g := range refillWith { + b.addWhenLocked(g) + } + } + return true, nil } -func resetBloomFilter( - bloomFilter *BloomFilter, - targetElements int, - targetFalsePositiveProbability, - resetFalsePositiveProbability float64, -) error { +func (b *BloomFilter) resetWhenLocked(targetElements int) error { numHashes, numEntries := bloom.OptimalParameters( targetElements, - targetFalsePositiveProbability, + b.targetFalsePositiveProbability, ) newBloom, err := bloom.New(numHashes, numEntries) if err != nil { @@ -122,10 +160,10 @@ func resetBloomFilter( return err } - bloomFilter.maxCount = bloom.EstimateCount(numHashes, numEntries, resetFalsePositiveProbability) - bloomFilter.bloom = newBloom - bloomFilter.salt = newSalt + b.maxCount = bloom.EstimateCount(numHashes, numEntries, b.resetFalsePositiveProbability) + b.bloom = newBloom + b.salt = newSalt - bloomFilter.metrics.Reset(newBloom, 
b.maxCount) return nil } diff --git a/network/p2p/gossip/bloom_test.go b/network/p2p/gossip/bloom_test.go index 61fe94e5bc60..497319c2ca79 100644 --- a/network/p2p/gossip/bloom_test.go +++ b/network/p2p/gossip/bloom_test.go @@ -5,6 +5,7 @@ package gossip import ( "slices" + "sync" "testing" "github.com/prometheus/client_golang/prometheus" @@ -106,3 +107,84 @@ func TestBloomFilterRefresh(t *testing.T) { }) } } + +func TestBloomFilterClobber(t *testing.T) { + b, err := NewBloomFilter(prometheus.NewRegistry(), "", 1, 0.5, 0.5) + require.NoError(t, err, "NewBloomFilter()") + + start := make(chan struct{}) + var wg sync.WaitGroup + + for _, fn := range []func(){ + func() { b.Add(&testTx{}) }, + func() { b.Has(&testTx{}) }, + func() { b.Marshal() }, + func() { + _, err := b.ResetIfNeeded(1, nil) + require.NoErrorf(t, err, "%T.ResetIfNeeded()", b) + }, + } { + for range 10_000 { + wg.Add(1) + go func() { + <-start + fn() + wg.Done() + }() + } + } + + close(start) + wg.Wait() +} + +func TestBloomFilterRefillAfterReset(t *testing.T) { + b, err := NewBloomFilter(prometheus.NewRegistry(), "", 1, 0.5, 0.5) + require.NoError(t, err, "NewBloomFilter()") + + var before []*testTx + for i := range byte(10) { + before = append(before, &testTx{ids.ID{1, i}}) + } + + after := &testTx{ids.ID{2, 0}} + refill := func(yield func(Gossipable) bool) { + yield(after) + } + + steps := []struct { + setup func() + targetEls int + // Although we assert if resetting occurred, this is just to confirm + // proper test setup. The real test is via [BloomFilter.Has] of the + // before / after elements. 
+ wantReset bool + }{ + { + setup: func() { + b.Add(before[0]) + }, + targetEls: 1e6, + wantReset: false, + }, + { + setup: func() { + for _, g := range before { + b.Add(g) + } + }, + targetEls: 1, + wantReset: true, + }, + } + + for _, s := range steps { + s.setup() + reset, err := b.ResetIfNeeded(s.targetEls, refill) + require.NoError(t, err, "ResetIfNeeded()") + require.Equal(t, s.wantReset, reset, "ResetIfNeeded()") + + require.Equalf(t, !reset, b.Has(before[0]), "Has([existing element]) when ResetIfNeeded() returned %t", reset) + require.Equalf(t, reset, b.Has(after), "Has([iterator element]) when ResetIfNeeded() returned %t", reset) + } +}