@@ -27,8 +27,11 @@ import (
2727 "net"
2828 "reflect"
2929 "strings"
30+ "syscall"
3031 "testing"
32+ "time"
3133
34+ "golang.org/x/sys/unix"
3235 core "google.golang.org/grpc/credentials/alts/internal"
3336 "google.golang.org/grpc/internal/grpctest"
3437)
@@ -105,6 +108,94 @@ func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (cli
105108 return clientConn , serverConn
106109}
107110
111+ // newTCPConnPair returns a pair of conns backed by TCP over loopback.
112+ func newTCPConnPair (rp string , clientProtected []byte , serverProtected []byte ) (* conn , * conn , error ) {
113+ const address = "localhost:50935"
114+
115+ // Start the server.
116+ serverChan := make (chan net.Conn )
117+ listenChan := make (chan struct {})
118+ go func () {
119+ listener , err := net .Listen ("tcp4" , address )
120+ if err != nil {
121+ panic (fmt .Sprintf ("failed to listen: %v" , err ))
122+ }
123+ defer listener .Close ()
124+ listenChan <- struct {}{}
125+ conn , err := listener .Accept ()
126+ if err != nil {
127+ panic (fmt .Sprintf ("failed to aceept: %v" , err ))
128+ }
129+ serverChan <- conn
130+ }()
131+
132+ // Ensure the server is listening before trying to connect.
133+ <- listenChan
134+ clientTCP , err := net .DialTimeout ("tcp4" , address , 5 * time .Second )
135+ if err != nil {
136+ return nil , nil , fmt .Errorf ("failed to Dial: %w" , err )
137+ }
138+
139+ // Get the server-side connection returned by Accept().
140+ var serverTCP net.Conn
141+ select {
142+ case serverTCP = <- serverChan :
143+ case <- time .After (5 * time .Second ):
144+ return nil , nil , fmt .Errorf ("timed out waiting for server conn" )
145+ }
146+
147+ // Make the connection behave a little bit like a real one by imposing
148+ // an MTU.
149+ clientTCP = & mtuConn {clientTCP , 1500 }
150+
151+ // 16 arbitrary bytes.
152+ key := []byte {
153+ 0x1f , 0x8b , 0x08 , 0x00 , 0x00 , 0x09 , 0x6e , 0x88 ,
154+ 0x02 , 0xff , 0xe2 , 0xd2 , 0x4c , 0xce , 0x4f , 0x49 ,
155+ }
156+
157+ client , err := NewConn (clientTCP , core .ClientSide , rp , key , clientProtected )
158+ if err != nil {
159+ panic (fmt .Sprintf ("Unexpected error creating test ALTS record connection: %v" , err ))
160+ }
161+ server , err := NewConn (serverTCP , core .ServerSide , rp , key , serverProtected )
162+ if err != nil {
163+ panic (fmt .Sprintf ("Unexpected error creating test ALTS record connection: %v" , err ))
164+ }
165+
166+ return client .(* conn ), server .(* conn ), nil
167+ }
168+
169+ // mtuConn imposes an MTU on writes. It simulates an important quality of real
170+ // network traffic that is lost when using loopback devices. On loopback, even
171+ // large messages (e.g. 512 KiB) when written often arrive at the receiver
172+ // instantaneously as a single payload. By explicitly splitting such writes into
173+ // smaller, MTU-sized paylaods we give the receiver a chance to respond to
174+ // smaller message sizes.
175+ type mtuConn struct {
176+ net.Conn
177+ mtu int
178+ }
179+
180+ // Write implements net.Conn.
181+ func (rc * mtuConn ) Write (buf []byte ) (int , error ) {
182+ var written int
183+ for len (buf ) > 0 {
184+ n , err := rc .Conn .Write (buf [:min (rc .mtu , len (buf ))])
185+ written += n
186+ if err != nil {
187+ return written , err
188+ }
189+ buf = buf [n :]
190+ }
191+ return written , nil
192+ }
193+
194+ // SyscallConn implements syscall.Conn.
195+ func (rc * mtuConn ) SycallConn () (syscall.RawConn , error ) {
196+ return rc .Conn .(syscall.Conn ).SyscallConn ()
197+ }
198+
108199func testPingPong (t * testing.T , rp string ) {
109200 clientConn , serverConn := newConnPair (rp , nil , nil )
110201 clientMsg := []byte ("Client Message" )
@@ -231,6 +322,115 @@ func BenchmarkLargeMessage(b *testing.B) {
231322 }
232323}
233324
325+ // BenchmarkTCP is a simple throughput test that sends payloads over a local TCP
326+ // connection.
327+ func BenchmarkTCP (b * testing.B ) {
328+ tcs := []struct {
329+ name string
330+ size int
331+ }{
332+ {"1 KiB" , 1024 },
333+ {"4 KiB" , 4 * 1024 },
334+ {"64 KiB" , 64 * 1024 },
335+ {"512 KiB" , 512 * 1024 },
336+ {"1 MiB" , 1024 * 1024 },
337+ {"4 MiB" , 4 * 1024 * 1024 },
338+ }
339+ for _ , tc := range tcs {
340+ b .Run ("size=" + tc .name , func (b * testing.B ) {
341+ benchmarkTCP (b , tc .size )
342+ })
343+ }
344+ }
345+
346+ // sum makes unwanted compiler optimizations in benchmarkTCP's loop less likely.
347+ var sum int
348+
349+ func benchmarkTCP (b * testing.B , size int ) {
350+ // Initialize the connection.
351+ client , server , err := newTCPConnPair (rekeyRecordProtocol , nil , nil )
352+ if err != nil {
353+ b .Fatalf ("failed to create TCP conn pair: %v" , err )
354+ }
355+ defer client .Close ()
356+ defer server .Close ()
357+
358+ rcvBuf := make ([]byte , size )
359+ sndBuf := make ([]byte , size )
360+ done := make (chan struct {})
361+ errChan := make (chan error )
362+
363+ // Launch a writer goroutine.
364+ go func () {
365+ for {
366+ select {
367+ case <- done :
368+ return
369+ default :
370+ }
371+ n , err := client .Write (sndBuf )
372+ if n != size || err != nil {
373+ errChan <- fmt .Errorf ("Write() = %v, %v; want %v, <nil>" , n , err , size )
374+ return
375+ }
376+ // Act a bit like a real workload that can't just fill
377+ // every buffer immediately.
378+ time .Sleep (10 * time .Millisecond )
379+ }
380+ }()
381+
382+ // Get the initial rusage so we can measure CPU time.
383+ var startUsage unix.Rusage
384+ if err := unix .Getrusage (unix .RUSAGE_SELF , & startUsage ); err != nil {
385+ b .Fatalf ("failed to get initial rusage: %v" , err )
386+ }
387+
388+ // Read as much as possible.
389+ var rcvd uint64
390+ for b .Loop () {
391+ n , err := io .ReadFull (server , rcvBuf )
392+ rcvd += uint64 (n )
393+ if n != size || err != nil {
394+ b .Fatalf ("Read() = %v, %v; want %v, <nil>" , n , err , size )
395+ }
396+ // Act a bit like a real workload and utilize received bytes.
397+ for _ , b := range rcvBuf [:n ] {
398+ sum += int (b )
399+ }
400+ }
401+
402+ // Turn off the writer.
403+ done <- struct {}{}
404+
405+ // Get the ending rusage.
406+ var endUsage unix.Rusage
407+ if err := unix .Getrusage (unix .RUSAGE_SELF , & endUsage ); err != nil {
408+ b .Fatalf ("failed to get final rusage: %v" , err )
409+ }
410+
411+ // Error check the writer goroutine.
412+ select {
413+ case err := <- errChan :
414+ b .Fatal (err )
415+ default :
416+ }
417+
418+ // Emit extra metrics.
419+ utime := timevalDiffUsec (& startUsage .Utime , & endUsage .Utime )
420+ stime := timevalDiffUsec (& startUsage .Stime , & endUsage .Stime )
421+ b .ReportMetric (float64 (utime )/ float64 (b .N ), "usr-usec/op" )
422+ b .ReportMetric (float64 (stime )/ float64 (b .N ), "sys-usec/op" )
423+ b .ReportMetric (float64 (stime + utime )/ float64 (b .N ), "cpu-usec/op" )
424+ b .ReportMetric (float64 (rcvd * 8 / (1024 * 1024 ))/ float64 (b .Elapsed ().Seconds ()), "Mbps" )
425+ }
426+
427+ // timevalDiffUsec returns the difference in microseconds between start and end.
428+ func timevalDiffUsec (start , end * unix.Timeval ) int64 {
429+ // Note: the int64 type conversion is needed because unix.Timeval uses
430+ // 32 bit values on some architectures.
431+ return int64 (1_000_000 * (end .Sec - start .Sec ) + end .Usec - start .Usec )
432+ }
433+
234434func testIncorrectMsgType (t * testing.T , rp string ) {
235435 // framedMsg is an empty ciphertext with correct framing but wrong
236436 // message type.
0 commit comments