diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 801d1e2b2d..4996337df1 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -31,6 +31,7 @@ import { PolicyDecision, type ApprovalMode } from '../policy/types.js'; import { ToolConfirmationOutcome, type AnyDeclarativeTool, + MUTATOR_KINDS, } from '../tools/tools.js'; import { getToolSuggestion } from '../utils/tool-utils.js'; import { runInDevTraceSpan } from '../telemetry/trace.js'; @@ -457,12 +458,21 @@ export class Scheduler { return true; } - // If the first tool is parallelizable, batch all contiguous parallelizable tools. + // If the first tool is parallelizable, batch all contiguous parallelizable tools + // that do not conflict with already batched tools. if (this._isParallelizable(next.request)) { while (this.state.queueLength > 0) { const peeked = this.state.peekQueue(); if (peeked && this._isParallelizable(peeked.request)) { - this.state.dequeue(); + const activeCalls = this.state.allActiveCalls; + const hasConflict = activeCalls.some((c) => + this._hasConflict(c, peeked), + ); + if (!hasConflict) { + this.state.dequeue(); + } else { + break; + } } else { break; } @@ -558,6 +568,38 @@ export class Scheduler { return true; } + private _hasConflict(call1: ToolCall, call2: ToolCall): boolean { + if ( + !('tool' in call1) || + !call1.tool || + !('tool' in call2) || + !call2.tool || + !('invocation' in call1) || + !call1.invocation || + !('invocation' in call2) || + !call2.invocation + ) { + return false; + } + const isMutator1 = MUTATOR_KINDS.includes(call1.tool.kind); + const isMutator2 = MUTATOR_KINDS.includes(call2.tool.kind); + if (!isMutator1 && !isMutator2) { + return false; + } + + const locs1 = call1.invocation.toolLocations(); + const locs2 = call2.invocation.toolLocations(); + + for (const l1 of locs1) { + for (const l2 of locs2) { + if (l1.path === l2.path) { + return true; + } + } + } + return false; + } + private async _processValidatingCall( active: ValidatingToolCall, signal: AbortSignal, diff --git a/packages/core/src/scheduler/scheduler_parallel.test.ts b/packages/core/src/scheduler/scheduler_parallel.test.ts index 1f1f5efafd..3cc878c5d1 100644 --- a/packages/core/src/scheduler/scheduler_parallel.test.ts +++ b/packages/core/src/scheduler/scheduler_parallel.test.ts @@ -175,6 +175,7 @@ describe('Scheduler Parallel Execution', () => { const mockInvocation = { shouldConfirmExecute: vi.fn().mockResolvedValue(false), + toolLocations: vi.fn().mockReturnValue([]), }; beforeEach(() => { @@ -318,20 +319,50 @@ describe('Scheduler Parallel Execution', () => { schedulerId: 'root', }); - vi.mocked(readTool1.build).mockReturnValue( - mockInvocation as unknown as AnyToolInvocation, + vi.mocked(readTool1.build).mockImplementation( + (args) => + ({ + ...mockInvocation, + toolLocations: () => [ + { path: (args as { path?: string })?.path || 'a.txt' }, + ], + }) as unknown as AnyToolInvocation, ); - vi.mocked(readTool2.build).mockReturnValue( - mockInvocation as unknown as AnyToolInvocation, + vi.mocked(readTool2.build).mockImplementation( + (args) => + ({ + ...mockInvocation, + toolLocations: () => [ + { path: (args as { path?: string })?.path || 'b.txt' }, + ], + }) as unknown as AnyToolInvocation, ); - vi.mocked(writeTool.build).mockReturnValue( - mockInvocation as unknown as AnyToolInvocation, + vi.mocked(writeTool.build).mockImplementation( + (args) => + ({ + ...mockInvocation, + toolLocations: () => [ + { path: (args as { path?: string })?.path || 'c.txt' }, + ], + }) as unknown as AnyToolInvocation, ); - vi.mocked(agentTool1.build).mockReturnValue( - mockInvocation as unknown as AnyToolInvocation, + vi.mocked(agentTool1.build).mockImplementation( + (args) => + ({ + ...mockInvocation, + toolLocations: () => [ + { path: (args as { path?: string })?.path || 'agent1.txt' }, + ], + }) as unknown as AnyToolInvocation, ); - vi.mocked(agentTool2.build).mockReturnValue( - mockInvocation as unknown as AnyToolInvocation, + vi.mocked(agentTool2.build).mockImplementation( + (args) => + ({ + ...mockInvocation, + toolLocations: () => [ + { path: (args as { path?: string })?.path || 'agent2.txt' }, + ], + }) as unknown as AnyToolInvocation, ); }); @@ -510,7 +541,7 @@ describe('Scheduler Parallel Execution', () => { expect(start1).toBeGreaterThan(end3); }); - it('should execute non-read-only tools in parallel if wait_for_previous is false', async () => { + it('should execute non-read-only tools in parallel if wait_for_previous is false and paths differ', async () => { const executionLog: string[] = []; mockExecutor.execute.mockImplementation(async ({ call }) => { const id = call.request.callId; @@ -523,8 +554,16 @@ describe('Scheduler Parallel Execution', () => { } as unknown as SuccessfulToolCall; }); - const w1 = { ...req3, callId: 'w1', args: { wait_for_previous: false } }; - const w2 = { ...req3, callId: 'w2', args: { wait_for_previous: false } }; + const w1 = { + ...req3, + callId: 'w1', + args: { path: 'w1.txt', wait_for_previous: false }, + }; + const w2 = { + ...req3, + callId: 'w2', + args: { path: 'w2.txt', wait_for_previous: false }, + }; await scheduler.schedule([w1, w2], signal); @@ -532,6 +571,38 @@ describe('Scheduler Parallel Execution', () => { expect(executionLog.slice(0, 2)).toContain('start-w2'); }); + it('should execute non-read-only tools sequentially if they target the same file even if wait_for_previous is false', async () => { + const executionLog: string[] = []; + mockExecutor.execute.mockImplementation(async ({ call }) => { + const id = call.request.callId; + executionLog.push(`start-${id}`); + await new Promise((resolve) => setTimeout(resolve, 10)); + executionLog.push(`end-${id}`); + return { + status: 'success', + response: { callId: id, responseParts: [] }, + } as unknown as SuccessfulToolCall; + }); + + const w1 = { + ...req3, + callId: 'w1', + args: { path: 'same.txt', wait_for_previous: false }, + }; + const w2 = { + ...req3, + callId: 'w2', + args: { path: 'same.txt', wait_for_previous: false }, + }; + + await scheduler.schedule([w1, w2], signal); + + expect(executionLog[0]).toBe('start-w1'); + expect(executionLog[1]).toBe('end-w1'); + expect(executionLog[2]).toBe('start-w2'); + expect(executionLog[3]).toBe('end-w2'); + }); + it('should execute read-only tools sequentially if wait_for_previous is true', async () => { const executionLog: string[] = []; mockExecutor.execute.mockImplementation(async ({ call }) => {