diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index d6c205cb0a..4d710d7037 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -2,15 +2,16 @@ // Licensed under the MIT License. using System; +using System.Collections.Concurrent; using System.IO; -using System.Threading.Tasks; +using System.Security.Cryptography.X509Certificates; using System.Threading; -using Microsoft.Identity.Client.Internal; +using System.Threading.Tasks; using Microsoft.Identity.Client.ApiConfig.Parameters; -using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity.V2; -using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.PlatformsCommon.Shared; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -30,6 +31,9 @@ internal class ManagedIdentityClient internal static void ResetSourceForTest() { s_sourceName = ManagedIdentitySource.None; + + // Clear cert caches so each test starts fresh + ImdsV2ManagedIdentitySource.ResetCertCacheForTest(); } internal async Task SendTokenRequestForManagedIdentityAsync( diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateCacheEntry.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateCacheEntry.cs new file mode 100644 index 0000000000..417ad21211 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateCacheEntry.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// In-memory entry owned by the cache. Disposing the entry disposes the certificate it owns. + /// + internal sealed class CertificateCacheEntry : IDisposable + { + private int _disposed; + + /// + /// Represents the minimum remaining lifetime for an operation or resource. + /// + public static readonly TimeSpan MinRemainingLifetime = TimeSpan.FromHours(24); + + /// + /// certificate+endpoint+clientId cache entry. + /// + /// + /// + /// + /// + /// + public CertificateCacheEntry(X509Certificate2 certificate, DateTimeOffset notAfterUtc, string endpoint, string clientId) + { + Certificate = certificate ?? throw new ArgumentNullException(nameof(certificate)); + NotAfterUtc = notAfterUtc; + Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); + ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + } + + /// + /// certificate owned by this entry. + /// + public X509Certificate2 Certificate { get; } + /// + /// notAfterUtc of the certificate. + /// + public DateTimeOffset NotAfterUtc { get; } + /// + /// endpoint associated with this certificate. + /// + public string Endpoint { get; } + /// + /// clientId associated with this certificate. + /// + public string ClientId { get; } + + /// Whether this entry has been disposed. + public bool IsDisposed => Volatile.Read(ref _disposed) != 0; + + /// + /// is expired at the specified time. + /// + /// + /// + public bool IsExpiredUtc(DateTimeOffset nowUtc) => nowUtc >= (NotAfterUtc - MinRemainingLifetime); + + /// + /// dispose the entry and its certificate. + /// + public void Dispose() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; // already disposed + } + + // Idempotent due to the atomic guard + Certificate.Dispose(); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateCacheValue.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateCacheValue.cs new file mode 100644 index 0000000000..bf795ff56e --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateCacheValue.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography.X509Certificates; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// Immutable snapshot of a cached certificate and its associated metadata. + /// + internal readonly struct CertificateCacheValue + { + public CertificateCacheValue(X509Certificate2 certificate, string endpoint, string clientId) + { + if (certificate == null) throw new ArgumentNullException(nameof(certificate)); + if (endpoint == null) throw new ArgumentNullException(nameof(endpoint)); + if (clientId == null) throw new ArgumentNullException(nameof(clientId)); + + Certificate = certificate; + Endpoint = endpoint; + ClientId = clientId; + } + + /// The certificate (clone owned by the caller). + public X509Certificate2 Certificate { get; } + + /// The base endpoint to use with this certificate. + public string Endpoint { get; } + + /// The canonical client id to be posted to the mTLS token endpoint. + public string ClientId { get; } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICertificateCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICertificateCache.cs new file mode 100644 index 0000000000..1341e6e46a --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICertificateCache.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// Process-local cache for an mTLS certificate and its endpoint. + /// Expiration is based solely on certificate.NotAfter. + /// + internal interface ICertificateCache + { + /// + /// Try to get a cached certificate+endpoint+clientId for the specified cacheKey. + /// Returns true and non-null outputs if found and not expired. + /// + bool TryGet( + string cacheKey, + out CertificateCacheValue value, + ILoggerAdapter logger = null); + + /// + /// Insert or replace the cached certificate+endpoint+clientId for cacheKey. + /// + void Set( + string cacheKey, + in CertificateCacheValue value, + ILoggerAdapter logger = null); + + /// Remove an entry if present. + bool Remove(string cacheKey, ILoggerAdapter logger = null); + + /// Clear all entries. + void Clear(ILoggerAdapter logger = null); + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index cd937f44a2..aa98c98121 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -2,10 +2,12 @@ // Licensed under the MIT License. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Net; using System.Net.Http; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; @@ -21,6 +23,13 @@ namespace Microsoft.Identity.Client.ManagedIdentity.V2 { internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { + // Central, process-local cache for mTLS binding (cert + endpoint + canonical client_id). + internal static readonly ICertificateCache s_mtlsCertificateCache = new InMemoryCertificateCache(); + + // Per-key async de-duplication so concurrent callers don’t double-mint. + internal static readonly ConcurrentDictionary s_perKeyGates = + new ConcurrentDictionary(StringComparer.Ordinal); + // used in unit tests public const string ImdsV2ApiVersion = "2.0"; public const string CsrMetadataPath = "/metadata/identity/getplatformmetadata"; @@ -211,17 +220,15 @@ private async Task ExecuteCertificateRequestAsync( "[ImdsV2] mTLS Proof-of-Possession requires a KeyGuard-backed key. Enable KeyGuard or use a KeyGuard-supported environment."); } - // TODO: : Normalize and validate attestation endpoint Code needs to be removed - // once IMDS team start returning full URI - Uri normalizedEndpoint = NormalizeAttestationEndpoint(attestationEndpoint, _requestContext.Logger); - // Ask helper for JWT only for KeyGuard keys string attestationJwt = string.Empty; + var attestationUri = new Uri(attestationEndpoint); + if (managedIdentityKeyInfo.Type == ManagedIdentityKeyType.KeyGuard) { attestationJwt = await GetAttestationJwtAsync( clientId, - normalizedEndpoint, + attestationUri, managedIdentityKeyInfo, _requestContext.UserCancellationToken).ConfigureAwait(false); } @@ -256,12 +263,14 @@ private async Task ExecuteCertificateRequestAsync( } catch (Exception ex) { + int? statusCode = response != null ? (int?)response.StatusCode : null; + throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, - $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed.", + "[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed.", ex, ManagedIdentitySource.ImdsV2, - (int)response.StatusCode); + statusCode); } if (response.StatusCode != HttpStatusCode.OK) @@ -284,46 +293,80 @@ protected override async Task CreateRequestAsync(string { var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false); - IManagedIdentityKeyProvider keyProvider = _requestContext.ServiceBundle.PlatformProxy.ManagedIdentityKeyProvider; + string certCacheKey = _requestContext.ServiceBundle.Config.ClientId; - ManagedIdentityKeyInfo keyInfo = await keyProvider - .GetOrCreateKeyAsync( - _requestContext.Logger, - _requestContext.UserCancellationToken) + var certEndpointAndClientId = await GetOrCreateMtlsBindingAsync( + cacheKey: certCacheKey, + async () => + { + IManagedIdentityKeyProvider keyProvider = _requestContext.ServiceBundle.PlatformProxy.ManagedIdentityKeyProvider; + + ManagedIdentityKeyInfo keyInfo = await keyProvider + .GetOrCreateKeyAsync(_requestContext.Logger, _requestContext.UserCancellationToken) + .ConfigureAwait(false); + + var csrAndKey = _requestContext.ServiceBundle.Config.CsrFactory.Generate( + keyInfo.Key, + csrMetadata.ClientId, + csrMetadata.TenantId, + csrMetadata.CuId); + + string csr = csrAndKey.csrPem; + var privateKey = csrAndKey.privateKey; + + var certificateRequestResponse = await ExecuteCertificateRequestAsync( + csrMetadata.ClientId, + csrMetadata.AttestationEndpoint, + csr, + keyInfo).ConfigureAwait(false); + + X509Certificate2 mtlsCertificate = CommonCryptographyManager.AttachPrivateKeyToCert( + certificateRequestResponse.Certificate, + privateKey); + + // Base endpoint = "{mtlsAuthEndpoint}/{tenantId}" + string endpointBase = + (certificateRequestResponse.MtlsAuthenticationEndpoint).TrimEnd('/') + + "/" + + (certificateRequestResponse.TenantId).Trim('/'); + + // Canonical GUID to use as client_id in the token call + string clientIdGuid = certificateRequestResponse.ClientId; + + return Tuple.Create(mtlsCertificate, endpointBase, clientIdGuid); + }, + _requestContext.UserCancellationToken, + _requestContext.Logger) .ConfigureAwait(false); - var (csr, privateKey) = _requestContext.ServiceBundle.Config.CsrFactory.Generate(keyInfo.Key, csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - - var certificateRequestResponse = await ExecuteCertificateRequestAsync( - csrMetadata.ClientId, - csrMetadata.AttestationEndpoint, - csr, - keyInfo).ConfigureAwait(false); + X509Certificate2 bindingCertificate = certEndpointAndClientId.Item1; + string endpointBaseForToken = certEndpointAndClientId.Item2; + string clientIdForToken = certEndpointAndClientId.Item3; - // transform certificateRequestResponse.Certificate to x509 with private key - var mtlsCertificate = CommonCryptographyManager.AttachPrivateKeyToCert( - certificateRequestResponse.Certificate, - privateKey); + ManagedIdentityRequest request = new ManagedIdentityRequest( + HttpMethod.Post, + new Uri(endpointBaseForToken + AcquireEntraTokenPath)); - ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); + Dictionary idParams = MsalIdHelper.GetMsalIdParameters(_requestContext.Logger); - var idParams = MsalIdHelper.GetMsalIdParameters(_requestContext.Logger); - foreach (var idParam in idParams) + foreach (KeyValuePair idParam in idParams) { request.Headers[idParam.Key] = idParam.Value; } + request.Headers.Add(OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString()); request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue); request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true"); - request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); + var tokenType = _isMtlsPopRequested ? Constants.MtlsPoPTokenType : Constants.BearerTokenType; + + request.BodyParameters.Add("client_id", clientIdForToken); request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials); request.BodyParameters.Add("scope", resource.TrimEnd('/') + "/.default"); - request.BodyParameters.Add("token_type", "mtls_pop"); + request.BodyParameters.Add("token_type", tokenType); request.RequestType = RequestType.STS; - - request.MtlsCertificate = mtlsCertificate; + request.MtlsCertificate = bindingCertificate; return request; } @@ -397,56 +440,82 @@ private async Task GetAttestationJwtAsync( return response.AttestationToken; } - //To-do : Remove this method once IMDS team start returning full URI + // ...unchanged usings and class header... + /// - /// Temporarily normalize attestation endpoint values to a full https:// URI. - /// IMDS team will eventually return a full URI. + /// Read-through cache: try cache; if missing, run async factory once (per key), + /// store the result, and return it. Thread-safe for the given cacheKey. /// - /// - /// - /// - private static Uri NormalizeAttestationEndpoint(string rawEndpoint, ILoggerAdapter logger) + private static async Task> GetOrCreateMtlsBindingAsync( + string cacheKey, + Func>> factory, + CancellationToken cancellationToken, + ILoggerAdapter logger) { - if (string.IsNullOrWhiteSpace(rawEndpoint)) - { - return null; - } + if (string.IsNullOrWhiteSpace(cacheKey)) + throw new ArgumentException("cacheKey must be non-empty.", nameof(cacheKey)); + if (factory is null) + throw new ArgumentNullException(nameof(factory)); - // Trim whitespace - rawEndpoint = rawEndpoint.Trim(); + X509Certificate2 cachedCertificate; + string cachedEndpointBase; + string cachedClientId; - // If it already parses as an absolute URI with https, keep it. - if (Uri.TryCreate(rawEndpoint, UriKind.Absolute, out var absolute) && - (absolute.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase))) + // 1) Only lookup by cacheKey + if (s_mtlsCertificateCache.TryGet(cacheKey, out var cached, logger)) { - return absolute; + cachedCertificate = cached.Certificate; + cachedEndpointBase = cached.Endpoint; + cachedClientId = cached.ClientId; + + return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId); } - // If it has no scheme (common service behavior returning only host) - // prepend https:// and try again. - if (!rawEndpoint.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + // 2) Gate per cacheKey + var gate = s_perKeyGates.GetOrAdd(cacheKey, _ => new SemaphoreSlim(1, 1)); + await gate.WaitAsync(cancellationToken).ConfigureAwait(false); + + try { - var candidate = "https://" + rawEndpoint; - if (Uri.TryCreate(candidate, UriKind.Absolute, out var httpsUri)) + // Re-check after acquiring the gate + if (s_mtlsCertificateCache.TryGet(cacheKey, out cached, logger)) { - logger.Info(() => $"[Managed Identity] Normalized attestation endpoint '{rawEndpoint}' -> '{httpsUri.ToString()}'."); - return httpsUri; + cachedCertificate = cached.Certificate; + cachedEndpointBase = cached.Endpoint; + cachedClientId = cached.ClientId; + return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId); } + + // 3) Mint + cache under the provided cacheKey + var created = await factory().ConfigureAwait(false); + + s_mtlsCertificateCache.Set(cacheKey, + new CertificateCacheValue(created.Item1, created.Item2, created.Item3), + logger); + + return created; + } + finally + { + gate.Release(); } + } - // Final attempt: reject http (non‑TLS) or malformed - if (Uri.TryCreate(rawEndpoint, UriKind.Absolute, out var anyUri)) + internal static void ResetCertCacheForTest() + { + // Clear caches so each test starts fresh + if (s_mtlsCertificateCache != null) { - if (!anyUri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) - { - logger.Warning($"[Managed Identity] Attestation endpoint uses unsupported scheme '{anyUri.Scheme}'. HTTPS is required."); - return null; - } - return anyUri; + s_mtlsCertificateCache.Clear(); } - logger.Warning($"[Managed Identity] Failed to normalize attestation endpoint value '{rawEndpoint}'."); - return null; + foreach (var gate in s_perKeyGates.Values) + { + try + { gate.Dispose(); } + catch { } + } + s_perKeyGates.Clear(); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/InMemoryCertificateCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/InMemoryCertificateCache.cs new file mode 100644 index 0000000000..ab310fdaf2 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/InMemoryCertificateCache.cs @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Concurrent; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// Certificate + endpoint + clientId cache stored in process memory. + /// + internal sealed class InMemoryCertificateCache : ICertificateCache, IDisposable + { + private readonly ConcurrentDictionary _entriesByCacheKey = + new ConcurrentDictionary(StringComparer.Ordinal); + + private int _disposed; + + /// + public bool TryGet( + string cacheKey, + out CertificateCacheValue value, + ILoggerAdapter logger = null) + { + ThrowIfDisposed(); + ValidateCacheKey(cacheKey); + + value = default; + + if (_entriesByCacheKey.TryGetValue(cacheKey, out var entry)) + { + if (TryEvictIfExpired(cacheKey, entry, logger)) + { + return false; + } + + // Return a clone so the caller can dispose independently. + var certClone = new X509Certificate2(entry.Certificate); + value = new CertificateCacheValue(certClone, entry.Endpoint, entry.ClientId); + + logger?.Verbose(() => "[CertCache] HIT (key='" + Mask(cacheKey) + "')."); + return true; + } + + logger?.Verbose(() => "[CertCache] MISS (key='" + Mask(cacheKey) + "')."); + return false; + } + + /// + public void Set( + string cacheKey, + in CertificateCacheValue value, + ILoggerAdapter logger = null) + { + ThrowIfDisposed(); + ValidateCacheKey(cacheKey); + + if (value.Certificate is null) + throw new ArgumentNullException(nameof(value.Certificate)); + if (string.IsNullOrWhiteSpace(value.Endpoint)) + throw new ArgumentException("Endpoint must be non-empty.", nameof(value.Endpoint)); + if (string.IsNullOrWhiteSpace(value.ClientId)) + throw new ArgumentException("ClientId must be non-empty.", nameof(value.ClientId)); + + var notAfterUtc = ToNotAfterUtc(value.Certificate); + var nowUtc = DateTimeOffset.UtcNow; + + // Enforce minimum remaining lifetime (e.g., 24h). + if (notAfterUtc <= nowUtc + CertificateCacheEntry.MinRemainingLifetime) + { + var remaining = notAfterUtc - nowUtc; + logger?.Verbose(() => + "[CertCache] Skipping certificate with insufficient remaining lifetime " + + $"({remaining.TotalHours:F2}h) (key='{Mask(cacheKey)}')."); + return; + } + + // Cache owns its copy; it will dispose upon eviction. + var cachedCopy = new X509Certificate2(value.Certificate); + var newEntry = new CertificateCacheEntry(cachedCopy, notAfterUtc, value.Endpoint, value.ClientId); + + _entriesByCacheKey.AddOrUpdate( + cacheKey, + _ => + { + logger?.Verbose(() => "[CertCache] SET (key='" + Mask(cacheKey) + "')."); + return newEntry; + }, + (_, old) => + { + if (!old.IsDisposed) + { + old.Dispose(); + } + logger?.Verbose(() => "[CertCache] REPLACE (key='" + Mask(cacheKey) + "')."); + return newEntry; + }); + } + + /// + public bool Remove(string cacheKey, ILoggerAdapter logger = null) + { + ThrowIfDisposed(); + ValidateCacheKey(cacheKey); + + if (_entriesByCacheKey.TryRemove(cacheKey, out var entry)) + { + if (!entry.IsDisposed) + { + entry.Dispose(); + } + logger?.Verbose(() => "[CertCache] REMOVE (key='" + Mask(cacheKey) + "')."); + return true; + } + return false; + } + + /// + public void Clear(ILoggerAdapter logger = null) + { + ThrowIfDisposed(); + + foreach (var kvp in _entriesByCacheKey) + { + if (_entriesByCacheKey.TryRemove(kvp.Key, out var entry)) + { + if (!entry.IsDisposed) + { + entry.Dispose(); + } + } + } + + logger?.Verbose(() => "[CertCache] CLEAR."); + } + + public void Dispose() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + return; + + // Dispose entries and empty the map + foreach (var kvp in _entriesByCacheKey) + { + if (_entriesByCacheKey.TryRemove(kvp.Key, out var entry)) + { + if (!entry.IsDisposed) + { + entry.Dispose(); + } + } + } + } + + // --- helpers --- + + private void ThrowIfDisposed() + { + if (Volatile.Read(ref _disposed) != 0) + { + throw new ObjectDisposedException(nameof(InMemoryCertificateCache)); + } + } + + private static void ValidateCacheKey(string cacheKey) + { + if (string.IsNullOrWhiteSpace(cacheKey)) + throw new ArgumentException("Cache key must be non-empty.", nameof(cacheKey)); + } + + private bool TryEvictIfExpired(string cacheKey, CertificateCacheEntry entry, ILoggerAdapter logger) + { + var nowUtc = DateTimeOffset.UtcNow; + if (!entry.IsExpiredUtc(nowUtc)) + { + return false; + } + + if (_entriesByCacheKey.TryRemove(cacheKey, out var removed)) + { + if (!removed.IsDisposed) + { + removed.Dispose(); + } + logger?.Verbose(() => "[CertCache] Evicted expired entry (key='" + Mask(cacheKey) + "')."); + } + + return true; + } + + private static DateTimeOffset ToNotAfterUtc(X509Certificate2 cert) + { + var notAfter = cert.NotAfter; + if (notAfter.Kind == DateTimeKind.Unspecified) + { + notAfter = DateTime.SpecifyKind(notAfter, DateTimeKind.Local); + } + return new DateTimeOffset(notAfter.ToUniversalTime()); + } + + /// + /// Used for logging cache keys without exposing full values. + /// + /// The sensitive string. + /// Masked representation. + private static string Mask(string s) + { + if (string.IsNullOrEmpty(s)) + return ""; + + // Do not reveal full value for short keys + if (s.Length <= 8) + return $"…({s.Length})"; + + var take = 8; + return "…" + s.Substring(s.Length - take, take) + "(" + s.Length + ")"; + } + } +} diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index b82578d179..a8d6cdaded 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -602,18 +602,22 @@ public static MockHttpMessageHandler MockCsrResponse( HttpStatusCode statusCode = HttpStatusCode.OK, string responseServerHeader = "IMDS/150.870.65.1854", UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, - string userAssignedId = null) + string userAssignedId = null, + string clientIdOverride = null, + string tenantIdOverride = null, + string attestationEndpointOverride = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); IList presentRequestHeaders = new List - { - OAuth2Header.XMsCorrelationId - }; + { + OAuth2Header.XMsCorrelationId + }; if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( + (ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); } expectedQueryParams.Add("cred-api-version", "2.0"); @@ -622,9 +626,9 @@ public static MockHttpMessageHandler MockCsrResponse( string content = "{" + "\"cuId\": { \"vmId\": \"fake_vmId\" }," + - "\"clientId\": \"" + TestConstants.ClientId + "\"," + - "\"tenantId\": \"" + TestConstants.TenantId + "\"," + - "\"attestationEndpoint\": \"fake_attestation_endpoint\"" + + "\"clientId\": \"" + (clientIdOverride ?? TestConstants.ClientId) + "\"," + + "\"tenantId\": \"" + (tenantIdOverride ?? TestConstants.TenantId) + "\"," + + "\"attestationEndpoint\": \"" + (attestationEndpointOverride ?? "https://fake_attestation_endpoint") + "\"" + "}"; var handler = new MockHttpMessageHandler() @@ -654,20 +658,24 @@ public static MockHttpMessageHandler MockCsrResponseFailure() } public static MockHttpMessageHandler MockCertificateRequestResponse( - UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, - string userAssignedId = null, - string certificate = TestConstants.ValidRawCertificate) + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificate = TestConstants.ValidRawCertificate, + string clientIdOverride = null, + string tenantIdOverride = null, + string mtlsEndpointOverride = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); IList presentRequestHeaders = new List - { - OAuth2Header.XMsCorrelationId - }; + { + OAuth2Header.XMsCorrelationId + }; if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) { - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( + (ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); } expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); @@ -675,11 +683,11 @@ public static MockHttpMessageHandler MockCertificateRequestResponse( string content = "{" + - "\"client_id\": \"" + TestConstants.ClientId + "\"," + - "\"tenant_id\": \"" + TestConstants.TenantId + "\"," + + "\"client_id\": \"" + (clientIdOverride ?? TestConstants.ClientId) + "\"," + + "\"tenant_id\": \"" + (tenantIdOverride ?? TestConstants.TenantId) + "\"," + "\"certificate\": \"" + certificate + "\"," + - "\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned", it doesn't matter for these tests - "\"mtls_authentication_endpoint\": \"" + TestConstants.MtlsAuthenticationEndpoint + "\"" + + "\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned" - not relevant in tests + "\"mtls_authentication_endpoint\": \"" + (mtlsEndpointOverride ?? TestConstants.MtlsAuthenticationEndpoint) + "\"" + "}"; var handler = new MockHttpMessageHandler() @@ -734,5 +742,77 @@ public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( return handler; } + + internal static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponseExpectClientId( + IdentityLoggerAdapter identityLoggerAdapter, + bool mTLSPop = false, + string expectedClientId = TestConstants.ClientId) + { + IDictionary expectedPostData = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary + { + { ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue } + }; + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter); + foreach (var idParam in idParams) + { + expectedRequestHeaders[idParam.Key] = idParam.Value; + } + + var tokenType = mTLSPop ? "mtls_pop" : "bearer"; + expectedPostData.Add("token_type", tokenType); + expectedPostData.Add("client_id", expectedClientId); // <— assert canonical GUID + + return new MockHttpMessageHandler() + { + ExpectedUrl = $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", + ExpectedMethod = HttpMethod.Post, + ExpectedPostData = expectedPostData, + ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(GetMsiSuccessfulResponse(imdsV2: true)), + } + }; + } + + internal static void AddMocksToGetEntraTokenUsingCachedCert( + MockHttpManager httpManager, + IdentityLoggerAdapter identityLoggerAdapter, + bool mTLSPop = false, + bool assertClientId = false, + string expectedClientId = TestConstants.ClientId, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null) + { + // cached‑cert refresh still calls /getplatformmetadata (SAMI or UAMI flavor) + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + httpManager.AddMockHandler( + MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + } + + // Token request (no /issuecredential added here) + if (assertClientId) + { + httpManager.AddMockHandler( + MockHelpers.MockImdsV2EntraTokenRequestResponseExpectClientId(identityLoggerAdapter, mTLSPop, expectedClientId)); + } + else + { + httpManager.AddMockHandler( + MockHelpers.MockImdsV2EntraTokenRequestResponse(identityLoggerAdapter)); + } + } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 873fc8a813..4c2d9404b5 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -1,14 +1,18 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. using System; +using System.Collections.Generic; +using System.IO; using System.Net; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; +using System.Xml; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.KeyProviders; @@ -679,5 +683,784 @@ await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Res } } #endregion + + #region cached certificate tests + [TestMethod] + public async Task mTLSPop_ForceRefresh_UsesCachedCert_NoIssueCredential_PostsCanonicalClientId_AndSkipsAttestation() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + // Start clean across tests + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + // First acquire: full flow (CSR + issuecredential + token) + AddMocksToGetEntraToken(httpManager); + + int attestationCalls = 0; + Func> countingProvider = + (input, ct) => + { + Interlocked.Increment(ref attestationCalls); + return Task.FromResult(new AttestationTokenResponse { AttestationToken = "header.payload.sig" }); + }; + + var result1 = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(countingProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(ImdsV2Tests.MTLSPoP, result1.TokenType); + Assert.IsNotNull(result1.BindingCertificate); + Assert.AreEqual(TokenSource.IdentityProvider, result1.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(1, attestationCalls, "Attestation must be called exactly once on first mint."); + + // Second acquire: FORCE REFRESH to bypass token cache. + // Expect: 1x getplatformmetadata + token request. NO /issuecredential. Attestation NOT called again. + MockHelpers.AddMocksToGetEntraTokenUsingCachedCert( + httpManager, + _identityLoggerAdapter, + mTLSPop: true, + assertClientId: true, // assert canonical client_id is posted + expectedClientId: TestConstants.ClientId); + + var result2 = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithForceRefresh(true) // if your API is parameterless, use .WithForceRefresh() + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(countingProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(ImdsV2Tests.MTLSPoP, result2.TokenType); + Assert.IsNotNull(result2.BindingCertificate); + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(1, attestationCalls, "Attestation must NOT be invoked on refresh when cert is cached."); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, TestConstants.ClientId + "-2")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, TestConstants.MiResourceId + "-2")] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, TestConstants.ObjectId + "-2")] + public async Task mTLSPop_CachedCertIsPerIdentity_OnRefresh_Identity1UsesCache_Identity2Mints( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId1, + string userAssignedId2) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // Identity 1 – first acquire (mint) + var mi1 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId1, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId1); + + var result1 = await mi1.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + Assert.AreEqual(TokenSource.IdentityProvider, result1.AuthenticationResultMetadata.TokenSource); + + // Identity 1 – force refresh (should use cached cert → NO /issuecredential) + MockHelpers.AddMocksToGetEntraTokenUsingCachedCert( + httpManager, + _identityLoggerAdapter, + mTLSPop: true, + assertClientId: true, + expectedClientId: TestConstants.ClientId, + userAssignedIdentityId: userAssignedIdentityId, + userAssignedId: userAssignedId1 + ); + + var result1Refresh = await mi1.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithForceRefresh(true) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync() + .ConfigureAwait(false); + + Assert.AreEqual(TokenSource.IdentityProvider, result1Refresh.AuthenticationResultMetadata.TokenSource); + + // Identity 2 – new identity (should MINT again → requires /issuecredential) + var mi2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId2, addProbeMock: false, addSourceCheck: false, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId2); + + var result2 = await mi2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + } + } + #endregion + + #region Cert cache tests + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null, /*isUami*/ false)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, /*isUami*/ true)] // UAMI by client_id + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, /*isUami*/ true)] // UAMI by resource_id + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, /*isUami*/ true)] // UAMI by object_id + public async Task mTLSPopTokenHappyPath_LongLivedCert_IdentityMapping( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId, + bool isUami) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // Force KeyGuard so the PoP path is taken + var managedIdentityApp = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId, + userAssignedId, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard + ).ConfigureAwait(false); + + // --- First acquire: MINT (CSR + issuecredential + token) with a long-lived cert --- + // Use the known-good cert that matches TestCsrFactory's RSA and already has a far NotAfter (>= 20 years) + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId, + userAssignedId, + certificateRequestCertificate: TestConstants.ValidRawCertificate); + + var first = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(first); + Assert.AreEqual(MTLSPoP, first.TokenType, "Token type must be mtls_pop"); + Assert.IsNotNull(first.BindingCertificate, "Binding certificate should be present on mTLS PoP tokens"); + Assert.AreEqual(TokenSource.IdentityProvider, first.AuthenticationResultMetadata.TokenSource); + + Assert.IsTrue(first.BindingCertificate.NotAfter.ToUniversalTime() >= DateTime.UtcNow.AddYears(20).AddDays(-1), + $"Binding cert NotAfter {first.BindingCertificate.NotAfter:u} should be >= ~20 years from now."); + + var second = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(second); + Assert.AreEqual(MTLSPoP, second.TokenType); + Assert.IsNotNull(second.BindingCertificate, "Binding certificate should be present on cached mTLS PoP tokens"); + Assert.AreEqual(TokenSource.Cache, second.AuthenticationResultMetadata.TokenSource); + + // Optional: Same thumbprint between the two (same cached binding cert) + Assert.AreEqual(first.BindingCertificate.Thumbprint, second.BindingCertificate.Thumbprint, + "Cached mTLS flow should reuse the same binding certificate."); + + // Your existing CN assertion against the baked-in TestConstants.ValidRawCertificate + AssertCertCN(first.BindingCertificate, "Test"); + AssertCertCN(second.BindingCertificate, "Test"); + } + } + + /// + /// Create TWO long-lived (20y) raw DER (base64) certs with the CSR key: + /// - One for SAMI (CN=SAMI-20Y) + /// - One for UAMI (CN=UAMI-20Y) + /// Then run mint + cached flows and assert thumbprints. + /// + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null, /*aliasLabel*/ "SAMI")] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, /*aliasLabel*/ "UAMI-ClientId")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId,/*aliasLabel*/ "UAMI-ResourceId")] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, /*aliasLabel*/ "UAMI-ObjectId")] + public async Task mTLSPop_LongLivedCerts_SamiVsUami_DistinctAndCached( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId, + string aliasLabel) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // Create the two test certs (20-year) from the SAME RSA as CSR (XmlPrivateKey) + string rawCertSami = CreateRawCertFromXml("CN=SAMI-20Y", notAfterUtc: DateTimeOffset.UtcNow.AddYears(20)); + string rawCertUami = CreateRawCertFromXml("CN=UAMI-20Y", notAfterUtc: DateTimeOffset.UtcNow.AddYears(20)); + + // Build an MI app for the row's identity kind (force KeyGuard so mTLS path is used) + var mi = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId, + userAssignedId, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + // --- First acquire (MINT): return the identity-specific cert we want --- + // SAMI → use rawCertSami ; UAMI (any alias) → use rawCertUami + string selectedCert = (userAssignedIdentityId == UserAssignedIdentityId.None) ? rawCertSami : rawCertUami; + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId, + userAssignedId, + certificateRequestCertificate: selectedCert); + + var first = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(first); + Assert.AreEqual(MTLSPoP, first.TokenType, $"[{aliasLabel}] token type must be mtls_pop"); + Assert.IsNotNull(first.BindingCertificate, $"[{aliasLabel}] binding cert missing"); + Assert.AreEqual(TokenSource.IdentityProvider, first.AuthenticationResultMetadata.TokenSource, $"[{aliasLabel}] first acquire must mint from IDP"); + Assert.IsTrue(first.BindingCertificate.NotAfter.ToUniversalTime() >= DateTime.UtcNow.AddYears(20).AddDays(-1), + $"[{aliasLabel}] NotAfter {first.BindingCertificate.NotAfter:u} should be ~20y+"); + + var thumb1 = first.BindingCertificate.Thumbprint; + + // --- Second acquire: cached; cert should be the SAME (cached binding cert) --- + var second = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(second); + Assert.AreEqual(MTLSPoP, second.TokenType, $"[{aliasLabel}] cached token type"); + Assert.IsNotNull(second.BindingCertificate, $"[{aliasLabel}] cached binding cert missing"); + Assert.AreEqual(TokenSource.Cache, second.AuthenticationResultMetadata.TokenSource, $"[{aliasLabel}] second acquire should be from cache"); + Assert.AreEqual(thumb1, second.BindingCertificate.Thumbprint, $"[{aliasLabel}] cached must reuse same binding cert"); + + var expectedCn = (userAssignedIdentityId == UserAssignedIdentityId.None) ? "SAMI-20Y" : "UAMI-20Y"; + AssertCertCN(first.BindingCertificate, expectedCn); + AssertCertCN(second.BindingCertificate, expectedCn); + } + } + + /// + /// End-to-end: mint SAMI & UAMI in one test and prove their binding certs differ, + /// while each identity reuses its own binding cert from cache. + /// + [TestMethod] + public async Task mTLSPop_LongLivedCerts_SamiAndUami_ThumbprintsDiffer_AndEachCaches() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // Make two long-lived certs **from the CSR key** so AttachPrivateKey succeeds + string rawCertSami = CreateRawCertForCsrKey("CN=SAMI-20Y", DateTimeOffset.UtcNow.AddYears(20)); + string rawCertUami = CreateRawCertForCsrKey("CN=UAMI-20Y", DateTimeOffset.UtcNow.AddYears(20)); + + // ---------- SAMI ---------- + var sami = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.None, + userAssignedId: null, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard + ).ConfigureAwait(false); + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.None, + userAssignedId: null, + certificateRequestCertificate: rawCertSami); + + var s1 = await sami.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(s1.BindingCertificate); + AssertCertCN(s1.BindingCertificate, "SAMI-20Y"); + + var samiThumb = s1.BindingCertificate.Thumbprint; + + // cached + var s2 = await sami.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.Cache, s2.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(samiThumb, s2.BindingCertificate.Thumbprint, "SAMI must reuse cached binding cert"); + + // ---------- UAMI (client_id) ---------- + var uami = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ClientId, + userAssignedId: TestConstants.ClientId, + addProbeMock: false, + addSourceCheck: false, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard + ).ConfigureAwait(false); + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ClientId, + userAssignedId: TestConstants.ClientId, + certificateRequestCertificate: rawCertUami); + + var u1 = await uami.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(u1.BindingCertificate); + AssertCertCN(u1.BindingCertificate, "UAMI-20Y"); + + var uamiThumb = u1.BindingCertificate.Thumbprint; + + var u2 = await uami.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.Cache, u2.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(uamiThumb, u2.BindingCertificate.Thumbprint, "UAMI must reuse cached binding cert"); + + // Cross-identity certs must differ + Assert.AreNotEqual(samiThumb, uamiThumb, "SAMI and UAMI must use different binding certs"); + } + } + + /// + /// Subject mapping test that mirrors prod: CN=canonical client_id, DC=tenant id. + /// - SAMI → CN = Constants.ManagedIdentityDefaultClientId + /// - UAMI (client_id|object_id|resource_id) → CN = TestConstants.ClientId (canonical) + /// Both assert DC = TestConstants.TenantId and cert cache reuse. + /// + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null, /*label*/ "SAMI", /*isUami*/ false)] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, /*label*/ "UAMI-ClientId", /*isUami*/ true)] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, /*label*/ "UAMI-ObjectId", /*isUami*/ true)] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId,/*label*/"UAMI-ResourceId",/*isUami*/ true)] + public async Task mTLSPop_SubjectCnDc_MatchesMetadata_AndCaches( + UserAssignedIdentityId idKind, + string idValue, + string label, + bool isUami) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // Expected mapping (mirrors your live logs) + string expectedCn = isUami ? TestConstants.ClientId : Constants.ManagedIdentityDefaultClientId; + string expectedDc = TestConstants.TenantId; + + // Mint a 20-year cert with Subject "CN=, DC=" using the CSR key + string rawCert = CreateRawCertForCsrKeyWithCnDc(expectedCn, expectedDc, DateTimeOffset.UtcNow.AddYears(20)); + + var mi = await CreateManagedIdentityAsync(httpManager, idKind, idValue, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard) + .ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, idKind, idValue, certificateRequestCertificate: rawCert); + + var first = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(MTLSPoP, first.TokenType, $"[{label}]"); + AssertCertSubjectCnDc(first.BindingCertificate, expectedCn, expectedDc, label); + + var second = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.Cache, second.AuthenticationResultMetadata.TokenSource, $"[{label}] cache"); + Assert.AreEqual(first.BindingCertificate.Thumbprint, second.BindingCertificate.Thumbprint, $"[{label}] thumbprint must be stable"); + AssertCertSubjectCnDc(second.BindingCertificate, expectedCn, expectedDc, label); + } + } + + [TestMethod] + public async Task mTLSPoP_Uami_ClientIdThenObjectId_MintsThenCaches_SubjectCNIsClientId() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + string expectedCn = TestConstants.ClientId; + string expectedDc = TestConstants.TenantId; + string rawCert = CreateRawCertForCsrKeyWithCnDc(expectedCn, expectedDc, DateTimeOffset.UtcNow.AddYears(20)); + + // (1) client_id → MINT (CSR + issuecredential + token) + var miClientId = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ClientId, + userAssignedId: TestConstants.ClientId, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ClientId, + userAssignedId: TestConstants.ClientId, + certificateRequestCertificate: rawCert); + + var c1 = await miClientId.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(MTLSPoP, c1.TokenType); + Assert.AreEqual(TokenSource.IdentityProvider, c1.AuthenticationResultMetadata.TokenSource); + AssertCertSubjectCnDc(c1.BindingCertificate, expectedCn, expectedDc, "[client_id]"); + + // (2) object_id → MINT (new alias → its own cache key) + var miObjectId = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ObjectId, + userAssignedId: TestConstants.ObjectId, + addProbeMock: false, + addSourceCheck: false, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ObjectId, + userAssignedId: TestConstants.ObjectId, + certificateRequestCertificate: rawCert); + + var o1 = await miObjectId.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(MTLSPoP, o1.TokenType); + Assert.AreEqual(TokenSource.IdentityProvider, o1.AuthenticationResultMetadata.TokenSource); + AssertCertSubjectCnDc(o1.BindingCertificate, expectedCn, expectedDc, "[object_id first]"); + var objectIdThumb = o1.BindingCertificate.Thumbprint; + + // (3) object_id again → CACHED + var o2 = await miObjectId.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.Cache, o2.AuthenticationResultMetadata.TokenSource); + Assert.AreEqual(objectIdThumb, o2.BindingCertificate.Thumbprint); + AssertCertSubjectCnDc(o2.BindingCertificate, expectedCn, expectedDc, "[object_id second]"); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, "object_id")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, "resource_id")] + public async Task mTLSPoP_Uami_ClientIdThenAlias_MintsThenCaches_SubjectCNIsClientId( + UserAssignedIdentityId aliasKind, + string aliasValue, + string label) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + string expectedCn = TestConstants.ClientId; + string expectedDc = TestConstants.TenantId; + string rawCert = CreateRawCertForCsrKeyWithCnDc(expectedCn, expectedDc, DateTimeOffset.UtcNow.AddYears(20)); + + // (1) client_id → MINT (CSR + issuecredential + token) + var miClientId = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ClientId, + userAssignedId: TestConstants.ClientId, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId: UserAssignedIdentityId.ClientId, + userAssignedId: TestConstants.ClientId, + certificateRequestCertificate: rawCert); + + var c1 = await miClientId.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(MTLSPoP, c1.TokenType, "[client_id]"); + Assert.AreEqual(TokenSource.IdentityProvider, c1.AuthenticationResultMetadata.TokenSource, "[client_id] should mint"); + AssertCertSubjectCnDc(c1.BindingCertificate, expectedCn, expectedDc, "[client_id]"); + + // (2) alias (object_id/resource_id) → MINT (new alias → new cache key) + var miAlias = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId: aliasKind, + userAssignedId: aliasValue, + addProbeMock: false, + addSourceCheck: false, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken( + httpManager, + userAssignedIdentityId: aliasKind, + userAssignedId: aliasValue, + certificateRequestCertificate: rawCert); + + var a1 = await miAlias.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(MTLSPoP, a1.TokenType, $"[{label} first]"); + Assert.AreEqual(TokenSource.IdentityProvider, a1.AuthenticationResultMetadata.TokenSource, $"[{label} first] should mint"); + AssertCertSubjectCnDc(a1.BindingCertificate, expectedCn, expectedDc, $"[{label} first]"); + var aliasThumb = a1.BindingCertificate.Thumbprint; + + // (3) alias again → CACHED (no /issuecredential; no extra mocks needed) + var a2 = await miAlias.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.Cache, a2.AuthenticationResultMetadata.TokenSource, $"[{label} second] should be cached"); + Assert.AreEqual(aliasThumb, a2.BindingCertificate.Thumbprint, $"[{label}] cached binding cert must match"); + AssertCertSubjectCnDc(a2.BindingCertificate, expectedCn, expectedDc, $"[{label} second]"); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, "UAMI-ClientId")] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, "UAMI-ObjectId")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, "UAMI-ResourceId")] + [DataRow(UserAssignedIdentityId.None, null, "SAMI")] + public async Task mTLSPop_ShortLivedCert_LessThan24h_NotCached_ReMints( + UserAssignedIdentityId idKind, + string idValue, + string label) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // short-lived cert #1: < 24h => must NOT be cached + var rawShort1 = CreateRawCertForCsrKeyWithCnDc( + cn: (idKind == UserAssignedIdentityId.None ? Constants.ManagedIdentityDefaultClientId : TestConstants.ClientId), + dc: TestConstants.TenantId, + notAfterUtc: DateTimeOffset.UtcNow.AddHours(23)); + + // short-lived cert #2: also < 24h (ensures new thumbprint on re-mint) + var rawShort2 = CreateRawCertForCsrKeyWithCnDc( + cn: (idKind == UserAssignedIdentityId.None ? Constants.ManagedIdentityDefaultClientId : TestConstants.ClientId), + dc: TestConstants.TenantId, + notAfterUtc: DateTimeOffset.UtcNow.AddHours(23).AddMinutes(5)); + + var mi = await CreateManagedIdentityAsync(httpManager, idKind, idValue, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard) + .ConfigureAwait(false); + + // FIRST acquire -> MINT with short-lived cert #1 + AddMocksToGetEntraToken(httpManager, idKind, idValue, certificateRequestCertificate: rawShort1); + + var first = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.IdentityProvider, first.AuthenticationResultMetadata.TokenSource, $"[{label}] first must mint."); + + // SECOND acquire -> FORCE REFRESH to bypass AT cache; since cert #1 wasn't cached, we must mint again. + AddMocksToGetEntraToken(httpManager, idKind, idValue, certificateRequestCertificate: rawShort2); + + var second = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithForceRefresh(true) // <-- key change + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.IdentityProvider, second.AuthenticationResultMetadata.TokenSource, $"[{label}] second must mint (no cert cache for <24h)."); + Assert.AreNotEqual(first.BindingCertificate.Thumbprint, second.BindingCertificate.Thumbprint, $"[{label}] re-mint should produce a new binding cert."); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, "UAMI-ClientId")] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, "UAMI-ObjectId")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, "UAMI-ResourceId")] + [DataRow(UserAssignedIdentityId.None, null, "SAMI")] + public async Task mTLSPop_CertAtLeast24h_IsCached_ReusedOnSecondAcquire( + UserAssignedIdentityId idKind, + string idValue, + string label) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + ManagedIdentityClient.ResetSourceForTest(); + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // NotAfter >= 24h + 1min → should be cached and reused + var rawLong = CreateRawCertForCsrKeyWithCnDc( + cn: (idKind == UserAssignedIdentityId.None ? Constants.ManagedIdentityDefaultClientId : TestConstants.ClientId), + dc: TestConstants.TenantId, + notAfterUtc: DateTimeOffset.UtcNow.AddHours(24).AddMinutes(1)); + + var mi = await CreateManagedIdentityAsync(httpManager, idKind, idValue, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard) + .ConfigureAwait(false); + + // First acquire → MINT + AddMocksToGetEntraToken(httpManager, idKind, idValue, certificateRequestCertificate: rawLong); + + var first = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.IdentityProvider, first.AuthenticationResultMetadata.TokenSource, $"[{label}] first must mint long-lived cert."); + + // Second acquire → CACHED (no /issuecredential mocks needed) + var second = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.AreEqual(TokenSource.Cache, second.AuthenticationResultMetadata.TokenSource, $"[{label}] second should be cache."); + Assert.AreEqual(first.BindingCertificate.Thumbprint, second.BindingCertificate.Thumbprint, $"[{label}] cached cert must be reused."); + } + } + #endregion + + #region Cert cache test helpers + + // Build a base64 DER cert (public part only) whose public key == the CSR key used by tests + private static string CreateRawCertForCsrKey(string subjectCN, DateTimeOffset notAfter) + { + using var rsa = TestCsrFactory.CreateMockRsa(); // same key the CSR factory uses + return CreateRawCertFromKey(rsa, subjectCN, notAfter); + } + + // Build a base64 DER cert (public part only) with Subject "CN=, DC=" and CSR key + private static string CreateRawCertForCsrKeyWithCnDc(string cn, string dc, DateTimeOffset notAfterUtc) + { + using var rsa = TestCsrFactory.CreateMockRsa(); + var subject = $"CN={cn}, DC={dc}"; + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName(subject), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + var notBefore = DateTimeOffset.UtcNow.AddMinutes(-2); + using var cert = req.CreateSelfSigned(notBefore, notAfterUtc); + return Convert.ToBase64String(cert.Export(X509ContentType.Cert)); + } + + private static string CreateRawCertFromKey(RSA key, string subjectCN, DateTimeOffset notAfter) + { + var now = DateTimeOffset.UtcNow.AddMinutes(-2); + + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName(subjectCN), + key, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + using var cert = req.CreateSelfSigned(now, notAfter); + // Return public portion only; the product code attaches the private key + return Convert.ToBase64String(cert.Export(X509ContentType.Cert)); + } + + private static RSA RsaFromXml(string xml) + { + var rsa = RSA.Create(); + + var settings = new XmlReaderSettings + { + DtdProcessing = DtdProcessing.Prohibit, + XmlResolver = null + }; + + var doc = new XmlDocument { XmlResolver = null }; + using (var sr = new StringReader(xml)) + using (var xr = XmlReader.Create(sr, settings)) + { + doc.Load(xr); + } + + byte[] B64(string s) => Convert.FromBase64String(s); + + var p = new RSAParameters + { + Modulus = B64(doc.DocumentElement["Modulus"].InnerText), + Exponent = B64(doc.DocumentElement["Exponent"].InnerText), + P = B64(doc.DocumentElement["P"].InnerText), + Q = B64(doc.DocumentElement["Q"].InnerText), + DP = B64(doc.DocumentElement["DP"].InnerText), + DQ = B64(doc.DocumentElement["DQ"].InnerText), + InverseQ = B64(doc.DocumentElement["InverseQ"].InnerText), + D = B64(doc.DocumentElement["D"].InnerText), + }; + + rsa.ImportParameters(p); + return rsa; + } + + private static string CreateRawCertFromXml(string subjectCN, DateTimeOffset notAfterUtc) + { + using var rsa = RsaFromXml(TestConstants.XmlPrivateKey); // same RSA as CSR/keyguard in tests + + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName(subjectCN), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + var notBefore = DateTimeOffset.UtcNow.AddMinutes(-2); + using var cert = req.CreateSelfSigned(notBefore, notAfterUtc); + + // IMPORTANT: return **public part only** – product code attaches the private key + return Convert.ToBase64String(cert.Export(X509ContentType.Cert)); + } + + private static void AssertCertCN(X509Certificate2 cert, string expectedCn) + { + // SimpleName returns the CN without the "CN=" prefix + var cn = cert.GetNameInfo(X509NameType.SimpleName, forIssuer: false); + + // Defensive fallback in case SimpleName is empty on some runtimes + if (string.IsNullOrEmpty(cn) && !string.IsNullOrEmpty(cert.Subject)) + { + var subject = cert.Subject; // e.g. "CN=SAMI-20Y" + const string cnPrefix = "CN="; + var idx = subject.IndexOf(cnPrefix, StringComparison.OrdinalIgnoreCase); + if (idx >= 0) + { + var end = subject.IndexOf(',', idx); + cn = (end > idx ? subject.Substring(idx + cnPrefix.Length, end - (idx + cnPrefix.Length)) + : subject.Substring(idx + cnPrefix.Length)).Trim(); + } + } + + Assert.AreEqual(expectedCn, cn, $"Expected CN={expectedCn}, got Subject='{cert.Subject}'."); + } + + // Parse a specific RDN (e.g., "CN" or "DC") out of the subject + private static string GetRdn(X509Certificate2 cert, string rdn) + { + var dn = cert?.SubjectName?.Name ?? string.Empty; + foreach (var part in dn.Split(',')) + { + var kv = part.Trim().Split('='); + if (kv.Length == 2 && kv[0].Trim().Equals(rdn, StringComparison.OrdinalIgnoreCase)) + return kv[1].Trim(); + } + return null; + } + + private static void AssertCertSubjectCnDc(X509Certificate2 cert, string expectedCn, string expectedDc, string label) + { + Assert.IsNotNull(cert); + var cn = GetRdn(cert, "CN"); + var dc = GetRdn(cert, "DC"); + + Assert.AreEqual(expectedCn, cn, $"[{label}] CN mismatch. Subject='{cert.Subject}'"); + Assert.AreEqual(expectedDc, dc, $"[{label}] DC mismatch. Subject='{cert.Subject}'"); + } + + #endregion } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/InMemoryCertificateCacheTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/InMemoryCertificateCacheTests.cs new file mode 100644 index 0000000000..fe7befba61 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/InMemoryCertificateCacheTests.cs @@ -0,0 +1,330 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Concurrent; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.ManagedIdentity.V2; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + [TestClass] + public class InMemoryCertificateCacheTests + { + private static X509Certificate2 CreateSelfSignedCert(TimeSpan lifetime, string subjectCn = "CN=CacheTest") + { + using var rsa = RSA.Create(2048); + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName(subjectCn), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + // Give NotBefore a small headroom to avoid clock skew flakes + var notBefore = DateTimeOffset.UtcNow.AddMinutes(-2); + var notAfter = notBefore.Add(lifetime); + return req.CreateSelfSigned(notBefore, notAfter); + } + + [TestMethod] + public void TryGet_EmptyCache_ReturnsFalse() + { + var cache = new InMemoryCertificateCache(); + var ok = cache.TryGet("key-1", out _); + Assert.IsFalse(ok); + } + + [TestMethod] + public void Set_Then_TryGet_Hit_And_ReturnsClone() + { + var cache = new InMemoryCertificateCache(); + using var cert = CreateSelfSignedCert(TimeSpan.FromDays(2)); + const string key = "key-hit-1"; + const string ep = "https://mtls.endpoint"; + const string cid = "11111111-1111-1111-1111-111111111111"; + + cache.Set(key, new CertificateCacheValue(cert, ep, cid)); + + var ok = cache.TryGet(key, out var value); + Assert.IsTrue(ok); + Assert.IsNotNull(value.Certificate); + try + { + Assert.AreEqual(ep, value.Endpoint); + Assert.AreEqual(cid, value.ClientId); + + // Verify clone: instance is different but same thumbprint + Assert.AreNotSame(cert, value.Certificate); + Assert.AreEqual(cert.Thumbprint, value.Certificate.Thumbprint, ignoreCase: true); + } + finally + { + // Caller owns the clone returned by TryGet + value.Certificate.Dispose(); + } + } + + [TestMethod] + public void Set_Skips_When_LessThan_MinLifetime() + { + var cache = new InMemoryCertificateCache(); + + // Certificate lifetime shorter than product threshold (24h) + using var shortCert = CreateSelfSignedCert(TimeSpan.FromHours(1)); + cache.Set("short-key", new CertificateCacheValue(shortCert, "https://mtls", "client-guid")); + + var ok = cache.TryGet("short-key", out _); + Assert.IsFalse(ok, "Cache should skip certs with remaining lifetime < 24h."); + } + + [TestMethod] + public void Set_SameKey_Replaces_Previous() + { + var cache = new InMemoryCertificateCache(); + using var certA = CreateSelfSignedCert(TimeSpan.FromDays(3), "CN=A"); + using var certB = CreateSelfSignedCert(TimeSpan.FromDays(3), "CN=B"); + + const string key = "replace-key"; + cache.Set(key, new CertificateCacheValue(certA, "https://ep", "cid")); + cache.Set(key, new CertificateCacheValue(certB, "https://ep", "cid")); + + var ok = cache.TryGet(key, out var v); + Assert.IsTrue(ok); + try + { + Assert.AreEqual(certB.Thumbprint, v.Certificate.Thumbprint, ignoreCase: true, + "Newest certificate should be returned after REPLACE."); + } + finally + { + v.Certificate.Dispose(); + } + } + + [TestMethod] + public void Remove_Removes_Entry() + { + var cache = new InMemoryCertificateCache(); + using var cert = CreateSelfSignedCert(TimeSpan.FromDays(2)); + + cache.Set("k1", new CertificateCacheValue(cert, "https://ep", "cid")); + var removed = cache.Remove("k1"); + Assert.IsTrue(removed); + + var ok = cache.TryGet("k1", out _); + Assert.IsFalse(ok); + } + + [TestMethod] + public void Clear_Removes_All() + { + var cache = new InMemoryCertificateCache(); + using var c1 = CreateSelfSignedCert(TimeSpan.FromDays(2)); + using var c2 = CreateSelfSignedCert(TimeSpan.FromDays(2)); + + cache.Set("k1", new CertificateCacheValue(c1, "https://ep1", "cid1")); + cache.Set("k2", new CertificateCacheValue(c2, "https://ep2", "cid2")); + + cache.Clear(); + + Assert.IsFalse(cache.TryGet("k1", out _)); + Assert.IsFalse(cache.TryGet("k2", out _)); + } + + [TestMethod] + public void Validate_Arguments() + { + var cache = new InMemoryCertificateCache(); + using var cert = CreateSelfSignedCert(TimeSpan.FromDays(2)); + + // TryGet + Assert.ThrowsException(() => cache.TryGet(" ", out _)); + Assert.ThrowsException(() => cache.TryGet(null, out _)); + + // Set + Assert.ThrowsException(() => cache.Set(" ", new CertificateCacheValue(cert, "ep", "cid"))); + Assert.ThrowsException(() => cache.Set("k", new CertificateCacheValue(null, "ep", "cid"))); // This will throw as expected in the test + Assert.ThrowsException(() => cache.Set("k", new CertificateCacheValue(cert, " ", "cid"))); + Assert.ThrowsException(() => cache.Set("k", new CertificateCacheValue(cert, "ep", " "))); + + // Remove + Assert.ThrowsException(() => cache.Remove("")); + Assert.ThrowsException(() => cache.Remove(null)); + } + + [TestMethod] + public void Dispose_Prevents_Use() + { + var cache = new InMemoryCertificateCache(); + cache.Dispose(); + + Assert.ThrowsException(() => cache.TryGet("k", out _)); + using var cert = CreateSelfSignedCert(TimeSpan.FromDays(2)); + Assert.ThrowsException(() => cache.Set("k", new CertificateCacheValue(cert, "ep", "cid"))); + Assert.ThrowsException(() => cache.Remove("k")); + Assert.ThrowsException(() => cache.Clear()); + } + + [TestMethod] + public void Set_Skips_When_Lifetime_Equals_MinRemainingLifetime() + { + using var cert = CreateSelfSignedCert(CertificateCacheEntry.MinRemainingLifetime); + var cache = new InMemoryCertificateCache(); + cache.Set("eq-threshold", new CertificateCacheValue(cert, "https://ep", "cid")); + Assert.IsFalse(cache.TryGet("eq-threshold", out _), "Exact threshold should be skipped (<= comparison)."); + } + + [TestMethod] + public void TryGet_Evicts_When_Remaining_Lifetime_Below_Threshold() + { + using var longCert = CreateSelfSignedCert(TimeSpan.FromHours(30)); + var cache = new InMemoryCertificateCache(); + cache.Set("will-expire", new CertificateCacheValue(longCert, "https://ep", "cid")); + + // Use reflection to inject an entry whose NotAfterUtc makes remaining lifetime < MinRemainingLifetime + var field = typeof(InMemoryCertificateCache).GetField("_entriesByCacheKey", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var dict = (ConcurrentDictionary)field.GetValue(cache); + + // Remaining lifetime 23h (< 24h) triggers expiration on access + var expiredEntry = new CertificateCacheEntry( + certificate: new X509Certificate2(longCert), + notAfterUtc: DateTimeOffset.UtcNow.AddHours(23), + endpoint: "https://ep", + clientId: "cid"); + + dict["will-expire"] = expiredEntry; + + var ok = cache.TryGet("will-expire", out _); + Assert.IsFalse(ok, "Entry should be evicted when remaining lifetime drops below threshold."); + } + } + + [TestClass] + public class CertificateCacheEntryTests + { + private static X509Certificate2 MakeCert(TimeSpan lifetime) + { + using var rsa = RSA.Create(2048); + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName("CN=EntryTest"), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + var notBefore = DateTimeOffset.UtcNow.AddMinutes(-2); + var notAfter = notBefore.Add(lifetime); + return req.CreateSelfSigned(notBefore, notAfter); + } + + [TestMethod] + public void IsExpiredUtc_Boundary() + { + // Choose a lifetime > MinRemainingLifetime (24h) so boundary is in the future. + using var cert = MakeCert(TimeSpan.FromHours(26)); + var notAfterUtc = cert.NotAfter.ToUniversalTime(); + + var entry = new CertificateCacheEntry( + certificate: new X509Certificate2(cert), + notAfterUtc: new DateTimeOffset(notAfterUtc), + endpoint: "https://ep", + clientId: "cid"); + + var boundary = new DateTimeOffset(notAfterUtc).Add(-CertificateCacheEntry.MinRemainingLifetime); + + // Just before boundary -> not expired + Assert.IsFalse(entry.IsExpiredUtc(boundary.AddSeconds(-1))); + + // At boundary -> expired + Assert.IsTrue(entry.IsExpiredUtc(boundary)); + + // After boundary -> expired + Assert.IsTrue(entry.IsExpiredUtc(boundary.AddMinutes(1))); + } + + [TestMethod] + public void Dispose_IsIdempotent_SetsFlag() + { + using var cert = MakeCert(TimeSpan.FromDays(2)); + var entry = new CertificateCacheEntry( + certificate: new X509Certificate2(cert), + notAfterUtc: DateTimeOffset.UtcNow.AddDays(2), + endpoint: "https://ep", + clientId: "cid"); + + Assert.IsFalse(entry.IsDisposed); + entry.Dispose(); + Assert.IsTrue(entry.IsDisposed); + + // No throw on second dispose + entry.Dispose(); + Assert.IsTrue(entry.IsDisposed); + } + + [TestMethod] + public void IsExpiredUtc_WellBeforeBoundary_NotExpired() + { + using var cert = MakeCert(TimeSpan.FromHours(48)); + var entry = new CertificateCacheEntry( + certificate: new X509Certificate2(cert), + notAfterUtc: cert.NotAfter.ToUniversalTime(), + endpoint: "https://ep", + clientId: "cid"); + + // Far before boundary + Assert.IsFalse(entry.IsExpiredUtc(DateTimeOffset.UtcNow)); + } + } + + [TestClass] + public class CertificateCacheValueTests + { + private static X509Certificate2 MakeCert() + { + using var rsa = RSA.Create(2048); + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName("CN=ValTest"), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + var notBefore = DateTimeOffset.UtcNow.AddMinutes(-1); + var notAfter = notBefore.AddDays(2); + return req.CreateSelfSigned(notBefore, notAfter); + } + + [TestMethod] + public void Ctor_Throws_On_Nulls() + { + // certificate null + Assert.ThrowsException(() => + new CertificateCacheValue(null, "ep", "cid")); + + using var cert = MakeCert(); + + // endpoint null + Assert.ThrowsException(() => + new CertificateCacheValue(cert, null, "cid")); + + // clientId null + Assert.ThrowsException(() => + new CertificateCacheValue(cert, "ep", null)); + } + + [TestMethod] + public void Properties_Are_Immutable_And_Preserved() + { + using var cert = MakeCert(); + + var value = new CertificateCacheValue(cert, "https://ep", "cid"); + Assert.AreEqual("https://ep", value.Endpoint); + Assert.AreEqual("cid", value.ClientId); + Assert.AreEqual(cert.Thumbprint, value.Certificate.Thumbprint, ignoreCase: true); + + // Caller should dispose when done + value.Certificate.Dispose(); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs index ccd522e1fe..723ff5ca9a 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; @@ -18,10 +17,29 @@ namespace Microsoft.Identity.Test.Common.Core.Mocks /// internal sealed class TestKeyGuardManagedIdentityKeyProvider : IManagedIdentityKeyProvider { - public Task GetOrCreateKeyAsync(ILoggerAdapter logger, CancellationToken cancellationToken) + // Keep a single ManagedIdentityKeyInfo per provider instance + private readonly ManagedIdentityKeyInfo _keyInfo; + + /// + /// Creates a provider with a fresh 2048-bit RSACng key. + /// + public TestKeyGuardManagedIdentityKeyProvider() + : this(new RSACng(2048)) + { } + + /// + /// Creates a provider that will always return the supplied RSACng key. + /// Useful when you want two identities with different, fixed keys. + /// + public TestKeyGuardManagedIdentityKeyProvider(RSACng fixedKey) { - var rsacng = new RSACng(2048); - return Task.FromResult(new ManagedIdentityKeyInfo(rsacng, ManagedIdentityKeyType.KeyGuard, "Test KeyGuard Provider")); + _keyInfo = new ManagedIdentityKeyInfo( + fixedKey, + ManagedIdentityKeyType.KeyGuard, + "Test KeyGuard Provider (fixed)"); } + + public Task GetOrCreateKeyAsync(ILoggerAdapter logger, CancellationToken cancellationToken) + => Task.FromResult(_keyInfo); } } diff --git a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs index f9f72091a9..68eade0c26 100644 --- a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs +++ b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs @@ -1,61 +1,414 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#nullable enable +using System; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.IdentityModel.Abstractions; using Microsoft.Identity.Client.MtlsPop; -IIdentityLogger identityLogger = new IdentityLogger(); +internal class Program +{ + // App state + private static IManagedIdentityApplication s_miApp = null!; + private static ManagedIdentityId s_currentMiId = ManagedIdentityId.SystemAssigned; + private static string s_identityLabel = "SAMI"; -IManagedIdentityApplication mi = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithLogging(identityLogger, true) - .Build(); + // Defaults + private static string s_resource = "https://graph.microsoft.com"; + private const string DefaultUamiClientId = "209b9435-3a7d-4967-8647-52c648d6f67f"; + private const string DefaultUamiObjectId = "981430e1-6890-498c-b882-7f7a0cf853fe"; + private const string DefaultUamiResourceId = + "/subscriptions/ff71c235-108e-4869-9779-5f275ce45c44/resourcegroups/nbhargava/providers/Microsoft.ManagedIdentity/userAssignedIdentities/nidhi_uai_centraluseuap"; -string? scope = "https://management.azure.com"; + // Session flags + private static bool s_forceRefresh = false; // persistent toggle + private static bool s_forceRefreshNext = false; // one-shot bypass + private static bool s_lastWasBound = false; + private static bool s_hasLast = false; -do -{ - Console.WriteLine($"Acquiring token with scope {scope}"); + // Toggle for full token printing (default OFF) + private static bool s_printFullToken = false; - try + private static async Task Main() { - var result = await mi.AcquireTokenForManagedIdentity(scope) - .WithMtlsProofOfPossession() - .ExecuteAsync().ConfigureAwait(false); + Console.Clear(); + WriteTitle("Managed Identity Token Tester"); + + // Resource first (keeps intuitive default) + s_resource = NormalizeResource(Ask($"Resource", defaultValue: s_resource)); + + var logger = new IdentityLogger(); + BuildMiApp(ManagedIdentityId.SystemAssigned, "SAMI", logger); + + PrintHelp(); + + while (true) + { + DrawStatusBar(); + WriteMenu(); + + Console.Write("\nChoice (A/P/T/B/F/I/R/H/Enter/X): "); + var key = Console.ReadKey(intercept: true); + + if (key.Key == ConsoleKey.Enter) + { + if (!s_hasLast) + { + Info("No previous acquisition yet. Press A or P first."); + continue; + } + await AcquireAndReportAsync(bound: s_lastWasBound, forceRefresh: s_forceRefresh).ConfigureAwait(false); + continue; + } + + var ch = char.ToUpperInvariant(key.KeyChar); + Console.WriteLine(ch == '\0' ? "" : ch.ToString()); + + try + { + switch (ch) + { + case 'A': // Bearer + await AcquireAndReportAsync(bound: false, forceRefresh: s_forceRefresh).ConfigureAwait(false); + s_lastWasBound = false; + s_hasLast = true; + break; + + case 'P': // PoP (mTLS-bound) + await AcquireAndReportAsync(bound: true, forceRefresh: s_forceRefresh).ConfigureAwait(false); + s_lastWasBound = true; + s_hasLast = true; + break; + + case 'T': // Toggle persistent Force Refresh + s_forceRefresh = !s_forceRefresh; + Info($"Force Refresh is now {(s_forceRefresh ? "ON" : "OFF")}."); + break; + + case 'B': // One-shot Force Refresh (next acquisition only) + s_forceRefreshNext = true; + Info("Next acquisition will bypass cache (one-shot)."); + break; + + case 'F': // Toggle full token printing + s_printFullToken = !s_printFullToken; + Info($"Full Token Print is now {(s_printFullToken ? "ON" : "OFF")}."); + if (s_printFullToken) + { + Error("WARNING: Full token printing exposes secrets. Do NOT use in shared terminals or logs."); + } + break; + + case 'I': // Identity + await SwitchIdentityAsync(logger).ConfigureAwait(false); + break; + + case 'R': // Resource + var newRes = Ask("New resource (blank=keep)", allowEmpty: true); + if (!string.IsNullOrWhiteSpace(newRes)) + { + s_resource = NormalizeResource(newRes); + Success($"Resource set to {s_resource}"); + } + break; + + case 'H': + PrintHelp(); + break; - Console.WriteLine("Success"); - Console.ReadLine(); + case 'X': + Console.WriteLine("\nGoodbye."); + return; + + default: + Info("Unknown choice. Press H for help."); + break; + } + } + catch (MsalServiceException ex) + { + Error("[MSAL Service Error]"); + Console.WriteLine($" Code : {ex.ErrorCode}"); + Console.WriteLine($" Message: {ex.Message}"); + } + catch (Exception ex) + { + Error("[Unexpected Error]"); + Console.WriteLine(ex.ToString()); + } + } } - catch (MsalServiceException e) + + // === UX === + + private static void WriteTitle(string text) { - Console.WriteLine(e.ErrorCode); - Console.WriteLine(e.Message); - Console.WriteLine(e.StackTrace); - Console.ReadLine(); + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine(text); + Console.ForegroundColor = old; + Console.WriteLine(); } - Console.WriteLine("Enter the scope to acquire token."); - scope = Console.ReadLine(); -} while (!string.IsNullOrEmpty(scope)); + private static void DrawStatusBar() + { + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.Cyan; + Console.WriteLine("\n================ STATUS ================"); + Console.WriteLine($"Resource : {s_resource}"); + Console.WriteLine($"Identity : {s_identityLabel}"); + Console.WriteLine($"Mode : {(s_lastWasBound ? "Bound (mTLS PoP)" : "Bearer")}, ForceRefresh={(s_forceRefresh ? "ON" : "OFF")}, NextBypass={(s_forceRefreshNext ? "ON" : "OFF")}"); + Console.WriteLine($"Secrets : FullTokenPrint={(s_printFullToken ? "ON" : "OFF")}"); + Console.WriteLine("=======================================\n"); + Console.ForegroundColor = old; + } -class IdentityLogger : IIdentityLogger -{ - public EventLogLevel MinLogLevel { get; } + private static void WriteMenu() + { + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.Green; + Console.WriteLine("A) Acquire Bearer token"); + Console.WriteLine("P) Acquire Bound token (mTLS PoP)"); + Console.WriteLine("T) Toggle Force Refresh (persistent)"); + Console.WriteLine("B) Bypass cache on NEXT acquisition (one-shot)"); + Console.WriteLine("F) Toggle Full Token Print"); + Console.WriteLine("I) Switch Identity (SAMI/UAMI presets)"); + Console.WriteLine("R) Change Resource"); + Console.WriteLine("H) Help"); + Console.WriteLine("Enter) Repeat last acquisition"); + Console.WriteLine("X) Exit"); + Console.ForegroundColor = old; + } - public IdentityLogger() + private static void PrintHelp() { - MinLogLevel = EventLogLevel.Verbose; + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.DarkGray; + Console.WriteLine("Help:"); + Console.WriteLine(" - A: Get a Bearer token for the current resource."); + Console.WriteLine(" - P: Get a Bound (mTLS PoP) token (only useful if the API accepts cert-bound tokens)."); + Console.WriteLine(" - T: Toggle Force Refresh to bypass cache on every acquisition until turned off."); + Console.WriteLine(" - B: One-shot bypass; only the NEXT acquisition will bypass cache, then it resets."); + Console.WriteLine(" - F: Toggle Full Token Print. WARNING: prints the entire token (secret)."); + Console.WriteLine(" - I: Switch identity. Options:"); + Console.WriteLine(" 1) SAMI"); + Console.WriteLine(" 2) UAMI (ClientId) [default: 209b9435-...-52c648d6f67f]"); + Console.WriteLine(" 3) UAMI (ResourceId) [default: …/nidhi_uai_centraluseuap]"); + Console.WriteLine(" 4) UAMI (ObjectId) [default: 981430e1-...-7f7a0cf853fe]"); + Console.WriteLine(" - R: Change the resource (accepts `/.default`, will normalize)."); + Console.WriteLine(" - Enter: Repeat the last acquisition (great for cache testing)."); + Console.ForegroundColor = old; + Console.WriteLine(); } - public bool IsEnabled(EventLogLevel eventLogLevel) + private static void Success(string msg) { - return eventLogLevel <= MinLogLevel; + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.Green; + Console.WriteLine(msg); + Console.ForegroundColor = old; + } + private static void Info(string msg) + { + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.DarkCyan; + Console.WriteLine(msg); + Console.ForegroundColor = old; + } + private static void Error(string msg) + { + var old = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine(msg); + Console.ForegroundColor = old; } - public void Log(LogEntry entry) + // === Identity + Acquire === + + private static async Task SwitchIdentityAsync(IIdentityLogger logger) { - //Log Message here: - Console.WriteLine(entry.Message); + Console.WriteLine(); + Console.WriteLine("Identities:"); + Console.WriteLine(" 1) SAMI"); + Console.WriteLine(" 2) UAMI (ClientId)"); + Console.WriteLine(" 3) UAMI (ResourceId)"); + Console.WriteLine(" 4) UAMI (ObjectId)"); + + var pick = Ask("Selection", defaultValue: "1"); + switch (pick) + { + case "1": + BuildMiApp(ManagedIdentityId.SystemAssigned, "SAMI", logger); + break; + + case "2": + { + var id = Ask("UAMI ClientId (GUID)", defaultValue: DefaultUamiClientId); + BuildMiApp(ManagedIdentityId.WithUserAssignedClientId(id), $"UAMI (ClientId={Short(id)})", logger); + } + break; + + case "3": + { + var rid = Ask("UAMI ResourceId", defaultValue: DefaultUamiResourceId); + BuildMiApp(ManagedIdentityId.WithUserAssignedResourceId(rid), "UAMI (ResourceId=…/nidhi_uai_centraluseuap)", logger); + } + break; + + case "4": + { + var oid = Ask("UAMI ObjectId (GUID)", defaultValue: DefaultUamiObjectId); + BuildMiApp(ManagedIdentityId.WithUserAssignedObjectId(oid), $"UAMI (ObjectId={Short(oid)})", logger); + } + break; + + default: + Info("Unknown selection. Keeping current identity."); + break; + } + await Task.CompletedTask.ConfigureAwait(false); + } + + private static void BuildMiApp(ManagedIdentityId miId, string label, IIdentityLogger logger) + { + s_currentMiId = miId; + s_identityLabel = label; + s_miApp = ManagedIdentityApplicationBuilder + .Create(miId) + .WithLogging(logger, enablePiiLogging: true) + .Build(); + Success($"Identity set to {s_identityLabel}"); } + + private static async Task AcquireAndReportAsync(bool bound, bool forceRefresh) + { + // Compute effective force refresh (global OR one-shot), then consume one-shot + bool effectiveForceRefresh = forceRefresh || s_forceRefreshNext; + s_forceRefreshNext = false; // consume one-shot if it was set + + Info($"\nAcquiring {(bound ? "BOUND (mTLS PoP)" : "BEARER")} token for {s_resource} ..."); + + var builder = s_miApp.AcquireTokenForManagedIdentity(s_resource); + if (bound) + builder = builder.WithMtlsProofOfPossession(); + if (effectiveForceRefresh) + builder = builder.WithForceRefresh(true); + + var result = await builder.ExecuteAsync().ConfigureAwait(false); + + Success("Success!"); + var source = result.AuthenticationResultMetadata?.TokenSource.ToString() ?? "Unknown"; + Console.WriteLine($" Token source : {source}"); + Console.WriteLine($" Expires On : {result.ExpiresOn.UtcDateTime:O} (UTC)"); + Console.WriteLine($" Token type : {(bound ? "Bound (mTLS PoP)" : "Bearer")}"); + + if (s_printFullToken) + { + Error(" NOTE: Printing full token (secret) per toggle!"); + Console.WriteLine($" Access Token : {result.AccessToken}"); + } + else + { + var preview = result.AccessToken?.Length > 32 ? result.AccessToken[..32] + "..." : result.AccessToken; + Console.WriteLine($" Access Token : {preview}"); + } + + if (bound && !string.IsNullOrEmpty(result.AccessToken)) + { + var cnf = TryGetCnfClaim(result.AccessToken); + if (cnf is not null) + { + Console.WriteLine(" PoP cnf :"); + if (cnf.Value.TryGetProperty("x5t#S256", out var x5t)) + Console.WriteLine($" x5t#S256 : {x5t.GetString()}"); + if (cnf.Value.TryGetProperty("kid", out var kid)) + Console.WriteLine($" kid : {kid.GetString()}"); + if (cnf.Value.TryGetProperty("xms_mirid", out var mirid)) + Console.WriteLine($" xms_mirid : {mirid.GetString()}"); + } + else + { + Console.WriteLine(" PoP cnf : (not present)"); + } + } + } + + // === Helpers === + + private static string Ask(string prompt, string? defaultValue = null, bool allowEmpty = false) + { + if (defaultValue is null) + Console.Write($"{prompt}: "); + else + Console.Write($"{prompt} [{defaultValue}]: "); + + while (true) + { + var s = Console.ReadLine(); + if (string.IsNullOrWhiteSpace(s)) + { + if (defaultValue != null) + return defaultValue; + if (allowEmpty) + return string.Empty; + } + else + { + return s.Trim(); + } + Console.Write("Please enter a value: "); + } + } + + private static string NormalizeResource(string input) + { + var s = input.Trim(); + if (s.EndsWith("/.default", StringComparison.OrdinalIgnoreCase)) + s = s[..^"/.default".Length]; + if (s.EndsWith("/")) + s = s.TrimEnd('/'); + return s; + } + + private static string Short(string id) => id.Length <= 8 ? id : id[..8]; + + private static JsonElement? TryGetCnfClaim(string jwt) + { + var parts = jwt.Split('.'); + if (parts.Length < 2) + return null; + var payloadJson = Encoding.UTF8.GetString(Base64UrlDecode(parts[1])); + using var doc = JsonDocument.Parse(payloadJson); + if (!doc.RootElement.TryGetProperty("cnf", out var cnfEl)) + return null; + using var cnfDoc = JsonDocument.Parse(cnfEl.GetRawText()); + return cnfDoc.RootElement.Clone(); + } + + private static byte[] Base64UrlDecode(string s) + { + s = s.Replace('-', '+').Replace('_', '/'); + switch (s.Length % 4) + { + case 2: + s += "=="; + break; + case 3: + s += "="; + break; + } + return Convert.FromBase64String(s); + } +} + +class IdentityLogger : IIdentityLogger +{ + public EventLogLevel MinLogLevel { get; } = EventLogLevel.Verbose; + public bool IsEnabled(EventLogLevel eventLogLevel) => eventLogLevel <= MinLogLevel; + public void Log(LogEntry entry) => Console.WriteLine($"[MSAL] {entry.Message}"); }