diff --git a/packages/core/src/agent/agent.ts b/packages/core/src/agent/agent.ts index 4a8ee1a0d..8b5f73869 100644 --- a/packages/core/src/agent/agent.ts +++ b/packages/core/src/agent/agent.ts @@ -246,8 +246,8 @@ export class Agent< this.onTaskStartTip = this.opts.onTaskStartTip; - this.insight = new Insight(async (action: InsightAction) => { - return this.getUIContext(action); + this.insight = new Insight(async () => { + return this.getUIContext(); }); // Process cache configuration diff --git a/packages/core/src/agent/tasks.ts b/packages/core/src/agent/tasks.ts index 801347953..a5e5bb54e 100644 --- a/packages/core/src/agent/tasks.ts +++ b/packages/core/src/agent/tasks.ts @@ -5,22 +5,20 @@ import { plan, uiTarsPlanning, } from '@/ai-model'; -import { Executor } from '@/ai-model/action-executor'; import type { TMultimodalPrompt, TUserPrompt } from '@/ai-model/common'; import type { AbstractInterface } from '@/device'; +import { Executor } from '@/executor'; import type Insight from '@/insight'; import type { AIUsageInfo, DetailedLocateParam, DumpSubscriber, ElementCacheFeature, - ExecutionRecorderItem, ExecutionTaskActionApply, ExecutionTaskApply, ExecutionTaskHitBy, ExecutionTaskInsightLocateApply, ExecutionTaskInsightQueryApply, - ExecutionTaskPlanning, ExecutionTaskPlanningApply, ExecutionTaskProgressOptions, ExecutorContext, @@ -109,43 +107,6 @@ export class TaskExecutor { this.conversationHistory = new ConversationHistory(); } - private async recordScreenshot(timing: ExecutionRecorderItem['timing']) { - const base64 = await this.interface.screenshotBase64(); - const item: ExecutionRecorderItem = { - type: 'screenshot', - ts: Date.now(), - screenshot: base64, - timing, - }; - return item; - } - - private prependExecutorWithScreenshot( - taskApply: ExecutionTaskApply, - appendAfterExecution = false, - ): ExecutionTaskApply { - const taskWithScreenshot: ExecutionTaskApply = { - ...taskApply, - executor: async (param, context, ...args) => { - const recorder: ExecutionRecorderItem[] = []; - const { task } = context; - // set the recorder before executor in case of error - task.recorder = recorder; - const shot = await this.recordScreenshot(`before ${task.type}`); - recorder.push(shot); - - const result = await taskApply.executor(param, context, ...args); - - if (appendAfterExecution) { - const shot2 = await this.recordScreenshot('after Action'); - recorder.push(shot2); - } - return result; - }, - }; - return taskWithScreenshot; - } - public async convertPlanToExecutable( plans: PlanningAction[], modelConfig: IModelConfig, @@ -201,19 +162,9 @@ export class TaskExecutor { } }; this.insight.onceDumpUpdatedFn = dumpCollector; - const shotTime = Date.now(); - // Get context through contextRetrieverFn which handles frozen context - const uiContext = await this.insight.contextRetrieverFn('locate'); - task.uiContext = uiContext; - - const recordItem: ExecutionRecorderItem = { - type: 'screenshot', - ts: shotTime, - screenshot: uiContext.screenshotBase64, - timing: 'before Insight', - }; - task.recorder = [recordItem]; + const { uiContext } = taskContext; + assert(uiContext, 'uiContext is required for Insight task'); // try matching xpath const elementFromXpath = @@ -470,17 +421,17 @@ export class TaskExecutor { subType: planType, thought: plan.thought, param: plan.param, - executor: async (param, context) => { + executor: async (param, taskContext) => { debug( 'executing action', planType, param, - `context.element.center: ${context.element?.center}`, + `taskContext.element.center: ${taskContext.element?.center}`, ); // Get context for actionSpace operations to ensure size info is available - const uiContext = await this.insight.contextRetrieverFn('locate'); - context.task.uiContext = uiContext; + const uiContext = taskContext.uiContext; + assert(uiContext, 'uiContext is required for Action task'); requiredLocateFields.forEach((field) => { assert( @@ -523,7 +474,7 @@ export class TaskExecutor { debug('calling action', action.name); const actionFn = action.call.bind(this.interface); - await actionFn(param, context); + await actionFn(param, taskContext); debug('called action', action.name); try { @@ -554,35 +505,14 @@ export class TaskExecutor { } } - const wrappedTasks = tasks.map( - (task: ExecutionTaskApply, index: number) => { - if (task.type === 'Action') { - return this.prependExecutorWithScreenshot( - task, - index === tasks.length - 1, - ); - } - return task; - }, - ); - return { - tasks: wrappedTasks, + tasks, }; } private async setupPlanningContext(executorContext: ExecutorContext) { - const shotTime = Date.now(); - const uiContext = await this.insight.contextRetrieverFn('locate'); - const recordItem: ExecutionRecorderItem = { - type: 'screenshot', - ts: shotTime, - screenshot: uiContext.screenshotBase64, - timing: 'before Planning', - }; - - executorContext.task.recorder = [recordItem]; - (executorContext.task as ExecutionTaskPlanning).uiContext = uiContext; + const uiContext = executorContext.uiContext; + assert(uiContext, 'uiContext is required for Planning task'); return { uiContext, @@ -590,9 +520,13 @@ export class TaskExecutor { } async loadYamlFlowAsPlanning(userInstruction: string, yamlString: string) { - const taskExecutor = new Executor(taskTitleStr('Action', userInstruction), { - onTaskStart: this.onTaskStartCallback, - }); + const taskExecutor = new Executor( + taskTitleStr('Action', userInstruction), + () => Promise.resolve(this.insight.contextRetrieverFn()), + { + onTaskStart: this.onTaskStartCallback, + }, + ); const task: ExecutionTaskPlanningApply = { type: 'Planning', @@ -741,9 +675,13 @@ export class TaskExecutor { plans: PlanningAction[], modelConfig: IModelConfig, ): Promise { - const taskExecutor = new Executor(title, { - onTaskStart: this.onTaskStartCallback, - }); + const taskExecutor = new Executor( + title, + () => Promise.resolve(this.insight.contextRetrieverFn()), + { + onTaskStart: this.onTaskStartCallback, + }, + ); const { tasks } = await this.convertPlanToExecutable(plans, modelConfig); await taskExecutor.append(tasks); const result = await taskExecutor.flush(); @@ -781,9 +719,13 @@ export class TaskExecutor { > { this.conversationHistory.reset(); - const taskExecutor = new Executor(taskTitleStr('Action', userPrompt), { - onTaskStart: this.onTaskStartCallback, - }); + const taskExecutor = new Executor( + taskTitleStr('Action', userPrompt), + () => Promise.resolve(this.insight.contextRetrieverFn()), + { + onTaskStart: this.onTaskStartCallback, + }, + ); let replanCount = 0; const yamlFlow: MidsceneYamlFlowItem[] = []; @@ -891,17 +833,8 @@ export class TaskExecutor { this.insight.onceDumpUpdatedFn = dumpCollector; // Get context for query operations - const shotTime = Date.now(); - const uiContext = await this.insight.contextRetrieverFn('extract'); - task.uiContext = uiContext; - - const recordItem: ExecutionRecorderItem = { - type: 'screenshot', - ts: shotTime, - screenshot: uiContext.screenshotBase64, - timing: 'before Extract', - }; - task.recorder = [recordItem]; + const uiContext = taskContext.uiContext; + assert(uiContext, 'uiContext is required for Query task'); const ifTypeRestricted = type !== 'Query'; let demandInput = demand; @@ -965,6 +898,7 @@ export class TaskExecutor { type, typeof demand === 'string' ? demand : JSON.stringify(demand), ), + () => Promise.resolve(this.insight.contextRetrieverFn()), { onTaskStart: this.onTaskStartCallback, }, @@ -978,7 +912,7 @@ export class TaskExecutor { multimodalPrompt, ); - await taskExecutor.append(this.prependExecutorWithScreenshot(queryTask)); + await taskExecutor.append(queryTask); const result = await taskExecutor.flush(); if (!result) { @@ -1012,7 +946,7 @@ export class TaskExecutor { [errorPlan], modelConfig, ); - await taskExecutor.append(this.prependExecutorWithScreenshot(tasks[0])); + await taskExecutor.append(tasks[0]); await taskExecutor.flush(); return { @@ -1035,7 +969,7 @@ export class TaskExecutor { modelConfig, ); - return this.prependExecutorWithScreenshot(sleepTasks[0]); + return sleepTasks[0]; } async waitFor( @@ -1046,9 +980,13 @@ export class TaskExecutor { const { textPrompt, multimodalPrompt } = parsePrompt(assertion); const description = `waitFor: ${textPrompt}`; - const taskExecutor = new Executor(taskTitleStr('WaitFor', description), { - onTaskStart: this.onTaskStartCallback, - }); + const taskExecutor = new Executor( + taskTitleStr('WaitFor', description), + () => Promise.resolve(this.insight.contextRetrieverFn()), + { + onTaskStart: this.onTaskStartCallback, + }, + ); const { timeoutMs, checkIntervalMs } = opt; assert(assertion, 'No assertion for waitFor'); @@ -1075,7 +1013,7 @@ export class TaskExecutor { multimodalPrompt, ); - await taskExecutor.append(this.prependExecutorWithScreenshot(queryTask)); + await taskExecutor.append(queryTask); const result = (await taskExecutor.flush()) as | { output: boolean; diff --git a/packages/core/src/ai-model/action-executor.ts b/packages/core/src/executor.ts similarity index 80% rename from packages/core/src/ai-model/action-executor.ts rename to packages/core/src/executor.ts index 38a0ef90b..83d0783f7 100644 --- a/packages/core/src/ai-model/action-executor.ts +++ b/packages/core/src/executor.ts @@ -1,11 +1,13 @@ import type { ExecutionDump, + ExecutionRecorderItem, ExecutionTask, ExecutionTaskApply, ExecutionTaskInsightLocateOutput, ExecutionTaskProgressOptions, ExecutionTaskReturn, ExecutorContext, + UIContext, } from '@/types'; import { assert } from '@midscene/shared/utils'; @@ -19,8 +21,11 @@ export class Executor { onTaskStart?: ExecutionTaskProgressOptions['onTaskStart']; + private readonly uiContextBuilder: () => Promise; + constructor( name: string, + uiContextBuilder: () => Promise, options?: ExecutionTaskProgressOptions & { tasks?: ExecutionTaskApply[]; }, @@ -32,6 +37,45 @@ export class Executor { this.markTaskAsPending(item), ); this.onTaskStart = options?.onTaskStart; + this.uiContextBuilder = uiContextBuilder; + } + + private async captureScreenshot(): Promise { + try { + const uiContext = await this.uiContextBuilder(); + return uiContext?.screenshotBase64; + } catch (error) { + console.error('error while capturing screenshot', error); + } + return undefined; + } + + private attachRecorderItem( + task: ExecutionTask, + contextOrScreenshot: UIContext | string | undefined, + phase: 'before' | 'after', + ): void { + const timing = phase; + const screenshot = + typeof contextOrScreenshot === 'string' + ? contextOrScreenshot + : contextOrScreenshot?.screenshotBase64; + if (!timing || !screenshot) { + return; + } + + const recorderItem: ExecutionRecorderItem = { + type: 'screenshot', + ts: Date.now(), + screenshot, + timing, + }; + + if (!task.recorder) { + task.recorder = [recorderItem]; + return; + } + task.recorder.push(recorderItem); } private markTaskAsPending(task: ExecutionTaskApply): ExecutionTask { @@ -108,9 +152,12 @@ export class Executor { assert(executor, `executor is required for task type: ${task.type}`); let returnValue; + const uiContext = await this.uiContextBuilder(); + task.uiContext = uiContext; const executorContext: ExecutorContext = { task, element: previousFindOutput?.element, + uiContext, }; if (task.type === 'Insight') { @@ -139,6 +186,13 @@ export class Executor { returnValue = await task.executor(param, executorContext); } + const isLastTask = taskIndex === this.tasks.length - 1; + + if (isLastTask) { + const screenshot = await this.captureScreenshot(); + this.attachRecorderItem(task, screenshot, 'after'); + } + Object.assign(task, returnValue); task.status = 'finished'; task.timing.end = Date.now(); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index ad0118e68..8176223da 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,5 +1,5 @@ import { z } from 'zod'; -import { Executor } from './ai-model/action-executor'; +import { Executor } from './executor'; import Insight from './insight/index'; import { getVersion } from './utils'; diff --git a/packages/core/src/insight/index.ts b/packages/core/src/insight/index.ts index b06af8b3d..7816373b2 100644 --- a/packages/core/src/insight/index.ts +++ b/packages/core/src/insight/index.ts @@ -50,9 +50,7 @@ export default class Insight< ElementType extends BaseElement = BaseElement, ContextType extends UIContext = UIContext, > { - contextRetrieverFn: ( - action: InsightAction, - ) => Promise | ContextType; + contextRetrieverFn: () => Promise | ContextType; aiVendorFn: Exclude = callAIWithObjectResponse; @@ -62,9 +60,7 @@ export default class Insight< taskInfo?: Omit; constructor( - context: - | ContextType - | ((action: InsightAction) => Promise | ContextType), + context: ContextType | (() => Promise | ContextType), opt?: InsightOptions, ) { assert(context, 'context is required for Insight'); @@ -115,7 +111,7 @@ export default class Insight< searchAreaPrompt = undefined; } - const context = opt?.context || (await this.contextRetrieverFn('locate')); + const context = opt?.context || (await this.contextRetrieverFn()); let searchArea: Rect | undefined = undefined; let searchAreaRawResponse: string | undefined = undefined; @@ -255,7 +251,7 @@ export default class Insight< const dumpSubscriber = this.onceDumpUpdatedFn; this.onceDumpUpdatedFn = undefined; - const context = await this.contextRetrieverFn('extract'); + const context = await this.contextRetrieverFn(); const startTime = Date.now(); @@ -320,7 +316,7 @@ export default class Insight< }, ): Promise> { assert(target, 'target is required for insight.describe'); - const context = await this.contextRetrieverFn('describe'); + const context = await this.contextRetrieverFn(); const { screenshotBase64, size } = context; assert(screenshotBase64, 'screenshot is required for insight.describe'); // The result of the "describe" function will be used for positioning, so essentially it is a form of grounding. diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 70ab6dc32..2f49c303e 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -333,6 +333,7 @@ export type ExecutionTaskType = export interface ExecutorContext { task: ExecutionTask; element?: LocateResultElement | null; + uiContext?: UIContext; } export interface ExecutionTaskApply< diff --git a/packages/core/tests/unit-test/executor/index.test.ts b/packages/core/tests/unit-test/executor/index.test.ts index 5db52d058..da249e1c9 100644 --- a/packages/core/tests/unit-test/executor/index.test.ts +++ b/packages/core/tests/unit-test/executor/index.test.ts @@ -5,6 +5,7 @@ import type { ExecutionTaskInsightLocate, ExecutionTaskInsightLocateApply, InsightDump, + UIContext, } from '@/index'; import { fakeInsight } from 'tests/utils'; import { describe, expect, it, vi } from 'vitest'; @@ -47,9 +48,7 @@ const insightFindTask = (shouldThrow?: boolean) => { output: { element, }, - log: { - dump: insightDump, - }, + log: insightDump, cache: { hit: false, }, @@ -59,6 +58,13 @@ const insightFindTask = (shouldThrow?: boolean) => { return insightFindTask; }; +const fakeUIContextBuilder = async () => + ({ + screenshotBase64: '', + tree: { node: null, children: [] }, + size: { width: 0, height: 0 }, + }) as unknown as UIContext; + describe( 'executor', { @@ -92,7 +98,7 @@ describe( const inputTasks = [insightTask1, actionTask, actionTask2]; - const executor = new Executor('test', { + const executor = new Executor('test', fakeUIContextBuilder, { tasks: inputTasks, }); const flushResult = await executor.flush(); @@ -119,7 +125,7 @@ describe( }); it('insight - init and append', async () => { - const initExecutor = new Executor('test'); + const initExecutor = new Executor('test', fakeUIContextBuilder); expect(initExecutor.status).toBe('init'); const tapperFn = vi.fn(); @@ -175,7 +181,7 @@ describe( }); it('insight - run with error', async () => { - const executor = new Executor('test', { + const executor = new Executor('test', fakeUIContextBuilder, { tasks: [insightFindTask(true), insightFindTask()], }); const r = await executor.flush();