diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 20b10fd3..c1c15b4b 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -253,17 +253,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		return a.err(fmt.Errorf("failed to get session: %w", err))
 	}
 	if session.SummaryMessageID != "" {
-		summaryMsgInex := -1
-		for i, msg := range msgs {
-			if msg.ID == session.SummaryMessageID {
-				summaryMsgInex = i
-				break
-			}
-		}
-		if summaryMsgInex != -1 {
-			msgs = msgs[summaryMsgInex:]
-			msgs[0].Role = message.User
-		}
+		msgs = a.filterMessagesFromSummary(msgs, session.SummaryMessageID)
 	}
 
 	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
@@ -272,6 +262,8 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 	}
 	// Append the new user message to the conversation history.
 	msgHistory := append(msgs, userMsg)
+	compactionAttempts := 0    // Track compaction attempts to prevent infinite loops
+	maxCompactionAttempts := 2 // Allow at most 2 compaction attempts per request
 
 	for {
 		// Check for cancellation before each iteration
@@ -281,6 +273,47 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		default:
 			// Continue processing
 		}
+
+		// Check if auto-compaction should be triggered before each model call
+		// This is crucial for long tool use loops that can exceed context limits
+		if cfg.AutoCompact && compactionAttempts < maxCompactionAttempts && a.shouldTriggerAutoCompactionFromHistory(msgHistory) {
+			compactionAttempts++
+			logging.Info("Auto-compaction triggered during tool use loop", "session_id", sessionID, "history_length", len(msgHistory), "attempt", compactionAttempts)
+
+			// Perform synchronous compaction to shrink context
+			if err := a.performSynchronousCompaction(ctx, sessionID); err != nil {
+				logging.Warn("Failed to perform auto-compaction during tool use", "error", err, "attempt", compactionAttempts)
+				// Continue anyway - better to risk context overflow than stop completely
+			} else {
+				// After successful compaction, reload messages and rebuild msgHistory
+				msgs, err := a.messages.List(ctx, sessionID)
+				if err != nil {
+					return a.err(fmt.Errorf("failed to reload messages after compaction: %w", err))
+				}
+
+				session, err := a.sessions.Get(ctx, sessionID)
+				if err != nil {
+					return a.err(fmt.Errorf("failed to get session after compaction: %w", err))
+				}
+				msgs = a.filterMessagesFromSummary(msgs, session.SummaryMessageID)
+
+				msgHistory = append(msgs, userMsg)
+				logging.Info("Context compacted, continuing with reduced history", "session_id", sessionID, "new_history_length", len(msgHistory), "attempt", compactionAttempts)
+
+				// NOTE: Check if compaction actually reduced the context size
+				// If it's still above threshold, we need to break the loop to prevent infinite compaction
+				if a.shouldTriggerAutoCompactionFromHistory(msgHistory) {
+					logging.Warn("Auto-compaction did not sufficiently reduce context size, proceeding anyway to prevent infinite loop",
+						"session_id", sessionID, "history_length", len(msgHistory), "attempt", compactionAttempts)
+					// Don't continue - proceed with the current msgHistory to avoid infinite loop
+				} else {
+					// After compaction, continue to the next iteration to re-check context size
+					// This prevents sending an oversized request immediately after compaction
+					continue
+				}
+			}
+		}
+
 		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 		if err != nil {
 			if errors.Is(err, context.Canceled) {
@@ -297,10 +330,31 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 		} else {
 			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 		}
-		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
-			// We are not done, we need to respond with the tool response
-			msgHistory = append(msgHistory, agentMessage, *toolResults)
-			continue
+		if agentMessage.FinishReason() == message.FinishReasonToolUse {
+			if toolResults != nil {
+				// We have tool results, continue with the tool response
+				msgHistory = append(msgHistory, agentMessage, *toolResults)
+				continue
+			} else {
+				// Tool results are nil (tool execution failed or returned empty)
+				// Create an empty tool results message to allow the LLM to provide a final response
+				logging.Warn("Tool results are nil, creating empty tool results message to allow final response", "session_id", sessionID)
+				emptyToolMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
+					Role:  message.Tool,
+					Parts: []message.ContentPart{message.TextContent{Text: "Tool execution completed with no results."}},
+				})
+				if err != nil {
+					logging.Warn("Failed to create empty tool results message", "error", err)
+					// If we can't create the message, just return what we have
+					return AgentEvent{
+						Type:    AgentEventTypeResponse,
+						Message: agentMessage,
+						Done:    true,
+					}
+				}
+				msgHistory = append(msgHistory, agentMessage, emptyToolMsg)
+				continue
+			}
 		}
 		return AgentEvent{
 			Type:    AgentEventTypeResponse,
@@ -532,6 +586,165 @@ func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (mode
 	return a.provider.Model(), nil
 }
 
+// filterMessagesFromSummary filters messages to start from the summary message if one exists
+// This reduces context size by excluding messages before the summary
+func (a *agent) filterMessagesFromSummary(msgs []message.Message, summaryMessageID string) []message.Message {
+	if summaryMessageID == "" {
+		return msgs
+	}
+
+	summaryMsgIndex := -1
+	for i, msg := range msgs {
+		if msg.ID == summaryMessageID {
+			summaryMsgIndex = i
+			break
+		}
+	}
+
+	if summaryMsgIndex != -1 {
+		filteredMsgs := msgs[summaryMsgIndex:]
+		// Convert the summary message role to User so it can be used in conversation
+		filteredMsgs[0].Role = message.User
+		return filteredMsgs
+	}
+
+	return msgs
+}
+
+// shouldTriggerAutoCompaction checks if the session should trigger auto-compaction
+// based on token usage approaching the context window limit
+func (a *agent) shouldTriggerAutoCompaction(session session.Session) bool {
+	model := a.provider.Model()
+	contextWindow := model.ContextWindow
+
+	// If context window is not defined, we can't determine if compaction is needed
+	if contextWindow <= 0 {
+		return false
+	}
+
+	totalTokens := session.CompletionTokens + session.PromptTokens
+	threshold := int64(float64(contextWindow) * 0.95) // 95% threshold
+
+	return totalTokens >= threshold
+}
+
+// shouldTriggerAutoCompactionFromHistory estimates token usage from message history
+// and determines if auto-compaction should be triggered. This is used during tool use loops
+// where we don't have real-time token counts from the session.
+func (a *agent) shouldTriggerAutoCompactionFromHistory(msgHistory []message.Message) bool {
+	model := a.provider.Model()
+	contextWindow := model.ContextWindow
+
+	// If context window is not defined, we can't determine if compaction is needed
+	if contextWindow <= 0 {
+		return false
+	}
+
+	// Estimate tokens from message history
+	// This is a rough estimation: ~4 characters per token for most models
+	totalChars := 0
+	for _, msg := range msgHistory {
+		for _, part := range msg.Parts {
+			if textPart, ok := part.(message.TextContent); ok {
+				totalChars += len(textPart.Text)
+			}
+			// For tool calls and other content types, add some overhead
+			totalChars += 100 // rough estimate for metadata
+		}
+	}
+
+	// Convert characters to estimated tokens (rough approximation)
+	estimatedTokens := int64(totalChars / 4)
+	threshold := int64(float64(contextWindow) * 0.90) // Use 90% for history-based estimation to be more conservative
+
+	logging.Debug("Token estimation for auto-compaction",
+		"estimated_tokens", estimatedTokens,
+		"threshold", threshold,
+		"context_window", contextWindow,
+		"message_count", len(msgHistory))
+
+	return estimatedTokens >= threshold
+}
+
+// performSynchronousCompaction performs summarization synchronously and waits for completion
+// This is used for auto-compaction in non-interactive mode to shrink context before continuing
+func (a *agent) performSynchronousCompaction(ctx context.Context, sessionID string) error {
+	if a.summarizeProvider == nil {
+		return fmt.Errorf("summarize provider not available")
+	}
+
+	// Note: We don't check IsSessionBusy here because this is called from within
+	// an active request processing loop, so the session is already marked as busy
+
+	logging.Info("Starting synchronous compaction", "session_id", sessionID)
+
+	// Get all messages from the session
+	msgs, err := a.messages.List(ctx, sessionID)
+	if err != nil {
+		return fmt.Errorf("failed to list messages: %w", err)
+	}
+
+	if len(msgs) == 0 {
+		return fmt.Errorf("no messages to summarize")
+	}
+
+	// Add session context
+	summarizeCtx := context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
+
+	// Add a system message to guide the summarization
+	summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
+
+	// Create a new message with the summarize prompt
+	promptMsg := message.Message{
+		Role:  message.User,
+		Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
+	}
+
+	// Append the prompt to the messages
+	msgsWithPrompt := append(msgs, promptMsg)
+
+	// Send the messages to the summarize provider
+	response, err := a.summarizeProvider.SendMessages(
+		summarizeCtx,
+		msgsWithPrompt,
+		make([]tools.BaseTool, 0),
+	)
+	if err != nil {
+		return fmt.Errorf("failed to summarize: %w", err)
+	}
+
+	summary := strings.TrimSpace(response.Content)
+	if summary == "" {
+		return fmt.Errorf("empty summary returned")
+	}
+
+	// Get the session to update
+	oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
+	if err != nil {
+		return fmt.Errorf("failed to get session: %w", err)
+	}
+
+	// Create a new message with the summary
+	msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
+		Role:  message.User,
+		Parts: []message.ContentPart{message.TextContent{Text: summary}},
+		Model: a.summarizeProvider.Model().ID,
+	})
+	if err != nil {
+		return fmt.Errorf("failed to create summary message: %w", err)
+	}
+
+	// Update the session with the summary message ID
+	oldSession.SummaryMessageID = msg.ID
+	_, err = a.sessions.Save(summarizeCtx, oldSession)
+	if err != nil {
+		return fmt.Errorf("failed to save session: %w", err)
+	}
+
+	logging.Info("Synchronous compaction completed successfully", "session_id", sessionID)
+	return nil
+}
+
 func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 	if a.summarizeProvider == nil {
 		return fmt.Errorf("summarize provider not available")
diff --git a/internal/llm/agent/agent_test.go b/internal/llm/agent/agent_test.go
new file mode 100644
index 00000000..4bf206de
--- /dev/null
+++ b/internal/llm/agent/agent_test.go
@@ -0,0 +1,300 @@
+package agent
+
+import (
+	"context"
+	"strings"
+	"testing"
+
+	"github.com/opencode-ai/opencode/internal/llm/models"
+	"github.com/opencode-ai/opencode/internal/llm/provider"
+	"github.com/opencode-ai/opencode/internal/llm/tools"
+	"github.com/opencode-ai/opencode/internal/message"
+	"github.com/opencode-ai/opencode/internal/session"
+)
+
+func TestShouldTriggerAutoCompaction(t *testing.T) {
+	tests := []struct {
+		name             string
+		contextWindow    int64
+		promptTokens     int64
+		completionTokens int64
+		expected         bool
+	}{
+		{
+			name:             "Below threshold",
+			contextWindow:    100000,
+			promptTokens:     40000,
+			completionTokens: 40000,
+			expected:         false,
+		},
+		{
+			name:             "At threshold",
+			contextWindow:    100000,
+			promptTokens:     47500,
+			completionTokens: 47500,
+			expected:         true,
+		},
+		{
+			name:             "Above threshold",
+			contextWindow:    100000,
+			promptTokens:     50000,
+			completionTokens: 50000,
+			expected:         true,
+		},
+		{
+			name:             "No context window",
+			contextWindow:    0,
+			promptTokens:     50000,
+			completionTokens: 50000,
+			expected:         false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create a minimal agent with just the model info we need
+			a := &agent{}
+
+			// Mock the provider's Model() method by setting up the agent's provider field
+			a.provider = &testProvider{
+				model: models.Model{
+					ContextWindow: tt.contextWindow,
+				},
+			}
+
+			session := session.Session{
+				PromptTokens:     tt.promptTokens,
+				CompletionTokens: tt.completionTokens,
+			}
+
+			result := a.shouldTriggerAutoCompaction(session)
+			if result != tt.expected {
+				t.Errorf("shouldTriggerAutoCompaction() = %v, want %v", result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestShouldTriggerAutoCompactionFromHistory(t *testing.T) {
+	tests := []struct {
+		name          string
+		contextWindow int64
+		messages      []message.Message
+		expected      bool
+	}{
+		{
+			name:          "Small history below threshold",
+			contextWindow: 100000,
+			messages: []message.Message{
+				{Parts: []message.ContentPart{message.TextContent{Text: "Hello"}}},
+				{Parts: []message.ContentPart{message.TextContent{Text: "Hi there"}}},
+			},
+			expected: false,
+		},
+		{
+			name:          "Large history above threshold",
+			contextWindow: 1000,
+			messages: []message.Message{
+				{Parts: []message.ContentPart{message.TextContent{Text: strings.Repeat("This is a long message that will consume many tokens. ", 100)}}},
+				{Parts: []message.ContentPart{message.TextContent{Text: strings.Repeat("Another long response with lots of content. ", 100)}}},
+			},
+			expected: true,
+		},
+		{
+			name:          "No context window",
+			contextWindow: 0,
+			messages: []message.Message{
+				{Parts: []message.ContentPart{message.TextContent{Text: strings.Repeat("Long message ", 1000)}}},
+			},
+			expected: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			a := &agent{}
+			a.provider = &testProvider{
+				model: models.Model{
+					ContextWindow: tt.contextWindow,
+				},
+			}
+
+			result := a.shouldTriggerAutoCompactionFromHistory(tt.messages)
+			if result != tt.expected {
+				t.Errorf("shouldTriggerAutoCompactionFromHistory() = %v, want %v", result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestPerformSynchronousCompaction_NoSummarizeProvider(t *testing.T) {
+	a := &agent{
+		summarizeProvider: nil,
+	}
+
+	err := a.performSynchronousCompaction(context.Background(), "test-session")
+	if err == nil {
+		t.Error("expected error when summarizeProvider is nil, got nil")
+	}
+	if err.Error() != "summarize provider not available" {
+		t.Errorf("expected 'summarize provider not available', got %v", err.Error())
+	}
+}
+
+func TestAutoCompactionAttemptLimit(t *testing.T) {
+	// Test that compaction attempts are limited to prevent infinite loops
+	// This test verifies the fix for the deadloop issue
+
+	longMessage := strings.Repeat("This is a very long message that will exceed the context window. ", 100)
+	msgHistory := []message.Message{
+		{Parts: []message.ContentPart{message.TextContent{Text: longMessage}}},
+		{Parts: []message.ContentPart{message.TextContent{Text: longMessage}}},
+	}
+
+	a := &agent{}
+	a.provider = &testProvider{
+		model: models.Model{
+			ContextWindow: 1000, // Small context window to trigger compaction
+		},
+	}
+
+	// Test that shouldTriggerAutoCompactionFromHistory returns true for this history
+	shouldTrigger := a.shouldTriggerAutoCompactionFromHistory(msgHistory)
+	if !shouldTrigger {
+		t.Error("expected shouldTriggerAutoCompactionFromHistory to return true for large message history")
+	}
+}
+
+func TestMessageFilteringAfterCompaction(t *testing.T) {
+	// Test that message filtering logic is applied correctly after compaction
+	// This verifies the fix for the issue where compaction was increasing context size
+
+	// Create test messages that simulate a conversation with a summary
+	messages := []message.Message{
+		{ID: "msg1", Parts: []message.ContentPart{message.TextContent{Text: "Old message 1"}}},
+		{ID: "msg2", Parts: []message.ContentPart{message.TextContent{Text: "Old message 2"}}},
+		{ID: "summary", Parts: []message.ContentPart{message.TextContent{Text: "Summary of conversation"}}},
+		{ID: "msg3", Parts: []message.ContentPart{message.TextContent{Text: "New message after summary"}}},
+	}
+
+	// Simulate finding the summary message and filtering
+	summaryMsgIndex := -1
+	for i, msg := range messages {
+		if msg.ID == "summary" {
+			summaryMsgIndex = i
+			break
+		}
+	}
+
+	if summaryMsgIndex == -1 {
+		t.Fatal("summary message not found")
+	}
+
+	// Apply the filtering logic (same as in the agent code)
+	filteredMessages := messages[summaryMsgIndex:]
+	filteredMessages[0].Role = message.User
+
+	// Verify that filtering worked correctly
+	if len(filteredMessages) != 2 {
+		t.Errorf("expected 2 filtered messages, got %d", len(filteredMessages))
+	}
+
+	if filteredMessages[0].ID != "summary" {
+		t.Errorf("expected first message to be summary, got %s", filteredMessages[0].ID)
+	}
+
+	if filteredMessages[0].Role != message.User {
+		t.Errorf("expected summary message role to be User, got %s", filteredMessages[0].Role)
+	}
+
+	if filteredMessages[1].ID != "msg3" {
+		t.Errorf("expected second message to be msg3, got %s", filteredMessages[1].ID)
+	}
+}
+
+func TestFilterMessagesFromSummary(t *testing.T) {
+	a := &agent{}
+
+	tests := []struct {
+		name             string
+		messages         []message.Message
+		summaryMessageID string
+		expectedCount    int
+		expectedFirstID  string
+		expectedRole     message.MessageRole
+	}{
+		{
+			name: "No summary message ID",
+			messages: []message.Message{
+				{ID: "msg1", Role: message.Assistant},
+				{ID: "msg2", Role: message.User},
+			},
+			summaryMessageID: "",
+			expectedCount:    2,
+			expectedFirstID:  "msg1",
+			expectedRole:     message.Assistant,
+		},
+		{
+			name: "Summary message exists",
+			messages: []message.Message{
+				{ID: "msg1", Role: message.Assistant},
+				{ID: "summary", Role: message.Assistant},
+				{ID: "msg3", Role: message.User},
+			},
+			summaryMessageID: "summary",
+			expectedCount:    2,
+			expectedFirstID:  "summary",
+			expectedRole:     message.User, // Should be converted to User
+		},
+		{
+			name: "Summary message not found",
+			messages: []message.Message{
+				{ID: "msg1", Role: message.Assistant},
+				{ID: "msg2", Role: message.User},
+			},
+			summaryMessageID: "nonexistent",
+			expectedCount:    2,
+			expectedFirstID:  "msg1",
+			expectedRole:     message.Assistant,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := a.filterMessagesFromSummary(tt.messages, tt.summaryMessageID)
+
+			if len(result) != tt.expectedCount {
+				t.Errorf("expected %d messages, got %d", tt.expectedCount, len(result))
+			}
+
+			if len(result) > 0 {
+				if result[0].ID != tt.expectedFirstID {
+					t.Errorf("expected first message ID %s, got %s", tt.expectedFirstID, result[0].ID)
+				}
+
+				if result[0].Role != tt.expectedRole {
+					t.Errorf("expected first message role %s, got %s", tt.expectedRole, result[0].Role)
+				}
+			}
+		})
+	}
+}
+
+type testProvider struct {
+	model models.Model
+}
+
+func (tp *testProvider) Model() models.Model {
+	return tp.model
+}
+
+func (tp *testProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*provider.ProviderResponse, error) {
+	return &provider.ProviderResponse{
+		Content: "Test summary of the conversation",
+	}, nil
+}
+
+func (tp *testProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan provider.ProviderEvent {
+	return nil
+}