diff --git a/credentials/jwt/doc.go b/credentials/jwt/doc.go new file mode 100644 index 000000000000..f74d3446afb4 --- /dev/null +++ b/credentials/jwt/doc.go @@ -0,0 +1,50 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwt implements JWT token file-based call credentials. +// +// This package provides support for A97 JWT Call Credentials, allowing gRPC +// clients to authenticate using JWT tokens read from files. While originally +// designed for xDS environments, these credentials are general-purpose. +// +// The credentials can be used directly in gRPC clients or configured via xDS. +// +// # Token Requirements +// +// JWT tokens must: +// - Be valid, well-formed JWT tokens with header, payload, and signature +// - Include an "exp" (expiration) claim +// - Be readable from the specified file path +// +// # Considerations +// +// - Tokens are cached until expiration to avoid excessive file I/O +// - Transport security is required (RequireTransportSecurity returns true) +// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or +// UNAUTHENTICATED errors +// - These errors are cached and retried with exponential backoff. +// +// This implementation is originally intended for use in service mesh +// environments like Istio where JWT tokens are provisioned and rotated by the +// infrastructure. +// +// # Experimental +// +// Notice: All APIs in this package are experimental and may be removed in a +// later release. +package jwt diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go new file mode 100644 index 000000000000..c42dbccd114d --- /dev/null +++ b/credentials/jwt/jwt_token_file.go @@ -0,0 +1,281 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwt implements gRPC credentials using JWT tokens from files. +package jwt + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + "sync" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/status" +) + +// jwtClaims represents the JWT claims structure for extracting expiration time. +type jwtClaims struct { + Exp int64 `json:"exp"` +} + +// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads +// tokens from a file. +// This implementation follows the A97 JWT Call Credentials specification. +type jwtTokenFileCallCreds struct { + tokenFilePath string + backoffStrategy backoff.Strategy // Strategy when error occurs + + // Cached token data + mu sync.RWMutex + cachedToken string + cachedExpiration time.Time // Slightly less than actual expiration time + cachedError error // Error from last failed attempt + retryAttempt int // Current retry attempt number + nextRetryTime time.Time // When next retry is allowed + + // Pre-emptive refresh mutex + refreshMu sync.Mutex +} + +// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens +// from the specified file path. +func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) { + if tokenFilePath == "" { + return nil, fmt.Errorf("tokenFilePath cannot be empty") + } + + return &jwtTokenFileCallCreds{ + tokenFilePath: tokenFilePath, + backoffStrategy: backoff.DefaultExponential, + }, nil +} + +// GetRequestMetadata gets the current request metadata, refreshing tokens if +// required. This implementation follows the PerRPCCredentials interface. The +// tokens will get automatically refreshed if they are about to expire or if +// they haven't been loaded successfully yet. +// If it's not possible to extract a token from the file, UNAVAILABLE is +// returned. +// If the token is extracted but invalid, then UNAUTHENTICATED is returned. +// If errors are encoutered, a backoff is applied before retrying. +func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { + return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err) + } + + // This may be delayed if the token needs to be refreshed from file. + token, err := c.getToken() + if err != nil { + return nil, err + } + + return map[string]string{ + "authorization": "Bearer " + token, + }, nil +} + +// RequireTransportSecurity indicates whether the credentials requires +// transport security. +func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool { + return true +} + +// getToken returns a valid JWT token, reading from file if necessary. +// Implements pre-emptive refresh and caches errors with backoff. +func (c *jwtTokenFileCallCreds) getToken() (string, error) { + c.mu.RLock() + + if c.isTokenValidLocked() { + token := c.cachedToken + shouldRefresh := c.needsPreemptiveRefreshLocked() + c.mu.RUnlock() + + if shouldRefresh { + c.triggerPreemptiveRefresh() + } + return token, nil + } + + // If still within backoff period, return cached error to avoid repeated + // file reads. + if c.cachedError != nil && time.Now().Before(c.nextRetryTime) { + err := c.cachedError + c.mu.RUnlock() + return "", err + } + + c.mu.RUnlock() + // Token is expired or missing or the retry backoff period has expired. + // So we should refresh synchronously. + // NOTE: refreshTokenSync itself acquires the write lock. + return c.refreshTokenSync(false) +} + +// isTokenValidLocked checks if the cached token is still valid. +// Caller must hold c.mu.RLock(). +func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool { + if c.cachedToken == "" { + return false + } + return c.cachedExpiration.After(time.Now()) +} + +// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be +// triggered. +// Returns true if the cached token is valid but expires within 1 minute. +// We only trigger pre-emptive refresh for valid tokens - if the token is +// invalid or expired, the next RPC will handle synchronous refresh instead. +// Caller must hold c.mu.RLock(). +func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool { + return c.isTokenValidLocked() && time.Until(c.cachedExpiration) < time.Minute +} + +// triggerPreemptiveRefresh starts a background refresh if needed. +// Multiple concurrent calls are safe - only one refresh will run at a time. +// The refresh runs in a separate goroutine and does not block the caller. +func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() { + go func() { + c.refreshMu.Lock() + defer c.refreshMu.Unlock() + + // Re-check if refresh is still needed under mutex. + c.mu.RLock() + stillNeeded := c.needsPreemptiveRefreshLocked() + c.mu.RUnlock() + + if !stillNeeded { + return // Another goroutine already refreshed or token expired. + } + + // Force refresh to read new token even if current one is still valid. + _, _ = c.refreshTokenSync(true) + }() +} + +// refreshTokenSync reads a new token from the file and updates the cache. If +// forceRefresh is true, bypasses the validity check of the currently +// cached token and always reads from file. +// This is used for pre-emptive refresh to ensure new tokens are loaded even +// when the cached token is still valid. If forceRefresh is false, skips +// file read when cached token is still valid, optimizing concurrent synchronous +// refresh calls where one RPC may have already updated the cache while another +// was waiting on the lock. +func (c *jwtTokenFileCallCreds) refreshTokenSync(forceRefresh bool) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + + // Double-check under write lock but skip if preemptive refresh is + // requested. + if !forceRefresh && c.isTokenValidLocked() { + return c.cachedToken, nil + } + + tokenBytes, err := os.ReadFile(c.tokenFilePath) + if err != nil { + err = status.Errorf(codes.Unavailable, "failed to read token file %q: %v", c.tokenFilePath, err) + c.setErrorWithBackoffLocked(err) + return "", err + } + + token := strings.TrimSpace(string(tokenBytes)) + if token == "" { + err := status.Errorf(codes.Unavailable, "token file %q is empty", c.tokenFilePath) + c.setErrorWithBackoffLocked(err) + return "", err + } + + // Parse JWT to extract expiration. + exp, err := c.extractExpiration(token) + if err != nil { + err = status.Errorf(codes.Unauthenticated, "failed to parse JWT from token file %q: %v", c.tokenFilePath, err) + c.setErrorWithBackoffLocked(err) + return "", err + } + + // Success - clear any cached error and backoff state, update token cache. + c.clearErrorAndBackoffLocked() + c.cachedToken = token + // Per RFC A97: consider token invalid if it expires within the next 30 + // seconds to accommodate for clock skew and server processing time. + c.cachedExpiration = exp.Add(-30 * time.Second) + + return token, nil +} + +// extractExpiration parses the JWT token to extract the expiration time. +func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + + payload := parts[1] + // Add padding if necessary for base64 decoding. + if m := len(payload) % 4; m != 0 { + payload += strings.Repeat("=", 4-m) + } + + payloadBytes, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err) + } + + var claims jwtClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err) + } + + if claims.Exp == 0 { + return time.Time{}, fmt.Errorf("JWT token has no expiration claim") + } + + expTime := time.Unix(claims.Exp, 0) + + // Check if token is already expired. + if expTime.Before(time.Now()) { + return time.Time{}, fmt.Errorf("JWT token is expired") + } + + return expTime, nil +} + +// setErrorWithBackoffLocked caches an error and calculates the next retry time +// using exponential backoff. +// Caller must hold c.mu write lock. +func (c *jwtTokenFileCallCreds) setErrorWithBackoffLocked(err error) { + c.cachedError = err + c.retryAttempt++ + backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1) + c.nextRetryTime = time.Now().Add(backoffDelay) +} + +// clearErrorAndBackoffLocked clears the cached error and resets backoff state. +// Caller must hold c.mu write lock. +func (c *jwtTokenFileCallCreds) clearErrorAndBackoffLocked() { + c.cachedError = nil + c.retryAttempt = 0 + c.nextRetryTime = time.Time{} +} diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go new file mode 100644 index 000000000000..afad2b152602 --- /dev/null +++ b/credentials/jwt/jwt_token_file_test.go @@ -0,0 +1,733 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwt + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/status" +) + +const defaultTestTimeout = 5 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestNewTokenFileCallCredentials(t *testing.T) { + tests := []struct { + name string + tokenFilePath string + wantErr string + }{ + { + name: "some filepath", + tokenFilePath: "/path/to/token", + wantErr: "", + }, + { + name: "empty filepath", + tokenFilePath: "", + wantErr: "tokenFilePath cannot be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds, err := NewTokenFileCallCredentials(tt.tokenFilePath) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("NewTokenFileCallCredentials() expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err) + } + if creds == nil { + t.Fatal("NewTokenFileCallCredentials() returned nil credentials") + } + }) + } +} + +func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) { + creds, err := NewTokenFileCallCredentials("/path/to/token") + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + if !creds.RequireTransportSecurity() { + t.Error("RequireTransportSecurity() = false, want true") + } +} + +func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { + tempDir, err := os.MkdirTemp("", "jwt_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + now := time.Now().Truncate(time.Second) + tests := []struct { + name string + tokenContent string + authInfo credentials.AuthInfo + wantErr bool + wantErrContains string + wantMetadata map[string]string + }{ + { + name: "valid token without expiration", + tokenContent: createTestJWT(t, "", time.Time{}), + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + wantErr: true, + wantErrContains: "JWT token has no expiration claim", + }, + { + name: "valid token with future expiration", + tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)), + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + wantErr: false, + wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, "https://example.com", now.Add(time.Hour))}, + }, + { + name: "insufficient security level", + tokenContent: createTestJWT(t, "", time.Time{}), + authInfo: &testAuthInfo{secLevel: credentials.NoSecurity}, + wantErr: true, + wantErrContains: "unable to transfer JWT token file PerRPCCredentials", + }, + { + name: "expired token", + tokenContent: createTestJWT(t, "", now.Add(-time.Hour)), + authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + wantErr: true, + wantErrContains: "JWT token is expired", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenFile := writeTempFile(t, "token", tt.tokenContent) + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: tt.authInfo, + }) + + metadata, err := creds.GetRequestMetadata(ctx) + if tt.wantErr { + if err == nil { + t.Fatalf("GetRequestMetadata() expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains) + } + return + } + + if err != nil { + t.Fatalf("GetRequestMetadata() unexpected error: %v", err) + } + + if len(metadata) != len(tt.wantMetadata) { + t.Fatalf("GetRequestMetadata() returned %d metadata entries, want %d", len(metadata), len(tt.wantMetadata)) + } + + for k, v := range tt.wantMetadata { + if metadata[k] != v { + t.Errorf("GetRequestMetadata() metadata[%q] = %q, want %q", k, metadata[k], v) + } + } + }) + } +} + +func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { + token := createTestJWT(t, "", time.Now().Add(time.Hour)) + tokenFile := writeTempFile(t, "token", token) + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // First call should read from file. + metadata1, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("First GetRequestMetadata() failed: %v", err) + } + + // Update the file with a different token. + newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { + t.Fatalf("Failed to update token file: %v", err) + } + + // Second call should return cached token (not the updated one). + metadata2, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("Second GetRequestMetadata() failed: %v", err) + } + + if metadata1["authorization"] != metadata2["authorization"] { + t.Error("Expected cached token to be returned, but got different token") + } +} + +func (s) TestTokenFileCallCreds_FileErrors(t *testing.T) { + tests := []struct { + name string + setupFile func(string) error + wantErrContains string + }{ + { + name: "nonexistent file", + setupFile: func(_ string) error { + return nil // Don't create the file + }, + wantErrContains: "failed to read token file", + }, + { + name: "empty file", + setupFile: func(path string) error { + return os.WriteFile(path, []byte(""), 0600) + }, + wantErrContains: "token file", + }, + { + name: "file with whitespace only", + setupFile: func(path string) error { + return os.WriteFile(path, []byte(" \n\t "), 0600) + }, + wantErrContains: "token file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "jwt_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + tokenFile := filepath.Join(tempDir, "token") + if err := tt.setupFile(tokenFile); err != nil { + t.Fatalf("Failed to setup test file: %v", err) + } + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + _, err = creds.GetRequestMetadata(ctx) + if err == nil { + t.Fatal("GetRequestMetadata() expected error, got nil") + } + + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains) + } + }) + } +} + +// testAuthInfo implements credentials.AuthInfo for testing. +type testAuthInfo struct { + secLevel credentials.SecurityLevel +} + +func (t *testAuthInfo) AuthType() string { + return "test" +} + +func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} +} + +// createTestJWT creates a test JWT token with the specified audience and +// expiration. +func createTestJWT(t *testing.T, audience string, expiration time.Time) string { + t.Helper() + + header := map[string]any{ + "typ": "JWT", + "alg": "HS256", + } + + claims := map[string]any{} + if audience != "" { + claims["aud"] = audience + } + if !expiration.IsZero() { + claims["exp"] = expiration.Unix() + } + + headerBytes, err := json.Marshal(header) + if err != nil { + t.Fatalf("Failed to marshal header: %v", err) + } + + claimsBytes, err := json.Marshal(claims) + if err != nil { + t.Fatalf("Failed to marshal claims: %v", err) + } + + headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes) + + // Remove padding for URL-safe base64 + headerB64 = strings.TrimRight(headerB64, "=") + claimsB64 = strings.TrimRight(claimsB64, "=") + + // For testing, we'll use a fake signature + signature := base64.URLEncoding.EncodeToString([]byte("fake_signature")) + signature = strings.TrimRight(signature, "=") + + return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature) +} + +// Tests that cached token expiration is set to 30 seconds before actual token +// expiration. +func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { + // Create token that expires in 2 hours. + tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) + token := createTestJWT(t, "", tokenExp) + tokenFile := writeTempFile(t, "token", token) + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // Get token to trigger caching. + _, err = creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata() failed: %v", err) + } + + // Verify cached expiration is 30 seconds before actual token expiration. + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + cachedExp := impl.cachedExpiration + impl.mu.RUnlock() + + expectedExp := tokenExp.Add(-30 * time.Second) + if !cachedExp.Equal(expectedExp) { + t.Errorf("cache expiration = %v, want %v", cachedExp, expectedExp) + } +} + +// Tests that pre-emptive refresh is triggered within 1 minute of expiration. +func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { + // Create token that expires in 80 seconds (=> cache expires in ~50s). + // This ensures pre-emptive refresh triggers since 50s < the 1 minute check. + tokenExp := time.Now().Add(80 * time.Second) + expiringToken := createTestJWT(t, "", tokenExp) + tokenFile := writeTempFile(t, "token", expiringToken) + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // Get token - should trigger pre-emptive refresh. + metadata1, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata() failed: %v", err) + } + + // Verify token was cached and check if refresh should be triggered. + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + cacheExp := impl.cachedExpiration + tokenCached := impl.cachedToken != "" + shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked() + impl.mu.RUnlock() + + if !tokenCached { + t.Error("token should be cached after successful GetRequestMetadata") + } + + if !shouldTriggerRefresh { + timeUntilExp := time.Until(cacheExp) + t.Errorf("cache expires in %v, should be < 1 minute to trigger pre-emptive refresh", timeUntilExp) + } + + // Create new token file with different expiration while refresh is + // happening. + newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) + if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil { + t.Fatalf("Failed to write updated token file: %v", err) + } + + // Get token again - should trigger a refresh given that the first one was + // cached but expiring soon. + // However, the function should have returned right away with the current + // cached token. + metadata2, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("Second GetRequestMetadata() failed: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // Now should get the new token. + metadata3, err := creds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("Second GetRequestMetadata() failed: %v", err) + } + + // If pre-emptive refresh worked, we should get the new token. + expectedAuth1 := "Bearer " + expiringToken + expectedAuth2 := "Bearer " + expiringToken + expectedAuth3 := "Bearer " + newToken + + actualAuth1 := metadata1["authorization"] + actualAuth2 := metadata2["authorization"] + actualAuth3 := metadata3["authorization"] + + if actualAuth1 != expectedAuth1 { + t.Errorf("First call should return original token: got %q, want %q", actualAuth1, expectedAuth1) + } + + if actualAuth2 != expectedAuth2 { + t.Errorf("Second call should return the original token: got %q, want %q", actualAuth2, expectedAuth2) + } + if actualAuth3 != expectedAuth3 { + t.Errorf("Third call should return the new token: got %q, want %q", actualAuth3, expectedAuth3) + } +} + +// Tests that backoff behavior handles file read errors correctly. +func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { + // This test has the following expectations: + // First call to GetRequestMetadata() fails with UNAVAILABLE due to a + // missing file. + // Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff. + // Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry. + // Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff + // even though file exists. + // Fifth call to GetRequestMetadata() succeeds after reading the file and + // backoff has expired. + tempDir := t.TempDir() + nonExistentFile := filepath.Join(tempDir, "nonexistent") + + creds, err := NewTokenFileCallCredentials(nonExistentFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // First call should fail with UNAVAILABLE. + _, err1 := creds.GetRequestMetadata(ctx) + if err1 == nil { + t.Fatal("Expected error from nonexistent file") + } + if status.Code(err1) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err1)) + } + + // Verify error is cached internally. + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + cachedErr := impl.cachedError + retryAttempt := impl.retryAttempt + nextRetryTime := impl.nextRetryTime + impl.mu.RUnlock() + + if cachedErr == nil { + t.Error("error should be cached internally after failed file read") + } + if retryAttempt != 1 { + t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt) + } + if nextRetryTime.IsZero() || nextRetryTime.Before(time.Now()) { + t.Error("Next retry time should be set to future time") + } + + // Second call should still return cached error. + _, err2 := creds.GetRequestMetadata(ctx) + if err2 == nil { + t.Fatal("Expected cached error") + } + if status.Code(err2) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err2)) + } + if err1.Error() != err2.Error() { + t.Errorf("cached error = %q, want %q", err2.Error(), err1.Error()) + } + + impl.mu.RLock() + retryAttempt2 := impl.retryAttempt + nextRetryTime2 := impl.nextRetryTime + impl.mu.RUnlock() + + if !nextRetryTime2.Equal(nextRetryTime) { + t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, nextRetryTime) + } + if retryAttempt2 != 1 { + t.Error("retry attempt should not change due to backoff") + } + + // Fast-forward the backoff retry time to allow next retry attempt. + impl.mu.Lock() + impl.nextRetryTime = time.Now().Add(-1 * time.Minute) + impl.mu.Unlock() + + // Third call should retry but still fail with UNAVAILABLE. + _, err3 := creds.GetRequestMetadata(ctx) + if err3 == nil { + t.Fatal("Expected cached error") + } + if status.Code(err3) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err3)) + } + if err3.Error() != err1.Error() { + t.Errorf("cached error = %q, want %q", err3.Error(), err1.Error()) + } + + impl.mu.RLock() + retryAttempt3 := impl.retryAttempt + nextRetryTime3 := impl.nextRetryTime + impl.mu.RUnlock() + + if !nextRetryTime3.After(nextRetryTime2) { + t.Error("nextRetryTime should not change due to backoff") + } + if retryAttempt3 != 2 { + t.Error("retry attempt should not change due to backoff") + } + + // Create valid token file. + validToken := createTestJWT(t, "", time.Now().Add(time.Hour)) + if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil { + t.Fatalf("Failed to create valid token file: %v", err) + } + + // Fourth call should still fail even though the file now exists. + _, err4 := creds.GetRequestMetadata(ctx) + if err4 == nil { + t.Fatal("Expected cached error") + } + if status.Code(err4) != codes.Unavailable { + t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err4)) + } + if err4.Error() != err3.Error() { + t.Errorf("cached error = %q, want %q", err4.Error(), err3.Error()) + } + + impl.mu.RLock() + retryAttempt4 := impl.retryAttempt + nextRetryTime4 := impl.nextRetryTime + impl.mu.RUnlock() + + if !nextRetryTime4.Equal(nextRetryTime3) { + t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, nextRetryTime3) + } + if retryAttempt4 != retryAttempt3 { + t.Error("retry attempt should not change due to backoff") + } + + // Fast-forward the backoff retry time to allow next retry attempt. + impl.mu.Lock() + impl.nextRetryTime = time.Now().Add(-1 * time.Minute) + impl.mu.Unlock() + // Fifth call should succeed since the file now exists + // and the backoff has expired. + _, err5 := creds.GetRequestMetadata(ctx) + if err5 != nil { + t.Errorf("after creating valid token file, GetRequestMetadata() should eventually succeed, but got: %v", err5) + t.Error("backoff should expire and trigger new attempt on next RPC") + } else { + // If successful, verify error cache and backoff state were cleared. + impl.mu.RLock() + clearedErr := impl.cachedError + retryAttempt := impl.retryAttempt + nextRetryTime := impl.nextRetryTime + impl.mu.RUnlock() + + if clearedErr != nil { + t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr) + } + if retryAttempt != 0 { + t.Errorf("after successful retry, retry attempt should be reset, got: %d", retryAttempt) + } + if !nextRetryTime.IsZero() { + t.Error("after successful retry, next retry time should be cleared") + } + } +} + +// Tests that invalid JWT tokens are handled with UNAUTHENTICATED status. +func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { + // Write invalid JWT (missing exp field). + invalidJWT := createTestJWT(t, "", time.Time{}) + tokenFile := writeTempFile(t, "token", invalidJWT) + + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + _, err = creds.GetRequestMetadata(ctx) + if err == nil { + t.Fatal("Expected UNAUTHENTICATED from invalid JWT") + } + if status.Code(err) != codes.Unauthenticated { + t.Errorf("GetRequestMetadata() = %v, want UNAUTHENTICATED for invalid JWT", status.Code(err)) + } +} + +// Tests that RPCs are queued during file operations and all receive the same +// result. +func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { + tempDir := t.TempDir() + tokenFile := filepath.Join(tempDir, "token") + + // Start with no token file to force file read during first RPC. + creds, err := NewTokenFileCallCredentials(tokenFile) + if err != nil { + t.Fatalf("NewTokenFileCallCredentials() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + // Launch multiple concurrent RPCs before creating the token file. + const numConcurrentRPCs = 5 + results := make(chan error, numConcurrentRPCs) + + for range numConcurrentRPCs { + go func() { + _, err := creds.GetRequestMetadata(ctx) + results <- err + }() + } + + // Collect all results - they should all be the same error (UNAVAILABLE). + var errors []error + for range numConcurrentRPCs { + err := <-results + errors = append(errors, err) + } + + // All RPCs should fail with the same error (file not found). + for i, err := range errors { + if err == nil { + t.Errorf("RPC %d should have failed with UNAVAILABLE", i) + continue + } + if status.Code(err) != codes.Unavailable { + t.Errorf("RPC %d = %v, want UNAVAILABLE", i, status.Code(err)) + } + if i > 0 && err.Error() != errors[0].Error() { + t.Errorf("RPC %d error should match first RPC error for proper queueing", i) + } + } + + // Verify error was cached after concurrent RPCs. + impl := creds.(*jwtTokenFileCallCreds) + impl.mu.RLock() + finalCachedErr := impl.cachedError + impl.mu.RUnlock() + + if finalCachedErr == nil { + t.Error("error should be cached after failed concurrent RPCs") + } + if finalCachedErr.Error() != errors[0].Error() { + t.Error("cached error should match the errors returned to RPCs") + } +} + +func writeTempFile(t *testing.T, name, content string) string { + t.Helper() + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, name) + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + return filePath +} diff --git a/internal/envconfig/xds.go b/internal/envconfig/xds.go index e87551552ad7..6420558c0b7a 100644 --- a/internal/envconfig/xds.go +++ b/internal/envconfig/xds.go @@ -68,4 +68,9 @@ var ( // trust. For more details, see: // https://github.com/grpc/proposal/blob/master/A87-mtls-spiffe-support.md XDSSPIFFEEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false) + + // XDSBootstrapCallCredsEnabled controls if JWT call credentials can be used + // in xDS bootstrap configuration. For more details, see: + // https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md + XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false) ) diff --git a/internal/xds/bootstrap/bootstrap.go b/internal/xds/bootstrap/bootstrap.go index f409e4bd77b2..46dbf6bc98bc 100644 --- a/internal/xds/bootstrap/bootstrap.go +++ b/internal/xds/bootstrap/bootstrap.go @@ -31,6 +31,7 @@ import ( "strings" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -64,11 +65,26 @@ type ChannelCreds struct { Config json.RawMessage `json:"config,omitempty"` } +// CallCreds contains the call credentials configuration for individual RPCs. +// This type implements RFC A97 call credentials structure. +type CallCreds struct { + // Type contains a unique name identifying the call credentials type. + // Currently only "jwt_token_file" is supported. + Type string `json:"type,omitempty"` + // Config contains the JSON configuration associated with the call credentials. + Config json.RawMessage `json:"config,omitempty"` +} + // Equal reports whether cc and other are considered equal. func (cc ChannelCreds) Equal(other ChannelCreds) bool { return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) } +// Equal reports whether cc and other are considered equal. +func (cc CallCreds) Equal(other CallCreds) bool { + return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) +} + // String returns a string representation of the credentials. It contains the // type and the config (if non-nil) separated by a "-". func (cc ChannelCreds) String() string { @@ -172,13 +188,15 @@ type ServerConfig struct { serverURI string channelCreds []ChannelCreds serverFeatures []string + callCreds []CallCreds // As part of unmarshalling the JSON config into this struct, we ensure that // the credentials config is valid by building an instance of the specified // credentials and store it here for easy access. - selectedCreds ChannelCreds - credsDialOption grpc.DialOption - extraDialOptions []grpc.DialOption + selectedCreds ChannelCreds + credsDialOption grpc.DialOption + extraDialOptions []grpc.DialOption + selectedCallCreds []credentials.PerRPCCredentials // Built call credentials cleanups []func() } @@ -200,6 +218,17 @@ func (sc *ServerConfig) ServerFeatures() []string { return sc.serverFeatures } +// CallCreds returns the call credentials configuration for this server. +func (sc *ServerConfig) CallCreds() []CallCreds { + return sc.callCreds +} + +// SelectedCallCreds returns the built call credentials that are ready to use. +// These are the credentials that were successfully built from the call_creds configuration. +func (sc *ServerConfig) SelectedCallCreds() []credentials.PerRPCCredentials { + return sc.selectedCallCreds +} + // ServerFeaturesIgnoreResourceDeletion returns true if this server supports a // feature where the xDS client can ignore resource deletions from this server, // as described in gRFC A53. @@ -233,6 +262,28 @@ func (sc *ServerConfig) DialOptions() []grpc.DialOption { return dopts } +// DialOptionsWithCallCredsForTransport returns dial options including call credentials +// only if they are compatible with the specified transport credentials type. +// Call credentials that require transport security will be skipped for insecure transports. +func (sc *ServerConfig) DialOptionsWithCallCredsForTransport(transportCredsType string, transportCreds credentials.TransportCredentials) []grpc.DialOption { + dopts := sc.DialOptions() + + // Check if transport is insecure + isInsecureTransport := transportCredsType == "insecure" || + (transportCreds != nil && transportCreds.Info().SecurityProtocol == "insecure") + + // Add call credentials only if compatible with transport security + for _, callCred := range sc.selectedCallCreds { + // Skip call credentials that require transport security on insecure transports + if isInsecureTransport && callCred.RequireTransportSecurity() { + continue + } + dopts = append(dopts, grpc.WithPerRPCCredentials(callCred)) + } + + return dopts +} + // Cleanups returns a collection of functions to be called when the xDS client // for this server is closed. Allows cleaning up resources created specifically // for this server. @@ -251,6 +302,8 @@ func (sc *ServerConfig) Equal(other *ServerConfig) bool { return false case !slices.EqualFunc(sc.channelCreds, other.channelCreds, func(a, b ChannelCreds) bool { return a.Equal(b) }): return false + case !slices.EqualFunc(sc.callCreds, other.callCreds, func(a, b CallCreds) bool { return a.Equal(b) }): + return false case !slices.Equal(sc.serverFeatures, other.serverFeatures): return false case !sc.selectedCreds.Equal(other.selectedCreds): @@ -273,6 +326,7 @@ type serverConfigJSON struct { ServerURI string `json:"server_uri,omitempty"` ChannelCreds []ChannelCreds `json:"channel_creds,omitempty"` ServerFeatures []string `json:"server_features,omitempty"` + CallCreds []CallCreds `json:"call_creds,omitempty"` } // MarshalJSON returns marshaled JSON bytes corresponding to this server config. @@ -281,6 +335,7 @@ func (sc *ServerConfig) MarshalJSON() ([]byte, error) { ServerURI: sc.serverURI, ChannelCreds: sc.channelCreds, ServerFeatures: sc.serverFeatures, + CallCreds: sc.callCreds, } return json.Marshal(server) } @@ -301,6 +356,7 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.serverURI = server.ServerURI sc.channelCreds = server.ChannelCreds sc.serverFeatures = server.ServerFeatures + sc.callCreds = server.CallCreds for _, cc := range server.ChannelCreds { // We stop at the first credential type that we support. @@ -320,6 +376,27 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.cleanups = append(sc.cleanups, cancel) break } + + // Process call credentials - unlike channel creds, we use ALL supported types + // Call credentials are optional per RFC A97 + for _, callCred := range server.CallCreds { + c := bootstrap.GetCredentials(callCred.Type) + if c == nil { + // Skip unsupported call credential types (don't fail bootstrap) + continue + } + bundle, cancel, err := c.Build(callCred.Config) + if err != nil { + // Call credential validation failed - this should fail bootstrap + return fmt.Errorf("failed to build call credentials from bootstrap for %q: %v", callCred.Type, err) + } + // Extract the PerRPCCredentials from the bundle. Sanity check for nil just in case + if callCredentials := bundle.PerRPCCredentials(); callCredentials != nil { + sc.selectedCallCreds = append(sc.selectedCallCreds, callCredentials) + } + sc.cleanups = append(sc.cleanups, cancel) + } + if sc.serverURI == "" { return fmt.Errorf("xds: `server_uri` field in server config cannot be empty: %s", string(data)) } @@ -341,6 +418,9 @@ type ServerConfigTestingOptions struct { ChannelCreds []ChannelCreds // ServerFeatures represents the list of features supported by this server. ServerFeatures []string + // CallCreds contains a list of call credentials to use for individual RPCs + // to this server. Optional. + CallCreds []CallCreds } // ServerConfigForTesting creates a new ServerConfig from the passed in options, @@ -356,6 +436,7 @@ func ServerConfigForTesting(opts ServerConfigTestingOptions) (*ServerConfig, err ServerURI: opts.URI, ChannelCreds: cc, ServerFeatures: opts.ServerFeatures, + CallCreds: opts.CallCreds, } scJSON, err := json.Marshal(scInternal) if err != nil { diff --git a/internal/xds/bootstrap/bootstrap_test.go b/internal/xds/bootstrap/bootstrap_test.go index d057197804d6..93e90144fd28 100644 --- a/internal/xds/bootstrap/bootstrap_test.go +++ b/internal/xds/bootstrap/bootstrap_test.go @@ -19,15 +19,21 @@ package bootstrap import ( + "context" "encoding/json" "errors" "fmt" + "net" "os" + "strings" "testing" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/jwt" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -196,6 +202,74 @@ var ( "server_features" : ["ignore_resource_deletion", "xds_v3"] }] }`, + // example data seeded from + // https://github.com/istio/istio/blob/master/pkg/istio-agent/testdata/grpc-bootstrap.json + "istioStyleWithJWTCallCreds": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "insecure" } + ], + "call_creds": [ + { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } + ], + "server_features" : ["xds_v3"] + }] + }`, + "istioStyleWithoutCallCreds": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "insecure" } + ], + "server_features" : ["xds_v3"] + }] + }`, + "istioStyleWithTLSAndJWT": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "tls", "config": {} } + ], + "call_creds": [ + { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } + ], + "server_features" : ["xds_v3"] + }] + }`, } metadata = &structpb.Struct{ Fields: map[string]*structpb.Value{ @@ -276,6 +350,82 @@ var ( node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } + + istioNodeMetadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "GENERATOR": { + Kind: &structpb.Value_StringValue{StringValue: "grpc"}, + }, + "INSTANCE_IPS": { + Kind: &structpb.Value_StringValue{StringValue: "127.0.0.1"}, + }, + "ISTIO_VERSION": { + Kind: &structpb.Value_StringValue{StringValue: "1.26.2"}, + }, + "WORKLOAD_IDENTITY_SOCKET_FILE": { + Kind: &structpb.Value_StringValue{StringValue: "socket"}, + }, + }, + } + jwtCallCreds, _ = jwt.NewTokenFileCallCredentials("/var/run/secrets/tokens/istio-token") + selectedJWTCallCreds = []credentials.PerRPCCredentials{jwtCallCreds} + configWithIstioJWTCallCreds = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: []CallCreds{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, + serverFeatures: []string{"xds_v3"}, + selectedCreds: ChannelCreds{Type: "insecure"}, + selectedCallCreds: selectedJWTCallCreds, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } + + configWithIstioStyleNoCallCreds = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + serverFeatures: []string{"xds_v3"}, + selectedCreds: ChannelCreds{Type: "insecure"}, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } + + configWithIstioStyleWithTLSAndJWT = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "tls", Config: json.RawMessage("{}")}}, + callCreds: []CallCreds{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, + serverFeatures: []string{"xds_v3"}, + selectedCreds: ChannelCreds{Type: "tls", Config: json.RawMessage("{}")}, + selectedCallCreds: selectedJWTCallCreds, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } ) func fileReadFromFileMap(bootstrapFileMap map[string]string, name string) ([]byte, error) { @@ -425,6 +575,35 @@ func (s) TestGetConfiguration_Success(t *testing.T) { {"goodBootstrap", configWithGoogleDefaultCredsAndV3}, {"multipleXDSServers", configWithMultipleServers}, {"serverSupportsIgnoreResourceDeletion", configWithGoogleDefaultCredsAndIgnoreResourceDeletion}, + {"istioStyleWithoutCallCreds", configWithIstioStyleNoCallCreds}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testGetConfigurationWithFileNameEnv(t, test.name, false, test.wantConfig) + testGetConfigurationWithFileContentEnv(t, test.name, false, test.wantConfig) + }) + } +} + +// Tests Istio-style bootstrap configurations with JWT call credentials +func (s) TestGetConfiguration_IstioStyleWithCallCreds(t *testing.T) { + // Enable JWT call credentials feature + original := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = original + }() + + cancel := setupBootstrapOverride(v3BootstrapFileMap) + defer cancel() + + tests := []struct { + name string + wantConfig *Config + }{ + {"istioStyleWithJWTCallCreds", configWithIstioJWTCallCreds}, + {"istioStyleWithTLSAndJWT", configWithIstioStyleWithTLSAndJWT}, } for _, test := range tests { @@ -1018,12 +1197,203 @@ func (s) TestDefaultBundles(t *testing.T) { } } -type s struct { - grpctest.Tester +func (s) TestCallCreds_Equal(t *testing.T) { + tests := []struct { + name string + cc1 CallCreds + cc2 CallCreds + expect bool + }{ + { + name: "identical configs", + cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + expect: true, + }, + { + name: "different types", + cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCreds{Type: "other_type", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + expect: false, + }, + { + name: "different configs", + cc1: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/different/path"}`)}, + expect: false, + }, + { + name: "nil vs non-nil configs", + cc1: CallCreds{Type: "jwt_token_file", Config: nil}, + cc2: CallCreds{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + expect: false, + }, + { + name: "both nil configs", + cc1: CallCreds{Type: "jwt_token_file", Config: nil}, + cc2: CallCreds{Type: "jwt_token_file", Config: nil}, + expect: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.cc1.Equal(test.cc2) + if result != test.expect { + t.Errorf("CallCreds.Equal() = %v, want %v", result, test.expect) + } + }) + } } -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) +func (s) TestServerConfig_UnmarshalJSON_WithCallCreds(t *testing.T) { + original := envconfig.XDSBootstrapCallCredsEnabled + defer func() { envconfig.XDSBootstrapCallCredsEnabled = original }() + envconfig.XDSBootstrapCallCredsEnabled = true // Enable call creds in bootstrap + tests := []struct { + name string + json string + wantCallCreds []CallCreds + wantErr bool + errContains string + }{ + { + name: "valid call_creds with jwt_token_file", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/path/to/token.jwt"} + } + ] + }`, + wantCallCreds: []CallCreds{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file": "/path/to/token.jwt"}`), + }}, + }, + { + name: "multiple call_creds types", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + {"type": "jwt_token_file", "config": {"jwt_token_file": "/token1.jwt"}}, + {"type": "unsupported_type", "config": {}} + ] + }`, + wantCallCreds: []CallCreds{ + {Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/token1.jwt"}`)}, + {Type: "unsupported_type", Config: json.RawMessage(`{}`)}, + }, + }, + { + name: "empty call_creds array", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [] + }`, + wantCallCreds: []CallCreds{}, + }, + { + name: "missing call_creds field", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}] + }`, + wantCallCreds: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var sc ServerConfig + err := sc.UnmarshalJSON([]byte(test.json)) + + if test.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + if test.errContains != "" && !strings.Contains(err.Error(), test.errContains) { + t.Errorf("Error %v should contain %q", err, test.errContains) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if diff := cmp.Diff(test.wantCallCreds, sc.CallCreds()); diff != "" { + t.Errorf("CallCreds mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func (s) TestServerConfig_Equal_WithCallCreds(t *testing.T) { + callCreds := []CallCreds{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file": "/test/token.jwt"}`), + }} + sc1 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: callCreds, + serverFeatures: []string{"feature1"}, + } + sc2 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: callCreds, + serverFeatures: []string{"feature1"}, + } + sc3 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: []CallCreds{{Type: "different"}}, + serverFeatures: []string{"feature1"}, + } + + if !sc1.Equal(sc2) { + t.Error("Equal ServerConfigs with same call creds should be equal") + } + if sc1.Equal(sc3) { + t.Error("ServerConfigs with different call creds should not be equal") + } +} + +func (s) TestServerConfig_MarshalJSON_WithCallCreds(t *testing.T) { + original := envconfig.XDSBootstrapCallCredsEnabled + defer func() { envconfig.XDSBootstrapCallCredsEnabled = original }() + envconfig.XDSBootstrapCallCredsEnabled = true // Enable call creds in bootstrap + sc := &ServerConfig{ + serverURI: "test-server:443", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCreds: []CallCreds{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file":"/test/token.jwt"}`), + }}, + serverFeatures: []string{"test_feature"}, + } + + data, err := sc.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + // confirm Marshal/Unmarshal symmetry + var unmarshaled ServerConfig + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if diff := cmp.Diff(sc.CallCreds(), unmarshaled.CallCreds()); diff != "" { + t.Errorf("Marshal/Unmarshal call credentials produces differences:\n%s", diff) + } } func newStructProtoFromMap(t *testing.T, input map[string]any) *structpb.Struct { @@ -1269,3 +1639,231 @@ func (s) TestGetConfiguration_FallbackDisabled(t *testing.T) { testGetConfigurationWithFileContentEnv(t, "multipleXDSServers", false, wantConfig) }) } + +func (s) TestBootstrap_SelectedCredsAndCallCreds(t *testing.T) { + // Enable JWT call credentials + original := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = original + }() + + tokenFile := "/token.jwt" + tests := []struct { + name string + bootstrapConfig string + expectCallCreds int + expectTransportType string + }{ + { + name: "JWT call creds with TLS channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }`, + expectCallCreds: 1, + expectTransportType: "tls", + }, + { + name: "JWT call creds with multiple channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}, {"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + }, + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }`, + expectCallCreds: 2, + expectTransportType: "tls", // the first channel creds is selected + }, + { + name: "JWT call creds with insecure channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }`, + expectCallCreds: 1, + expectTransportType: "insecure", + }, + { + name: "No call creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}] + }`, + expectCallCreds: 0, + expectTransportType: "insecure", + }, + { + name: "No call creds multiple channel creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}, {"type": "tls", "config": {}}] + }`, + expectCallCreds: 0, + expectTransportType: "insecure", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var sc ServerConfig + err := sc.UnmarshalJSON([]byte(test.bootstrapConfig)) + if err != nil { + t.Fatalf("Failed to unmarshal bootstrap config: %v", err) + } + + // Verify call credentials processing + callCreds := sc.CallCreds() + selectedCallCreds := sc.SelectedCallCreds() + + if len(callCreds) != test.expectCallCreds { + t.Errorf("Call creds count = %d, want %d", len(callCreds), test.expectCallCreds) + } + if len(selectedCallCreds) != test.expectCallCreds { + t.Errorf("Selected call creds count = %d, want %d", len(selectedCallCreds), test.expectCallCreds) + } + + // Verify transport credentials are properly selected + if sc.SelectedCreds().Type != test.expectTransportType { + t.Errorf("Selected transport creds type = %q, want %q", + sc.SelectedCreds().Type, test.expectTransportType) + } + }) + } +} + +func (s) TestDialOptionsWithCallCredsForTransport(t *testing.T) { + // Create test JWT credentials that require transport security + testJWTCreds := &testPerRPCCreds{requireSecurity: true} + testInsecureCreds := &testPerRPCCreds{requireSecurity: false} + + sc := &ServerConfig{ + selectedCallCreds: []credentials.PerRPCCredentials{ + testJWTCreds, + testInsecureCreds, + }, + extraDialOptions: []grpc.DialOption{ + grpc.WithUserAgent("test-agent"), // Test extra option + }, + } + + tests := []struct { + name string + transportType string + transportCreds credentials.TransportCredentials + expectJWTCreds bool + expectOtherCreds bool + }{ + { + name: "insecure transport by type", + transportType: "insecure", + transportCreds: nil, + expectJWTCreds: false, // JWT requires security + expectOtherCreds: true, // Non-security creds allowed + }, + { + name: "insecure transport by protocol", + transportType: "custom", + transportCreds: insecure.NewCredentials(), + expectJWTCreds: false, // JWT requires security + expectOtherCreds: true, // Non-security creds allowed + }, + { + name: "secure transport", + transportType: "tls", + transportCreds: &testTransportCreds{securityProtocol: "tls"}, + expectJWTCreds: true, // JWT allowed on secure transport + expectOtherCreds: true, // All creds allowed + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := sc.DialOptionsWithCallCredsForTransport(test.transportType, test.transportCreds) + + // Count dial options (should include extra options + applicable call creds) + expectedCount := 2 // extraDialOptions + always include non-security creds + if test.expectJWTCreds { + expectedCount++ + } + + if len(opts) != expectedCount { + t.Errorf("DialOptions count = %d, want %d", len(opts), expectedCount) + } + }) + } +} + +type testPerRPCCreds struct { + requireSecurity bool +} + +func (c *testPerRPCCreds) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + return map[string]string{"test": "metadata"}, nil +} + +func (c *testPerRPCCreds) RequireTransportSecurity() bool { + return c.requireSecurity +} + +type testTransportCreds struct { + securityProtocol string +} + +func (c *testTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &testAuthInfo{}, nil +} + +func (c *testTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &testAuthInfo{}, nil +} + +func (c *testTransportCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: c.securityProtocol} +} + +func (c *testTransportCreds) Clone() credentials.TransportCredentials { + return &testTransportCreds{securityProtocol: c.securityProtocol} +} + +func (c *testTransportCreds) OverrideServerName(string) error { + return nil +} + +type testAuthInfo struct{} + +func (a *testAuthInfo) AuthType() string { + return "test" +} + +func (a *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{} +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} diff --git a/internal/xds/bootstrap/jwtcreds/bundle.go b/internal/xds/bootstrap/jwtcreds/bundle.go new file mode 100644 index 000000000000..2b2b2103e908 --- /dev/null +++ b/internal/xds/bootstrap/jwtcreds/bundle.go @@ -0,0 +1,81 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwtcreds implements JWT Call Credentials in xDS Bootstrap File. +// See gRFC A97: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md +package jwtcreds + +import ( + "encoding/json" + "fmt" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/jwt" +) + +// bundle is an implementation of credentials.Bundle which implements JWT +// Call Credentials in xDS Bootstrap File per RFC A97. +// This bundle only provides call credentials, not transport credentials. +type bundle struct { + transportCreds credentials.TransportCredentials // Always nil for JWT call creds + callCreds credentials.PerRPCCredentials +} + +// NewBundle returns a credentials.Bundle which implements JWT Call Credentials +// in xDS Bootstrap File per RFC A97. This implementation focuses on call credentials +// only and expects the config to match RFC A97 structure. +// See gRFC A97: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md +func NewBundle(configJSON json.RawMessage) (credentials.Bundle, func(), error) { + var cfg struct { + JWTTokenFile string `json:"jwt_token_file"` + } + + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal JWT call credentials config: %v", err) + } + + if cfg.JWTTokenFile == "" { + return nil, nil, fmt.Errorf("jwt_token_file is required in JWT call credentials config") + } + + // Create JWT call credentials + callCreds, err := jwt.NewTokenFileCallCredentials(cfg.JWTTokenFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to create JWT call credentials: %v", err) + } + + bundle := &bundle{ + transportCreds: nil, // JWT call creds don't provide transport security + callCreds: callCreds, + } + + return bundle, func() {}, nil +} + +func (b *bundle) TransportCredentials() credentials.TransportCredentials { + // Transport credentials should be configured separately via channel_creds + return nil +} + +func (b *bundle) PerRPCCredentials() credentials.PerRPCCredentials { + return b.callCreds +} + +func (b *bundle) NewWithMode(_ string) (credentials.Bundle, error) { + return nil, fmt.Errorf("JWT call credentials bundle does not support mode switching") +} diff --git a/internal/xds/bootstrap/jwtcreds/bundle_test.go b/internal/xds/bootstrap/jwtcreds/bundle_test.go new file mode 100644 index 000000000000..74f49a710246 --- /dev/null +++ b/internal/xds/bootstrap/jwtcreds/bundle_test.go @@ -0,0 +1,214 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwtcreds + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" +) + +func TestNewBundle(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + + tests := []struct { + name string + config string + wantErr bool + wantErrContains string + }{ + { + name: "valid RFC A97 config with jwt_token_file", + config: `{ + "jwt_token_file": "` + tokenFile + `" + }`, + wantErr: false, + }, + { + name: "empty config", + config: `""`, + wantErr: true, + wantErrContains: "unmarshal", + }, + { + name: "empty config", + config: `{}`, + wantErr: true, + wantErrContains: "jwt_token_file is required", + }, + { + name: "empty path", + config: `{ + "jwt_token_file": "" + }`, + wantErr: true, + wantErrContains: "jwt_token_file is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bundle, cleanup, err := NewBundle(json.RawMessage(tt.config)) + + if tt.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("Error %v should contain %q", err, tt.wantErrContains) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if bundle == nil { + t.Fatal("Expected non-nil bundle") + } + + if cleanup == nil { + t.Error("Expected non-nil cleanup function") + } else { + defer cleanup() + } + + // JWT bundle only deals with PerRPCCredentials, not TransportCredentials + if bundle.TransportCredentials() != nil { + t.Error("Expected nil transport credentials for JWT call creds bundle") + } + + if bundle.PerRPCCredentials() == nil { + t.Error("Expected non-nil per-RPC credentials for valid JWT config") + } + + // Test that call credentials work + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + + metadata, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata failed: %v", err) + } + + if len(metadata) == 0 { + t.Error("Expected metadata to be returned") + } + + authHeader, ok := metadata["authorization"] + if !ok { + t.Error("Expected authorization header in metadata") + } + + if !strings.HasPrefix(authHeader, "Bearer ") { + t.Errorf("Authorization header should start with 'Bearer ', got %q", authHeader) + } + }) + } +} + +func TestBundle_NewWithMode(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + config := `{"jwt_token_file": "` + tokenFile + `"}` + bundle, cleanup, err := NewBundle(json.RawMessage(config)) + if err != nil { + t.Fatalf("NewBundle failed: %v", err) + } + defer cleanup() + + _, err = bundle.NewWithMode("test_mode") + if err == nil { + t.Error("Expected error from NewWithMode, got nil") + } + if !strings.Contains(err.Error(), "does not support mode switching") { + t.Errorf("Error should mention mode switching, got: %v", err) + } +} + +func TestBundle_Cleanup(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + config := `{"jwt_token_file": "` + tokenFile + `"}` + _, cleanup, err := NewBundle(json.RawMessage(config)) + if err != nil { + t.Fatalf("NewBundle failed: %v", err) + } + + if cleanup == nil { + t.Fatal("Expected non-nil cleanup function") + } + + // Cleanup should not panic + cleanup() + + // Multiple cleanup calls should be safe + cleanup() +} + +// testAuthInfo implements credentials.AuthInfo for testing +type testAuthInfo struct { + secLevel credentials.SecurityLevel +} + +func (t *testAuthInfo) AuthType() string { + return "test" +} + +func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} +} + +// createTestJWT creates a test JWT token for testing +func createTestJWT(t *testing.T) string { + t.Helper() + + // Create a valid JWT with proper base64 encoding for testing + // Header: {"typ":"JWT","alg":"HS256"} + header := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" + + // Claims: {"aud":"https://example.com","exp":future_timestamp} + claims := "eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tIiwiZXhwIjoyMDAwMDAwMDAwfQ" + + // Fake signature for testing + signature := "fake_signature_for_testing" + + return header + "." + claims + "." + signature +} + +func writeTempFile(t *testing.T, content string) string { + t.Helper() + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, "tempfile") + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + return filePath +} diff --git a/xds/bootstrap/bootstrap.go b/xds/bootstrap/bootstrap.go index ef55ff0c02db..b1a5e831b2a6 100644 --- a/xds/bootstrap/bootstrap.go +++ b/xds/bootstrap/bootstrap.go @@ -29,6 +29,7 @@ import ( "encoding/json" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/envconfig" ) // registry is a map from credential type name to Credential builder. @@ -58,6 +59,9 @@ func RegisterCredentials(c Credentials) { // GetCredentials returns the credentials associated with a given name. // If no credentials are registered with the name, nil will be returned. func GetCredentials(name string) Credentials { + if name == "jwt_token_file" && !envconfig.XDSBootstrapCallCredsEnabled { + return nil + } if c, ok := registry[name]; ok { return c } diff --git a/xds/bootstrap/bootstrap_test.go b/xds/bootstrap/bootstrap_test.go index d1f7a1b64ee5..935976975513 100644 --- a/xds/bootstrap/bootstrap_test.go +++ b/xds/bootstrap/bootstrap_test.go @@ -22,6 +22,7 @@ import ( "testing" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/envconfig" ) const testCredsBuilderName = "test_creds" @@ -64,12 +65,14 @@ func TestRegisterNew(t *testing.T) { func TestCredsBuilders(t *testing.T) { tests := []struct { - typename string - builder Credentials + typename string + builder Credentials + minimumRequiredConfig json.RawMessage }{ - {"google_default", &googleDefaultCredsBuilder{}}, - {"insecure", &insecureCredsBuilder{}}, - {"tls", &tlsCredsBuilder{}}, + {"google_default", &googleDefaultCredsBuilder{}, nil}, + {"insecure", &insecureCredsBuilder{}, nil}, + {"tls", &tlsCredsBuilder{}, nil}, + {"jwt_token_file", &jwtCallCredsBuilder{}, json.RawMessage(`{"jwt_token_file":"/path/to/token.jwt"}`)}, } for _, test := range tests { @@ -78,10 +81,13 @@ func TestCredsBuilders(t *testing.T) { t.Errorf("%T.Name = %v, want %v", test.builder, got, want) } - _, stop, err := test.builder.Build(nil) + bundle, stop, err := test.builder.Build(test.minimumRequiredConfig) if err != nil { t.Fatalf("%T.Build failed: %v", test.builder, err) } + if bundle == nil { + t.Errorf("%T.Build returned nil bundle, expected non-nil", test.builder) + } stop() }) } @@ -100,3 +106,27 @@ func TestTlsCredsBuilder(t *testing.T) { stop() } } + +func TestJwtCallCredentials_BuildDisabledIfFeatureNotEnabled(t *testing.T) { + builder := GetCredentials("jwt_call_creds") + if builder != nil { + t.Fatal("Expected nil Credentials for jwt_call_creds when the feature is disabled.") + } + + // Enable JWT call credentials + original := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = original + }() + + // Test that GetCredentials returns the JWT builder + builder = GetCredentials("jwt_token_file") + if builder == nil { + t.Fatal("GetCredentials(\"jwt_token_file\") returned nil") + } + + if got, want := builder.Name(), "jwt_token_file"; got != want { + t.Errorf("Retrieved builder name = %q, want %q", got, want) + } +} diff --git a/xds/bootstrap/credentials.go b/xds/bootstrap/credentials.go index 578e1278970d..38018972f383 100644 --- a/xds/bootstrap/credentials.go +++ b/xds/bootstrap/credentials.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/google" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/xds/bootstrap/jwtcreds" "google.golang.org/grpc/internal/xds/bootstrap/tlscreds" ) @@ -31,6 +32,7 @@ func init() { RegisterCredentials(&insecureCredsBuilder{}) RegisterCredentials(&googleDefaultCredsBuilder{}) RegisterCredentials(&tlsCredsBuilder{}) + RegisterCredentials(&jwtCallCredsBuilder{}) } // insecureCredsBuilder implements the `Credentials` interface defined in @@ -68,3 +70,15 @@ func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func (d *googleDefaultCredsBuilder) Name() string { return "google_default" } + +// jwtCallCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates JWT call credentials. +type jwtCallCredsBuilder struct{} + +func (j *jwtCallCredsBuilder) Build(configJSON json.RawMessage) (credentials.Bundle, func(), error) { + return jwtcreds.NewBundle(configJSON) +} + +func (j *jwtCallCredsBuilder) Name() string { + return "jwt_token_file" +} diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 967182740719..80bf8d0e8183 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -229,7 +229,9 @@ func populateGRPCTransportConfigsFromServerConfig(sc *bootstrap.ServerConfig, gr grpcTransportConfigs[cc.Type] = grpctransport.Config{ Credentials: bundle, GRPCNewClient: func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - opts = append(opts, sc.DialOptions()...) + // Only add call credentials that are compatible with this transport type + // Call credentials requiring transport security are skipped for insecure transports + opts = append(opts, sc.DialOptionsWithCallCredsForTransport(cc.Type, bundle.TransportCredentials())...) return grpc.NewClient(target, opts...) }, } diff --git a/xds/internal/xdsclient/clientimpl_test.go b/xds/internal/xdsclient/clientimpl_test.go index fbfc24a074ec..c7884e8ebff6 100644 --- a/xds/internal/xdsclient/clientimpl_test.go +++ b/xds/internal/xdsclient/clientimpl_test.go @@ -19,8 +19,10 @@ package xdsclient import ( + "context" "encoding/json" "fmt" + "net" "reflect" "sync" "testing" @@ -28,7 +30,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/testutils/stats" "google.golang.org/grpc/internal/xds/bootstrap" "google.golang.org/grpc/xds/internal/clients" @@ -259,3 +263,90 @@ func (s) TestBuildXDSClientConfig_Success(t *testing.T) { }) } } + +func TestServerConfigCallCredsIntegration(t *testing.T) { + // Enable JWT call credentials + originalJWTEnabled := envconfig.XDSBootstrapCallCredsEnabled + envconfig.XDSBootstrapCallCredsEnabled = true + defer func() { + envconfig.XDSBootstrapCallCredsEnabled = originalJWTEnabled + }() + + tokenFile := "/token.jwt" + // Test server config with both channel and call credentials + serverConfigJSON := `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "` + tokenFile + `"} + } + ] + }` + + var sc bootstrap.ServerConfig + if err := sc.UnmarshalJSON([]byte(serverConfigJSON)); err != nil { + t.Fatalf("Failed to unmarshal server config: %v", err) + } + + // Verify call credentials are processed + callCreds := sc.CallCreds() + if len(callCreds) != 1 { + t.Errorf("Expected 1 call credential, got %d", len(callCreds)) + } + + selectedCallCreds := sc.SelectedCallCreds() + if len(selectedCallCreds) != 1 { + t.Errorf("Expected 1 selected call credential, got %d", len(selectedCallCreds)) + } + + // Test dial options for secure transport (should include JWT) + secureOpts := sc.DialOptionsWithCallCredsForTransport("tls", &mockTransportCreds{protocol: "tls"}) + if len(secureOpts) != 1 { + t.Errorf("Expected dial options for secure transport. Got: %#v", secureOpts) + } + + // Test dial options for insecure transport (should exclude JWT) + insecureOpts := sc.DialOptionsWithCallCredsForTransport("insecure", &mockTransportCreds{protocol: "insecure"}) + + // JWT should be filtered out for insecure transport + if len(insecureOpts) >= len(secureOpts) { + t.Error("Expected fewer dial options for insecure transport (JWT should be filtered)") + } +} + +// Mock transport credentials for testing +type mockTransportCreds struct { + protocol string +} + +func (m *mockTransportCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &mockAuthInfo{}, nil +} + +func (m *mockTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, &mockAuthInfo{}, nil +} + +func (m *mockTransportCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: m.protocol} +} + +func (m *mockTransportCreds) Clone() credentials.TransportCredentials { + return &mockTransportCreds{protocol: m.protocol} +} + +func (m *mockTransportCreds) OverrideServerName(string) error { + return nil +} + +type mockAuthInfo struct{} + +func (m *mockAuthInfo) AuthType() string { + return "mock" +} + +func (m *mockAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{} +}