diff --git a/credentials/credentials_ext_test.go b/credentials/credentials_ext_test.go index fc0de6c0c306..908c3827fdd4 100644 --- a/credentials/credentials_ext_test.go +++ b/credentials/credentials_ext_test.go @@ -80,6 +80,7 @@ func loadTLSCreds(t *testing.T) (grpc.ServerOption, grpc.DialOption) { // used. func (s) TestCorrectAuthorityWithCreds(t *testing.T) { const authority = "auth.test.example.com" + const authorityWithPort = "auth.test.example.com:8010" tests := []struct { name string @@ -109,6 +110,13 @@ func (s) TestCorrectAuthorityWithCreds(t *testing.T) { }, expectedAuth: authority, }, + { + name: "TLSAuthorityWithPort", + creds: func(t *testing.T) (grpc.ServerOption, grpc.DialOption) { + return loadTLSCreds(t) + }, + expectedAuth: authorityWithPort, + }, } for _, tt := range tests { @@ -155,35 +163,50 @@ func (s) TestIncorrectAuthorityWithTLS(t *testing.T) { t.Fatalf("Failed to create credentials %v", err) } - serverCalled := make(chan struct{}) - ss := &stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - close(serverCalled) - return nil, nil + tests := []struct { + name string + authority string + }{ + { + name: "IncorrectAuthority", + authority: "auth.example.com", + }, + { + name: "IncorrectAuthorityWithPort", + authority: "auth.example.com:8443", }, } - if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - - cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) - if err != nil { - t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) - } - defer cc.Close() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { - const authority = "auth.example.com" - if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.Unavailable { - t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable) - } - select { - case <-serverCalled: - t.Fatalf("Server handler should not have been called") - case <-time.After(defaultTestShortTimeout): + serverCalled := make(chan struct{}) + ss := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + close(serverCalled) + return nil, nil + }, + } + if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) + if err != nil { + t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) + } + defer cc.Close() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.authority)); status.Code(err) != codes.Unavailable { + t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable) + } + select { + case <-serverCalled: + t.Fatalf("Server handler should not have been called") + case <-time.After(defaultTestShortTimeout): + } + }) } } diff --git a/credentials/tls.go b/credentials/tls.go index 8277be7d6f85..0bcd16dbbf1d 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -56,9 +56,13 @@ func (t TLSInfo) AuthType() string { // non-nil error if the validation fails. func (t TLSInfo) ValidateAuthority(authority string) error { var errs []error + host, _, err := net.SplitHostPort(authority) + if err != nil { + host = authority + } for _, cert := range t.State.PeerCertificates { var err error - if err = cert.VerifyHostname(authority); err == nil { + if err = cert.VerifyHostname(host); err == nil { return nil } errs = append(errs, err)