From ffa397e36d8bb01421b14b2ffde534c6eff02920 Mon Sep 17 00:00:00 2001 From: yao <63141491+yaonyan@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:33:54 +0800 Subject: [PATCH 1/5] feat(ai): implement experimental_prepareStep for dynamic step configuration - Add PrepareStepFunction and PrepareStepResult types for step-by-step customization - Enable dynamic model, toolChoice, and activeTools modification per step - Support system prompt and messages override during multi-step execution - Implement experimental_prepareStep parameter in generateText and streamText - Add comprehensive test coverage for experimental_prepareStep functionality - Fix stepNumber calculation bug in stream-text.ts (use currentStep instead of recordedSteps.length) --- .../__snapshots__/stream-text.test.ts.snap | 507 ++++++++++++++++++ .../ai/core/generate-text/generate-text.ts | 32 +- .../ai/core/generate-text/prepare-step.ts | 39 ++ .../ai/core/generate-text/stream-text.test.ts | 243 +++++++++ packages/ai/core/generate-text/stream-text.ts | 50 +- 5 files changed, 847 insertions(+), 24 deletions(-) create mode 100644 packages/ai/core/generate-text/prepare-step.ts diff --git a/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap b/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap index 9cb9d8ce1baa..426167333992 100644 --- a/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap +++ b/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap @@ -931,6 +931,513 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > value p ] `; +exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with experimental_prepareStep > callbacks > onFinish should send correct information 1`] = ` +{ + "experimental_providerMetadata": undefined, + "files": [], + "finishReason": "stop", + "logprobs": undefined, + "providerMetadata": undefined, + "reasoning": undefined, + "reasoningDetails": [], + "request": {}, + "response": { + "headers": undefined, + "id": "id-1", + "messages": [ + { + "content": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "id": "msg-0", + "role": "assistant", + }, + { + "content": [ + { + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "id": "msg-3", + "role": "tool", + }, + { + "content": [ + { + "text": "Hello, world!", + "type": "text", + }, + ], + "id": "msg-2", + "role": "assistant", + }, + ], + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:10.000Z, + }, + "sources": [], + "steps": [ + { + "experimental_providerMetadata": undefined, + "files": [], + "finishReason": "tool-calls", + "isContinued": false, + "logprobs": undefined, + "providerMetadata": undefined, + "reasoning": undefined, + "reasoningDetails": [], + "request": {}, + "response": { + "headers": undefined, + "id": "id-0", + "messages": [ + { + "content": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "id": "msg-0", + "role": "assistant", + }, + { + "content": [ + { + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "id": "msg-3", + "role": "tool", + }, + ], + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:00.000Z, + }, + "sources": [], + "stepType": "initial", + "text": "", + "toolCalls": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", 
+ "toolName": "tool1", + "type": "tool-call", + }, + ], + "toolResults": [ + { + "args": { + "value": "value", + }, + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "usage": { + "completionTokens": 5, + "promptTokens": 10, + "totalTokens": 15, + }, + "warnings": undefined, + }, + { + "experimental_providerMetadata": undefined, + "files": [], + "finishReason": "stop", + "isContinued": false, + "logprobs": undefined, + "providerMetadata": undefined, + "reasoning": undefined, + "reasoningDetails": [], + "request": {}, + "response": { + "headers": undefined, + "id": "id-1", + "messages": [ + { + "content": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "id": "msg-0", + "role": "assistant", + }, + { + "content": [ + { + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "id": "msg-3", + "role": "tool", + }, + { + "content": [ + { + "text": "Hello, world!", + "type": "text", + }, + ], + "id": "msg-2", + "role": "assistant", + }, + ], + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:10.000Z, + }, + "sources": [], + "stepType": "tool-result", + "text": "Hello, world!", + "toolCalls": [], + "toolResults": [], + "usage": { + "completionTokens": 10, + "promptTokens": 5, + "totalTokens": 15, + }, + "warnings": undefined, + }, + ], + "text": "Hello, world!", + "toolCalls": [], + "toolResults": [], + "usage": { + "completionTokens": 15, + "promptTokens": 15, + "totalTokens": 30, + }, + "warnings": undefined, +} +`; + +exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with experimental_prepareStep > callbacks > onStepFinish should send correct information 1`] = `[]`; + +exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with experimental_prepareStep > should contain assistant response message and tool message from all steps 1`] = ` +[ + { + "messageId": "msg-0", + "request": {}, + "type": "step-start", + "warnings": [], + }, + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + { + "args": { + "value": "value", + }, + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + { + "experimental_providerMetadata": undefined, + "finishReason": "tool-calls", + "isContinued": false, + "logprobs": undefined, + "messageId": "msg-0", + "providerMetadata": undefined, + "request": {}, + "response": { + "headers": undefined, + "id": "id-0", + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:00.000Z, + }, + "type": "step-finish", + "usage": { + "completionTokens": 5, + "promptTokens": 10, + "totalTokens": 15, + }, + "warnings": undefined, + }, + { + "messageId": "msg-2", + "request": {}, + "type": "step-start", + "warnings": [], + }, + { + "textDelta": "Hello", + "type": "text-delta", + }, + { + "textDelta": ", ", + "type": "text-delta", + }, + { + "textDelta": "world!", + "type": "text-delta", + }, + { + "experimental_providerMetadata": undefined, + "finishReason": "stop", + "isContinued": false, + "logprobs": undefined, + "messageId": "msg-2", + "providerMetadata": undefined, + "request": {}, + "response": { + "headers": undefined, + "id": "id-1", + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:10.000Z, + }, + "type": "step-finish", + "usage": { + "completionTokens": 10, + "promptTokens": 5, + "totalTokens": 15, + }, + "warnings": 
undefined, + }, + { + "experimental_providerMetadata": undefined, + "finishReason": "stop", + "logprobs": undefined, + "providerMetadata": undefined, + "response": { + "headers": undefined, + "id": "id-1", + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:10.000Z, + }, + "type": "finish", + "usage": { + "completionTokens": 15, + "promptTokens": 15, + "totalTokens": 30, + }, + }, +] +`; + +exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with experimental_prepareStep > value promises > result.response.messages should contain response messages from all steps 1`] = ` +[ + { + "content": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "id": "msg-0", + "role": "assistant", + }, + { + "content": [ + { + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "id": "msg-3", + "role": "tool", + }, + { + "content": [ + { + "text": "Hello, world!", + "type": "text", + }, + ], + "id": "msg-2", + "role": "assistant", + }, +] +`; + +exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with experimental_prepareStep > value promises > result.steps should contain all steps 1`] = ` +[ + { + "experimental_providerMetadata": undefined, + "files": [], + "finishReason": "tool-calls", + "isContinued": false, + "logprobs": undefined, + "providerMetadata": undefined, + "reasoning": undefined, + "reasoningDetails": [], + "request": {}, + "response": { + "headers": undefined, + "id": "id-0", + "messages": [ + { + "content": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "id": "msg-0", + "role": "assistant", + }, + { + "content": [ + { + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "id": "msg-3", + "role": "tool", + }, + ], + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:00.000Z, + }, + "sources": [], + "stepType": "initial", + "text": "", + "toolCalls": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "toolResults": [ + { + "args": { + "value": "value", + }, + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "usage": { + "completionTokens": 5, + "promptTokens": 10, + "totalTokens": 15, + }, + "warnings": undefined, + }, + { + "experimental_providerMetadata": undefined, + "files": [], + "finishReason": "stop", + "isContinued": false, + "logprobs": undefined, + "providerMetadata": undefined, + "reasoning": undefined, + "reasoningDetails": [], + "request": {}, + "response": { + "headers": undefined, + "id": "id-1", + "messages": [ + { + "content": [ + { + "args": { + "value": "value", + }, + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-call", + }, + ], + "id": "msg-0", + "role": "assistant", + }, + { + "content": [ + { + "result": "result1", + "toolCallId": "call-1", + "toolName": "tool1", + "type": "tool-result", + }, + ], + "id": "msg-3", + "role": "tool", + }, + { + "content": [ + { + "text": "Hello, world!", + "type": "text", + }, + ], + "id": "msg-2", + "role": "assistant", + }, + ], + "modelId": "mock-model-id", + "timestamp": 1970-01-01T00:00:10.000Z, + }, + "sources": [], + "stepType": "tool-result", + "text": "Hello, world!", + "toolCalls": [], + "toolResults": [], + "usage": { + "completionTokens": 10, + "promptTokens": 5, + 
"totalTokens": 15, + }, + "warnings": undefined, + }, +] +`; + exports[`streamText > options.maxSteps > 4 steps: initial, continue, continue, continue > callbacks > onFinish should send correct information 1`] = ` { "experimental_providerMetadata": undefined, diff --git a/packages/ai/core/generate-text/generate-text.ts b/packages/ai/core/generate-text/generate-text.ts index 9c8bf6ced92e..13c720d0775d 100644 --- a/packages/ai/core/generate-text/generate-text.ts +++ b/packages/ai/core/generate-text/generate-text.ts @@ -31,6 +31,7 @@ import { GenerateTextResult } from './generate-text-result'; import { DefaultGeneratedFile, GeneratedFile } from './generated-file'; import { Output } from './output'; import { parseToolCall } from './parse-tool-call'; +import { PrepareStepFunction } from './prepare-step'; import { asReasoningText, ReasoningDetail } from './reasoning-detail'; import { ResponseMessage, StepResult } from './step-result'; import { toResponseMessages } from './to-response-messages'; @@ -213,19 +214,7 @@ Optional function that you can use to provide different settings for a step. @returns An object that contains the settings for the step. If you return undefined (or for undefined settings), the settings from the outer level will be used. */ - experimental_prepareStep?: (options: { - steps: Array>; - stepNumber: number; - maxSteps: number; - model: LanguageModel; - }) => PromiseLike< - | { - model?: LanguageModel; - toolChoice?: ToolChoice; - experimental_activeTools?: Array; - } - | undefined - >; + experimental_prepareStep?: PrepareStepFunction; /** A function that attempts to repair a tool call that failed to parse. @@ -329,11 +318,22 @@ A function that attempts to repair a tool call that failed to parse. ...responseMessages, ]; + const promptMessages = await convertToLanguageModelPrompt({ + prompt: { + type: promptFormat, + system: initialPrompt.system, + messages: stepInputMessages, + }, + modelSupportsImageUrls: model.supportsImageUrls, + modelSupportsUrl: model.supportsUrl?.bind(model), // support 'this' context + }); + const prepareStepResult = await prepareStep?.({ model, steps, maxSteps, stepNumber: stepCount, + messages: promptMessages, }); const stepToolChoice = prepareStepResult?.toolChoice ?? toolChoice; @@ -341,7 +341,7 @@ A function that attempts to repair a tool call that failed to parse. prepareStepResult?.experimental_activeTools ?? activeTools; const stepModel = prepareStepResult?.model ?? model; - const promptMessages = await convertToLanguageModelPrompt({ + const promptMessagesForStep = await convertToLanguageModelPrompt({ prompt: { type: promptFormat, system: initialPrompt.system, @@ -377,7 +377,7 @@ A function that attempts to repair a tool call that failed to parse. // prompt: 'ai.prompt.format': { input: () => promptFormat }, 'ai.prompt.messages': { - input: () => stringifyForTelemetry(promptMessages), + input: () => stringifyForTelemetry(promptMessagesForStep), }, 'ai.prompt.tools': { // convert the language model level tools: @@ -409,7 +409,7 @@ A function that attempts to repair a tool call that failed to parse. 
...callSettings, inputFormat: promptFormat, responseFormat: output?.responseFormat({ model }), - prompt: promptMessages, + prompt: promptMessagesForStep, providerMetadata: providerOptions, abortSignal, headers, diff --git a/packages/ai/core/generate-text/prepare-step.ts b/packages/ai/core/generate-text/prepare-step.ts new file mode 100644 index 000000000000..483b871f6ba3 --- /dev/null +++ b/packages/ai/core/generate-text/prepare-step.ts @@ -0,0 +1,39 @@ +import { LanguageModelV1Message } from '@ai-sdk/provider'; +import { LanguageModel, ToolChoice } from '../types'; +import { Tool } from '../tool'; +import { StepResult } from './step-result'; + +/** +Function that you can use to provide different settings for a step. + +@param options - The options for the step. +@param options.steps - The steps that have been executed so far. +@param options.stepNumber - The number of the step that is being executed. +@param options.maxSteps - The maximum number of steps. +@param options.model - The model that is being used. +@param options.messages - The messages that will be sent to the model. + +@returns An object that contains the settings for the step. +If you return undefined (or for undefined settings), the settings from the outer level will be used. + */ +export type PrepareStepFunction< + TOOLS extends Record = Record, +> = (options: { + steps: Array>>; + stepNumber: number; + maxSteps: number; + model: LanguageModel; + messages: Array; +}) => PromiseLike> | PrepareStepResult; + +export type PrepareStepResult< + TOOLS extends Record = Record, +> = + | { + model?: LanguageModel; + toolChoice?: ToolChoice>; + experimental_activeTools?: Array>; + system?: string; + messages?: Array; + } + | undefined; diff --git a/packages/ai/core/generate-text/stream-text.test.ts b/packages/ai/core/generate-text/stream-text.test.ts index fa9a99c52905..367118016c01 100644 --- a/packages/ai/core/generate-text/stream-text.test.ts +++ b/packages/ai/core/generate-text/stream-text.test.ts @@ -2913,6 +2913,249 @@ describe('streamText', () => { expect(await result.sources).toMatchSnapshot(); }); }); + + describe('2 steps: initial, tool-result with experimental_prepareStep', () => { + beforeEach(async () => { + result = undefined as any; + onFinishResult = undefined as any; + onStepFinishResults = []; + + let responseCount = 0; + + const trueModel = new MockLanguageModelV1({ + doStream: async ({ prompt, mode }) => { + switch (responseCount++) { + case 0: { + expect(mode).toStrictEqual({ + type: 'regular', + tools: [ + { + type: 'function', + name: 'tool1', + description: undefined, + parameters: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: false, + properties: { value: { type: 'string' } }, + required: ['value'], + type: 'object', + }, + }, + ], + toolChoice: { type: 'tool', toolName: 'tool1' }, + }); + + expect(prompt).toStrictEqual([ + { + role: 'user', + content: [{ type: 'text', text: 'test-input' }], + providerMetadata: undefined, + }, + ]); + + return { + stream: convertArrayToReadableStream([ + { + type: 'response-metadata', + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), + }, + { + type: 'tool-call', + toolCallType: 'function', + toolCallId: 'call-1', + toolName: 'tool1', + args: `{ "value": "value" }`, + }, + { + type: 'finish', + finishReason: 'tool-calls', + logprobs: undefined, + usage: { completionTokens: 5, promptTokens: 10 }, + }, + ]), + rawCall: { rawPrompt: 'prompt', rawSettings: {} }, + }; + } + + case 1: { + expect(mode).toStrictEqual({ + type: 
'regular', + toolChoice: { type: 'auto' }, + tools: [], + }); + + expect(prompt).toStrictEqual([ + { + role: 'user', + content: [{ type: 'text', text: 'test-input' }], + providerMetadata: undefined, + }, + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'tool1', + args: { value: 'value' }, + providerMetadata: undefined, + }, + ], + providerMetadata: undefined, + }, + { + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: 'call-1', + toolName: 'tool1', + result: 'result1', + content: undefined, + isError: undefined, + providerMetadata: undefined, + }, + ], + providerMetadata: undefined, + }, + ]); + + return { + stream: convertArrayToReadableStream([ + { + type: 'response-metadata', + id: 'id-1', + modelId: 'mock-model-id', + timestamp: new Date(10000), + }, + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: 'world!' }, + { + type: 'finish', + finishReason: 'stop', + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 5 }, + }, + ]), + rawCall: { rawPrompt: 'prompt', rawSettings: {} }, + }; + } + + default: + throw new Error(`Unexpected response count: ${responseCount}`); + } + }, + }); + + const originalModel = new MockLanguageModelV1({ + doStream: async () => { + throw new Error('Should not be called'); + }, + }); + + result = streamText({ + model: originalModel, + tools: { + tool1: tool({ + parameters: z.object({ value: z.string() }), + execute: async (args, options) => { + expect(args).toStrictEqual({ value: 'value' }); + expect(options.messages).toStrictEqual([ + { role: 'user', content: 'test-input' }, + ]); + return 'result1'; + }, + }), + }, + prompt: 'test-input', + maxSteps: 3, + onFinish: async event => { + onFinishResult = event as any; + }, + onStepFinish: async event => { + onStepFinishResults.push(event); + }, + experimental_prepareStep: async ({ model, stepNumber, steps }) => { + expect(model).toStrictEqual(originalModel); + + if (stepNumber === 0) { + expect(steps).toStrictEqual([]); + return { + model: trueModel, + toolChoice: { + type: 'tool', + toolName: 'tool1' as const, + }, + }; + } + + if (stepNumber === 1) { + expect(steps.length).toStrictEqual(0); // step 0 not yet recorded in streaming + return { + model: trueModel, + toolChoice: 'auto', + experimental_activeTools: [], + }; + } + }, + experimental_generateMessageId: mockId({ prefix: 'msg' }), + }); + }); + + it('should contain assistant response message and tool message from all steps', async () => { + // We need to consume the stream to trigger tool execution and second step + expect( + await convertAsyncIterableToArray(result.fullStream), + ).toMatchSnapshot(); + }); + + describe('callbacks', () => { + it('onFinish should send correct information', async () => { + result.consumeStream(); + expect(await result.finishReason).toStrictEqual('stop'); + expect(onFinishResult).toMatchSnapshot(); + }); + + it('onStepFinish should send correct information', async () => { + result.consumeStream(); + expect(onStepFinishResults).toMatchSnapshot(); + }); + }); + + describe('value promises', () => { + it('result.usage should contain total token usage', async () => { + result.consumeStream(); + assert.deepStrictEqual(await result.usage, { + completionTokens: 15, + promptTokens: 15, + totalTokens: 30, + }); + }); + + it('result.finishReason should contain finish reason from final step', async () => { + result.consumeStream(); + expect(await 
result.finishReason).toStrictEqual('stop'); + }); + + it('result.text should contain text from final step', async () => { + result.consumeStream(); + assert.deepStrictEqual(await result.text, 'Hello, world!'); + }); + + it('result.steps should contain all steps', async () => { + result.consumeStream(); + expect(await result.steps).toMatchSnapshot(); + }); + + it('result.response.messages should contain response messages from all steps', async () => { + result.consumeStream(); + expect((await result.response).messages).toMatchSnapshot(); + }); + }); + }); }); describe('options.headers', () => { diff --git a/packages/ai/core/generate-text/stream-text.ts b/packages/ai/core/generate-text/stream-text.ts index c478c5be3633..1aac8bdcaa07 100644 --- a/packages/ai/core/generate-text/stream-text.ts +++ b/packages/ai/core/generate-text/stream-text.ts @@ -49,6 +49,7 @@ import { splitOnLastWhitespace } from '../util/split-on-last-whitespace'; import { writeToServerResponse } from '../util/write-to-server-response'; import { GeneratedFile } from './generated-file'; import { Output } from './output'; +import { PrepareStepFunction } from './prepare-step'; import { asReasoningText, ReasoningDetail } from './reasoning-detail'; import { runToolsTransformation, @@ -214,6 +215,7 @@ export function streamText< experimental_toolCallStreaming = false, toolCallStreaming = experimental_toolCallStreaming, experimental_activeTools: activeTools, + experimental_prepareStep: prepareStep, experimental_repairToolCall: repairToolCall, experimental_transform: transform, onChunk, @@ -297,6 +299,20 @@ A function that attempts to repair a tool call that failed to parse. */ experimental_repairToolCall?: ToolCallRepairFunction; + /** +Optional function that you can use to provide different settings for a step. + +@param options - The options for the step. +@param options.steps - The steps that have been executed so far. +@param options.stepNumber - The number of the step that is being executed. +@param options.maxSteps - The maximum number of steps. +@param options.model - The model that is being used. + +@returns An object that contains the settings for the step. +If you return undefined (or for undefined settings), the settings from the outer level will be used. + */ + experimental_prepareStep?: PrepareStepFunction; + /** Enable streaming of tool call deltas as they are generated. Disabled by default. */ @@ -370,6 +386,7 @@ Internal. For test use only. May change without notice. toolCallStreaming, transforms: asArray(transform), activeTools, + prepareStep, repairToolCall, maxSteps, output, @@ -544,6 +561,7 @@ class DefaultStreamTextResult toolCallStreaming, transforms, activeTools, + prepareStep, repairToolCall, maxSteps, output, @@ -572,6 +590,7 @@ class DefaultStreamTextResult toolCallStreaming: boolean; transforms: Array>; activeTools: Array | undefined; + prepareStep: PrepareStepFunction | undefined; repairToolCall: ToolCallRepairFunction | undefined; maxSteps: number; output: Output | undefined; @@ -973,19 +992,34 @@ class DefaultStreamTextResult ...responseMessages, ]; + const prepareStepResult = await prepareStep?.({ + model, + steps: recordedSteps, + stepNumber: currentStep, + maxSteps, + messages: stepInputMessages as any, // TODO: Fix type compatibility + }); + const promptMessages = await convertToLanguageModelPrompt({ prompt: { type: promptFormat, - system: initialPrompt.system, - messages: stepInputMessages, + system: prepareStepResult?.system ?? initialPrompt.system, + messages: prepareStepResult?.messages ?? 
stepInputMessages, }, modelSupportsImageUrls: model.supportsImageUrls, modelSupportsUrl: model.supportsUrl?.bind(model), // support 'this' context }); + const stepModel = prepareStepResult?.model ?? model; + const mode = { type: 'regular' as const, - ...prepareToolsAndToolChoice({ tools, toolChoice, activeTools }), + ...prepareToolsAndToolChoice({ + tools, + toolChoice: prepareStepResult?.toolChoice ?? toolChoice, + activeTools: + prepareStepResult?.experimental_activeTools ?? activeTools, + }), }; const { @@ -1021,8 +1055,8 @@ class DefaultStreamTextResult }, // standardized gen-ai llm span attributes: - 'gen_ai.system': model.provider, - 'gen_ai.request.model': model.modelId, + 'gen_ai.system': stepModel.provider, + 'gen_ai.request.model': stepModel.modelId, 'gen_ai.request.frequency_penalty': settings.frequencyPenalty, 'gen_ai.request.max_tokens': settings.maxTokens, 'gen_ai.request.presence_penalty': settings.presencePenalty, @@ -1037,11 +1071,11 @@ class DefaultStreamTextResult fn: async doStreamSpan => ({ startTimestampMs: now(), // get before the call doStreamSpan, - result: await model.doStream({ + result: await stepModel.doStream({ mode, ...prepareCallSettings(settings), inputFormat: promptFormat, - responseFormat: output?.responseFormat({ model }), + responseFormat: output?.responseFormat({ model: stepModel }), prompt: promptMessages, providerMetadata: providerOptions, abortSignal, @@ -1087,7 +1121,7 @@ class DefaultStreamTextResult let stepResponse: { id: string; timestamp: Date; modelId: string } = { id: generateId(), timestamp: currentDate(), - modelId: model.modelId, + modelId: stepModel.modelId, }; // chunk buffer when using continue: From dcc3e9ba6c185e6b43cea2ecc608867d3d090827 Mon Sep 17 00:00:00 2001 From: yao <63141491+yaonyan@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:42:35 +0800 Subject: [PATCH 2/5] feat(ai): add changeset for experimental_prepareStep implementation in v4 --- .changeset/short-feet-peel.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/short-feet-peel.md diff --git a/.changeset/short-feet-peel.md b/.changeset/short-feet-peel.md new file mode 100644 index 000000000000..eebcb2c00cfb --- /dev/null +++ b/.changeset/short-feet-peel.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat(ai): implement experimental_prepareStep for streamText in v4 \ No newline at end of file From 97eb733a2a7a85a685ab3635b0914b3498e91a27 Mon Sep 17 00:00:00 2001 From: yao <63141491+yaonyan@users.noreply.github.com> Date: Fri, 18 Jul 2025 17:32:55 +0800 Subject: [PATCH 3/5] fix(ai): update type for step input messages to LanguageModelV1Message in stream-text.ts --- packages/ai/core/generate-text/stream-text.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/ai/core/generate-text/stream-text.ts b/packages/ai/core/generate-text/stream-text.ts index 1aac8bdcaa07..6f38e0f08407 100644 --- a/packages/ai/core/generate-text/stream-text.ts +++ b/packages/ai/core/generate-text/stream-text.ts @@ -1,4 +1,4 @@ -import { AISDKError, LanguageModelV1Source } from '@ai-sdk/provider'; +import { AISDKError, LanguageModelV1Message, LanguageModelV1Source } from '@ai-sdk/provider'; import { createIdGenerator, IDGenerator } from '@ai-sdk/provider-utils'; import { DataStreamString, formatDataStreamPart } from '@ai-sdk/ui-utils'; import { Span } from '@opentelemetry/api'; @@ -990,14 +990,14 @@ class DefaultStreamTextResult const stepInputMessages = [ ...initialPrompt.messages, ...responseMessages, - ]; + ] as LanguageModelV1Message[]; const 
prepareStepResult = await prepareStep?.({ model, steps: recordedSteps, stepNumber: currentStep, maxSteps, - messages: stepInputMessages as any, // TODO: Fix type compatibility + messages: stepInputMessages }); const promptMessages = await convertToLanguageModelPrompt({ From 61d3754f90be36517f6e7daa243f73bf3fc9fc38 Mon Sep 17 00:00:00 2001 From: yao <63141491+yaonyan@users.noreply.github.com> Date: Fri, 18 Jul 2025 20:11:43 +0800 Subject: [PATCH 4/5] refactor(ai): remove maxSteps parameter from prepareStep function and update related calls --- packages/ai/core/generate-text/generate-text.ts | 1 - packages/ai/core/generate-text/prepare-step.ts | 1 - .../ai/core/generate-text/stream-text.test.ts | 17 +++++++++++++++-- packages/ai/core/generate-text/stream-text.ts | 3 +-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/packages/ai/core/generate-text/generate-text.ts b/packages/ai/core/generate-text/generate-text.ts index 13c720d0775d..56eef95153aa 100644 --- a/packages/ai/core/generate-text/generate-text.ts +++ b/packages/ai/core/generate-text/generate-text.ts @@ -331,7 +331,6 @@ A function that attempts to repair a tool call that failed to parse. const prepareStepResult = await prepareStep?.({ model, steps, - maxSteps, stepNumber: stepCount, messages: promptMessages, }); diff --git a/packages/ai/core/generate-text/prepare-step.ts b/packages/ai/core/generate-text/prepare-step.ts index 483b871f6ba3..3409de79c8f9 100644 --- a/packages/ai/core/generate-text/prepare-step.ts +++ b/packages/ai/core/generate-text/prepare-step.ts @@ -21,7 +21,6 @@ export type PrepareStepFunction< > = (options: { steps: Array>>; stepNumber: number; - maxSteps: number; model: LanguageModel; messages: Array; }) => PromiseLike> | PrepareStepResult; diff --git a/packages/ai/core/generate-text/stream-text.test.ts b/packages/ai/core/generate-text/stream-text.test.ts index 367118016c01..65778fae6749 100644 --- a/packages/ai/core/generate-text/stream-text.test.ts +++ b/packages/ai/core/generate-text/stream-text.test.ts @@ -2982,8 +2982,21 @@ describe('streamText', () => { case 1: { expect(mode).toStrictEqual({ type: 'regular', - toolChoice: { type: 'auto' }, - tools: [], + toolChoice: { type: 'tool', toolName: 'tool1' }, + tools: [ + { + type: 'function', + name: 'tool1', + description: undefined, + parameters: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: false, + properties: { value: { type: 'string' } }, + required: ['value'], + type: 'object', + }, + }, + ], }); expect(prompt).toStrictEqual([ diff --git a/packages/ai/core/generate-text/stream-text.ts b/packages/ai/core/generate-text/stream-text.ts index 6f38e0f08407..776dd346fa21 100644 --- a/packages/ai/core/generate-text/stream-text.ts +++ b/packages/ai/core/generate-text/stream-text.ts @@ -995,8 +995,7 @@ class DefaultStreamTextResult const prepareStepResult = await prepareStep?.({ model, steps: recordedSteps, - stepNumber: currentStep, - maxSteps, + stepNumber: recordedSteps.length, messages: stepInputMessages }); From 5099e2dea7f9223252b22d2235e0b266e1de4d6e Mon Sep 17 00:00:00 2001 From: yao <63141491+yaonyan@users.noreply.github.com> Date: Fri, 18 Jul 2025 21:53:32 +0800 Subject: [PATCH 5/5] feat(ai): update step handling in stream processing to support dynamic maxSteps configuration --- .../__snapshots__/stream-text.test.ts.snap | 54 +++++++++---------- .../ai/core/generate-text/generate-text.ts | 1 + .../ai/core/generate-text/prepare-step.ts | 1 + .../ai/core/generate-text/stream-text.test.ts | 19 ++----- 
packages/ai/core/generate-text/stream-text.ts | 11 ++++ 5 files changed, 43 insertions(+), 43 deletions(-) diff --git a/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap b/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap index 426167333992..0c6d744223e4 100644 --- a/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap +++ b/packages/ai/core/generate-text/__snapshots__/stream-text.test.ts.snap @@ -127,7 +127,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -137,7 +137,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], @@ -194,7 +194,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, ], @@ -275,7 +275,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -285,7 +285,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], @@ -367,7 +367,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, ], @@ -448,7 +448,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -458,7 +458,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > callbac "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], @@ -534,7 +534,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > should "warnings": undefined, }, { - "messageId": "msg-2", + "messageId": "msg-3", "request": {}, "type": "step-start", "warnings": [], @@ -552,7 +552,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > should "finishReason": "stop", "isContinued": false, "logprobs": undefined, - "messageId": "msg-2", + "messageId": "msg-3", "providerMetadata": undefined, "request": {}, "response": { @@ -752,7 +752,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > value p "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -762,7 +762,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > value p "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ] @@ -818,7 +818,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > value p "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, ], @@ -899,7 +899,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > value p "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -909,7 +909,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result > value p "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], @@ -968,7 +968,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ 
-978,7 +978,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], @@ -1024,7 +1024,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, ], @@ -1099,7 +1099,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -1109,7 +1109,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], @@ -1191,7 +1191,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "warnings": undefined, }, { - "messageId": "msg-2", + "messageId": "msg-3", "request": {}, "type": "step-start", "warnings": [], @@ -1213,7 +1213,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "finishReason": "stop", "isContinued": false, "logprobs": undefined, - "messageId": "msg-2", + "messageId": "msg-3", "providerMetadata": undefined, "request": {}, "response": { @@ -1276,7 +1276,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -1286,7 +1286,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ] @@ -1331,7 +1331,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, ], @@ -1406,7 +1406,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "tool-result", }, ], - "id": "msg-3", + "id": "msg-1", "role": "tool", }, { @@ -1416,7 +1416,7 @@ exports[`streamText > options.maxSteps > 2 steps: initial, tool-result with expe "type": "text", }, ], - "id": "msg-2", + "id": "msg-3", "role": "assistant", }, ], diff --git a/packages/ai/core/generate-text/generate-text.ts b/packages/ai/core/generate-text/generate-text.ts index 56eef95153aa..13c720d0775d 100644 --- a/packages/ai/core/generate-text/generate-text.ts +++ b/packages/ai/core/generate-text/generate-text.ts @@ -331,6 +331,7 @@ A function that attempts to repair a tool call that failed to parse. 
const prepareStepResult = await prepareStep?.({ model, steps, + maxSteps, stepNumber: stepCount, messages: promptMessages, }); diff --git a/packages/ai/core/generate-text/prepare-step.ts b/packages/ai/core/generate-text/prepare-step.ts index 3409de79c8f9..483b871f6ba3 100644 --- a/packages/ai/core/generate-text/prepare-step.ts +++ b/packages/ai/core/generate-text/prepare-step.ts @@ -21,6 +21,7 @@ export type PrepareStepFunction< > = (options: { steps: Array>>; stepNumber: number; + maxSteps: number; model: LanguageModel; messages: Array; }) => PromiseLike> | PrepareStepResult; diff --git a/packages/ai/core/generate-text/stream-text.test.ts b/packages/ai/core/generate-text/stream-text.test.ts index 65778fae6749..eea729aef8a3 100644 --- a/packages/ai/core/generate-text/stream-text.test.ts +++ b/packages/ai/core/generate-text/stream-text.test.ts @@ -2982,21 +2982,8 @@ describe('streamText', () => { case 1: { expect(mode).toStrictEqual({ type: 'regular', - toolChoice: { type: 'tool', toolName: 'tool1' }, - tools: [ - { - type: 'function', - name: 'tool1', - description: undefined, - parameters: { - $schema: 'http://json-schema.org/draft-07/schema#', - additionalProperties: false, - properties: { value: { type: 'string' } }, - required: ['value'], - type: 'object', - }, - }, - ], + toolChoice: { type: 'auto' }, + tools: [], }); expect(prompt).toStrictEqual([ @@ -3106,7 +3093,7 @@ describe('streamText', () => { } if (stepNumber === 1) { - expect(steps.length).toStrictEqual(0); // step 0 not yet recorded in streaming + expect(steps.length).toStrictEqual(1); // step 0 is now properly recorded due to race condition fix return { model: trueModel, toolChoice: 'auto', diff --git a/packages/ai/core/generate-text/stream-text.ts b/packages/ai/core/generate-text/stream-text.ts index 776dd346fa21..38296f0a42a6 100644 --- a/packages/ai/core/generate-text/stream-text.ts +++ b/packages/ai/core/generate-text/stream-text.ts @@ -647,6 +647,7 @@ class DefaultStreamTextResult let stepType: 'initial' | 'continue' | 'tool-result' = 'initial'; const recordedSteps: StepResult[] = []; let rootSpan!: Span; + let stepFinish!: DelayedPromise; const eventProcessor = new TransformStream< EnrichedStreamPart, @@ -798,6 +799,9 @@ class DefaultStreamTextResult recordedResponse.messages.push(...stepMessages); recordedContinuationText = ''; } + + // Signal that the step is fully processed + stepFinish.resolve(); } if (part.type === 'finish') { @@ -983,6 +987,9 @@ class DefaultStreamTextResult hasLeadingWhitespace: boolean; messageId: string; }) { + // Create a new promise for this step + stepFinish = new DelayedPromise(); + // after the 1st step, we need to switch to messages format: const promptFormat = responseMessages.length === 0 ? initialPrompt.type : 'messages'; @@ -996,6 +1003,7 @@ class DefaultStreamTextResult model, steps: recordedSteps, stepNumber: recordedSteps.length, + maxSteps, messages: stepInputMessages }); @@ -1453,6 +1461,9 @@ class DefaultStreamTextResult self.closeStream(); // close the stitchable stream } else { + // wait for the step to be fully processed by the event processor + await stepFinish.value; + // append to messages for the next step: if (stepType === 'continue') { // continue step: update the last assistant message