Skip to content

Commit d59105d

Browse files
committed
fix
1 parent d9c74a4 commit d59105d

File tree

1 file changed

+45
-50
lines changed

1 file changed

+45
-50
lines changed

p2p/discover/v4_udp_test.go

Lines changed: 45 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -110,29 +110,50 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr n
110110

111111
// waits for a packet to be sent by the transport.
112112
// validate should have type func(X, netip.AddrPort, []byte), where X is a packet type.
113-
func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) {
113+
// skipTypes is an optional list of packet types to skip
114+
func (test *udpTest) waitPacketOut(validate interface{}, skipTypes ...reflect.Type) (closed bool) {
114115
test.t.Helper()
115116

116-
dgram, err := test.pipe.receive()
117-
if err == errClosed {
118-
return true
119-
} else if err != nil {
120-
test.t.Error("packet receive error:", err)
121-
return false
122-
}
123-
p, _, hash, err := v4wire.Decode(dgram.data)
124-
if err != nil {
125-
test.t.Errorf("sent packet decode error: %v", err)
126-
return false
127-
}
128-
fn := reflect.ValueOf(validate)
129-
exptype := fn.Type().In(0)
130-
if !reflect.TypeOf(p).AssignableTo(exptype) {
131-
test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
117+
timeout := time.After(1 * time.Minute)
118+
for {
119+
select {
120+
case <-timeout:
121+
test.t.Fatal("timeout waiting for packet")
122+
default:
123+
}
124+
125+
dgram, err := test.pipe.receive()
126+
if err == errClosed {
127+
return true
128+
} else if err != nil {
129+
test.t.Error("packet receive error:", err)
130+
return false
131+
}
132+
p, _, hash, err := v4wire.Decode(dgram.data)
133+
if err != nil {
134+
test.t.Errorf("sent packet decode error: %v", err)
135+
return false
136+
}
137+
ptype := reflect.TypeOf(p)
138+
skip := false
139+
for _, skipType := range skipTypes {
140+
if ptype == skipType {
141+
skip = true
142+
break
143+
}
144+
}
145+
if skip {
146+
continue
147+
}
148+
fn := reflect.ValueOf(validate)
149+
exptype := fn.Type().In(0)
150+
if !reflect.TypeOf(p).AssignableTo(exptype) {
151+
test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
152+
return false
153+
}
154+
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(hash)})
132155
return false
133156
}
134-
fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(hash)})
135-
return false
136157
}
137158

138159
func TestUDPv4_packetErrors(t *testing.T) {
@@ -282,46 +303,20 @@ func TestUDPv4_findnode(t *testing.T) {
282303
expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true)
283304
test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp})
284305
waitNeighbors := func(want []*enode.Node) {
285-
timeout := time.After(1 * time.Minute)
286-
for {
287-
select {
288-
case <-timeout:
289-
t.Fatal("timeout waiting for neighbors response")
290-
default:
291-
}
292-
dgram, err := test.pipe.receive()
293-
if err == errClosed {
294-
t.Fatal("socket closed before receiving neighbors")
295-
} else if err != nil {
296-
t.Fatal("packet receive error:", err)
297-
}
298-
p, _, _, err := v4wire.Decode(dgram.data)
299-
if err != nil {
300-
t.Fatalf("sent packet decode error: %v", err)
301-
}
302-
// Skip any PING packets from revalidation
303-
if _, ok := p.(*v4wire.Ping); ok {
304-
continue
305-
}
306-
// Check we got NEIGHBORS
307-
neighbors, ok := p.(*v4wire.Neighbors)
308-
if !ok {
309-
t.Fatalf("sent packet type mismatch, got: %T, want: *v4wire.Neighbors", p)
310-
}
311-
if len(neighbors.Nodes) != len(want) {
312-
t.Errorf("wrong number of results: got %d, want %d", len(neighbors.Nodes), len(want))
306+
test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) {
307+
if len(p.Nodes) != len(want) {
308+
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want))
313309
return
314310
}
315-
for i, n := range neighbors.Nodes {
311+
for i, n := range p.Nodes {
316312
if n.ID.ID() != want[i].ID() {
317313
t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i])
318314
}
319315
if !live[n.ID.ID()] {
320316
t.Errorf("result includes dead node %v", n.ID.ID())
321317
}
322318
}
323-
return
324-
}
319+
}, reflect.TypeOf((*v4wire.Ping)(nil)))
325320
}
326321
// Receive replies.
327322
want := expected.entries

0 commit comments

Comments
 (0)