Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import {
discoverOAuthProtectedResourceMetadata,
extractResourceMetadataUrl,
auth,
type OAuthClientProvider
type OAuthClientProvider,
selectClientAuthMethod
} from './auth.js';
import { ServerError } from '../server/auth/errors.js';
import { AuthorizationServerMetadata } from '../shared/auth.js';
Expand Down Expand Up @@ -881,6 +882,25 @@ describe('OAuth Authorization', () => {
});
});

describe('selectClientAuthMethod', () => {
it('selects the correct client authentication method from client information', () => {
const clientInfo = {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
token_endpoint_auth_method: 'client_secret_basic'
};
const supportedMethods = ['client_secret_post', 'client_secret_basic', 'none'];
const authMethod = selectClientAuthMethod(clientInfo, supportedMethods);
expect(authMethod).toBe('client_secret_basic');
});
it('selects the correct client authentication method from supported methods', () => {
const clientInfo = { client_id: 'test-client-id' };
const supportedMethods = ['client_secret_post', 'client_secret_basic', 'none'];
const authMethod = selectClientAuthMethod(clientInfo, supportedMethods);
expect(authMethod).toBe('none');
});
});

describe('startAuthorization', () => {
const validMetadata = {
issuer: 'https://auth.example.com',
Expand Down
27 changes: 21 additions & 6 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { LATEST_PROTOCOL_VERSION } from '../types.js';
import {
OAuthClientMetadata,
OAuthClientInformation,
OAuthClientInformationMixed,
OAuthTokens,
OAuthMetadata,
OAuthClientInformationFull,
Expand Down Expand Up @@ -56,7 +57,7 @@ export interface OAuthClientProvider {
* server, or returns `undefined` if the client is not registered with the
* server.
*/
clientInformation(): OAuthClientInformation | undefined | Promise<OAuthClientInformation | undefined>;
clientInformation(): OAuthClientInformationMixed | undefined | Promise<OAuthClientInformationMixed | undefined>;

/**
* If implemented, this permits the OAuth client to dynamically register with
Expand All @@ -66,7 +67,7 @@ export interface OAuthClientProvider {
* This method is not required to be implemented if client information is
* statically known (e.g., pre-registered).
*/
saveClientInformation?(clientInformation: OAuthClientInformationFull): void | Promise<void>;
saveClientInformation?(clientInformation: OAuthClientInformationMixed): void | Promise<void>;

/**
* Loads any existing OAuth tokens for the current session, or returns
Expand Down Expand Up @@ -149,6 +150,10 @@ export class UnauthorizedError extends Error {

type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none';

function isClientAuthMethod(method: string): method is ClientAuthMethod {
return ['client_secret_basic', 'client_secret_post', 'none'].includes(method);
}

const AUTHORIZATION_CODE_RESPONSE_TYPE = 'code';
const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256';

Expand All @@ -164,14 +169,24 @@ const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256';
* @param supportedMethods - Authentication methods supported by the authorization server
* @returns The selected authentication method
*/
function selectClientAuthMethod(clientInformation: OAuthClientInformation, supportedMethods: string[]): ClientAuthMethod {
export function selectClientAuthMethod(clientInformation: OAuthClientInformationMixed, supportedMethods: string[]): ClientAuthMethod {
const hasClientSecret = clientInformation.client_secret !== undefined;

// If server doesn't specify supported methods, use RFC 6749 defaults
if (supportedMethods.length === 0) {
return hasClientSecret ? 'client_secret_post' : 'none';
}

// Prefer the method returned by the server during client registration if valid and supported
if (
'token_endpoint_auth_method' in clientInformation &&
clientInformation.token_endpoint_auth_method &&
isClientAuthMethod(clientInformation.token_endpoint_auth_method) &&
supportedMethods.includes(clientInformation.token_endpoint_auth_method)
) {
return clientInformation.token_endpoint_auth_method;
}

// Try methods in priority order (most secure first)
if (hasClientSecret && supportedMethods.includes('client_secret_basic')) {
return 'client_secret_basic';
Expand Down Expand Up @@ -793,7 +808,7 @@ export async function startAuthorization(
resource
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformation;
clientInformation: OAuthClientInformationMixed;
redirectUrl: string | URL;
scope?: string;
state?: string;
Expand Down Expand Up @@ -876,7 +891,7 @@ export async function exchangeAuthorization(
fetchFn
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformation;
clientInformation: OAuthClientInformationMixed;
authorizationCode: string;
codeVerifier: string;
redirectUri: string | URL;
Expand Down Expand Up @@ -955,7 +970,7 @@ export async function refreshAuthorization(
fetchFn
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformation;
clientInformation: OAuthClientInformationMixed;
refreshToken: string;
resource?: URL;
addClientAuthentication?: OAuthClientProvider['addClientAuthentication'];
Expand Down
8 changes: 4 additions & 4 deletions src/examples/client/simpleOAuthClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { URL } from 'node:url';
import { exec } from 'node:child_process';
import { Client } from '../../client/index.js';
import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js';
import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js';
import { OAuthClientInformationMixed, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js';
import { CallToolRequest, ListToolsRequest, CallToolResultSchema, ListToolsResultSchema } from '../../types.js';
import { OAuthClientProvider, UnauthorizedError } from '../../client/auth.js';

Expand All @@ -20,7 +20,7 @@ const CALLBACK_URL = `http://localhost:${CALLBACK_PORT}/callback`;
* In production, you should persist tokens securely
*/
class InMemoryOAuthClientProvider implements OAuthClientProvider {
private _clientInformation?: OAuthClientInformationFull;
private _clientInformation?: OAuthClientInformationMixed;
private _tokens?: OAuthTokens;
private _codeVerifier?: string;

Expand All @@ -46,11 +46,11 @@ class InMemoryOAuthClientProvider implements OAuthClientProvider {
return this._clientMetadata;
}

clientInformation(): OAuthClientInformation | undefined {
clientInformation(): OAuthClientInformationMixed | undefined {
return this._clientInformation;
}

saveClientInformation(clientInformation: OAuthClientInformationFull): void {
saveClientInformation(clientInformation: OAuthClientInformationMixed): void {
this._clientInformation = clientInformation;
}

Expand Down
1 change: 1 addition & 0 deletions src/shared/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ export type OAuthErrorResponse = z.infer<typeof OAuthErrorResponseSchema>;
export type OAuthClientMetadata = z.infer<typeof OAuthClientMetadataSchema>;
export type OAuthClientInformation = z.infer<typeof OAuthClientInformationSchema>;
export type OAuthClientInformationFull = z.infer<typeof OAuthClientInformationFullSchema>;
export type OAuthClientInformationMixed = OAuthClientInformation | OAuthClientInformationFull;
export type OAuthClientRegistrationError = z.infer<typeof OAuthClientRegistrationErrorSchema>;
export type OAuthTokenRevocationRequest = z.infer<typeof OAuthTokenRevocationRequestSchema>;
export type OAuthProtectedResourceMetadata = z.infer<typeof OAuthProtectedResourceMetadataSchema>;
Expand Down