@@ -110,29 +110,50 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr n
110110
111111// waits for a packet to be sent by the transport.
112112// validate should have type func(X, netip.AddrPort, []byte), where X is a packet type.
113- func (test * udpTest ) waitPacketOut (validate interface {}) (closed bool ) {
113+ // skipTypes is an optional list of packet types to skip
114+ func (test * udpTest ) waitPacketOut (validate interface {}, skipTypes ... reflect.Type ) (closed bool ) {
114115 test .t .Helper ()
115116
116- dgram , err := test .pipe .receive ()
117- if err == errClosed {
118- return true
119- } else if err != nil {
120- test .t .Error ("packet receive error:" , err )
121- return false
122- }
123- p , _ , hash , err := v4wire .Decode (dgram .data )
124- if err != nil {
125- test .t .Errorf ("sent packet decode error: %v" , err )
126- return false
127- }
128- fn := reflect .ValueOf (validate )
129- exptype := fn .Type ().In (0 )
130- if ! reflect .TypeOf (p ).AssignableTo (exptype ) {
131- test .t .Errorf ("sent packet type mismatch, got: %v, want: %v" , reflect .TypeOf (p ), exptype )
117+ timeout := time .After (1 * time .Minute )
118+ for {
119+ select {
120+ case <- timeout :
121+ test .t .Fatal ("timeout waiting for packet" )
122+ default :
123+ }
124+
125+ dgram , err := test .pipe .receive ()
126+ if err == errClosed {
127+ return true
128+ } else if err != nil {
129+ test .t .Error ("packet receive error:" , err )
130+ return false
131+ }
132+ p , _ , hash , err := v4wire .Decode (dgram .data )
133+ if err != nil {
134+ test .t .Errorf ("sent packet decode error: %v" , err )
135+ return false
136+ }
137+ ptype := reflect .TypeOf (p )
138+ skip := false
139+ for _ , skipType := range skipTypes {
140+ if ptype == skipType {
141+ skip = true
142+ break
143+ }
144+ }
145+ if skip {
146+ continue
147+ }
148+ fn := reflect .ValueOf (validate )
149+ exptype := fn .Type ().In (0 )
150+ if ! reflect .TypeOf (p ).AssignableTo (exptype ) {
151+ test .t .Errorf ("sent packet type mismatch, got: %v, want: %v" , reflect .TypeOf (p ), exptype )
152+ return false
153+ }
154+ fn .Call ([]reflect.Value {reflect .ValueOf (p ), reflect .ValueOf (dgram .to ), reflect .ValueOf (hash )})
132155 return false
133156 }
134- fn .Call ([]reflect.Value {reflect .ValueOf (p ), reflect .ValueOf (dgram .to ), reflect .ValueOf (hash )})
135- return false
136157}
137158
138159func TestUDPv4_packetErrors (t * testing.T ) {
@@ -282,46 +303,20 @@ func TestUDPv4_findnode(t *testing.T) {
282303 expected := test .table .findnodeByID (testTarget .ID (), bucketSize , true )
283304 test .packetIn (nil , & v4wire.Findnode {Target : testTarget , Expiration : futureExp })
284305 waitNeighbors := func (want []* enode.Node ) {
285- timeout := time .After (1 * time .Minute )
286- for {
287- select {
288- case <- timeout :
289- t .Fatal ("timeout waiting for neighbors response" )
290- default :
291- }
292- dgram , err := test .pipe .receive ()
293- if err == errClosed {
294- t .Fatal ("socket closed before receiving neighbors" )
295- } else if err != nil {
296- t .Fatal ("packet receive error:" , err )
297- }
298- p , _ , _ , err := v4wire .Decode (dgram .data )
299- if err != nil {
300- t .Fatalf ("sent packet decode error: %v" , err )
301- }
302- // Skip any PING packets from revalidation
303- if _ , ok := p .(* v4wire.Ping ); ok {
304- continue
305- }
306- // Check we got NEIGHBORS
307- neighbors , ok := p .(* v4wire.Neighbors )
308- if ! ok {
309- t .Fatalf ("sent packet type mismatch, got: %T, want: *v4wire.Neighbors" , p )
310- }
311- if len (neighbors .Nodes ) != len (want ) {
312- t .Errorf ("wrong number of results: got %d, want %d" , len (neighbors .Nodes ), len (want ))
306+ test .waitPacketOut (func (p * v4wire.Neighbors , to netip.AddrPort , hash []byte ) {
307+ if len (p .Nodes ) != len (want ) {
308+ t .Errorf ("wrong number of results: got %d, want %d" , len (p .Nodes ), len (want ))
313309 return
314310 }
315- for i , n := range neighbors .Nodes {
311+ for i , n := range p .Nodes {
316312 if n .ID .ID () != want [i ].ID () {
317313 t .Errorf ("result mismatch at %d:\n got: %v\n want: %v" , i , n , expected .entries [i ])
318314 }
319315 if ! live [n .ID .ID ()] {
320316 t .Errorf ("result includes dead node %v" , n .ID .ID ())
321317 }
322318 }
323- return
324- }
319+ }, reflect .TypeOf ((* v4wire .Ping )(nil )))
325320 }
326321 // Receive replies.
327322 want := expected .entries
0 commit comments