From 38feba778cb907665565a48ea6385fcbd8b2c62c Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Fri, 11 Jul 2025 17:26:29 +0400 Subject: [PATCH 1/9] Major additions: - Complete xAI provider with 8 Grok models (2/3/4 series) - Deferred completion support for long-running requests - Concurrent request handling for improved performance - Web search tool integration for real-time information - Image generation and vision capabilities - Advanced reasoning support with model-specific validation Core provider features: - Full streaming support with proper error handling - Intelligent model selection and configuration - Robust caching and retry mechanisms - Provider-specific option handling (reasoning, deferred, concurrent) - Comprehensive test coverage for all capabilities Configuration enhancements: - Runtime options support via CLI flags (--deferred, --deferred-timeout) - Provider-level deferred completion configuration - Auto-enable rules for smart feature activation - Model-specific reasoning effort validation (xAI: low/high only) - Enhanced agent configuration with override capabilities Tool ecosystem expansion: - Web search tool for real-time information retrieval - Enhanced coder agent toolset with web capabilities - Improved tool validation and error handling - Better integration with existing LSP and file tools Documentation and schema updates: - Comprehensive README updates with xAI model details - Updated configuration schema with new provider options - Enhanced CLI help text and examples - Added feature capability documentation --- .claude/settings.local.json | 11 + README.md | 33 +- cmd/root.go | 24 +- internal/app/app.go | 16 +- internal/config/config.go | 130 +- internal/llm/agent/agent.go | 124 +- internal/llm/agent/tools.go | 43 +- internal/llm/agent/tools_test.go | 235 +++ internal/llm/models/models.go | 25 +- internal/llm/models/xai.go | 139 +- internal/llm/provider/concurrent.go | 151 ++ internal/llm/provider/concurrent_test.go | 320 +++++ internal/llm/provider/image_validation.go | 58 + internal/llm/provider/openai.go | 157 +- internal/llm/provider/provider.go | 24 +- internal/llm/provider/xai.go | 515 +++++++ internal/llm/provider/xai_caching_test.go | 311 ++++ internal/llm/provider/xai_deferred.go | 533 +++++++ internal/llm/provider/xai_deferred_test.go | 331 +++++ internal/llm/provider/xai_image_generation.go | 212 +++ .../llm/provider/xai_image_generation_test.go | 245 ++++ internal/llm/provider/xai_live_search_test.go | 433 ++++++ internal/llm/provider/xai_models.go | 292 ++++ internal/llm/provider/xai_streaming.go | 123 ++ internal/llm/provider/xai_streaming_test.go | 260 ++++ internal/llm/provider/xai_test.go | 1267 +++++++++++++++++ internal/llm/provider/xai_validation.go | 242 ++++ internal/llm/provider/xai_validation_test.go | 279 ++++ internal/llm/provider/xai_vision_test.go | 220 +++ internal/llm/tools/web_search.go | 337 +++++ internal/llm/tools/web_search_test.go | 485 +++++++ internal/tui/components/chat/editor.go | 32 +- internal/tui/components/dialog/filepicker.go | 21 +- internal/tui/components/dialog/models.go | 10 +- opencode-schema.json | 54 +- 35 files changed, 7540 insertions(+), 152 deletions(-) create mode 100644 .claude/settings.local.json create mode 100644 internal/llm/agent/tools_test.go create mode 100644 internal/llm/provider/concurrent.go create mode 100644 internal/llm/provider/concurrent_test.go create mode 100644 internal/llm/provider/image_validation.go create mode 100644 internal/llm/provider/xai.go create mode 100644 
internal/llm/provider/xai_caching_test.go create mode 100644 internal/llm/provider/xai_deferred.go create mode 100644 internal/llm/provider/xai_deferred_test.go create mode 100644 internal/llm/provider/xai_image_generation.go create mode 100644 internal/llm/provider/xai_image_generation_test.go create mode 100644 internal/llm/provider/xai_live_search_test.go create mode 100644 internal/llm/provider/xai_models.go create mode 100644 internal/llm/provider/xai_streaming.go create mode 100644 internal/llm/provider/xai_streaming_test.go create mode 100644 internal/llm/provider/xai_test.go create mode 100644 internal/llm/provider/xai_validation.go create mode 100644 internal/llm/provider/xai_validation_test.go create mode 100644 internal/llm/provider/xai_vision_test.go create mode 100644 internal/llm/tools/web_search.go create mode 100644 internal/llm/tools/web_search_test.go diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 00000000..eb295133 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,11 @@ +{ + "permissions": { + "allow": [ + "Bash(go fmt:*)", + "Bash(go vet:*)", + "Bash(go build:*)", + "Bash(go test:*)" + ], + "deny": [] + } +} \ No newline at end of file diff --git a/README.md b/README.md index eee06acd..dec333f0 100644 --- a/README.md +++ b/README.md @@ -23,9 +23,11 @@ OpenCode is a Go-based CLI application that brings AI assistance to your termina ## Features - **Interactive TUI**: Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) for a smooth terminal experience -- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, Google Gemini, AWS Bedrock, Groq, Azure OpenAI, and OpenRouter +- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, Google Gemini, AWS Bedrock, Groq, Azure OpenAI, xAI, and OpenRouter - **Session Management**: Save and manage multiple conversation sessions - **Tool Integration**: AI can execute commands, search files, and modify code +- **Image Recognition**: Support for analyzing images with vision-enabled models (xAI Grok) +- **Web Search**: Real-time web search capabilities with supported models - **Vim-like Editor**: Integrated editor with text input capabilities - **Persistent Storage**: SQLite database for storing conversations and sessions - **LSP Integration**: Language Server Protocol support for code intelligence @@ -105,6 +107,7 @@ You can configure OpenCode using environment variables: | `VERTEXAI_PROJECT` | For Google Cloud VertexAI (Gemini) | | `VERTEXAI_LOCATION` | For Google Cloud VertexAI (Gemini) | | `GROQ_API_KEY` | For Groq models | +| `XAI_API_KEY` | For xAI Grok models | | `AWS_ACCESS_KEY_ID` | For AWS Bedrock (Claude) | | `AWS_SECRET_ACCESS_KEY` | For AWS Bedrock (Claude) | | `AWS_REGION` | For AWS Bedrock (Claude) | @@ -244,6 +247,26 @@ OpenCode supports a variety of AI models from different providers: - Gemini 2.0 Flash - Gemini 2.0 Flash Lite +### xAI + +- Grok 4 (grok-4-0709) - Most capable, with reasoning_effort support +- Grok 3 (grok-3) - Advanced model (no reasoning support) +- Grok 3 Fast (grok-3-fast) - Optimized for speed (no reasoning support) +- Grok 3 Mini (grok-3-mini) - Smaller model with reasoning_effort support +- Grok 3 Mini Fast (grok-3-mini-fast) - Fastest with reasoning_effort support +- Grok 2 (grok-2-1212) - General purpose (no reasoning support) +- Grok 2 Vision (grok-2-vision-1212) - Vision understanding (no reasoning support) +- Grok 2 Image (grok-2-image-1212) - Image generation model + +**Special Features:** +- **Web 
Search**: All xAI models support live web search for current information +- **Reasoning Support** (verified via API): + - Grok 4 (grok-4-0709): Has automatic reasoning (returns reasoning_content) but does NOT accept `reasoningEffort` parameter + - Grok 3 Mini models: Support `reasoningEffort` parameter (only "low" or "high", not "medium") + - Grok 2 models, Grok 3/3-fast: No reasoning support +- **Vision Support**: grok-2-vision-1212 supports image understanding +- **Image Generation**: grok-2-image-1212 supports image generation + ### AWS Bedrock - Claude 3.7 Sonnet @@ -283,6 +306,13 @@ opencode -d opencode -c /path/to/project ``` +## Documentation + +- [Image Recognition](docs/image-recognition.md) - Guide for using vision-enabled models with images +- [Web Search](docs/web-search.md) - Using web search capabilities with supported models +- [Custom Commands](docs/custom-commands.md) - Creating custom commands with named arguments +- [Configuration](docs/configuration.md) - Detailed configuration options + ## Non-interactive Prompt Mode You can run OpenCode in non-interactive mode by passing a prompt directly as a command-line argument. This is useful for scripting, automation, or when you want a quick answer without launching the full TUI. @@ -418,6 +448,7 @@ OpenCode's AI assistant has access to various tools to help with coding tasks: | `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) | | `sourcegraph` | Search code across public repositories | `query` (required), `count` (optional), `context_window` (optional), `timeout` (optional) | | `agent` | Run sub-tasks with the AI agent | `prompt` (required) | +| `web_search` | Search the web (xAI models only) | `query` (required) - Note: Automatically used by xAI models when needed | ## Architecture diff --git a/cmd/root.go b/cmd/root.go index 3a58cec4..39793450 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -26,7 +26,15 @@ var rootCmd = &cobra.Command{ Short: "Terminal-based AI assistant for software development", Long: `OpenCode is a powerful terminal-based AI assistant that helps with software development tasks. It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration -to assist developers in writing, debugging, and understanding code directly from the terminal.`, +to assist developers in writing, debugging, and understanding code directly from the terminal. + +Key Features: +- Interactive AI chat with multiple model providers (OpenAI, Anthropic, xAI, etc.) 
+- Code analysis and editing capabilities +- LSP (Language Server Protocol) integration +- Web search support for current information (xAI models) +- File system operations and project navigation +- Multi-turn conversations with context retention`, Example: ` # Run in interactive mode opencode @@ -63,6 +71,8 @@ to assist developers in writing, debugging, and understanding code directly from prompt, _ := cmd.Flags().GetString("prompt") outputFormat, _ := cmd.Flags().GetString("output-format") quiet, _ := cmd.Flags().GetBool("quiet") + deferred, _ := cmd.Flags().GetBool("deferred") + deferredTimeout, _ := cmd.Flags().GetString("deferred-timeout") // Validate format option if !format.IsValid(outputFormat) { @@ -97,7 +107,13 @@ to assist developers in writing, debugging, and understanding code directly from ctx, cancel := context.WithCancel(context.Background()) defer cancel() - app, err := app.New(ctx, conn) + // Create runtime options from CLI flags + runtimeOpts := app.RuntimeOptions{ + DeferredEnabled: deferred, + DeferredTimeout: deferredTimeout, + } + + app, err := app.New(ctx, conn, runtimeOpts) if err != nil { logging.Error("Failed to create app: %v", err) return err @@ -302,6 +318,10 @@ func init() { // Add quiet flag to hide spinner in non-interactive mode rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode") + // Deferred completion flags + rootCmd.Flags().Bool("deferred", false, "Enable deferred completions for xAI models (useful for long-running requests)") + rootCmd.Flags().String("deferred-timeout", "10m", "Timeout for deferred completions (e.g., '5m', '30s')") + // Register custom validation for the format flag rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp diff --git a/internal/app/app.go b/internal/app/app.go index abdc1431..f34e803a 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -37,9 +37,18 @@ type App struct { watcherCancelFuncs []context.CancelFunc cancelFuncsMutex sync.Mutex watcherWG sync.WaitGroup + + // Runtime options (e.g., from CLI flags) + RuntimeOptions RuntimeOptions +} + +// RuntimeOptions contains runtime configuration options (e.g., from CLI flags) +type RuntimeOptions struct { + DeferredEnabled bool + DeferredTimeout string } -func New(ctx context.Context, conn *sql.DB) (*App, error) { +func New(ctx context.Context, conn *sql.DB, opts ...RuntimeOptions) (*App, error) { q := db.New(conn) sessions := session.NewService(q) messages := message.NewService(q) @@ -53,6 +62,11 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { LSPClients: make(map[string]*lsp.Client), } + // Apply runtime options if provided + if len(opts) > 0 { + app.RuntimeOptions = opts[0] + } + // Initialize theme based on configuration app.initTheme() diff --git a/internal/config/config.go b/internal/config/config.go index 630fac9b..221f4911 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,15 +45,32 @@ const ( // Agent defines configuration for different LLM models and their token limits. 
type Agent struct { - Model models.ModelID `json:"model"` - MaxTokens int64 `json:"maxTokens"` - ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh + Model models.ModelID `json:"model"` + MaxTokens int64 `json:"maxTokens"` + ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh + DeferredCompletion *bool `json:"deferredCompletion,omitempty"` // Override provider setting for this agent } // Provider defines configuration for an LLM provider. type Provider struct { - APIKey string `json:"apiKey"` - Disabled bool `json:"disabled"` + APIKey string `json:"apiKey"` + Disabled bool `json:"disabled"` + MaxConcurrentRequests int64 `json:"maxConcurrentRequests,omitempty"` // For providers that support concurrent requests (e.g., xAI) + DeferredCompletion *DeferredCompletionConfig `json:"deferredCompletion,omitempty"` // For providers that support deferred completions (e.g., xAI) +} + +// DeferredCompletionConfig defines settings for deferred completions +type DeferredCompletionConfig struct { + Enabled bool `json:"enabled"` // Enable deferred completions + Timeout string `json:"timeout,omitempty"` // Timeout duration (e.g., "10m") + PollInterval string `json:"pollInterval,omitempty"` // Poll interval duration (e.g., "10s") + AutoEnable *DeferredAutoEnableConfig `json:"autoEnable,omitempty"` // Smart activation rules +} + +// DeferredAutoEnableConfig defines rules for automatically enabling deferred completions +type DeferredAutoEnableConfig struct { + ForModels []string `json:"forModels,omitempty"` // Enable for specific models + WhenTokensExceed int64 `json:"whenTokensExceed,omitempty"` // Enable when max tokens exceed this value } // Data defines storage configuration. @@ -207,10 +224,10 @@ func Load(workingDir string, debug bool) (*Config, error) { cfg.Agents = make(map[AgentName]Agent) } - // Override the max tokens for title agent - cfg.Agents[AgentTitle] = Agent{ - Model: cfg.Agents[AgentTitle].Model, - MaxTokens: 80, + // Override the max tokens for title agent to ensure concise titles + if titleAgent, exists := cfg.Agents[AgentTitle]; exists { + titleAgent.MaxTokens = 80 + cfg.Agents[AgentTitle] = titleAgent } return cfg, nil } @@ -351,10 +368,10 @@ func setProviderDefaults() { // XAI configuration if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.XAIGrok3Beta) - viper.SetDefault("agents.summarizer.model", models.XAIGrok3Beta) - viper.SetDefault("agents.task.model", models.XAIGrok3Beta) - viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta) + viper.SetDefault("agents.coder.model", models.XAIGrok2) + viper.SetDefault("agents.summarizer.model", models.XAIGrok2) + viper.SetDefault("agents.task.model", models.XAIGrok2) + viper.SetDefault("agents.title.model", models.XAIGrok3MiniFast) return } @@ -471,6 +488,7 @@ func applyDefaultValues() { } } +// validateAgent ensures that the agent configuration is valid and supported. // It validates model IDs and providers, ensuring they are supported. 
func validateAgent(cfg *Config, name AgentName, agent Agent) error { // Check if model exists @@ -563,31 +581,12 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error { } // Validate reasoning effort for models that support reasoning - if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal { - if agent.ReasoningEffort == "" { - // Set default reasoning effort for models that support it - logging.Info("setting default reasoning effort for model that supports reasoning", - "agent", name, - "model", agent.Model) - - // Update the agent with default reasoning effort + if model.CanReason && isReasoningProvider(provider) { + validatedEffort := validateReasoningEffort(agent.ReasoningEffort, provider, string(name), string(agent.Model)) + if validatedEffort != agent.ReasoningEffort { updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "medium" + updatedAgent.ReasoningEffort = validatedEffort cfg.Agents[name] = updatedAgent - } else { - // Check if reasoning effort is valid (low, medium, high) - effort := strings.ToLower(agent.ReasoningEffort) - if effort != "low" && effort != "medium" && effort != "high" { - logging.Warn("invalid reasoning effort, setting to medium", - "agent", name, - "model", agent.Model, - "reasoning_effort", agent.ReasoningEffort) - - // Update the agent with valid reasoning effort - updatedAgent := cfg.Agents[name] - updatedAgent.ReasoningEffort = "medium" - cfg.Agents[name] = updatedAgent - } } } else if !model.CanReason && agent.ReasoningEffort != "" { // Model doesn't support reasoning but reasoning effort is set @@ -929,6 +928,65 @@ func UpdateTheme(themeName string) error { }) } +// isReasoningProvider checks if the provider supports reasoning effort configuration +func isReasoningProvider(provider models.ModelProvider) bool { + return provider == models.ProviderOpenAI || + provider == models.ProviderLocal || + provider == models.ProviderXAI +} + +// validateReasoningEffort validates and potentially adjusts the reasoning effort +// based on provider-specific constraints +func validateReasoningEffort(effort string, provider models.ModelProvider, agentName, modelName string) string { + if effort == "" { + // Set default reasoning effort + defaultEffort := "medium" + if provider == models.ProviderXAI { + defaultEffort = "high" // xAI doesn't support "medium" + } + logging.Info("setting default reasoning effort", + "agent", agentName, + "model", modelName, + "default", defaultEffort) + return defaultEffort + } + + // Normalize to lowercase + normalizedEffort := strings.ToLower(effort) + + // Provider-specific validation + if provider == models.ProviderXAI { + // xAI only supports "low" and "high" + switch normalizedEffort { + case "low", "high": + return normalizedEffort + case "medium": + logging.Info("xAI only supports low/high reasoning effort, mapping medium to high", + "agent", agentName, + "model", modelName) + return "high" + default: + logging.Warn("invalid reasoning effort for xAI, using high", + "agent", agentName, + "model", modelName, + "provided", effort) + return "high" + } + } + + // Standard validation for other providers + switch normalizedEffort { + case "low", "medium", "high": + return normalizedEffort + default: + logging.Warn("invalid reasoning effort, using medium", + "agent", agentName, + "model", modelName, + "provided", effort) + return "medium" + } +} + // Tries to load Github token from all possible locations func LoadGitHubToken() (string, error) { // First check environment 
variable diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 20b10fd3..34626a7b 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -210,7 +210,10 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac go func() { logging.Debug("Request started", "sessionID", sessionID) defer logging.RecoverPanic("agent.Run", func() { - events <- a.err(fmt.Errorf("panic while running the agent")) + events <- AgentEvent{ + Type: AgentEventTypeError, + Error: fmt.Errorf("panic while running the agent"), + } }) var attachmentParts []message.ContentPart for _, attachment := range attachments { @@ -239,7 +242,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string } if len(msgs) == 0 { go func() { - defer logging.RecoverPanic("agent.Run", func() { + defer logging.RecoverPanic("agent.generateTitle", func() { logging.ErrorPersist("panic while generating title") }) titleErr := a.generateTitle(context.Background(), sessionID, content) @@ -370,7 +373,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg tool = availableTool break } - // Monkey patch for Copilot Sonnet-4 tool repetition obfuscation + // TODO: Handle Copilot Sonnet-4 tool name repetition if needed // if strings.HasPrefix(toolCall.Name, availableTool.Info().Name) && // strings.HasPrefix(toolCall.Name, availableTool.Info().Name+availableTool.Info().Name) { // tool = availableTool @@ -731,20 +734,51 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), provider.WithMaxTokens(maxTokens), } - if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason { - opts = append( - opts, - provider.WithOpenAIOptions( - provider.WithReasoningEffort(agentConfig.ReasoningEffort), - ), - ) - } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder { - opts = append( - opts, - provider.WithAnthropicOptions( - provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn), - ), - ) + // Configure reasoning support based on provider + if model.CanReason { + switch model.Provider { + case models.ProviderOpenAI, models.ProviderLocal, models.ProviderXAI: + opts = append(opts, + provider.WithOpenAIOptions( + provider.WithReasoningEffort(agentConfig.ReasoningEffort), + ), + ) + case models.ProviderAnthropic: + if agentName == config.AgentCoder { + opts = append(opts, + provider.WithAnthropicOptions( + provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn), + ), + ) + } + } + } + + // Configure xAI-specific options + if model.Provider == models.ProviderXAI { + var xaiOpts []provider.XAIOption + + // Configure concurrent requests if specified + if providerCfg.MaxConcurrentRequests > 0 { + xaiOpts = append(xaiOpts, provider.WithMaxConcurrentRequests(providerCfg.MaxConcurrentRequests)) + } + + // Configure deferred completions + if shouldEnableDeferred(model, agentConfig, providerCfg) { + xaiOpts = append(xaiOpts, provider.WithDeferredCompletion()) + + // Add custom timeout/interval if configured + if providerCfg.DeferredCompletion != nil { + timeout, interval := parseDeferredTimings(providerCfg.DeferredCompletion) + if timeout > 0 && interval > 0 { + xaiOpts = append(xaiOpts, provider.WithDeferredOptions(timeout, interval)) + } + } + } + + if len(xaiOpts) > 0 { + opts = append(opts, provider.WithXAIOptions(xaiOpts...)) + } } 
agentProvider, err := provider.NewProvider( model.Provider, @@ -756,3 +790,59 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) return agentProvider, nil } + +// shouldEnableDeferred determines if deferred completions should be enabled +func shouldEnableDeferred(model models.Model, agentConfig config.Agent, providerCfg config.Provider) bool { + // Check agent-level override first + if agentConfig.DeferredCompletion != nil { + return *agentConfig.DeferredCompletion + } + + // Check provider-level configuration + if providerCfg.DeferredCompletion == nil { + return false + } + + // If explicitly enabled/disabled + if providerCfg.DeferredCompletion.Enabled { + return true + } + + // Check auto-enable rules + if providerCfg.DeferredCompletion.AutoEnable != nil { + autoEnable := providerCfg.DeferredCompletion.AutoEnable + + // Check if model is in the auto-enable list + for _, modelID := range autoEnable.ForModels { + if string(model.ID) == modelID { + return true + } + } + + // Check if max tokens exceed threshold + if autoEnable.WhenTokensExceed > 0 && agentConfig.MaxTokens > autoEnable.WhenTokensExceed { + return true + } + } + + return false +} + +// parseDeferredTimings parses timeout and poll interval from configuration +func parseDeferredTimings(cfg *config.DeferredCompletionConfig) (time.Duration, time.Duration) { + var timeout, interval time.Duration + + if cfg.Timeout != "" { + if t, err := time.ParseDuration(cfg.Timeout); err == nil { + timeout = t + } + } + + if cfg.PollInterval != "" { + if i, err := time.ParseDuration(cfg.PollInterval); err == nil { + interval = i + } + } + + return timeout, interval +} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index e6b0119a..693e4047 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -11,6 +11,8 @@ import ( "github.com/opencode-ai/opencode/internal/session" ) +// CoderAgentTools returns the complete set of tools available to the coder agent. +// This includes file manipulation, code search, LSP integration, and web search capabilities. 
func CoderAgentTools( permissions permission.Service, sessions session.Service, @@ -19,25 +21,32 @@ func CoderAgentTools( lspClients map[string]*lsp.Client, ) []tools.BaseTool { ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) + + // Base tools available to all coder agents + baseTools := []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewPatchTool(lspClients, permissions, history), + tools.NewWriteTool(lspClients, permissions, history), + NewAgentTool(sessions, messages, lspClients), + &tools.WebSearchTool{}, // Enables web search for compatible providers (e.g., xAI) + } + + // Add MCP tools if available + mcpTools := GetMcpTools(ctx, permissions) + + // Add diagnostics tool if LSP clients are configured if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + mcpTools = append(mcpTools, tools.NewDiagnosticsTool(lspClients)) } - return append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions, history), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewPatchTool(lspClients, permissions, history), - tools.NewWriteTool(lspClients, permissions, history), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ) + + return append(baseTools, mcpTools...) } func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { diff --git a/internal/llm/agent/tools_test.go b/internal/llm/agent/tools_test.go new file mode 100644 index 00000000..07cbc7ac --- /dev/null +++ b/internal/llm/agent/tools_test.go @@ -0,0 +1,235 @@ +package agent + +import ( + "context" + "testing" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/history" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/lsp" + "github.com/opencode-ai/opencode/internal/message" + "github.com/opencode-ai/opencode/internal/permission" + "github.com/opencode-ai/opencode/internal/pubsub" + "github.com/opencode-ai/opencode/internal/session" +) + +func TestCoderAgentToolsIncludesWebSearch(t *testing.T) { + // Setup test configuration + setupTestConfig(t) + + // Create mock services + mockPermissions := &mockPermissionService{} + mockSessions := &mockSessionService{} + mockMessages := &mockMessageService{} + mockHistory := &mockHistoryService{} + lspClients := make(map[string]*lsp.Client) + + // Get coder agent tools + agentTools := CoderAgentTools( + mockPermissions, + mockSessions, + mockMessages, + mockHistory, + lspClients, + ) + + // Check if web_search tool is included + found := false + var webSearchTool interface{} + for _, tool := range agentTools { + info := tool.Info() + if info.Name == "web_search" { + found = true + webSearchTool = tool + break + } + } + + if !found { + t.Error("CoderAgentTools should include web_search tool") + } + + // Additional validation + if webSearchTool != nil { + // Type assert to verify it implements the correct interface + if baseTool, ok := webSearchTool.(interface { + Info() tools.ToolInfo + Run(context.Context, tools.ToolCall) (tools.ToolResponse, error) + }); ok { + info := baseTool.Info() + if info.Name != "web_search" { + 
t.Errorf("Expected tool name 'web_search', got '%s'", info.Name) + } + if info.Description == "" { + t.Error("Web search tool should have a description") + } + // Verify it has properties with query parameter + properties, ok := info.Parameters["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Parameters should have 'properties' field") + } + if _, hasQuery := properties["query"]; !hasQuery { + t.Error("Web search tool should have 'query' parameter in properties") + } + // Verify required fields + if len(info.Required) == 0 || info.Required[0] != "query" { + t.Error("Web search tool should require 'query' parameter") + } + } else { + t.Error("Web search tool does not implement BaseTool interface correctly") + } + } +} + +// Mock implementations for testing +type mockPermissionService struct{} + +func (m *mockPermissionService) Subscribe(ctx context.Context) <-chan pubsub.Event[permission.PermissionRequest] { + ch := make(chan pubsub.Event[permission.PermissionRequest]) + close(ch) + return ch +} + +func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {} + +func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {} + +func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {} + +func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { + return true +} + +func (m *mockPermissionService) AutoApproveSession(sessionID string) {} + +// Mock session service +type mockSessionService struct{} + +func (m *mockSessionService) Subscribe(ctx context.Context) <-chan pubsub.Event[session.Session] { + ch := make(chan pubsub.Event[session.Session]) + close(ch) + return ch +} + +func (m *mockSessionService) Create(ctx context.Context, title string) (session.Session, error) { + return session.Session{ID: "test-session", Title: title}, nil +} + +func (m *mockSessionService) CreateTitleSession(ctx context.Context, parentSessionID string) (session.Session, error) { + return session.Session{ID: "test-title-session", ParentSessionID: parentSessionID}, nil +} + +func (m *mockSessionService) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (session.Session, error) { + return session.Session{ID: "test-task-session", ParentSessionID: parentSessionID, Title: title}, nil +} + +func (m *mockSessionService) Get(ctx context.Context, id string) (session.Session, error) { + return session.Session{ID: id}, nil +} + +func (m *mockSessionService) List(ctx context.Context) ([]session.Session, error) { + return []session.Session{}, nil +} + +func (m *mockSessionService) Save(ctx context.Context, session session.Session) (session.Session, error) { + return session, nil +} + +func (m *mockSessionService) Delete(ctx context.Context, id string) error { + return nil +} + +// Mock message service +type mockMessageService struct{} + +func (m *mockMessageService) Subscribe(ctx context.Context) <-chan pubsub.Event[message.Message] { + ch := make(chan pubsub.Event[message.Message]) + close(ch) + return ch +} + +func (m *mockMessageService) Create(ctx context.Context, sessionID string, params message.CreateMessageParams) (message.Message, error) { + return message.Message{ID: "test-message", SessionID: sessionID}, nil +} + +func (m *mockMessageService) Update(ctx context.Context, msg message.Message) error { + return nil +} + +func (m *mockMessageService) Get(ctx context.Context, id string) (message.Message, error) { + return message.Message{ID: id}, nil +} + +func (m 
*mockMessageService) List(ctx context.Context, sessionID string) ([]message.Message, error) { + return []message.Message{}, nil +} + +func (m *mockMessageService) Delete(ctx context.Context, id string) error { + return nil +} + +func (m *mockMessageService) DeleteSessionMessages(ctx context.Context, sessionID string) error { + return nil +} + +// Mock history service +type mockHistoryService struct{} + +func (m *mockHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { + ch := make(chan pubsub.Event[history.File]) + close(ch) + return ch +} + +func (m *mockHistoryService) Create(ctx context.Context, sessionID, path, content string) (history.File, error) { + return history.File{ID: "test-file", SessionID: sessionID, Path: path, Content: content}, nil +} + +func (m *mockHistoryService) CreateVersion(ctx context.Context, sessionID, path, content string) (history.File, error) { + return history.File{ID: "test-file-version", SessionID: sessionID, Path: path, Content: content}, nil +} + +func (m *mockHistoryService) Get(ctx context.Context, id string) (history.File, error) { + return history.File{ID: id}, nil +} + +func (m *mockHistoryService) GetByPathAndSession(ctx context.Context, path, sessionID string) (history.File, error) { + return history.File{Path: path, SessionID: sessionID}, nil +} + +func (m *mockHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { + return []history.File{}, nil +} + +func (m *mockHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { + return []history.File{}, nil +} + +func (m *mockHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { + return file, nil +} + +func (m *mockHistoryService) Delete(ctx context.Context, id string) error { + return nil +} + +func (m *mockHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { + return nil +} + +// setupTestConfig initializes a minimal configuration for testing +func setupTestConfig(t *testing.T) { + t.Helper() + + tmpDir := t.TempDir() + cfg, err := config.Load(tmpDir, false) + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Ensure MCPServers map is initialized to prevent nil pointer + if cfg.MCPServers == nil { + cfg.MCPServers = make(map[string]config.MCPServer) + } +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 2bcb508e..31074e2e 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -8,18 +8,19 @@ type ( ) type Model struct { - ID ModelID `json:"id"` - Name string `json:"name"` - Provider ModelProvider `json:"provider"` - APIModel string `json:"api_model"` - CostPer1MIn float64 `json:"cost_per_1m_in"` - CostPer1MOut float64 `json:"cost_per_1m_out"` - CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` - CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` - ContextWindow int64 `json:"context_window"` - DefaultMaxTokens int64 `json:"default_max_tokens"` - CanReason bool `json:"can_reason"` - SupportsAttachments bool `json:"supports_attachments"` + ID ModelID `json:"id"` + Name string `json:"name"` + Provider ModelProvider `json:"provider"` + APIModel string `json:"api_model"` + CostPer1MIn float64 `json:"cost_per_1m_in"` + CostPer1MOut float64 `json:"cost_per_1m_out"` + CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` + CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` + 
DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` + SupportsAttachments bool `json:"supports_attachments"` + SupportsImageGeneration bool `json:"supports_image_generation"` } // Model IDs diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go index 00caf3b8..95877e0d 100644 --- a/internal/llm/models/xai.go +++ b/internal/llm/models/xai.go @@ -1,61 +1,144 @@ package models +// xAI Model Capabilities (verified via API testing): +// - Reasoning support: +// - grok-4-0709: Has reasoning (returns reasoning_content) but does NOT accept reasoning_effort parameter +// - grok-3-mini, grok-3-mini-fast: Support reasoning_effort parameter ("low" or "high" only, NOT "medium") +// - grok-2 models, grok-3, grok-3-fast: No reasoning support +// - Vision support: grok-2-vision-1212 and grok-4 support image understanding +// - Image generation: grok-2-image and potentially grok-4 support image generation +// - Web search: All models support web search via tools +// - Note: Reasoning models cannot use presencePenalty, frequencyPenalty, or stop parameters + const ( ProviderXAI ModelProvider = "xai" - XAIGrok3Beta ModelID = "grok-3-beta" - XAIGrok3MiniBeta ModelID = "grok-3-mini-beta" - XAIGrok3FastBeta ModelID = "grok-3-fast-beta" - XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta" + // Current xAI models (from API as of 2025) + XAIGrok2 ModelID = "grok-2-1212" + XAIGrok2Vision ModelID = "grok-2-vision-1212" + XAIGrok2Image ModelID = "grok-2-image-1212" + XAIGrok3 ModelID = "grok-3" + XAIGrok3Fast ModelID = "grok-3-fast" + XAIGrok3Mini ModelID = "grok-3-mini" + XAIGrok3MiniFast ModelID = "grok-3-mini-fast" + XAIGrok4 ModelID = "grok-4-0709" ) var XAIModels = map[ModelID]Model{ - XAIGrok3Beta: { - ID: XAIGrok3Beta, - Name: "Grok3 Beta", + XAIGrok2: { + ID: XAIGrok2, + Name: "Grok 2", + Provider: ProviderXAI, + APIModel: "grok-2-1212", + CostPer1MIn: 2.0, // $2 per million input tokens + CostPer1MInCached: 0, + CostPer1MOut: 10.0, // $10 per million output tokens + CostPer1MOutCached: 0, + ContextWindow: 131_072, + DefaultMaxTokens: 20_000, + CanReason: false, // No reasoning support + // Capabilities: streaming, function calling, structured outputs, web search + }, + XAIGrok2Vision: { + ID: XAIGrok2Vision, + Name: "Grok 2 Vision", + Provider: ProviderXAI, + APIModel: "grok-2-vision-1212", + CostPer1MIn: 2.0, // $2 per million input tokens + CostPer1MInCached: 0, + CostPer1MOut: 10.0, // $10 per million output tokens + CostPer1MOutCached: 0, + ContextWindow: 8_192, + DefaultMaxTokens: 4_096, + SupportsAttachments: true, + CanReason: false, // No reasoning support + // Capabilities: image understanding, streaming, web search + }, + XAIGrok2Image: { + ID: XAIGrok2Image, + Name: "Grok 2 Image", + Provider: ProviderXAI, + APIModel: "grok-2-image-1212", + CostPer1MIn: 2.0, // Assuming same as Grok 2 + CostPer1MInCached: 0, + CostPer1MOut: 10.0, // Assuming same as Grok 2 + CostPer1MOutCached: 0, + ContextWindow: 8_192, + DefaultMaxTokens: 4_096, + SupportsAttachments: false, // Image generation models don't take image inputs + SupportsImageGeneration: true, + // Capabilities: image generation, web search + }, + XAIGrok3: { + ID: XAIGrok3, + Name: "Grok 3", Provider: ProviderXAI, - APIModel: "grok-3-beta", - CostPer1MIn: 3.0, + APIModel: "grok-3", + CostPer1MIn: 5.0, // Estimated pricing CostPer1MInCached: 0, - CostPer1MOut: 15, + CostPer1MOut: 15.0, // Estimated pricing CostPer1MOutCached: 0, ContextWindow: 131_072, DefaultMaxTokens: 20_000, + CanReason: 
false, // No reasoning support + // Capabilities: streaming, function calling, structured outputs, web search }, - XAIGrok3MiniBeta: { - ID: XAIGrok3MiniBeta, - Name: "Grok3 Mini Beta", + XAIGrok3Fast: { + ID: XAIGrok3Fast, + Name: "Grok 3 Fast", Provider: ProviderXAI, - APIModel: "grok-3-mini-beta", - CostPer1MIn: 0.3, + APIModel: "grok-3-fast", + CostPer1MIn: 3.0, // Estimated lower pricing for fast variant CostPer1MInCached: 0, - CostPer1MOut: 0.5, + CostPer1MOut: 10.0, // Estimated lower pricing for fast variant CostPer1MOutCached: 0, ContextWindow: 131_072, DefaultMaxTokens: 20_000, + CanReason: false, // No reasoning support + // Capabilities: streaming, function calling, structured outputs, web search }, - XAIGrok3FastBeta: { - ID: XAIGrok3FastBeta, - Name: "Grok3 Fast Beta", + XAIGrok3Mini: { + ID: XAIGrok3Mini, + Name: "Grok 3 Mini", Provider: ProviderXAI, - APIModel: "grok-3-fast-beta", - CostPer1MIn: 5, + APIModel: "grok-3-mini", + CostPer1MIn: 1.0, // Estimated lower pricing for mini CostPer1MInCached: 0, - CostPer1MOut: 25, + CostPer1MOut: 3.0, // Estimated lower pricing for mini CostPer1MOutCached: 0, ContextWindow: 131_072, DefaultMaxTokens: 20_000, + CanReason: true, // Supports reasoning_effort parameter ("low" or "high") + // Capabilities: streaming, function calling, structured outputs, reasoning, web search }, - XAiGrok3MiniFastBeta: { - ID: XAiGrok3MiniFastBeta, - Name: "Grok3 Mini Fast Beta", + XAIGrok3MiniFast: { + ID: XAIGrok3MiniFast, + Name: "Grok 3 Mini Fast", Provider: ProviderXAI, - APIModel: "grok-3-mini-fast-beta", - CostPer1MIn: 0.6, + APIModel: "grok-3-mini-fast", + CostPer1MIn: 0.5, // Estimated lowest pricing CostPer1MInCached: 0, - CostPer1MOut: 4.0, + CostPer1MOut: 1.5, // Estimated lowest pricing CostPer1MOutCached: 0, ContextWindow: 131_072, DefaultMaxTokens: 20_000, + CanReason: true, // Supports reasoning_effort parameter ("low" or "high") + // Capabilities: streaming, function calling, structured outputs, reasoning, web search + }, + XAIGrok4: { + ID: XAIGrok4, + Name: "Grok 4", + Provider: ProviderXAI, + APIModel: "grok-4-0709", + CostPer1MIn: 10.0, // $10 per million input tokens + CostPer1MInCached: 0, + CostPer1MOut: 30.0, // $30 per million output tokens + CostPer1MOutCached: 0, + ContextWindow: 131_072, + DefaultMaxTokens: 20_000, + CanReason: true, // Automatic reasoning (no reasoning_effort parameter) + SupportsAttachments: true, // Grok 4 supports vision + SupportsImageGeneration: false, // Will be detected dynamically via API + // Capabilities: streaming, function calling, structured outputs, automatic reasoning, web search, vision }, } diff --git a/internal/llm/provider/concurrent.go b/internal/llm/provider/concurrent.go new file mode 100644 index 00000000..8a5a8a8f --- /dev/null +++ b/internal/llm/provider/concurrent.go @@ -0,0 +1,151 @@ +package provider + +import ( + "context" + "sync" + + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/message" + "golang.org/x/sync/semaphore" +) + +// ConcurrentClient wraps a ProviderClient to add concurrent request handling with rate limiting. +// It uses a semaphore to control the maximum number of concurrent requests and provides +// optional response tracking for monitoring and compliance purposes. 
+type ConcurrentClient struct { + client ProviderClient + semaphore *semaphore.Weighted + maxInFlight int64 + mu sync.RWMutex + + // Optional callback for tracking responses (e.g., xAI fingerprints) + onResponse func(*ProviderResponse) +} + +// NewConcurrentClient creates a new concurrent client wrapper with the specified max concurrent requests. +// If maxConcurrent is <= 0, it defaults to 10 concurrent requests. +func NewConcurrentClient(client ProviderClient, maxConcurrent int64) *ConcurrentClient { + if maxConcurrent <= 0 { + maxConcurrent = 10 // Default to 10 concurrent requests + } + + return &ConcurrentClient{ + client: client, + semaphore: semaphore.NewWeighted(maxConcurrent), + maxInFlight: maxConcurrent, + } +} + +// Send implements ProviderClient interface with semaphore-based rate limiting +func (c *ConcurrentClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + // Acquire semaphore + if err := c.semaphore.Acquire(ctx, 1); err != nil { + return nil, err + } + defer c.semaphore.Release(1) + + // Forward to underlying client + resp, err := c.client.send(ctx, messages, tools) + + // Call callback if configured + if c.onResponse != nil && resp != nil { + c.onResponse(resp) + } + + return resp, err +} + +// Stream implements ProviderClient interface with semaphore-based rate limiting +func (c *ConcurrentClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + eventChan := make(chan ProviderEvent) + + go func() { + defer close(eventChan) + + // Acquire semaphore + if err := c.semaphore.Acquire(ctx, 1); err != nil { + eventChan <- ProviderEvent{Type: EventError, Error: err} + return + } + defer c.semaphore.Release(1) + + // Forward events from underlying client + for event := range c.client.stream(ctx, messages, tools) { + // Call callback for complete events if configured + if c.onResponse != nil && event.Type == EventComplete && event.Response != nil { + c.onResponse(event.Response) + } + eventChan <- event + } + }() + + return eventChan +} + +// SetMaxConcurrent updates the maximum concurrent requests allowed. +// If max is <= 0, it defaults to 10. 
+func (c *ConcurrentClient) SetMaxConcurrent(max int64) { + c.mu.Lock() + defer c.mu.Unlock() + + if max <= 0 { + max = 10 + } + + c.maxInFlight = max + c.semaphore = semaphore.NewWeighted(max) +} + +// GetMaxConcurrent returns the current maximum concurrent requests setting +func (c *ConcurrentClient) GetMaxConcurrent() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.maxInFlight +} + +// BatchRequest represents a single request in a batch +type BatchRequest struct { + Messages []message.Message + Tools []tools.BaseTool +} + +// BatchResponse represents the response for a batch request +type BatchResponse struct { + Response *ProviderResponse + Error error + Index int +} + +// SendBatch processes multiple requests concurrently respecting the semaphore limit +func (c *ConcurrentClient) SendBatch(ctx context.Context, requests []BatchRequest) []BatchResponse { + responses := make([]BatchResponse, len(requests)) + var wg sync.WaitGroup + + for i, req := range requests { + wg.Add(1) + go func(index int, request BatchRequest) { + defer wg.Done() + + resp, err := c.send(ctx, request.Messages, request.Tools) + responses[index] = BatchResponse{ + Response: resp, + Error: err, + Index: index, + } + }(i, req) + } + + wg.Wait() + return responses +} + +// StreamBatch processes multiple streaming requests concurrently +func (c *ConcurrentClient) StreamBatch(ctx context.Context, requests []BatchRequest) []<-chan ProviderEvent { + channels := make([]<-chan ProviderEvent, len(requests)) + + for i, req := range requests { + channels[i] = c.stream(ctx, req.Messages, req.Tools) + } + + return channels +} diff --git a/internal/llm/provider/concurrent_test.go b/internal/llm/provider/concurrent_test.go new file mode 100644 index 00000000..ad7585e2 --- /dev/null +++ b/internal/llm/provider/concurrent_test.go @@ -0,0 +1,320 @@ +package provider + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockProviderClient implements ProviderClient for testing +type mockProviderClient struct { + sendCount int32 + streamCount int32 + sendDelay time.Duration + streamDelay time.Duration + sendError error + streamError error + responseFunc func() *ProviderResponse +} + +func (m *mockProviderClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + atomic.AddInt32(&m.sendCount, 1) + if m.sendDelay > 0 { + time.Sleep(m.sendDelay) + } + if m.sendError != nil { + return nil, m.sendError + } + if m.responseFunc != nil { + return m.responseFunc(), nil + } + return &ProviderResponse{ + Content: "test response", + FinishReason: message.FinishReasonEndTurn, + Usage: TokenUsage{ + InputTokens: 10, + OutputTokens: 20, + }, + }, nil +} + +func (m *mockProviderClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + atomic.AddInt32(&m.streamCount, 1) + eventChan := make(chan ProviderEvent) + + go func() { + defer close(eventChan) + + if m.streamDelay > 0 { + time.Sleep(m.streamDelay) + } + + if m.streamError != nil { + eventChan <- ProviderEvent{Type: EventError, Error: m.streamError} + return + } + + // Send some content deltas + eventChan <- ProviderEvent{Type: EventContentDelta, Content: "test "} + eventChan <- ProviderEvent{Type: EventContentDelta, Content: "response"} + + // Send complete event + 
resp := &ProviderResponse{ + Content: "test response", + FinishReason: message.FinishReasonEndTurn, + Usage: TokenUsage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + if m.responseFunc != nil { + resp = m.responseFunc() + } + eventChan <- ProviderEvent{Type: EventComplete, Response: resp} + }() + + return eventChan +} + +func TestConcurrentClient_Send(t *testing.T) { + tests := []struct { + name string + maxConcurrent int64 + numRequests int + requestDelay time.Duration + expectError bool + }{ + { + name: "single request", + maxConcurrent: 1, + numRequests: 1, + requestDelay: 0, + expectError: false, + }, + { + name: "multiple concurrent requests within limit", + maxConcurrent: 5, + numRequests: 5, + requestDelay: 10 * time.Millisecond, + expectError: false, + }, + { + name: "requests exceed concurrent limit", + maxConcurrent: 2, + numRequests: 10, + requestDelay: 50 * time.Millisecond, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mockProviderClient{ + sendDelay: tt.requestDelay, + } + + concurrentClient := NewConcurrentClient(mockClient, tt.maxConcurrent) + + ctx := context.Background() + messages := []message.Message{} + tools := []tools.BaseTool{} + + var wg sync.WaitGroup + results := make([]*ProviderResponse, tt.numRequests) + errors := make([]error, tt.numRequests) + + start := time.Now() + + for i := 0; i < tt.numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + resp, err := concurrentClient.send(ctx, messages, tools) + results[idx] = resp + errors[idx] = err + }(i) + } + + wg.Wait() + elapsed := time.Since(start) + + // Check all requests completed + for i, err := range errors { + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, results[i]) + assert.Equal(t, "test response", results[i].Content) + } + } + + // Verify semaphore worked by checking timing + if tt.numRequests > int(tt.maxConcurrent) && tt.requestDelay > 0 { + expectedMinTime := tt.requestDelay * time.Duration(tt.numRequests/int(tt.maxConcurrent)) + assert.GreaterOrEqual(t, elapsed, expectedMinTime) + } + + // Verify all requests were made + assert.Equal(t, int32(tt.numRequests), atomic.LoadInt32(&mockClient.sendCount)) + }) + } +} + +func TestConcurrentClient_Stream(t *testing.T) { + mockClient := &mockProviderClient{ + streamDelay: 10 * time.Millisecond, + } + + concurrentClient := NewConcurrentClient(mockClient, 2) + + ctx := context.Background() + messages := []message.Message{} + tools := []tools.BaseTool{} + + // Start multiple streaming requests + numStreams := 5 + var wg sync.WaitGroup + results := make([][]ProviderEvent, numStreams) + + for i := 0; i < numStreams; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + eventChan := concurrentClient.stream(ctx, messages, tools) + var events []ProviderEvent + + for event := range eventChan { + events = append(events, event) + } + + results[idx] = events + }(i) + } + + wg.Wait() + + // Verify all streams completed successfully + for i, events := range results { + require.NotEmpty(t, events, "stream %d should have events", i) + + // Should have content deltas and complete event + var hasContentDelta, hasComplete bool + for _, event := range events { + if event.Type == EventContentDelta { + hasContentDelta = true + } + if event.Type == EventComplete { + hasComplete = true + assert.NotNil(t, event.Response) + assert.Equal(t, "test response", event.Response.Content) + } + } + + assert.True(t, hasContentDelta, "stream %d 
should have content delta", i) + assert.True(t, hasComplete, "stream %d should have complete event", i) + } + + // Verify all streams were made + assert.Equal(t, int32(numStreams), atomic.LoadInt32(&mockClient.streamCount)) +} + +func TestConcurrentClient_Callback(t *testing.T) { + var callbackCount int32 + var lastResponse *ProviderResponse + + mockClient := &mockProviderClient{ + responseFunc: func() *ProviderResponse { + return &ProviderResponse{ + Content: "test", + SystemFingerprint: "test-fingerprint", + Usage: TokenUsage{ + InputTokens: 5, + OutputTokens: 10, + }, + } + }, + } + + concurrentClient := NewConcurrentClient(mockClient, 1) + concurrentClient.onResponse = func(resp *ProviderResponse) { + atomic.AddInt32(&callbackCount, 1) + lastResponse = resp + } + + ctx := context.Background() + + // Test send callback + resp, err := concurrentClient.send(ctx, nil, nil) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, int32(1), atomic.LoadInt32(&callbackCount)) + assert.Equal(t, "test-fingerprint", lastResponse.SystemFingerprint) + + // Test stream callback + eventChan := concurrentClient.stream(ctx, nil, nil) + var events []ProviderEvent + for event := range eventChan { + events = append(events, event) + } + + assert.Equal(t, int32(2), atomic.LoadInt32(&callbackCount)) + assert.Equal(t, "test-fingerprint", lastResponse.SystemFingerprint) +} + +func TestConcurrentClient_BatchRequests(t *testing.T) { + mockClient := &mockProviderClient{ + sendDelay: 20 * time.Millisecond, + } + + concurrentClient := NewConcurrentClient(mockClient, 3) + + ctx := context.Background() + + // Create batch requests + requests := make([]BatchRequest, 10) + for i := range requests { + requests[i] = BatchRequest{ + Messages: []message.Message{}, + Tools: []tools.BaseTool{}, + } + } + + start := time.Now() + responses := concurrentClient.SendBatch(ctx, requests) + elapsed := time.Since(start) + + // Verify all responses + assert.Len(t, responses, 10) + for i, resp := range responses { + assert.NoError(t, resp.Error) + assert.NotNil(t, resp.Response) + assert.Equal(t, i, resp.Index) + } + + // Verify semaphore worked (10 requests / 3 concurrent = at least 4 batches) + expectedMinTime := 4 * 20 * time.Millisecond + assert.GreaterOrEqual(t, elapsed, expectedMinTime) +} + +func TestConcurrentClient_SetMaxConcurrent(t *testing.T) { + mockClient := &mockProviderClient{} + concurrentClient := NewConcurrentClient(mockClient, 2) + + assert.Equal(t, int64(2), concurrentClient.GetMaxConcurrent()) + + concurrentClient.SetMaxConcurrent(5) + assert.Equal(t, int64(5), concurrentClient.GetMaxConcurrent()) + + // Test with invalid value + concurrentClient.SetMaxConcurrent(0) + assert.Equal(t, int64(10), concurrentClient.GetMaxConcurrent()) // Should default to 10 +} diff --git a/internal/llm/provider/image_validation.go b/internal/llm/provider/image_validation.go new file mode 100644 index 00000000..69d8969d --- /dev/null +++ b/internal/llm/provider/image_validation.go @@ -0,0 +1,58 @@ +package provider + +import ( + "fmt" + "strings" + + "github.com/opencode-ai/opencode/internal/message" +) + +const ( + // MaxImageSize is the maximum allowed image size for xAI (20MiB) + MaxImageSize = 20 * 1024 * 1024 // 20 MiB +) + +// SupportedImageFormats lists the image formats supported by xAI +var SupportedImageFormats = []string{"image/jpeg", "image/jpg", "image/png"} + +// ValidateImageAttachment validates that an image attachment meets xAI requirements +func ValidateImageAttachment(attachment message.Attachment) 
error { + // Check file size + if len(attachment.Content) > MaxImageSize { + return fmt.Errorf("image size exceeds maximum allowed size of 20MiB (current: %.2fMiB)", + float64(len(attachment.Content))/(1024*1024)) + } + + // Check MIME type + mimeType := strings.ToLower(attachment.MimeType) + supported := false + for _, format := range SupportedImageFormats { + if mimeType == format { + supported = true + break + } + } + + if !supported { + return fmt.Errorf("unsupported image format: %s (supported: %s)", + mimeType, strings.Join(SupportedImageFormats, ", ")) + } + + return nil +} + +// IsVisionModel checks if a model supports image understanding +func IsVisionModel(modelID string) bool { + visionModels := []string{ + "grok-2-vision-1212", + "grok-4-0709", // grok-4 supports vision + } + + for _, vm := range visionModels { + if modelID == vm { + return true + } + } + + return false +} diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 8a561c77..0f956267 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -19,10 +19,13 @@ import ( ) type openaiOptions struct { - baseURL string - disableCache bool - reasoningEffort string - extraHeaders map[string]string + baseURL string + disableCache bool + reasoningEffort string + extraHeaders map[string]string + responseFormat *openai.ChatCompletionNewParamsResponseFormatUnion + toolChoice *openai.ChatCompletionToolChoiceOptionUnionParam + parallelToolCalls *bool } type OpenAIOption func(*openaiOptions) @@ -166,25 +169,68 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar Tools: tools, } - if o.providerOptions.model.CanReason == true { + if o.providerOptions.model.CanReason { params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) - switch o.options.reasoningEffort { - case "low": - params.ReasoningEffort = shared.ReasoningEffortLow - case "medium": - params.ReasoningEffort = shared.ReasoningEffortMedium - case "high": - params.ReasoningEffort = shared.ReasoningEffortHigh - default: - params.ReasoningEffort = shared.ReasoningEffortMedium + + // Determine if reasoning effort parameter should be applied + shouldApplyReasoningEffort := o.shouldApplyReasoningEffort() + + if o.options.reasoningEffort != "" && shouldApplyReasoningEffort { + switch o.options.reasoningEffort { + case "low": + params.ReasoningEffort = shared.ReasoningEffortLow + case "medium": + // xAI only supports "low" and "high", map "medium" to "high" + if o.providerOptions.model.Provider == models.ProviderXAI { + params.ReasoningEffort = shared.ReasoningEffortHigh + } else { + params.ReasoningEffort = shared.ReasoningEffortMedium + } + case "high": + params.ReasoningEffort = shared.ReasoningEffortHigh + default: + // Map invalid values to appropriate defaults + if o.providerOptions.model.Provider == models.ProviderXAI { + params.ReasoningEffort = shared.ReasoningEffortHigh + } else { + params.ReasoningEffort = shared.ReasoningEffortMedium + } + } } } else { params.MaxTokens = openai.Int(o.providerOptions.maxTokens) } + // Add response format if configured + if o.options.responseFormat != nil { + params.ResponseFormat = *o.options.responseFormat + } + + // Add tool choice if configured + if o.options.toolChoice != nil { + params.ToolChoice = *o.options.toolChoice + } + + // Add parallel tool calls setting if configured + if o.options.parallelToolCalls != nil { + params.ParallelToolCalls = openai.Bool(*o.options.parallelToolCalls) + } + return params } +// 
shouldApplyReasoningEffort determines if the reasoning_effort parameter should be applied +// based on the model and provider. Some models support reasoning but do not accept +// the reasoning_effort parameter (e.g., xAI's grok-4 has automatic reasoning). +func (o *openaiClient) shouldApplyReasoningEffort() bool { + // xAI grok-4 supports reasoning but does not accept reasoning_effort parameter + if o.providerOptions.model.Provider == models.ProviderXAI && + o.providerOptions.model.ID == models.XAIGrok4 { + return false + } + return true +} + func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) cfg := config.Get() @@ -230,10 +276,11 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too } return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: o.usage(*openaiResponse), - FinishReason: finishReason, + Content: content, + ToolCalls: toolCalls, + Usage: o.usage(*openaiResponse), + FinishReason: finishReason, + SystemFingerprint: openaiResponse.SystemFingerprint, }, nil } } @@ -294,10 +341,11 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: o.usage(acc.ChatCompletion), - FinishReason: finishReason, + Content: currentContent, + ToolCalls: toolCalls, + Usage: o.usage(acc.ChatCompletion), + FinishReason: finishReason, + SystemFingerprint: acc.ChatCompletion.SystemFingerprint, }, } close(eventChan) @@ -411,15 +459,78 @@ func WithOpenAIDisableCache() OpenAIOption { } } +func WithOpenAIJSONMode() OpenAIOption { + return func(options *openaiOptions) { + options.responseFormat = &openai.ChatCompletionNewParamsResponseFormatUnion{ + OfJSONObject: &shared.ResponseFormatJSONObjectParam{}, + } + } +} + +func WithOpenAIJSONSchema(name string, schema interface{}) OpenAIOption { + return func(options *openaiOptions) { + options.responseFormat = &openai.ChatCompletionNewParamsResponseFormatUnion{ + OfJSONSchema: &shared.ResponseFormatJSONSchemaParam{ + JSONSchema: shared.ResponseFormatJSONSchemaJSONSchemaParam{ + Name: name, + Schema: schema, + Strict: openai.Bool(true), + }, + }, + } + } +} + func WithReasoningEffort(effort string) OpenAIOption { return func(options *openaiOptions) { + // If effort is empty, don't set it at all + if effort == "" { + options.reasoningEffort = "" + return + } + defaultReasoningEffort := "medium" switch effort { case "low", "medium", "high": defaultReasoningEffort = effort default: - logging.Warn("Invalid reasoning effort, using default: medium") + logging.Warn("Invalid reasoning effort, using default: medium", "provided", effort) } options.reasoningEffort = defaultReasoningEffort } } + +func WithOpenAIToolChoice(choice string) OpenAIOption { + return func(options *openaiOptions) { + // Handle string-based tool choices: "auto", "required", "none" + switch choice { + case "auto", "required", "none": + options.toolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String(choice), + } + default: + logging.Warn("Invalid tool choice, using default: auto", "provided", choice) + options.toolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String("auto"), + } + } + } +} + +func WithOpenAIToolChoiceFunction(functionName string) OpenAIOption { + 
return func(options *openaiOptions) { + options.toolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{ + OfChatCompletionNamedToolChoice: &openai.ChatCompletionNamedToolChoiceParam{ + Function: openai.ChatCompletionNamedToolChoiceFunctionParam{ + Name: functionName, + }, + }, + } + } +} + +func WithOpenAIParallelToolCalls(enabled bool) OpenAIOption { + return func(options *openaiOptions) { + options.parallelToolCalls = &enabled + } +} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index d5be0ba0..2c77bff9 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -35,10 +35,12 @@ type TokenUsage struct { } type ProviderResponse struct { - Content string - ToolCalls []message.ToolCall - Usage TokenUsage - FinishReason message.FinishReason + Content string + ToolCalls []message.ToolCall + Usage TokenUsage + FinishReason message.FinishReason + SystemFingerprint string // For tracking xAI backend configuration changes + Citations []string // For Live Search citations (xAI) } type ProviderEvent struct { @@ -69,6 +71,7 @@ type providerClientOptions struct { geminiOptions []GeminiOption bedrockOptions []BedrockOption copilotOptions []CopilotOption + xaiOptions []XAIOption } type ProviderClientOption func(*providerClientOptions) @@ -145,12 +148,9 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption client: newOpenAIClient(clientOptions), }, nil case models.ProviderXAI: - clientOptions.openaiOptions = append(clientOptions.openaiOptions, - WithOpenAIBaseURL("https://api.x.ai/v1"), - ) - return &baseProvider[OpenAIClient]{ + return &baseProvider[XAIClient]{ options: clientOptions, - client: newOpenAIClient(clientOptions), + client: newXAIClient(clientOptions), }, nil case models.ProviderLocal: clientOptions.openaiOptions = append(clientOptions.openaiOptions, @@ -245,3 +245,9 @@ func WithCopilotOptions(copilotOptions ...CopilotOption) ProviderClientOption { options.copilotOptions = copilotOptions } } + +func WithXAIOptions(xaiOptions ...XAIOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.xaiOptions = xaiOptions + } +} diff --git a/internal/llm/provider/xai.go b/internal/llm/provider/xai.go new file mode 100644 index 00000000..1f3e6e57 --- /dev/null +++ b/internal/llm/provider/xai.go @@ -0,0 +1,515 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/openai/openai-go" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/logging" + "github.com/opencode-ai/opencode/internal/message" +) + +// FingerprintRecord tracks system fingerprint information for auditing and compliance purposes. +// It helps monitor xAI system changes and optimize caching performance. +type FingerprintRecord struct { + Fingerprint string `json:"fingerprint"` + Timestamp time.Time `json:"timestamp"` + Model string `json:"model"` + TokensUsed TokenUsage `json:"tokens_used"` +} + +// xaiClient wraps the OpenAI client with xAI-specific functionality. +// It provides enhanced features like deferred completions, concurrent request handling, +// Live Search integration, and comprehensive fingerprint tracking for monitoring. 
+type xaiClient struct { + openaiClient + mu sync.Mutex + lastFingerprint string + fingerprintHistory []FingerprintRecord // For compliance and auditing + concurrent *ConcurrentClient // Optional concurrent request handler + deferredEnabled bool // Enable deferred completions + deferredOptions DeferredOptions // Options for deferred completions + liveSearchEnabled bool // Enable Live Search + liveSearchOptions LiveSearchOptions // Options for Live Search +} + +type XAIClient ProviderClient + +// XAIOption represents xAI-specific configuration options +type XAIOption func(*xaiClient) + +// WithMaxConcurrentRequests configures the maximum number of concurrent requests +func WithMaxConcurrentRequests(maxConcurrent int64) XAIOption { + return func(x *xaiClient) { + x.concurrent = NewConcurrentClient(x, maxConcurrent) + // Set up callback to track fingerprints from concurrent requests + x.concurrent.onResponse = func(resp *ProviderResponse) { + if resp != nil && resp.SystemFingerprint != "" { + x.trackFingerprint(resp.SystemFingerprint, resp.Usage) + } + } + } +} + +// WithDeferredCompletion enables deferred completion mode +func WithDeferredCompletion() XAIOption { + return func(x *xaiClient) { + x.deferredEnabled = true + x.deferredOptions = DefaultDeferredOptions() + } +} + +// WithDeferredOptions configures deferred completion options +func WithDeferredOptions(timeout, interval time.Duration) XAIOption { + return func(x *xaiClient) { + x.deferredOptions = DeferredOptions{ + Timeout: timeout, + Interval: interval, + } + } +} + +// WithLiveSearch enables Live Search with default parameters +func WithLiveSearch() XAIOption { + return func(x *xaiClient) { + x.liveSearchEnabled = true + x.liveSearchOptions = DefaultLiveSearchOptions() + } +} + +// WithLiveSearchOptions enables Live Search with custom parameters +func WithLiveSearchOptions(opts LiveSearchOptions) XAIOption { + return func(x *xaiClient) { + x.liveSearchEnabled = true + x.liveSearchOptions = opts + } +} + +func newXAIClient(opts providerClientOptions) XAIClient { + // Create base OpenAI client with xAI-specific settings + opts.openaiOptions = append(opts.openaiOptions, + WithOpenAIBaseURL("https://api.x.ai/v1"), + ) + + baseClient := newOpenAIClient(opts) + openaiClientImpl := baseClient.(*openaiClient) + + xClient := &xaiClient{ + openaiClient: *openaiClientImpl, + fingerprintHistory: make([]FingerprintRecord, 0), + } + + // Apply xAI-specific options if any + for _, opt := range opts.xaiOptions { + opt(xClient) + } + + return xClient +} + +// shouldApplyReasoningEffort overrides the base implementation for xAI-specific logic +func (x *xaiClient) shouldApplyReasoningEffort() bool { + // xAI grok-4 supports reasoning but does not accept reasoning_effort parameter + if x.providerOptions.model.ID == models.XAIGrok4 { + return false + } + return true +} + +// trackFingerprint records fingerprint for monitoring, security, and compliance +func (x *xaiClient) trackFingerprint(fingerprint string, usage TokenUsage) { + if fingerprint == "" { + return + } + + x.mu.Lock() + defer x.mu.Unlock() + + // Record for audit trail + record := FingerprintRecord{ + Fingerprint: fingerprint, + Timestamp: time.Now(), + Model: string(x.providerOptions.model.ID), + TokensUsed: usage, + } + x.fingerprintHistory = append(x.fingerprintHistory, record) + + // Log for monitoring system changes + if x.lastFingerprint != "" && x.lastFingerprint != fingerprint { + // System configuration changed - important for debugging and performance optimization + 
logging.Info("xAI system configuration changed", + "previous", x.lastFingerprint, + "current", fingerprint, + "model", x.providerOptions.model.ID, + "timestamp", record.Timestamp.Format(time.RFC3339)) + } + + // Calculate caching efficiency + totalPromptTokens := usage.InputTokens + usage.CacheReadTokens + cacheHitRate := float64(0) + if totalPromptTokens > 0 { + cacheHitRate = float64(usage.CacheReadTokens) / float64(totalPromptTokens) * 100 + } + + // Log enhanced metrics including caching information + logFields := []interface{}{ + "fingerprint", fingerprint, + "model", x.providerOptions.model.ID, + "input_tokens", usage.InputTokens, + "output_tokens", usage.OutputTokens, + "cache_read_tokens", usage.CacheReadTokens, + "cache_creation_tokens", usage.CacheCreationTokens, + "total_prompt_tokens", totalPromptTokens, + "timestamp", record.Timestamp.Format(time.RFC3339), + } + + // Add cache efficiency metrics if caching is happening + if usage.CacheReadTokens > 0 { + logFields = append(logFields, + "cache_hit_rate_percent", cacheHitRate, + "cache_cost_savings", x.calculateCacheCostSavings(usage)) + + logging.Info("xAI prompt caching active", logFields...) + } else { + logging.Debug("xAI API response tracked", logFields...) + } + + x.lastFingerprint = fingerprint +} + +// calculateCacheCostSavings estimates cost savings from prompt caching +func (x *xaiClient) calculateCacheCostSavings(usage TokenUsage) float64 { + // Get model pricing (cost per 1M tokens) + model := x.providerOptions.model + costPer1MIn := model.CostPer1MIn + costPer1MInCached := model.CostPer1MInCached + + // If cached pricing isn't set, assume significant savings (typically 50% discount) + if costPer1MInCached == 0 { + costPer1MInCached = costPer1MIn * 0.5 + } + + // Calculate savings: (regular_cost - cached_cost) * tokens / 1M + if usage.CacheReadTokens > 0 { + regularCost := (costPer1MIn * float64(usage.CacheReadTokens)) / 1_000_000 + cachedCost := (costPer1MInCached * float64(usage.CacheReadTokens)) / 1_000_000 + return regularCost - cachedCost + } + + return 0 +} + +func (x *xaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + // Use deferred completion if enabled + if x.deferredEnabled { + return x.SendDeferred(ctx, messages, tools, x.deferredOptions) + } + + // Use custom HTTP client for Live Search in regular completions + if x.liveSearchEnabled { + return x.sendWithLiveSearch(ctx, messages, tools) + } + + // Use concurrent client if configured + if x.concurrent != nil { + return x.concurrent.send(ctx, messages, tools) + } + + // Call the base OpenAI implementation + response, err := x.openaiClient.send(ctx, messages, tools) + if err != nil { + return nil, err + } + + // Track fingerprint for monitoring, security, and compliance + if response.SystemFingerprint != "" { + x.trackFingerprint(response.SystemFingerprint, response.Usage) + } + + return response, nil +} + +// sendWithLiveSearch sends a regular completion request with Live Search parameters +func (x *xaiClient) sendWithLiveSearch(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + // Build request similar to deferred completions but without the deferred flag + reqBody := map[string]interface{}{ + "model": x.providerOptions.model.APIModel, + "messages": x.convertMessagesToAPI(messages), + "max_tokens": &x.providerOptions.maxTokens, + } + + // Add tools if provided + if len(tools) > 0 { + reqBody["tools"] = x.convertToolsToAPI(tools) + } + + 
// Temperature is not configurable in the current implementation + + // Apply reasoning effort if applicable + if x.shouldApplyReasoningEffort() && x.options.reasoningEffort != "" { + reqBody["reasoning_effort"] = x.options.reasoningEffort + } + + // Apply response format if configured + if x.options.responseFormat != nil { + reqBody["response_format"] = x.options.responseFormat + } + + // Apply tool choice if configured + if x.options.toolChoice != nil { + reqBody["tool_choice"] = x.options.toolChoice + } + + // Apply parallel tool calls if configured + if x.options.parallelToolCalls != nil { + reqBody["parallel_tool_calls"] = x.options.parallelToolCalls + } + + // Add Live Search parameters + reqBody["search_parameters"] = x.liveSearchOptions + + // Send the request using custom HTTP client + return x.sendCustomHTTPRequest(ctx, reqBody) +} + +// sendCustomHTTPRequest sends a custom HTTP request to the xAI API +func (x *xaiClient) sendCustomHTTPRequest(ctx context.Context, reqBody map[string]interface{}) (*ProviderResponse, error) { + // Import required packages for this method + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + // Get base URL (default to xAI API if not set) + baseURL := "https://api.x.ai" + if x.openaiClient.options.baseURL != "" { + baseURL = x.openaiClient.options.baseURL + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/v1/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + // Send request + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Check status code + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response as OpenAI-style completion result (same format as deferred) + var result DeferredResult + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + logging.Debug("Live Search completion received", "citations", len(result.Citations)) + + // Convert result to ProviderResponse (reuse existing conversion logic) + return x.convertDeferredResult(&result), nil +} + +func (x *xaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + // Use concurrent client if configured + if x.concurrent != nil { + return x.concurrent.stream(ctx, messages, tools) + } + + // Get the base stream + baseChan := x.openaiClient.stream(ctx, messages, tools) + + // Create a new channel to intercept and process events + eventChan := make(chan ProviderEvent) + + go func() { + defer close(eventChan) + + for event := range baseChan { + // If this is a complete event with a response, track the fingerprint + if event.Type == EventComplete && event.Response != nil && event.Response.SystemFingerprint != "" { + x.trackFingerprint(event.Response.SystemFingerprint, event.Response.Usage) + } + + // Forward the event + eventChan <- event + } + 
}() + + return eventChan +} + +// GetFingerprintHistory returns the fingerprint history for auditing and compliance +func (x *xaiClient) GetFingerprintHistory() []FingerprintRecord { + x.mu.Lock() + defer x.mu.Unlock() + + // Return a copy to prevent external modification + history := make([]FingerprintRecord, len(x.fingerprintHistory)) + copy(history, x.fingerprintHistory) + return history +} + +// GetCurrentFingerprint returns the current system fingerprint +func (x *xaiClient) GetCurrentFingerprint() string { + x.mu.Lock() + defer x.mu.Unlock() + return x.lastFingerprint +} + +// SendBatch processes multiple requests concurrently if concurrent client is configured +func (x *xaiClient) SendBatch(ctx context.Context, requests []BatchRequest) []BatchResponse { + if x.concurrent != nil { + return x.concurrent.SendBatch(ctx, requests) + } + + // Fallback to sequential processing if no concurrent client + responses := make([]BatchResponse, len(requests)) + for i, req := range requests { + resp, err := x.send(ctx, req.Messages, req.Tools) + responses[i] = BatchResponse{ + Response: resp, + Error: err, + Index: i, + } + } + return responses +} + +// StreamBatch processes multiple streaming requests concurrently if configured +func (x *xaiClient) StreamBatch(ctx context.Context, requests []BatchRequest) []<-chan ProviderEvent { + if x.concurrent != nil { + return x.concurrent.StreamBatch(ctx, requests) + } + + // Fallback to sequential processing if no concurrent client + channels := make([]<-chan ProviderEvent, len(requests)) + for i, req := range requests { + channels[i] = x.stream(ctx, req.Messages, req.Tools) + } + return channels +} + +// convertMessages overrides the base implementation to support xAI-specific image handling +func (x *xaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { + // Add system message first + openaiMessages = append(openaiMessages, openai.SystemMessage(x.providerOptions.systemMessage)) + + for _, msg := range messages { + switch msg.Role { + case message.User: + var content []openai.ChatCompletionContentPartUnionParam + + // Add text content if present + if msg.Content().String() != "" { + textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()} + content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock}) + } + + // Add binary content (base64 encoded images) + for _, binaryContent := range msg.BinaryContent() { + // xAI expects data URLs in format: data:image/jpeg;base64, + imageURL := openai.ChatCompletionContentPartImageImageURLParam{ + URL: binaryContent.String(models.ProviderOpenAI), // This already formats as data URL + Detail: "high", // Default to high detail for better recognition + } + imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} + content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) + } + + // Add image URL content (web URLs) + for _, imageURLContent := range msg.ImageURLContent() { + detail := imageURLContent.Detail + if detail == "" { + detail = "auto" // Default to auto if not specified + } + imageURL := openai.ChatCompletionContentPartImageImageURLParam{ + URL: imageURLContent.URL, + Detail: detail, + } + imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL} + content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock}) + } + + openaiMessages = append(openaiMessages, openai.UserMessage(content)) + + 
case message.Assistant: + // Use base implementation for assistant messages + assistantMsg := openai.ChatCompletionAssistantMessageParam{ + Role: "assistant", + } + + if msg.Content().String() != "" { + assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{ + OfString: openai.String(msg.Content().String()), + } + } + + if len(msg.ToolCalls()) > 0 { + assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls())) + for i, call := range msg.ToolCalls() { + assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{ + ID: call.ID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: call.Name, + Arguments: call.Input, + }, + } + } + } + + openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{ + OfAssistant: &assistantMsg, + }) + + case message.Tool: + for _, result := range msg.ToolResults() { + openaiMessages = append(openaiMessages, + openai.ToolMessage(result.Content, result.ToolCallID), + ) + } + } + } + + return +} + +// IsVisionCapable returns true if the current model supports image input +func (x *xaiClient) IsVisionCapable() bool { + return x.providerOptions.model.SupportsAttachments +} + +// SetMaxConcurrentRequests updates the maximum concurrent requests at runtime +func (x *xaiClient) SetMaxConcurrentRequests(maxConcurrent int64) { + if x.concurrent == nil { + x.concurrent = NewConcurrentClient(x, maxConcurrent) + } else { + x.concurrent.SetMaxConcurrent(maxConcurrent) + } +} diff --git a/internal/llm/provider/xai_caching_test.go b/internal/llm/provider/xai_caching_test.go new file mode 100644 index 00000000..723b7f20 --- /dev/null +++ b/internal/llm/provider/xai_caching_test.go @@ -0,0 +1,311 @@ +package provider + +import ( + "context" + "os" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_PromptCaching(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("prompt caching with repeated requests", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), // Use a fast model for testing + WithMaxTokens(100), + WithSystemMessage("You are a helpful assistant. Answer concisely."), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Create a longer prompt that's likely to be cached + longPrompt := `Please analyze the following scenario: A company is evaluating whether to implement + a new software system. The system costs $100,000 initially and $20,000 per year to maintain. + It will save the company $35,000 per year in operational costs. The company expects to use + this system for 5 years. Should they implement this system? 
Provide a brief analysis.` + + baseMessages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: longPrompt}, + }, + }, + } + + ctx := context.Background() + + // First request - should create cache + t.Log("Making first request (cache creation)...") + resp1, err := provider.SendMessages(ctx, baseMessages, nil) + require.NoError(t, err) + require.NotNil(t, resp1) + + t.Logf("First request usage: input=%d, output=%d, cached=%d, cache_creation=%d", + resp1.Usage.InputTokens, resp1.Usage.OutputTokens, + resp1.Usage.CacheReadTokens, resp1.Usage.CacheCreationTokens) + + // Wait a moment to ensure request completes + time.Sleep(1 * time.Second) + + // Second request with same prompt - should use cache + t.Log("Making second request (cache hit expected)...") + resp2, err := provider.SendMessages(ctx, baseMessages, nil) + require.NoError(t, err) + require.NotNil(t, resp2) + + t.Logf("Second request usage: input=%d, output=%d, cached=%d, cache_creation=%d", + resp2.Usage.InputTokens, resp2.Usage.OutputTokens, + resp2.Usage.CacheReadTokens, resp2.Usage.CacheCreationTokens) + + // Check if caching occurred in either request + totalCachedTokens := resp1.Usage.CacheReadTokens + resp2.Usage.CacheReadTokens + if totalCachedTokens > 0 { + t.Logf("✓ Prompt caching detected! Total cached tokens: %d", totalCachedTokens) + + // Calculate cache efficiency + totalPromptTokens := resp1.Usage.InputTokens + resp1.Usage.CacheReadTokens + + resp2.Usage.InputTokens + resp2.Usage.CacheReadTokens + cacheHitRate := float64(totalCachedTokens) / float64(totalPromptTokens) * 100 + t.Logf("Cache hit rate: %.1f%%", cacheHitRate) + + // Test cache cost savings calculation + savings := xaiClient.calculateCacheCostSavings(resp2.Usage) + if savings > 0 { + t.Logf("Estimated cost savings: $%.6f", savings) + } + } else { + t.Log("No caching detected in this test (may need more requests or longer prompts)") + } + + // Verify responses are different (since this is a generative task) + assert.NotEqual(t, resp1.Content, resp2.Content, "Responses should be different for generative tasks") + + // Verify both responses have content + assert.NotEmpty(t, resp1.Content) + assert.NotEmpty(t, resp2.Content) + + // Verify system fingerprints are present + assert.NotEmpty(t, resp1.SystemFingerprint) + assert.NotEmpty(t, resp2.SystemFingerprint) + }) + + t.Run("streaming with caching", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(50), + WithSystemMessage("You are a helpful assistant."), + ) + require.NoError(t, err) + + // Use the same prompt twice to test caching in streaming + prompt := "What is the capital of France? Answer in one sentence." 
+ messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: prompt}, + }, + }, + } + + ctx := context.Background() + + // First streaming request + t.Log("First streaming request...") + eventChan1 := provider.StreamResponse(ctx, messages, nil) + var finalResp1 *ProviderResponse + + for event := range eventChan1 { + if event.Type == EventComplete { + finalResp1 = event.Response + break + } else if event.Type == EventError { + t.Fatalf("Streaming error: %v", event.Error) + } + } + + require.NotNil(t, finalResp1) + t.Logf("First streaming usage: input=%d, output=%d, cached=%d", + finalResp1.Usage.InputTokens, finalResp1.Usage.OutputTokens, finalResp1.Usage.CacheReadTokens) + + // Wait a moment + time.Sleep(1 * time.Second) + + // Second streaming request + t.Log("Second streaming request...") + eventChan2 := provider.StreamResponse(ctx, messages, nil) + var finalResp2 *ProviderResponse + + for event := range eventChan2 { + if event.Type == EventComplete { + finalResp2 = event.Response + break + } else if event.Type == EventError { + t.Fatalf("Streaming error: %v", event.Error) + } + } + + require.NotNil(t, finalResp2) + t.Logf("Second streaming usage: input=%d, output=%d, cached=%d", + finalResp2.Usage.InputTokens, finalResp2.Usage.OutputTokens, finalResp2.Usage.CacheReadTokens) + + // Check for any caching + totalCached := finalResp1.Usage.CacheReadTokens + finalResp2.Usage.CacheReadTokens + if totalCached > 0 { + t.Logf("✓ Streaming caching detected! Total cached tokens: %d", totalCached) + } else { + t.Log("No caching detected in streaming requests") + } + + // Verify we got valid responses + assert.NotEmpty(t, finalResp1.Content) + assert.NotEmpty(t, finalResp2.Content) + }) + + t.Run("deferred completion with caching", func(t *testing.T) { + // Test caching with deferred completions + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(100), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Enable deferred mode + xaiClient.deferredEnabled = true + xaiClient.deferredOptions = DeferredOptions{ + Timeout: 2 * time.Minute, + Interval: 5 * time.Second, + } + + prompt := "Explain quantum computing in simple terms." + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: prompt}, + }, + }, + } + + ctx := context.Background() + + // Send deferred request + t.Log("Sending deferred request...") + resp, err := provider.SendMessages(ctx, messages, nil) + if err != nil { + t.Logf("Deferred request failed (expected for some models): %v", err) + return + } + + require.NotNil(t, resp) + t.Logf("Deferred usage: input=%d, output=%d, cached=%d", + resp.Usage.InputTokens, resp.Usage.OutputTokens, resp.Usage.CacheReadTokens) + + if resp.Usage.CacheReadTokens > 0 { + t.Logf("✓ Deferred completion caching detected! 
Cached tokens: %d", resp.Usage.CacheReadTokens) + } + + assert.NotEmpty(t, resp.Content) + }) + + t.Run("cache metrics validation", func(t *testing.T) { + // Test the cache cost savings calculation + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Test with mock usage data + mockUsage := TokenUsage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadTokens: 25, + CacheCreationTokens: 0, + } + + savings := xaiClient.calculateCacheCostSavings(mockUsage) + t.Logf("Mock cache savings for 25 cached tokens: $%.6f", savings) + + // Savings should be positive when there are cached tokens + if mockUsage.CacheReadTokens > 0 { + assert.Greater(t, savings, 0.0, "Should have positive savings with cached tokens") + } + + // Test with zero cached tokens + zeroUsage := TokenUsage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadTokens: 0, + } + zeroSavings := xaiClient.calculateCacheCostSavings(zeroUsage) + assert.Equal(t, 0.0, zeroSavings, "Should have zero savings with no cached tokens") + }) +} + +func TestCacheTokenHandling(t *testing.T) { + // Test the cached token parsing in deferred results + + // Mock deferred result with cached tokens + result := &DeferredResult{ + Usage: DeferredUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + PromptTokensDetails: &DeferredPromptTokensDetails{ + TextTokens: 75, + CachedTokens: 25, + ImageTokens: 0, + AudioTokens: 0, + }, + }, + } + + // Test parsing + cachedTokens := int64(0) + if result.Usage.PromptTokensDetails != nil { + cachedTokens = result.Usage.PromptTokensDetails.CachedTokens + } + inputTokens := result.Usage.PromptTokens - cachedTokens + + assert.Equal(t, int64(25), cachedTokens, "Should extract cached tokens correctly") + assert.Equal(t, int64(75), inputTokens, "Should calculate input tokens correctly") + assert.Equal(t, int64(100), inputTokens+cachedTokens, "Total should match prompt tokens") +} diff --git a/internal/llm/provider/xai_deferred.go b/internal/llm/provider/xai_deferred.go new file mode 100644 index 00000000..27d00083 --- /dev/null +++ b/internal/llm/provider/xai_deferred.go @@ -0,0 +1,533 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/logging" + "github.com/opencode-ai/opencode/internal/message" +) + +// DeferredCompletionRequest represents the request body for deferred completions +type DeferredCompletionRequest struct { + Model string `json:"model"` + Messages []map[string]interface{} `json:"messages"` + MaxTokens *int64 `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Tools []map[string]interface{} `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Deferred bool `json:"deferred"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + SearchParameters *LiveSearchOptions `json:"search_parameters,omitempty"` +} + +// DeferredCompletionResponse represents the initial response with request_id +type 
DeferredCompletionResponse struct { + RequestID string `json:"request_id"` +} + +// DeferredResult represents the final deferred completion result +type DeferredResult struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []DeferredChoice `json:"choices"` + Usage DeferredUsage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` + Citations []string `json:"citations,omitempty"` +} + +// DeferredChoice represents a choice in the deferred result +type DeferredChoice struct { + Index int `json:"index"` + Message DeferredMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// DeferredMessage represents a message in the deferred result +type DeferredMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []DeferredToolCall `json:"tool_calls,omitempty"` +} + +// DeferredToolCall represents a tool call in the deferred result +type DeferredToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function DeferredToolFunction `json:"function"` +} + +// DeferredToolFunction represents a tool function call +type DeferredToolFunction struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// DeferredUsage represents token usage in the deferred result +type DeferredUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + PromptTokensDetails *DeferredPromptTokensDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *DeferredCompletionTokensDetails `json:"completion_tokens_details,omitempty"` + NumSourcesUsed int64 `json:"num_sources_used,omitempty"` +} + +// DeferredPromptTokensDetails represents detailed prompt token usage +type DeferredPromptTokensDetails struct { + TextTokens int64 `json:"text_tokens"` + AudioTokens int64 `json:"audio_tokens"` + ImageTokens int64 `json:"image_tokens"` + CachedTokens int64 `json:"cached_tokens"` +} + +// DeferredCompletionTokensDetails represents detailed completion token usage +type DeferredCompletionTokensDetails struct { + ReasoningTokens int64 `json:"reasoning_tokens"` + AudioTokens int64 `json:"audio_tokens"` + AcceptedPredictionTokens int64 `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"` +} + +// DeferredOptions represents options for deferred completions +type DeferredOptions struct { + Timeout time.Duration + Interval time.Duration +} + +// DefaultDeferredOptions returns default options for deferred completions +func DefaultDeferredOptions() DeferredOptions { + return DeferredOptions{ + Timeout: 10 * time.Minute, + Interval: 10 * time.Second, + } +} + +// LiveSearchOptions represents options for Live Search +type LiveSearchOptions struct { + Mode string `json:"mode,omitempty"` // "auto", "on", "off" + MaxSearchResults *int `json:"max_search_results,omitempty"` // 1-20, default 20 + FromDate *string `json:"from_date,omitempty"` // YYYY-MM-DD + ToDate *string `json:"to_date,omitempty"` // YYYY-MM-DD + ReturnCitations *bool `json:"return_citations,omitempty"` // default true + Sources []LiveSearchSource `json:"sources,omitempty"` // Data sources +} + +// LiveSearchSource represents a data source for Live Search +type LiveSearchSource struct { + Type string `json:"type"` // "web", "x", "news", "rss" + Country *string 
`json:"country,omitempty"` // ISO alpha-2 (web, news) + ExcludedWebsites []string `json:"excluded_websites,omitempty"` // max 5 (web, news) + AllowedWebsites []string `json:"allowed_websites,omitempty"` // max 5 (web only) + SafeSearch *bool `json:"safe_search,omitempty"` // default true (web, news) + IncludedXHandles []string `json:"included_x_handles,omitempty"` // max 10 (x only) + ExcludedXHandles []string `json:"excluded_x_handles,omitempty"` // max 10 (x only) + PostFavoriteCount *int `json:"post_favorite_count,omitempty"` // min favorites (x only) + PostViewCount *int `json:"post_view_count,omitempty"` // min views (x only) + Links []string `json:"links,omitempty"` // RSS URLs, max 1 (rss only) +} + +// DefaultLiveSearchOptions returns default Live Search options +func DefaultLiveSearchOptions() LiveSearchOptions { + returnCitations := true + return LiveSearchOptions{ + Mode: "auto", + ReturnCitations: &returnCitations, + Sources: []LiveSearchSource{ + {Type: "web"}, + {Type: "x"}, + }, + } +} + +// sendDeferred sends a deferred completion request to xAI +func (x *xaiClient) sendDeferred(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (string, error) { + // Convert messages to the format expected by the API + apiMessages := x.convertMessagesToAPI(messages) + + // Convert tools to the format expected by the API + apiTools := x.convertToolsToAPI(tools) + + // Prepare request body + reqBody := DeferredCompletionRequest{ + Model: x.providerOptions.model.APIModel, + Messages: apiMessages, + MaxTokens: &x.providerOptions.maxTokens, + Deferred: true, + Tools: apiTools, + } + + // Apply reasoning effort if applicable + if x.shouldApplyReasoningEffort() && x.options.reasoningEffort != "" { + reqBody.ReasoningEffort = x.options.reasoningEffort + } + + // Apply response format if configured + if x.options.responseFormat != nil { + reqBody.ResponseFormat = x.options.responseFormat + } + + // Apply tool choice if configured + if x.options.toolChoice != nil { + reqBody.ToolChoice = x.options.toolChoice + } + + // Apply parallel tool calls if configured + if x.options.parallelToolCalls != nil { + reqBody.ParallelToolCalls = x.options.parallelToolCalls + } + + // Apply Live Search parameters if enabled + if x.liveSearchEnabled { + reqBody.SearchParameters = &x.liveSearchOptions + } + + // Marshal request body + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request body: %w", err) + } + + // Get base URL (default to xAI API if not set) + baseURL := "https://api.x.ai" + if x.openaiClient.options.baseURL != "" { + baseURL = x.openaiClient.options.baseURL + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/v1/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + // Send request + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + // Check status code + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + 
} + + // Parse response + var deferredResp DeferredCompletionResponse + if err := json.Unmarshal(body, &deferredResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if deferredResp.RequestID == "" { + return "", fmt.Errorf("no request_id in response") + } + + logging.Debug("Created deferred completion", "request_id", deferredResp.RequestID) + + return deferredResp.RequestID, nil +} + +// pollDeferredResult polls for the deferred completion result +func (x *xaiClient) pollDeferredResult(ctx context.Context, requestID string, opts DeferredOptions) (*DeferredResult, error) { + // Get base URL (default to xAI API if not set) + baseURL := "https://api.x.ai" + if x.openaiClient.options.baseURL != "" { + baseURL = x.openaiClient.options.baseURL + } + + url := fmt.Sprintf("%s/v1/chat/deferred-completion/%s", baseURL, requestID) + + // Create HTTP client + client := &http.Client{Timeout: 30 * time.Second} + + // Start polling + ticker := time.NewTicker(opts.Interval) + defer ticker.Stop() + + timeout := time.After(opts.Timeout) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timeout: + return nil, fmt.Errorf("timeout waiting for deferred completion after %v", opts.Timeout) + case <-ticker.C: + // Create request + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create polling request: %w", err) + } + + // Set headers + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + // Send request + resp, err := client.Do(req) + if err != nil { + logging.Debug("Error polling deferred result", "error", err) + continue // Retry on network errors + } + defer resp.Body.Close() + + // Check status code + if resp.StatusCode == http.StatusAccepted { + // 202 Accepted means still processing + logging.Debug("Deferred completion still processing", "request_id", requestID) + continue + } + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read polling response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("polling failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse result + var result DeferredResult + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse deferred result: %w", err) + } + + logging.Debug("Deferred completion ready", "request_id", requestID) + + return &result, nil + } + } +} + +// SendDeferred sends a deferred completion request and polls for the result +func (x *xaiClient) SendDeferred(ctx context.Context, messages []message.Message, tools []tools.BaseTool, opts DeferredOptions) (*ProviderResponse, error) { + // Send deferred request + requestID, err := x.sendDeferred(ctx, messages, tools) + if err != nil { + return nil, fmt.Errorf("failed to send deferred request: %w", err) + } + + // Poll for result + result, err := x.pollDeferredResult(ctx, requestID, opts) + if err != nil { + return nil, fmt.Errorf("failed to get deferred result: %w", err) + } + + // Convert result to ProviderResponse + return x.convertDeferredResult(result), nil +} + +// convertDeferredResult converts a DeferredResult to ProviderResponse +func (x *xaiClient) convertDeferredResult(result *DeferredResult) *ProviderResponse { + if result == nil || len(result.Choices) == 0 { + return &ProviderResponse{ + FinishReason: message.FinishReasonUnknown, + } + } + + choice := result.Choices[0] + + // Convert tool calls + var 
toolCalls []message.ToolCall + for _, tc := range choice.Message.ToolCalls { + toolCalls = append(toolCalls, message.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Input: tc.Function.Arguments, + Finished: true, + }) + } + + // Determine finish reason + finishReason := x.finishReason(choice.FinishReason) + if len(toolCalls) > 0 { + finishReason = message.FinishReasonToolUse + } + + // Calculate cached tokens and actual input tokens + var cachedTokens int64 + var inputTokens int64 + + if result.Usage.PromptTokensDetails != nil { + cachedTokens = result.Usage.PromptTokensDetails.CachedTokens + } + inputTokens = result.Usage.PromptTokens - cachedTokens + + // Create response + resp := &ProviderResponse{ + Content: choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: TokenUsage{ + InputTokens: inputTokens, + OutputTokens: result.Usage.CompletionTokens, + CacheCreationTokens: 0, // Not provided in deferred responses + CacheReadTokens: cachedTokens, + }, + SystemFingerprint: result.SystemFingerprint, + Citations: result.Citations, + } + + // Track fingerprint + if resp.SystemFingerprint != "" { + x.trackFingerprint(resp.SystemFingerprint, resp.Usage) + } + + return resp +} + +// convertMessagesToAPI converts internal messages to API format +func (x *xaiClient) convertMessagesToAPI(messages []message.Message) []map[string]interface{} { + var apiMessages []map[string]interface{} + + // Add system message + apiMessages = append(apiMessages, map[string]interface{}{ + "role": "system", + "content": x.providerOptions.systemMessage, + }) + + // Convert user messages + for _, msg := range messages { + switch msg.Role { + case message.User: + // Check if message has images + hasImages := len(msg.BinaryContent()) > 0 || len(msg.ImageURLContent()) > 0 + + if hasImages { + // Build content array for multimodal message + var content []map[string]interface{} + + // Add text content if present + if msg.Content().String() != "" { + content = append(content, map[string]interface{}{ + "type": "text", + "text": msg.Content().String(), + }) + } + + // Add binary images (base64 encoded) + for _, binaryContent := range msg.BinaryContent() { + content = append(content, map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": binaryContent.String(models.ProviderOpenAI), // data:image/jpeg;base64, + "detail": "high", // Default to high detail + }, + }) + } + + // Add image URLs (web URLs) + for _, imageURLContent := range msg.ImageURLContent() { + detail := imageURLContent.Detail + if detail == "" { + detail = "auto" + } + content = append(content, map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": imageURLContent.URL, + "detail": detail, + }, + }) + } + + apiMsg := map[string]interface{}{ + "role": "user", + "content": content, + } + apiMessages = append(apiMessages, apiMsg) + } else { + // Simple text message + apiMsg := map[string]interface{}{ + "role": "user", + "content": msg.Content().String(), + } + apiMessages = append(apiMessages, apiMsg) + } + + case message.Assistant: + apiMsg := map[string]interface{}{ + "role": "assistant", + } + + if msg.Content().String() != "" { + apiMsg["content"] = msg.Content().String() + } + + if len(msg.ToolCalls()) > 0 { + var toolCalls []map[string]interface{} + for _, tc := range msg.ToolCalls() { + toolCalls = append(toolCalls, map[string]interface{}{ + "id": tc.ID, + "type": "function", + "function": map[string]interface{}{ + "name": tc.Name, + "arguments": 
tc.Input, + }, + }) + } + apiMsg["tool_calls"] = toolCalls + } + + apiMessages = append(apiMessages, apiMsg) + + case message.Tool: + for _, result := range msg.ToolResults() { + apiMessages = append(apiMessages, map[string]interface{}{ + "role": "tool", + "content": result.Content, + "tool_call_id": result.ToolCallID, + }) + } + } + } + + return apiMessages +} + +// convertToolsToAPI converts internal tools to API format +func (x *xaiClient) convertToolsToAPI(tools []tools.BaseTool) []map[string]interface{} { + var apiTools []map[string]interface{} + + for _, tool := range tools { + info := tool.Info() + apiTools = append(apiTools, map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": info.Name, + "description": info.Description, + "parameters": map[string]interface{}{ + "type": "object", + "properties": info.Parameters, + "required": info.Required, + }, + }, + }) + } + + return apiTools +} diff --git a/internal/llm/provider/xai_deferred_test.go b/internal/llm/provider/xai_deferred_test.go new file mode 100644 index 00000000..1db8a58a --- /dev/null +++ b/internal/llm/provider/xai_deferred_test.go @@ -0,0 +1,331 @@ +package provider + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_DeferredCompletions(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("basic deferred completion", func(t *testing.T) { + // Create xAI client with deferred completion enabled + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 100, + systemMessage: "You are a helpful assistant.", + xaiOptions: []XAIOption{ + WithDeferredCompletion(), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + require.True(t, xaiClient.deferredEnabled) + + // Create a simple message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is 2+2? 
Answer in one word."}, + }, + }, + } + + // Send deferred request + resp, err := xaiClient.send(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify response + assert.NotEmpty(t, resp.Content) + assert.NotEmpty(t, resp.SystemFingerprint) + assert.Equal(t, message.FinishReasonEndTurn, resp.FinishReason) + + t.Logf("Deferred response: %s", resp.Content) + t.Logf("System fingerprint: %s", resp.SystemFingerprint) + }) + + t.Run("deferred completion with custom options", func(t *testing.T) { + // Create xAI client with custom deferred options + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 150, + systemMessage: "You are a helpful assistant.", + xaiOptions: []XAIOption{ + WithDeferredCompletion(), + WithDeferredOptions(5*time.Minute, 5*time.Second), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + assert.Equal(t, 5*time.Minute, xaiClient.deferredOptions.Timeout) + assert.Equal(t, 5*time.Second, xaiClient.deferredOptions.Interval) + + // Create message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Explain quantum computing in one sentence."}, + }, + }, + } + + // Send request + resp, err := xaiClient.send(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.NotEmpty(t, resp.Content) + assert.NotEmpty(t, resp.SystemFingerprint) + }) + + t.Run("deferred completion with tool use", func(t *testing.T) { + // Create xAI client with deferred completion + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 200, + systemMessage: "You are a helpful assistant. 
Use tools when appropriate.", + xaiOptions: []XAIOption{ + WithDeferredCompletion(), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + + // Create a mock tool + mockTool := &MockTool{ + name: "calculate", + description: "Perform calculations", + parameters: map[string]interface{}{ + "expression": map[string]interface{}{ + "type": "string", + "description": "The mathematical expression", + }, + }, + required: []string{"expression"}, + response: `{"result": 9}`, + } + + // Create message that should trigger tool use + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is 3 times 3?"}, + }, + }, + } + + // Send request + resp, err := xaiClient.send(context.Background(), messages, []tools.BaseTool{mockTool}) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify response (may or may not use tool depending on model) + assert.NotEmpty(t, resp.SystemFingerprint) + if resp.FinishReason == message.FinishReasonToolUse { + assert.NotEmpty(t, resp.ToolCalls) + t.Logf("Tool was called: %+v", resp.ToolCalls) + } else { + assert.NotEmpty(t, resp.Content) + t.Logf("Direct response: %s", resp.Content) + } + }) +} + +func TestXAIProvider_DeferredCompletionsMock(t *testing.T) { + // Mock server tests for deferred completions + t.Run("mock deferred completion flow", func(t *testing.T) { + requestCount := int32(0) + requestID := "test-request-123" + + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&requestCount, 1) + + switch r.URL.Path { + case "/v1/chat/completions": + // Initial deferred request + assert.Equal(t, "POST", r.Method) + + // Verify deferred flag + var reqBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&reqBody) + require.NoError(t, err) + assert.Equal(t, true, reqBody["deferred"]) + + // Return request ID + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(DeferredCompletionResponse{ + RequestID: requestID, + }) + + case "/v1/chat/deferred-completion/" + requestID: + // Polling request + assert.Equal(t, "GET", r.Method) + + if count < 3 { + // First two polls return 202 (still processing) + w.WriteHeader(http.StatusAccepted) + } else { + // Third poll returns the result + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(DeferredResult{ + ID: "completion-123", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "grok-3-fast", + Choices: []DeferredChoice{ + { + Index: 0, + Message: DeferredMessage{ + Role: "assistant", + Content: "42", + }, + FinishReason: "stop", + }, + }, + Usage: DeferredUsage{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + SystemFingerprint: "fp_test123", + }) + } + + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + // Create xAI client pointing to mock server + opts := providerClientOptions{ + apiKey: "test-key", + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 100, + systemMessage: "Test system message", + openaiOptions: []OpenAIOption{ + WithOpenAIBaseURL(server.URL), + }, + xaiOptions: []XAIOption{ + WithDeferredCompletion(), + WithDeferredOptions(30*time.Second, 100*time.Millisecond), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + + // Override base URL for deferred requests + xaiClient.openaiClient.options.baseURL = server.URL + + // Create test messages + messages := []message.Message{ + { + Role: 
message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is the answer?"}, + }, + }, + } + + // Send request + start := time.Now() + resp, err := xaiClient.SendDeferred(context.Background(), messages, nil, xaiClient.deferredOptions) + duration := time.Since(start) + + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify response + assert.Equal(t, "42", resp.Content) + assert.Equal(t, message.FinishReasonEndTurn, resp.FinishReason) + assert.Equal(t, int64(10), resp.Usage.InputTokens) + assert.Equal(t, int64(5), resp.Usage.OutputTokens) + assert.Equal(t, "fp_test123", resp.SystemFingerprint) + + // Verify polling happened (should have made at least 3 requests) + assert.GreaterOrEqual(t, atomic.LoadInt32(&requestCount), int32(3)) + + // Verify timing (should have taken at least 200ms due to 2 polls at 100ms interval) + assert.GreaterOrEqual(t, duration, 200*time.Millisecond) + }) + + t.Run("mock deferred completion timeout", func(t *testing.T) { + // Create mock server that always returns 202 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/chat/completions": + // Return request ID + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(DeferredCompletionResponse{ + RequestID: "timeout-test", + }) + default: + // Always return 202 (processing) + w.WriteHeader(http.StatusAccepted) + } + })) + defer server.Close() + + // Create client with short timeout + opts := providerClientOptions{ + apiKey: "test-key", + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 100, + systemMessage: "Test", + openaiOptions: []OpenAIOption{ + WithOpenAIBaseURL(server.URL), + }, + xaiOptions: []XAIOption{ + WithDeferredCompletion(), + WithDeferredOptions(500*time.Millisecond, 100*time.Millisecond), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + xaiClient.openaiClient.options.baseURL = server.URL + + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Test"}, + }, + }, + } + + // Should timeout + _, err := xaiClient.SendDeferred(context.Background(), messages, nil, xaiClient.deferredOptions) + assert.Error(t, err) + assert.Contains(t, err.Error(), "timeout") + }) +} diff --git a/internal/llm/provider/xai_image_generation.go b/internal/llm/provider/xai_image_generation.go new file mode 100644 index 00000000..137974a9 --- /dev/null +++ b/internal/llm/provider/xai_image_generation.go @@ -0,0 +1,212 @@ +package provider + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/openai/openai-go" + "github.com/opencode-ai/opencode/internal/logging" +) + +// ImageGenerationRequest represents a request to generate images +type ImageGenerationRequest struct { + Prompt string + Model string + N int // Number of images (1-10) + ResponseFormat string // "url" or "b64_json" +} + +// ImageGenerationResponse represents the response from image generation +type ImageGenerationResponse struct { + Images []GeneratedImage + RevisedPrompt string + Model string + Created time.Time +} + +// GeneratedImage represents a single generated image +type GeneratedImage struct { + URL string // For URL format + Base64 string // For b64_json format + ContentType string // MIME type +} + +// GenerateImages generates one or more images based on a text prompt +func (x *xaiClient) GenerateImages(ctx context.Context, req ImageGenerationRequest) (*ImageGenerationResponse, 
error) { + // Validate request + if req.Prompt == "" { + return nil, fmt.Errorf("prompt cannot be empty") + } + + if req.N < 1 { + req.N = 1 + } else if req.N > 10 { + return nil, fmt.Errorf("n must be between 1 and 10, got %d", req.N) + } + + if req.ResponseFormat == "" { + req.ResponseFormat = "url" + } else if req.ResponseFormat != "url" && req.ResponseFormat != "b64_json" { + return nil, fmt.Errorf("response_format must be 'url' or 'b64_json', got %s", req.ResponseFormat) + } + + // Use the model from request or fall back to provider's model + model := req.Model + if model == "" { + model = string(x.providerOptions.model.APIModel) + } + + // Check if model supports image generation + caps, err := x.DiscoverModelCapabilities(ctx, model) + if err != nil { + logging.Warn("Failed to discover model capabilities, proceeding anyway", "error", err) + } else if caps != nil && !caps.SupportsImageOutput { + return nil, fmt.Errorf("model %s does not support image generation", model) + } + + // Create the image generation request + params := openai.ImageGenerateParams{ + Model: openai.ImageModel(model), + Prompt: req.Prompt, + N: openai.Int(int64(req.N)), + ResponseFormat: openai.ImageGenerateParamsResponseFormat(req.ResponseFormat), + // xAI doesn't support quality, size, or style parameters + } + + logging.Debug("Generating images", + "model", model, + "prompt_length", len(req.Prompt), + "n", req.N, + "format", req.ResponseFormat) + + // Make the API call + startTime := time.Now() + result, err := x.client.Images.Generate(ctx, params) + if err != nil { + return nil, fmt.Errorf("image generation failed: %w", err) + } + + elapsed := time.Since(startTime) + logging.Debug("Image generation completed", + "model", model, + "n", req.N, + "elapsed", elapsed) + + // Convert response + response := &ImageGenerationResponse{ + Model: model, + Created: time.Unix(result.Created, 0), + Images: make([]GeneratedImage, 0, len(result.Data)), + } + + // Extract revised prompt if available + if len(result.Data) > 0 && result.Data[0].RevisedPrompt != "" { + response.RevisedPrompt = result.Data[0].RevisedPrompt + } + + // Process each generated image + for i, imgData := range result.Data { + img := GeneratedImage{ + ContentType: "image/jpeg", // xAI generates JPEGs + } + + if req.ResponseFormat == "url" { + img.URL = imgData.URL + } else { + // b64_json format + img.Base64 = imgData.B64JSON + } + + response.Images = append(response.Images, img) + + logging.Debug("Processed generated image", + "index", i, + "has_url", img.URL != "", + "has_base64", img.Base64 != "") + } + + // Track fingerprint if available + if x.providerOptions.model.Provider == "xai" { + // Note: Image generation responses don't include system fingerprint in the same way + // but we can still track the generation event + logging.Debug("Image generation completed", + "model", model, + "images_generated", len(response.Images)) + } + + return response, nil +} + +// GenerateImage is a convenience method to generate a single image +func (x *xaiClient) GenerateImage(ctx context.Context, prompt string) (*GeneratedImage, error) { + req := ImageGenerationRequest{ + Prompt: prompt, + N: 1, + ResponseFormat: "url", + } + + resp, err := x.GenerateImages(ctx, req) + if err != nil { + return nil, err + } + + if len(resp.Images) == 0 { + return nil, fmt.Errorf("no images generated") + } + + return &resp.Images[0], nil +} + +// SaveGeneratedImage downloads and returns the image data from a URL-based response +func (x *xaiClient) SaveGeneratedImage(ctx 
context.Context, image *GeneratedImage) ([]byte, error) { + if image.Base64 != "" { + // Already have base64 data, decode it + // Remove data URL prefix if present + b64Data := image.Base64 + if strings.HasPrefix(b64Data, "data:") { + parts := strings.SplitN(b64Data, ",", 2) + if len(parts) == 2 { + b64Data = parts[1] + } + } + + return base64.StdEncoding.DecodeString(b64Data) + } + + if image.URL == "" { + return nil, fmt.Errorf("no image data available") + } + + // Download from URL + req, err := http.NewRequestWithContext(ctx, "GET", image.URL, nil) + if err != nil { + return nil, err + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to download image: status %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// ConvertToDataURL converts image data to a data URL +func ConvertImageToDataURL(data []byte, contentType string) string { + if contentType == "" { + contentType = "image/jpeg" + } + encoded := base64.StdEncoding.EncodeToString(data) + return fmt.Sprintf("data:%s;base64,%s", contentType, encoded) +} diff --git a/internal/llm/provider/xai_image_generation_test.go b/internal/llm/provider/xai_image_generation_test.go new file mode 100644 index 00000000..14f13fdc --- /dev/null +++ b/internal/llm/provider/xai_image_generation_test.go @@ -0,0 +1,245 @@ +package provider + +import ( + "context" + "os" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_ImageGeneration(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("model capability discovery", func(t *testing.T) { + // Create provider with image generation model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Image]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Test capability discovery + ctx := context.Background() + caps, err := xaiClient.DiscoverModelCapabilities(ctx, "grok-2-image") + require.NoError(t, err) + require.NotNil(t, caps) + + assert.True(t, caps.SupportsText, "Image generation models should support text prompts") + assert.True(t, caps.SupportsImageOutput, "Image generation models should support image output") + assert.False(t, caps.SupportsImageInput, "Image generation models don't typically support image input") + + t.Logf("Image generation model capabilities: %+v", caps) + }) + + t.Run("single image generation", func(t *testing.T) { + // Create provider with image generation model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Image]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Generate a simple image + ctx, cancel := 
context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + image, err := xaiClient.GenerateImage(ctx, "A simple red circle on white background") + require.NoError(t, err) + require.NotNil(t, image) + + // Check that we got either a URL or base64 data + assert.True(t, image.URL != "" || image.Base64 != "", "Should have either URL or base64 data") + assert.Equal(t, "image/jpeg", image.ContentType) + + t.Logf("Generated image: URL=%v, has_base64=%v", image.URL != "", image.Base64 != "") + }) + + t.Run("multiple image generation", func(t *testing.T) { + // Create provider with image generation model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Image]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Generate multiple images + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + req := ImageGenerationRequest{ + Prompt: "A cat sitting on a tree branch", + N: 3, + ResponseFormat: "url", + } + + resp, err := xaiClient.GenerateImages(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Len(t, resp.Images, 3, "Should generate 3 images") + assert.NotEmpty(t, resp.RevisedPrompt, "Should have revised prompt") + + // Check each image + for i, img := range resp.Images { + assert.NotEmpty(t, img.URL, "Image %d should have URL", i) + assert.Equal(t, "image/jpeg", img.ContentType) + } + + t.Logf("Generated %d images", len(resp.Images)) + t.Logf("Original prompt: %s", req.Prompt) + t.Logf("Revised prompt: %s", resp.RevisedPrompt) + }) + + t.Run("base64 format", func(t *testing.T) { + // Create provider with image generation model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Image]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + // Generate image in base64 format + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + req := ImageGenerationRequest{ + Prompt: "A simple blue square", + N: 1, + ResponseFormat: "b64_json", + } + + resp, err := xaiClient.GenerateImages(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Images, 1) + + img := resp.Images[0] + assert.NotEmpty(t, img.Base64, "Should have base64 data") + assert.Empty(t, img.URL, "Should not have URL when using b64_json format") + + // Test saving the image + data, err := xaiClient.SaveGeneratedImage(ctx, &img) + require.NoError(t, err) + assert.Greater(t, len(data), 1000, "Image data should be reasonably sized") + + t.Logf("Generated base64 image with %d bytes", len(data)) + }) + + t.Run("validation tests", func(t *testing.T) { + // Create provider with image generation model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Image]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Test empty prompt + _, err = xaiClient.GenerateImages(ctx, 
ImageGenerationRequest{ + Prompt: "", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "prompt cannot be empty") + + // Test invalid N + _, err = xaiClient.GenerateImages(ctx, ImageGenerationRequest{ + Prompt: "test", + N: 15, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "n must be between 1 and 10") + + // Test invalid response format + _, err = xaiClient.GenerateImages(ctx, ImageGenerationRequest{ + Prompt: "test", + ResponseFormat: "invalid", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "response_format must be") + }) + + t.Run("capability discovery for grok-4", func(t *testing.T) { + // Test if grok-4 supports image generation + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + caps, err := xaiClient.DiscoverModelCapabilities(ctx, "grok-4") + if err != nil { + t.Logf("Could not discover grok-4 capabilities: %v", err) + return + } + + t.Logf("Grok-4 capabilities: text=%v, image_input=%v, image_output=%v, web_search=%v", + caps.SupportsText, caps.SupportsImageInput, caps.SupportsImageOutput, caps.SupportsWebSearch) + + // If grok-4 supports image generation, test it + if caps.SupportsImageOutput { + t.Log("Grok-4 supports image generation! Testing...") + + req := ImageGenerationRequest{ + Prompt: "A simple test image", + Model: "grok-4", + N: 1, + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := xaiClient.GenerateImages(ctx, req) + if err != nil { + t.Logf("Grok-4 image generation failed: %v", err) + } else { + t.Logf("Grok-4 successfully generated %d images", len(resp.Images)) + } + } + }) +} diff --git a/internal/llm/provider/xai_live_search_test.go b/internal/llm/provider/xai_live_search_test.go new file mode 100644 index 00000000..738f96db --- /dev/null +++ b/internal/llm/provider/xai_live_search_test.go @@ -0,0 +1,433 @@ +package provider + +import ( + "context" + "os" + "testing" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_LiveSearchOptions(t *testing.T) { + t.Run("default live search options", func(t *testing.T) { + opts := DefaultLiveSearchOptions() + + assert.Equal(t, "auto", opts.Mode) + assert.NotNil(t, opts.ReturnCitations) + assert.True(t, *opts.ReturnCitations) + assert.Len(t, opts.Sources, 2) + assert.Equal(t, "web", opts.Sources[0].Type) + assert.Equal(t, "x", opts.Sources[1].Type) + }) + + t.Run("custom live search options", func(t *testing.T) { + maxResults := 10 + fromDate := "2025-01-01" + toDate := "2025-12-31" + returnCitations := false + + opts := LiveSearchOptions{ + Mode: "on", + MaxSearchResults: &maxResults, + FromDate: &fromDate, + ToDate: &toDate, + ReturnCitations: &returnCitations, + Sources: []LiveSearchSource{ + { + Type: "web", + Country: stringPtr("US"), + ExcludedWebsites: []string{"example.com"}, + }, + { + Type: "news", + Country: stringPtr("UK"), + }, + { + Type: "x", + IncludedXHandles: []string{"xai", "openai"}, + PostFavoriteCount: intPtr(100), + }, + 
{ + Type: "rss", + Links: []string{"https://example.com/feed.xml"}, + }, + }, + } + + assert.Equal(t, "on", opts.Mode) + assert.Equal(t, 10, *opts.MaxSearchResults) + assert.Equal(t, "2025-01-01", *opts.FromDate) + assert.Equal(t, "2025-12-31", *opts.ToDate) + assert.False(t, *opts.ReturnCitations) + assert.Len(t, opts.Sources, 4) + + // Check web source + webSource := opts.Sources[0] + assert.Equal(t, "web", webSource.Type) + assert.Equal(t, "US", *webSource.Country) + assert.Equal(t, []string{"example.com"}, webSource.ExcludedWebsites) + + // Check X source + xSource := opts.Sources[2] + assert.Equal(t, "x", xSource.Type) + assert.Equal(t, []string{"xai", "openai"}, xSource.IncludedXHandles) + assert.Equal(t, 100, *xSource.PostFavoriteCount) + + // Check RSS source + rssSource := opts.Sources[3] + assert.Equal(t, "rss", rssSource.Type) + assert.Equal(t, []string{"https://example.com/feed.xml"}, rssSource.Links) + }) + + t.Run("xai client with live search options", func(t *testing.T) { + opts := providerClientOptions{ + apiKey: "test-key", + model: models.SupportedModels[models.XAIGrok4], + maxTokens: 1000, + systemMessage: "Test system message", + xaiOptions: []XAIOption{ + WithLiveSearch(), + }, + } + + client := newXAIClient(opts).(*xaiClient) + + assert.True(t, client.liveSearchEnabled) + assert.Equal(t, "auto", client.liveSearchOptions.Mode) + assert.NotNil(t, client.liveSearchOptions.ReturnCitations) + assert.True(t, *client.liveSearchOptions.ReturnCitations) + }) + + t.Run("xai client with custom live search options", func(t *testing.T) { + customOpts := LiveSearchOptions{ + Mode: "on", + MaxSearchResults: intPtr(5), + Sources: []LiveSearchSource{ + {Type: "web"}, + }, + } + + opts := providerClientOptions{ + apiKey: "test-key", + model: models.SupportedModels[models.XAIGrok4], + maxTokens: 1000, + systemMessage: "Test system message", + xaiOptions: []XAIOption{ + WithLiveSearchOptions(customOpts), + }, + } + + client := newXAIClient(opts).(*xaiClient) + + assert.True(t, client.liveSearchEnabled) + assert.Equal(t, "on", client.liveSearchOptions.Mode) + assert.Equal(t, 5, *client.liveSearchOptions.MaxSearchResults) + assert.Len(t, client.liveSearchOptions.Sources, 1) + assert.Equal(t, "web", client.liveSearchOptions.Sources[0].Type) + }) + + t.Run("combined xai options", func(t *testing.T) { + opts := providerClientOptions{ + apiKey: "test-key", + model: models.SupportedModels[models.XAIGrok4], + maxTokens: 1000, + systemMessage: "Test system message", + xaiOptions: []XAIOption{ + WithMaxConcurrentRequests(3), + WithDeferredCompletion(), + WithLiveSearch(), + }, + } + + client := newXAIClient(opts).(*xaiClient) + + // Verify all options are applied + assert.NotNil(t, client.concurrent) + assert.True(t, client.deferredEnabled) + assert.True(t, client.liveSearchEnabled) + + assert.Equal(t, int64(3), client.concurrent.GetMaxConcurrent()) + assert.Equal(t, "auto", client.liveSearchOptions.Mode) + }) +} + +func TestXAIProvider_LiveSearchIntegration(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("basic live search request", func(t *testing.T) { + // Create provider with Live Search enabled + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + 
WithSystemMessage("You are a helpful assistant that can search the web for current information."), + WithXAIOptions( + WithLiveSearch(), + ), + ) + require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + // Create a message that should trigger live search + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What are the latest news about artificial intelligence in 2025?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{webSearchTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Live Search responses should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Response should have content (either from tools or direct search) + assert.True(t, response.Content != "" || len(response.ToolCalls) > 0) + + // Log response details for manual verification + t.Logf("Response content length: %d", len(response.Content)) + t.Logf("Number of tool calls: %d", len(response.ToolCalls)) + t.Logf("Citations: %v", response.Citations) + t.Logf("System fingerprint: %s", response.SystemFingerprint) + + // If citations are present, they should be valid URLs + for i, citation := range response.Citations { + assert.NotEmpty(t, citation, "Citation %d should not be empty", i) + t.Logf("Citation %d: %s", i+1, citation) + } + }) + + t.Run("live search with date filtering", func(t *testing.T) { + fromDate := "2025-01-01" + toDate := "2025-01-31" + + customOpts := LiveSearchOptions{ + Mode: "on", + MaxSearchResults: intPtr(5), + FromDate: &fromDate, + ToDate: &toDate, + ReturnCitations: boolPtr(true), + Sources: []LiveSearchSource{ + {Type: "web"}, + {Type: "news"}, + }, + } + + // Create provider with custom Live Search options + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearchOptions(customOpts), + ), + ) + require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + // Create a message that should use the date-filtered search + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What major technology announcements happened in January 2025?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{webSearchTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Log response for verification + t.Logf("Date-filtered search response: %s", response.Content) + t.Logf("Citations: %v", response.Citations) + }) + + t.Run("live search with x source filtering", func(t *testing.T) { + customOpts := LiveSearchOptions{ + Mode: "on", + MaxSearchResults: intPtr(3), + ReturnCitations: boolPtr(true), + Sources: []LiveSearchSource{ + { + Type: "x", + IncludedXHandles: []string{"xai", "elonmusk"}, + PostFavoriteCount: intPtr(10), + }, + }, + } + + // Create provider with X-specific search + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearchOptions(customOpts), + ), + ) + 
require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + // Create a message about xAI updates + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What are the latest updates from xAI?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{webSearchTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Log response for verification + t.Logf("X-filtered search response: %s", response.Content) + t.Logf("Citations: %v", response.Citations) + }) + + t.Run("live search combined with deferred completion", func(t *testing.T) { + // Create provider with both Live Search and deferred completion + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearch(), + WithDeferredCompletion(), + ), + ) + require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + // Create a complex message that might require deferred processing + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Please provide a comprehensive analysis of recent AI developments, including the latest research papers, company announcements, and market trends."}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{webSearchTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Should have substantial content for comprehensive analysis + assert.NotEmpty(t, response.Content) + + // Log response for verification + t.Logf("Deferred + Live Search response length: %d characters", len(response.Content)) + t.Logf("Citations: %d", len(response.Citations)) + t.Logf("System fingerprint: %s", response.SystemFingerprint) + }) + + t.Run("streaming with live search", func(t *testing.T) { + // Create provider with Live Search + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(300), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearch(), + ), + ) + require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + // Create a message for streaming search + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What's happening in the tech world today?"}, + }, + }, + } + + // Stream response + eventChan := provider.StreamResponse(context.Background(), messages, []tools.BaseTool{webSearchTool}) + + var finalResponse *ProviderResponse + var hasContent bool + + for event := range eventChan { + switch event.Type { + case EventContentDelta: + hasContent = true + case EventComplete: + finalResponse = event.Response + case EventError: + t.Fatalf("Streaming error: %v", event.Error) + } + } + + require.NotNil(t, finalResponse) + + // Should have content or tool calls + assert.True(t, hasContent || finalResponse.Content != "" || len(finalResponse.ToolCalls) > 0) + + // Should have system fingerprint + 
assert.NotEmpty(t, finalResponse.SystemFingerprint) + + // Log final response + t.Logf("Streaming final response: %s", finalResponse.Content) + t.Logf("Streaming citations: %v", finalResponse.Citations) + }) +} + +// Helper functions are defined in xai_test.go to avoid duplication diff --git a/internal/llm/provider/xai_models.go b/internal/llm/provider/xai_models.go new file mode 100644 index 00000000..3abb8cb0 --- /dev/null +++ b/internal/llm/provider/xai_models.go @@ -0,0 +1,292 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/opencode-ai/opencode/internal/logging" +) + +// XAIModelInfo represents the detailed model information from xAI API +type XAIModelInfo struct { + ID string `json:"id"` + Fingerprint string `json:"fingerprint"` + Created int64 `json:"created"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Version string `json:"version"` + InputModalities []string `json:"input_modalities"` + OutputModalities []string `json:"output_modalities"` + PromptTextTokenPrice int `json:"prompt_text_token_price"` + CachedPromptTextTokenPrice int `json:"cached_prompt_text_token_price"` + PromptImageTokenPrice int `json:"prompt_image_token_price"` + CompletionTextTokenPrice int `json:"completion_text_token_price"` + SearchPrice int `json:"search_price"` + Aliases []string `json:"aliases"` +} + +// XAIImageModelInfo represents image generation model information +type XAIImageModelInfo struct { + ID string `json:"id"` + Fingerprint string `json:"fingerprint"` + MaxPromptLength int `json:"max_prompt_length"` + Created int64 `json:"created"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Version string `json:"version"` + InputModalities []string `json:"input_modalities"` + OutputModalities []string `json:"output_modalities"` + ImagePrice int `json:"image_price"` + PromptTextTokenPrice int `json:"prompt_text_token_price"` + PromptImageTokenPrice int `json:"prompt_image_token_price"` + GeneratedImageTokenPrice int `json:"generated_image_token_price"` + Aliases []string `json:"aliases"` +} + +// XAILanguageModelsResponse represents the response from /v1/language-models +type XAILanguageModelsResponse struct { + Models []XAIModelInfo `json:"models"` +} + +// XAIImageModelsResponse represents the response from /v1/image-generation-models +type XAIImageModelsResponse struct { + Models []XAIImageModelInfo `json:"models"` +} + +// ModelCapabilities represents the capabilities of a model +type ModelCapabilities struct { + SupportsText bool + SupportsImageInput bool + SupportsImageOutput bool + SupportsWebSearch bool + MaxPromptLength int + Aliases []string +} + +// DiscoverModelCapabilities queries the xAI API to discover model capabilities +func (x *xaiClient) DiscoverModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { + // First try language models endpoint + langCaps, err := x.getLanguageModelCapabilities(ctx, modelID) + if err == nil && langCaps != nil { + return langCaps, nil + } + + // Then try image generation models endpoint + imgCaps, err := x.getImageModelCapabilities(ctx, modelID) + if err == nil && imgCaps != nil { + return imgCaps, nil + } + + // Fallback to basic model info + return x.getBasicModelCapabilities(ctx, modelID) +} + +// getLanguageModelCapabilities fetches capabilities from language models endpoint +func (x *xaiClient) getLanguageModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { + url := 
fmt.Sprintf("%s/v1/language-models/%s", x.getBaseURL(), modelID) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil // Not a language model + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body)) + } + + var modelInfo XAIModelInfo + if err := json.NewDecoder(resp.Body).Decode(&modelInfo); err != nil { + return nil, err + } + + caps := &ModelCapabilities{ + Aliases: modelInfo.Aliases, + } + + // Check input modalities + for _, mod := range modelInfo.InputModalities { + switch mod { + case "text": + caps.SupportsText = true + case "image": + caps.SupportsImageInput = true + } + } + + // Check output modalities + for _, mod := range modelInfo.OutputModalities { + switch mod { + case "text": + // Text output is standard + case "image": + caps.SupportsImageOutput = true + } + } + + // Web search is available for all language models + caps.SupportsWebSearch = caps.SupportsText + + logging.Debug("Discovered language model capabilities", + "model", modelID, + "text", caps.SupportsText, + "image_input", caps.SupportsImageInput, + "image_output", caps.SupportsImageOutput, + "web_search", caps.SupportsWebSearch) + + return caps, nil +} + +// getImageModelCapabilities fetches capabilities from image generation models endpoint +func (x *xaiClient) getImageModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { + url := fmt.Sprintf("%s/v1/image-generation-models/%s", x.getBaseURL(), modelID) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil // Not an image generation model + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body)) + } + + var modelInfo XAIImageModelInfo + if err := json.NewDecoder(resp.Body).Decode(&modelInfo); err != nil { + return nil, err + } + + caps := &ModelCapabilities{ + SupportsText: true, // Image generation takes text prompts + SupportsImageOutput: true, + MaxPromptLength: modelInfo.MaxPromptLength, + Aliases: modelInfo.Aliases, + } + + logging.Debug("Discovered image generation model capabilities", + "model", modelID, + "max_prompt_length", caps.MaxPromptLength) + + return caps, nil +} + +// getBasicModelCapabilities fetches basic model info as fallback +func (x *xaiClient) getBasicModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { + url := fmt.Sprintf("%s/v1/models/%s", x.getBaseURL(), modelID) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != 
http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body)) + } + + // Basic model info doesn't provide capability details + // Return minimal capabilities + return &ModelCapabilities{ + SupportsText: true, // Assume text support for all models + }, nil +} + +// getBaseURL returns the base URL for API requests +func (x *xaiClient) getBaseURL() string { + if x.options.baseURL != "" { + return x.options.baseURL + } + return "https://api.x.ai" +} + +// ListAllModels fetches all available models from xAI API +func (x *xaiClient) ListAllModels(ctx context.Context) ([]XAIModelInfo, []XAIImageModelInfo, error) { + var languageModels []XAIModelInfo + var imageModels []XAIImageModelInfo + + // Fetch language models + langURL := fmt.Sprintf("%s/v1/language-models", x.getBaseURL()) + langReq, err := http.NewRequestWithContext(ctx, "GET", langURL, nil) + if err != nil { + return nil, nil, err + } + langReq.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + langResp, err := client.Do(langReq) + if err != nil { + return nil, nil, err + } + defer langResp.Body.Close() + + if langResp.StatusCode == http.StatusOK { + var langModelsResp XAILanguageModelsResponse + if err := json.NewDecoder(langResp.Body).Decode(&langModelsResp); err != nil { + return nil, nil, err + } + languageModels = langModelsResp.Models + } + + // Fetch image generation models + imgURL := fmt.Sprintf("%s/v1/image-generation-models", x.getBaseURL()) + imgReq, err := http.NewRequestWithContext(ctx, "GET", imgURL, nil) + if err != nil { + return languageModels, nil, err + } + imgReq.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + + imgResp, err := client.Do(imgReq) + if err != nil { + return languageModels, nil, err + } + defer imgResp.Body.Close() + + if imgResp.StatusCode == http.StatusOK { + var imgModelsResp XAIImageModelsResponse + if err := json.NewDecoder(imgResp.Body).Decode(&imgModelsResp); err != nil { + return languageModels, nil, err + } + imageModels = imgModelsResp.Models + } + + return languageModels, imageModels, nil +} diff --git a/internal/llm/provider/xai_streaming.go b/internal/llm/provider/xai_streaming.go new file mode 100644 index 00000000..29951c03 --- /dev/null +++ b/internal/llm/provider/xai_streaming.go @@ -0,0 +1,123 @@ +package provider + +import ( + "context" + "time" + + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/logging" + "github.com/opencode-ai/opencode/internal/message" +) + +// StreamingMetrics tracks streaming performance for xAI +type StreamingMetrics struct { + FirstTokenTime time.Duration + TotalStreamTime time.Duration + TokenCount int + ChunkCount int + SystemFingerprint string +} + +// streamWithMetrics wraps the base streaming with xAI-specific metrics and handling +func (x *xaiClient) streamWithMetrics(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + // Use concurrent client if configured + if x.concurrent != nil { + return x.concurrent.stream(ctx, messages, tools) + } + + // Get the base stream + baseChan := x.openaiClient.stream(ctx, messages, tools) + + // Create a new channel to intercept and process events + eventChan := make(chan ProviderEvent) + + go func() { + defer close(eventChan) + + startTime := time.Now() + var firstTokenTime time.Duration + var metrics StreamingMetrics + tokenCount := 0 + chunkCount := 0 + + for event := range 
baseChan { + // Track metrics + switch event.Type { + case EventContentDelta: + chunkCount++ + if tokenCount == 0 { + firstTokenTime = time.Since(startTime) + metrics.FirstTokenTime = firstTokenTime + logging.Debug("xAI streaming first token received", + "model", x.providerOptions.model.ID, + "time_to_first_token", firstTokenTime) + } + tokenCount += len(event.Content) + + case EventComplete: + if event.Response != nil { + metrics.TotalStreamTime = time.Since(startTime) + metrics.TokenCount = tokenCount + metrics.ChunkCount = chunkCount + metrics.SystemFingerprint = event.Response.SystemFingerprint + + // Track fingerprint for monitoring, security, and compliance + if event.Response.SystemFingerprint != "" { + x.trackFingerprint(event.Response.SystemFingerprint, event.Response.Usage) + } + + // Log streaming metrics + logging.Debug("xAI streaming completed", + "model", x.providerOptions.model.ID, + "total_time", metrics.TotalStreamTime, + "first_token_time", metrics.FirstTokenTime, + "chunks", metrics.ChunkCount, + "usage", event.Response.Usage, + "system_fingerprint", metrics.SystemFingerprint) + } + } + + // Forward the event + eventChan <- event + } + }() + + return eventChan +} + +// EnhancedStreamOptions provides xAI-specific streaming configuration +type EnhancedStreamOptions struct { + // TimeoutOverride allows manual timeout override for reasoning models + TimeoutOverride *time.Duration + + // EnableMetrics enables detailed streaming metrics + EnableMetrics bool + + // BufferSize controls the event channel buffer size + BufferSize int +} + +// streamWithOptions provides enhanced streaming with xAI-specific options +func (x *xaiClient) streamWithOptions(ctx context.Context, messages []message.Message, tools []tools.BaseTool, opts EnhancedStreamOptions) <-chan ProviderEvent { + // Apply timeout override if specified (useful for reasoning models) + if opts.TimeoutOverride != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, *opts.TimeoutOverride) + defer cancel() + } + + // Use metrics-enabled streaming if requested + if opts.EnableMetrics { + return x.streamWithMetrics(ctx, messages, tools) + } + + // Otherwise use standard streaming + return x.stream(ctx, messages, tools) +} + +// ValidateStreamingSupport checks if the model supports streaming +func (x *xaiClient) ValidateStreamingSupport() error { + // All xAI chat models support streaming + // This is a placeholder for future model-specific validation + return nil +} diff --git a/internal/llm/provider/xai_streaming_test.go b/internal/llm/provider/xai_streaming_test.go new file mode 100644 index 00000000..e1828edb --- /dev/null +++ b/internal/llm/provider/xai_streaming_test.go @@ -0,0 +1,260 @@ +package provider + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_Streaming(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("basic streaming", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + 
WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(200), + WithSystemMessage("You are a helpful assistant. Be concise."), + ) + require.NoError(t, err) + + // Create a simple message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Count from 1 to 5, one number per line."}, + }, + }, + } + + // Stream response + ctx := context.Background() + eventChan := provider.StreamResponse(ctx, messages, nil) + + // Collect events + var contentChunks []string + var finalResponse *ProviderResponse + eventCount := 0 + hasContentDelta := false + + for event := range eventChan { + eventCount++ + switch event.Type { + case EventContentDelta: + hasContentDelta = true + contentChunks = append(contentChunks, event.Content) + t.Logf("Content delta: %q", event.Content) + + case EventComplete: + finalResponse = event.Response + t.Logf("Stream complete - Total chunks: %d", len(contentChunks)) + + case EventError: + t.Fatalf("Streaming error: %v", event.Error) + } + } + + // Verify streaming worked correctly + require.NotNil(t, finalResponse) + assert.True(t, hasContentDelta, "Should have received content deltas") + assert.Greater(t, eventCount, 1, "Should have multiple events") + + // Verify content accumulation + accumulatedContent := strings.Join(contentChunks, "") + assert.Equal(t, finalResponse.Content, accumulatedContent, "Accumulated content should match final response") + + // Verify xAI-specific fields + assert.NotEmpty(t, finalResponse.SystemFingerprint, "Should have system fingerprint") + assert.Greater(t, finalResponse.Usage.InputTokens, int64(0), "Should have input tokens") + assert.Greater(t, finalResponse.Usage.OutputTokens, int64(0), "Should have output tokens") + + t.Logf("Final content: %s", finalResponse.Content) + t.Logf("System fingerprint: %s", finalResponse.SystemFingerprint) + t.Logf("Usage: %+v", finalResponse.Usage) + }) + + t.Run("streaming with enhanced metrics", func(t *testing.T) { + // Create xAI client directly to access enhanced features + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 150, + systemMessage: "You are a helpful assistant.", + } + + xaiClient := newXAIClient(opts).(*xaiClient) + + // Create message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is 2+2? 
Answer in one sentence."}, + }, + }, + } + + // Stream with metrics + ctx := context.Background() + startTime := time.Now() + eventChan := xaiClient.streamWithMetrics(ctx, messages, nil) + + // Track metrics + var firstTokenTime time.Duration + chunkCount := 0 + firstChunkReceived := false + + for event := range eventChan { + switch event.Type { + case EventContentDelta: + if !firstChunkReceived { + firstTokenTime = time.Since(startTime) + firstChunkReceived = true + t.Logf("Time to first token: %v", firstTokenTime) + } + chunkCount++ + + case EventComplete: + totalTime := time.Since(startTime) + t.Logf("Total streaming time: %v", totalTime) + t.Logf("Total chunks: %d", chunkCount) + + // Verify reasonable performance + assert.Less(t, firstTokenTime, 5*time.Second, "First token should arrive quickly") + assert.Less(t, totalTime, 10*time.Second, "Total streaming should complete reasonably fast") + } + } + + assert.True(t, firstChunkReceived, "Should have received at least one chunk") + assert.Greater(t, chunkCount, 0, "Should have received chunks") + }) + + t.Run("streaming with reasoning model", func(t *testing.T) { + // Create provider with grok-4 + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + ) + require.NoError(t, err) + + // Create a reasoning task + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is 15% of 80? Show your calculation."}, + }, + }, + } + + // Stream with longer timeout for reasoning + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + eventChan := provider.StreamResponse(ctx, messages, nil) + + // Collect response + var finalResponse *ProviderResponse + hasContent := false + + for event := range eventChan { + switch event.Type { + case EventContentDelta: + hasContent = true + + case EventComplete: + finalResponse = event.Response + + case EventError: + t.Fatalf("Streaming error: %v", event.Error) + } + } + + // Verify response + require.NotNil(t, finalResponse) + assert.True(t, hasContent, "Should have received content") + assert.NotEmpty(t, finalResponse.Content) + assert.NotEmpty(t, finalResponse.SystemFingerprint) + + t.Logf("Reasoning response: %s", finalResponse.Content) + }) + + t.Run("streaming interruption handling", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + ) + require.NoError(t, err) + + // Create a longer task + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Count from 1 to 20, one number per line."}, + }, + }, + } + + // Create cancellable context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // Ensure cancel is always called + + eventChan := provider.StreamResponse(ctx, messages, nil) + + // Cancel after receiving some chunks + chunkCount := 0 + for event := range eventChan { + if event.Type == EventContentDelta { + chunkCount++ + if chunkCount >= 3 { + cancel() // Interrupt streaming + break + } + } + } + + // Verify we received some chunks before cancellation + assert.GreaterOrEqual(t, chunkCount, 3, "Should have received chunks before cancellation") + + // Drain remaining events 
(should close soon after cancellation) + timeout := time.After(5 * time.Second) + for { + select { + case _, ok := <-eventChan: + if !ok { + return // Channel closed as expected + } + case <-timeout: + t.Fatal("Channel did not close after cancellation") + } + } + }) +} diff --git a/internal/llm/provider/xai_test.go b/internal/llm/provider/xai_test.go new file mode 100644 index 00000000..7253e88c --- /dev/null +++ b/internal/llm/provider/xai_test.go @@ -0,0 +1,1267 @@ +package provider + +import ( + "context" + "encoding/json" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockTool implements a simple tool for testing +type MockTool struct { + name string + description string + parameters map[string]interface{} + required []string + response string + callCount int +} + +func (t *MockTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: t.name, + Description: t.description, + Parameters: t.parameters, + Required: t.required, + } +} + +func (t *MockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) { + t.callCount++ + return tools.NewTextResponse(t.response), nil +} + +func TestXAIProvider_FunctionCalling(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("basic function calling", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(1000), + WithSystemMessage("You are a helpful assistant. Use the provided tools when appropriate."), + ) + require.NoError(t, err) + + // Create a mock tool + mockTool := &MockTool{ + name: "get_weather", + description: "Get the current weather in a given location", + parameters: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + }, + required: []string{"location"}, + response: `{"temperature": 72, "condition": "sunny"}`, + } + + // Create a message that should trigger tool use + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What's the weather like in San Francisco?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{mockTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify tool was called + assert.Equal(t, message.FinishReasonToolUse, response.FinishReason) + assert.NotEmpty(t, response.ToolCalls) + assert.NotEmpty(t, response.SystemFingerprint) + + // Verify tool call details + if len(response.ToolCalls) > 0 { + assert.Equal(t, "get_weather", response.ToolCalls[0].Name) + assert.NotEmpty(t, response.ToolCalls[0].ID) + assert.NotEmpty(t, response.ToolCalls[0].Input) + assert.True(t, response.ToolCalls[0].Finished) + + // Verify the input contains location + var input map[string]interface{} + err := json.Unmarshal([]byte(response.ToolCalls[0].Input), &input) + assert.NoError(t, err) + assert.Contains(t, input, "location") + } + }) + + t.Run("tool choice modes", func(t *testing.T) { + testCases := []struct { + name string + toolChoice string + message string + expectTool bool + }{ + { + name: "auto mode with tool-triggering prompt", + toolChoice: "auto", + message: "What's the weather in New York?", + expectTool: true, + }, + { + name: "none mode should not call tools", + toolChoice: "none", + message: "What's the weather in New York?", + expectTool: false, + }, + { + name: "required mode forces tool call", + toolChoice: "required", + message: "Hello there!", + expectTool: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create provider with specific tool choice + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + WithOpenAIOptions(WithOpenAIToolChoice(tc.toolChoice)), + ) + require.NoError(t, err) + + // Create a mock tool + mockTool := &MockTool{ + name: "get_info", + description: "Get information about a topic", + parameters: map[string]interface{}{ + "topic": map[string]interface{}{ + "type": "string", + "description": "The topic to get information about", + }, + }, + required: []string{"topic"}, + response: `{"info": "some information"}`, + } + + // Create message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: tc.message}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{mockTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify expectations + if tc.expectTool { + assert.Equal(t, message.FinishReasonToolUse, response.FinishReason) + assert.NotEmpty(t, response.ToolCalls) + } else { + assert.NotEqual(t, message.FinishReasonToolUse, response.FinishReason) + assert.Empty(t, response.ToolCalls) + assert.NotEmpty(t, response.Content) // Should have text response instead + } + }) + } + }) + + t.Run("parallel function calling", func(t *testing.T) { + // Create provider with parallel tool calls enabled + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(1000), + 
WithSystemMessage("You are a helpful assistant. When asked to do multiple things, use multiple tools in parallel if appropriate."), + WithOpenAIOptions(WithOpenAIParallelToolCalls(true)), + ) + require.NoError(t, err) + + // Create multiple mock tools + weatherTool := &MockTool{ + name: "get_weather", + description: "Get the weather in a location", + parameters: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The location", + }, + }, + required: []string{"location"}, + response: `{"temperature": 72}`, + } + + timeTool := &MockTool{ + name: "get_time", + description: "Get the current time in a location", + parameters: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The location", + }, + }, + required: []string{"location"}, + response: `{"time": "2:30 PM"}`, + } + + // Create a message that could trigger multiple tools + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What's the weather and current time in both Paris and London?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{weatherTool, timeTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify function calling occurred + assert.Equal(t, message.FinishReasonToolUse, response.FinishReason) + assert.NotEmpty(t, response.ToolCalls) + + // Log the number of tool calls for observation + t.Logf("Number of tool calls: %d", len(response.ToolCalls)) + for i, call := range response.ToolCalls { + t.Logf("Tool call %d: %s", i+1, call.Name) + } + }) + + t.Run("streaming with function calls", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant."), + ) + require.NoError(t, err) + + // Create a mock tool + mockTool := &MockTool{ + name: "calculate", + description: "Perform a calculation", + parameters: map[string]interface{}{ + "expression": map[string]interface{}{ + "type": "string", + "description": "The mathematical expression", + }, + }, + required: []string{"expression"}, + response: `{"result": 42}`, + } + + // Create message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is 6 times 7?"}, + }, + }, + } + + // Stream response + eventChan := provider.StreamResponse(context.Background(), messages, []tools.BaseTool{mockTool}) + + var finalResponse *ProviderResponse + var hasContent bool + + for event := range eventChan { + switch event.Type { + case EventContentDelta: + hasContent = true + case EventComplete: + finalResponse = event.Response + case EventError: + t.Fatalf("Streaming error: %v", event.Error) + } + } + + require.NotNil(t, finalResponse) + + // According to xAI docs, function calls come in whole chunks in streaming + if finalResponse.FinishReason == message.FinishReasonToolUse { + assert.NotEmpty(t, finalResponse.ToolCalls) + assert.NotEmpty(t, finalResponse.SystemFingerprint) + } else { + // If no tool was called, we should have content + assert.True(t, hasContent || finalResponse.Content != "") + } + }) + + t.Run("system fingerprint tracking", func(t *testing.T) { + // Create xAI provider (which includes fingerprint tracking) + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), 
+ WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(100), + WithSystemMessage("You are a helpful assistant."), + ) + require.NoError(t, err) + + // Send multiple requests to check fingerprint + fingerprints := make([]string, 0) + + for i := 0; i < 3; i++ { + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Say hello"}, + }, + }, + } + + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // xAI should always return a system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + fingerprints = append(fingerprints, response.SystemFingerprint) + + t.Logf("Request %d - System fingerprint: %s", i+1, response.SystemFingerprint) + } + + // Fingerprints might be the same or different depending on backend changes + // We just verify they are populated + for _, fp := range fingerprints { + assert.NotEmpty(t, fp) + } + }) +} + +func TestXAIProvider_StructuredOutput(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("JSON mode", func(t *testing.T) { + // Create provider with JSON mode + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant that responds in JSON format."), + WithOpenAIOptions(WithOpenAIJSONMode()), + ) + require.NoError(t, err) + + // Create message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Create a JSON object with a name and age for a fictional person."}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify response is valid JSON + var result map[string]interface{} + err = json.Unmarshal([]byte(response.Content), &result) + assert.NoError(t, err, "Response should be valid JSON") + assert.NotEmpty(t, result) + }) + + t.Run("JSON schema mode", func(t *testing.T) { + // Define a schema + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "Person's name", + }, + "age": map[string]interface{}{ + "type": "integer", + "description": "Person's age", + "minimum": 0, + "maximum": 120, + }, + "email": map[string]interface{}{ + "type": "string", + "description": "Email address", + "format": "email", + }, + }, + "required": []string{"name", "age"}, + } + + // Create provider with JSON schema + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant that provides structured data."), + WithOpenAIOptions(WithOpenAIJSONSchema("person_info", schema)), + ) + require.NoError(t, err) + + // Create message + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Create information for a person named Alice who is 25 years old."}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, 
nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify response matches schema + var result map[string]interface{} + err = json.Unmarshal([]byte(response.Content), &result) + require.NoError(t, err, "Response should be valid JSON") + + // Check required fields + assert.Contains(t, result, "name") + assert.Contains(t, result, "age") + + // Verify types + _, nameIsString := result["name"].(string) + assert.True(t, nameIsString, "name should be a string") + + // JSON numbers are parsed as float64 + age, ageIsNumber := result["age"].(float64) + assert.True(t, ageIsNumber, "age should be a number") + if ageIsNumber { + assert.GreaterOrEqual(t, age, float64(0)) + assert.LessOrEqual(t, age, float64(120)) + } + }) +} + +func TestXAIProvider_ConcurrentRequests(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("concurrent requests with rate limiting", func(t *testing.T) { + // Create xAI client with max 2 concurrent requests + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 100, + systemMessage: "You are a helpful assistant. Respond concisely.", + xaiOptions: []XAIOption{ + WithMaxConcurrentRequests(2), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + require.NotNil(t, xaiClient.concurrent) + + // Track request timings + var requestTimes sync.Map + var requestCount int32 + + // Create 5 concurrent requests + numRequests := 5 + var wg sync.WaitGroup + responses := make([]*ProviderResponse, numRequests) + errors := make([]error, numRequests) + + startTime := time.Now() + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + reqStart := time.Now() + atomic.AddInt32(&requestCount, 1) + + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Say hello"}, + }, + }, + } + + resp, err := xaiClient.send(context.Background(), messages, nil) + + reqDuration := time.Since(reqStart) + requestTimes.Store(idx, reqDuration) + + responses[idx] = resp + errors[idx] = err + }(i) + } + + wg.Wait() + totalDuration := time.Since(startTime) + + // Verify all requests completed + for i, err := range errors { + require.NoError(t, err, "Request %d should not error", i) + require.NotNil(t, responses[i], "Request %d should have response", i) + assert.NotEmpty(t, responses[i].Content) + assert.NotEmpty(t, responses[i].SystemFingerprint) + } + + // Verify fingerprint tracking + history := xaiClient.GetFingerprintHistory() + assert.Len(t, history, numRequests) + + // Log timing information + t.Logf("Total duration for %d requests: %v", numRequests, totalDuration) + t.Logf("Max concurrent requests: %d", xaiClient.concurrent.GetMaxConcurrent()) + + // Since we have max 2 concurrent requests, at least 3 batches should be needed + // This is a rough check, actual timing depends on API response times + t.Logf("Average time per request: %v", totalDuration/time.Duration(numRequests)) + }) + + t.Run("batch requests", func(t *testing.T) { + // Create xAI client with concurrent support + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 150, + systemMessage: "You are a helpful assistant.", + xaiOptions: []XAIOption{ + WithMaxConcurrentRequests(3), + }, 
+ } + + xaiClient := newXAIClient(opts).(*xaiClient) + + // Create batch requests + requests := []BatchRequest{ + { + Messages: []message.Message{{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is 2+2?"}, + }, + }}, + }, + { + Messages: []message.Message{{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What is the capital of France?"}, + }, + }}, + }, + { + Messages: []message.Message{{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What color is the sky?"}, + }, + }}, + }, + } + + // Send batch + responses := xaiClient.SendBatch(context.Background(), requests) + + // Verify all responses + assert.Len(t, responses, 3) + for i, resp := range responses { + assert.NoError(t, resp.Error) + assert.NotNil(t, resp.Response) + assert.NotEmpty(t, resp.Response.Content) + assert.Equal(t, i, resp.Index) + t.Logf("Response %d: %s", i, resp.Response.Content) + } + }) + + t.Run("streaming batch requests", func(t *testing.T) { + // Create xAI client with concurrent support + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 100, + systemMessage: "You are a helpful assistant. Respond concisely.", + xaiOptions: []XAIOption{ + WithMaxConcurrentRequests(2), + }, + } + + xaiClient := newXAIClient(opts).(*xaiClient) + + // Create batch streaming requests + requests := []BatchRequest{ + { + Messages: []message.Message{{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Count to 3"}, + }, + }}, + }, + { + Messages: []message.Message{{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Say ABC"}, + }, + }}, + }, + } + + // Stream batch + channels := xaiClient.StreamBatch(context.Background(), requests) + assert.Len(t, channels, 2) + + // Collect responses from all streams + var wg sync.WaitGroup + responses := make([]*ProviderResponse, len(channels)) + + for i, ch := range channels { + wg.Add(1) + go func(idx int, eventChan <-chan ProviderEvent) { + defer wg.Done() + + for event := range eventChan { + if event.Type == EventComplete { + responses[idx] = event.Response + } else if event.Type == EventError { + t.Errorf("Stream %d error: %v", idx, event.Error) + } + } + }(i, ch) + } + + wg.Wait() + + // Verify all streams completed + for i, resp := range responses { + assert.NotNil(t, resp, "Stream %d should have response", i) + assert.NotEmpty(t, resp.Content) + assert.NotEmpty(t, resp.SystemFingerprint) + } + }) + + t.Run("runtime max concurrent update", func(t *testing.T) { + // Create xAI client without initial concurrent support + opts := providerClientOptions{ + apiKey: apiKey, + model: models.SupportedModels[models.XAIGrok3Fast], + maxTokens: 50, + systemMessage: "You are a helpful assistant.", + } + + xaiClient := newXAIClient(opts).(*xaiClient) + + // Initially no concurrent client + assert.Nil(t, xaiClient.concurrent) + + // Set max concurrent requests at runtime + xaiClient.SetMaxConcurrentRequests(3) + assert.NotNil(t, xaiClient.concurrent) + assert.Equal(t, int64(3), xaiClient.concurrent.GetMaxConcurrent()) + + // Update max concurrent requests + xaiClient.SetMaxConcurrentRequests(5) + assert.Equal(t, int64(5), xaiClient.concurrent.GetMaxConcurrent()) + + // Test that it works + messages := []message.Message{{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Hi"}, + }, + }} + + resp, err := xaiClient.send(context.Background(), 
messages, nil) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.NotEmpty(t, resp.Content) + }) +} + +func TestXAIProvider_LiveSearchFunctionality(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("live search with web search tool", func(t *testing.T) { + // Create provider with Live Search enabled + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant. Use web search when you need current information."), + WithXAIOptions( + WithLiveSearch(), + ), + ) + require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + // Create a message that should trigger web search + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What are the latest developments in AI in 2025?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{webSearchTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Should have either tool calls or direct content with Live Search + assert.True(t, len(response.ToolCalls) > 0 || response.Content != "") + + // Log response details + t.Logf("Response content length: %d", len(response.Content)) + t.Logf("Tool calls: %d", len(response.ToolCalls)) + t.Logf("Citations: %d", len(response.Citations)) + + if len(response.ToolCalls) > 0 { + for i, call := range response.ToolCalls { + t.Logf("Tool call %d: %s", i, call.Name) + assert.Equal(t, "web_search", call.Name) + assert.NotEmpty(t, call.Input) + } + } + + // Check for citations if present + for i, citation := range response.Citations { + t.Logf("Citation %d: %s", i+1, citation) + assert.NotEmpty(t, citation) + } + }) + + t.Run("live search without tools (direct integration)", func(t *testing.T) { + // Create provider with Live Search enabled + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(300), + WithSystemMessage("You are a helpful assistant with access to current information."), + WithXAIOptions( + WithLiveSearchOptions(LiveSearchOptions{ + Mode: "on", + MaxSearchResults: intPtr(5), + ReturnCitations: boolPtr(true), + Sources: []LiveSearchSource{ + {Type: "web"}, + {Type: "news"}, + }, + }), + ), + ) + require.NoError(t, err) + + // Create a message asking for current information (no tools provided) + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What's happening in the tech industry today?"}, + }, + }, + } + + // Send request without providing tools + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Should have content since Live Search is enabled + assert.NotEmpty(t, response.Content) + + // Log response details + t.Logf("Direct Live Search response length: %d", len(response.Content)) + t.Logf("Citations: %d", len(response.Citations)) + 
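+			// Live Search results (and therefore citations) can vary between runs, so the checks below are intentionally conditional rather than strict assertions.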
+ // Should have citations since we enabled them + if len(response.Citations) > 0 { + t.Logf("Citations received: %v", response.Citations) + for _, citation := range response.Citations { + assert.NotEmpty(t, citation) + } + } + }) + + t.Run("live search combined with deferred completion", func(t *testing.T) { + // Create provider with both Live Search and deferred completion + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a research assistant."), + WithXAIOptions( + WithLiveSearch(), + WithDeferredCompletion(), + ), + ) + require.NoError(t, err) + + // Create a complex research query + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Please provide a comprehensive analysis of recent AI safety research, including key papers, industry initiatives, and regulatory developments in 2025."}, + }, + }, + } + + // Send request (should use deferred completion with Live Search) + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Should have substantial content + assert.NotEmpty(t, response.Content) + + // Log response details + t.Logf("Deferred + Live Search response length: %d", len(response.Content)) + t.Logf("Citations: %d", len(response.Citations)) + t.Logf("System fingerprint: %s", response.SystemFingerprint) + }) + + t.Run("live search with specific source filtering", func(t *testing.T) { + // Create provider with specific source configuration + customOpts := LiveSearchOptions{ + Mode: "on", + MaxSearchResults: intPtr(3), + ReturnCitations: boolPtr(true), + Sources: []LiveSearchSource{ + { + Type: "x", + IncludedXHandles: []string{"xai"}, + }, + }, + } + + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(300), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearchOptions(customOpts), + ), + ) + require.NoError(t, err) + + // Ask about xAI specifically + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What are the latest announcements from xAI?"}, + }, + }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Should have content + assert.NotEmpty(t, response.Content) + + // Log response details + t.Logf("X-filtered Live Search response: %s", response.Content) + t.Logf("Citations: %v", response.Citations) + }) + + t.Run("live search mode off", func(t *testing.T) { + // Create provider with Live Search explicitly turned off + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(200), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearchOptions(LiveSearchOptions{ + Mode: "off", + }), + ), + ) + require.NoError(t, err) + + // Ask for current information + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What's the latest news in technology?"}, + }, 
+ }, + } + + // Send request + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Should have system fingerprint + assert.NotEmpty(t, response.SystemFingerprint) + + // Should have content (but based on training data, not live search) + assert.NotEmpty(t, response.Content) + + // Should not have citations since Live Search is off + assert.Empty(t, response.Citations) + + t.Logf("Live Search off response: %s", response.Content) + }) +} + +func TestXAIProvider_LiveSearchIntegrationDetailed(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("basic live search integration", func(t *testing.T) { + // Create provider with default Live Search + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(300), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearch(), + ), + ) + require.NoError(t, err) + + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What are the latest AI developments in 2025?"}, + }, + }, + } + + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Validate response structure + assert.NotEmpty(t, response.Content) + assert.NotEmpty(t, response.SystemFingerprint) + + t.Logf("Basic Live Search Response:") + t.Logf(" Response length: %d characters", len(response.Content)) + t.Logf(" Citations: %d", len(response.Citations)) + t.Logf(" System fingerprint: %s", response.SystemFingerprint) + + if len(response.Citations) > 0 { + t.Logf(" Citations received:") + for i, citation := range response.Citations { + assert.NotEmpty(t, citation) + t.Logf(" %d. 
%s", i+1, citation) + } + } + }) + + t.Run("custom live search parameters integration", func(t *testing.T) { + // Create provider with custom Live Search options + maxResults := 5 + returnCitations := true + + customOpts := LiveSearchOptions{ + Mode: "on", + MaxSearchResults: &maxResults, + ReturnCitations: &returnCitations, + Sources: []LiveSearchSource{ + {Type: "web"}, + {Type: "news"}, + }, + } + + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(300), + WithSystemMessage("You are a helpful assistant."), + WithXAIOptions( + WithLiveSearchOptions(customOpts), + ), + ) + require.NoError(t, err) + + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What's happening in the tech industry today?"}, + }, + }, + } + + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Validate response structure + assert.NotEmpty(t, response.Content) + assert.NotEmpty(t, response.SystemFingerprint) + + t.Logf("Custom Live Search Response:") + t.Logf(" Response length: %d characters", len(response.Content)) + t.Logf(" Citations: %d", len(response.Citations)) + t.Logf(" System fingerprint: %s", response.SystemFingerprint) + + // Should have citations since we enabled them and mode is "on" + if returnCitations && len(response.Citations) > 0 { + t.Logf(" Citations (as expected):") + for i, citation := range response.Citations { + assert.NotEmpty(t, citation) + t.Logf(" %d. %s", i+1, citation) + } + } + }) + + t.Run("live search with web search tool integration", func(t *testing.T) { + // Create provider with Live Search + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(500), + WithSystemMessage("You are a helpful assistant. Use web search when you need current information."), + WithXAIOptions( + WithLiveSearch(), + ), + ) + require.NoError(t, err) + + // Create web search tool + webSearchTool := &tools.WebSearchTool{} + + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Search for recent news about artificial intelligence breakthroughs."}, + }, + }, + } + + response, err := provider.SendMessages(context.Background(), messages, []tools.BaseTool{webSearchTool}) + require.NoError(t, err) + require.NotNil(t, response) + + // Validate response structure + assert.NotEmpty(t, response.SystemFingerprint) + // Should have either content or tool calls + assert.True(t, response.Content != "" || len(response.ToolCalls) > 0) + + t.Logf("Tool-based Live Search Response:") + t.Logf(" Response length: %d characters", len(response.Content)) + t.Logf(" Tool calls: %d", len(response.ToolCalls)) + t.Logf(" Citations: %d", len(response.Citations)) + t.Logf(" System fingerprint: %s", response.SystemFingerprint) + + if len(response.ToolCalls) > 0 { + t.Logf(" Tool calls made:") + for i, call := range response.ToolCalls { + assert.Equal(t, "web_search", call.Name) + assert.NotEmpty(t, call.Input) + t.Logf(" %d. 
%s", i+1, call.Name) + + // Validate tool call input contains Live Search parameters + var params map[string]interface{} + err := json.Unmarshal([]byte(call.Input), ¶ms) + assert.NoError(t, err) + assert.Contains(t, params, "query") + t.Logf(" Query: %v", params["query"]) + } + } + + // Validate citations if present + for i, citation := range response.Citations { + assert.NotEmpty(t, citation) + t.Logf(" Citation %d: %s", i+1, citation) + } + }) + + t.Run("live search comprehensive feature test", func(t *testing.T) { + // Test with multiple advanced features + customOpts := LiveSearchOptions{ + Mode: "auto", + MaxSearchResults: intPtr(3), + ReturnCitations: boolPtr(true), + Sources: []LiveSearchSource{ + { + Type: "x", + IncludedXHandles: []string{"xai"}, + }, + { + Type: "web", + Country: stringPtr("US"), + }, + }, + } + + // Combine with other xAI features + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok4]), + WithMaxTokens(400), + WithSystemMessage("You are a research assistant with access to current information."), + WithXAIOptions( + WithLiveSearchOptions(customOpts), + WithMaxConcurrentRequests(2), + ), + ) + require.NoError(t, err) + + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What are the latest updates from xAI and recent AI research?"}, + }, + }, + } + + response, err := provider.SendMessages(context.Background(), messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Validate comprehensive response + assert.NotEmpty(t, response.Content) + assert.NotEmpty(t, response.SystemFingerprint) + + t.Logf("Comprehensive Live Search Response:") + t.Logf(" Response length: %d characters", len(response.Content)) + t.Logf(" Citations: %d", len(response.Citations)) + t.Logf(" System fingerprint: %s", response.SystemFingerprint) + t.Logf(" First 200 chars: %s...", + func() string { + if len(response.Content) > 200 { + return response.Content[:200] + } + return response.Content + }()) + + // Log citations for verification + if len(response.Citations) > 0 { + t.Logf(" Citations from X and web sources:") + for i, citation := range response.Citations { + assert.NotEmpty(t, citation) + t.Logf(" %d. 
%s", i+1, citation) + } + } + }) +} + +// Helper functions for pointer creation +func stringPtr(s string) *string { + return &s +} + +func intPtr(i int) *int { + return &i +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/internal/llm/provider/xai_validation.go b/internal/llm/provider/xai_validation.go new file mode 100644 index 00000000..f4e553bc --- /dev/null +++ b/internal/llm/provider/xai_validation.go @@ -0,0 +1,242 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/opencode-ai/opencode/internal/logging" +) + +// XAIAPIKeyInfo represents the information about an xAI API key +type XAIAPIKeyInfo struct { + RedactedAPIKey string `json:"redacted_api_key"` + UserID string `json:"user_id"` + Name string `json:"name"` + CreateTime string `json:"create_time"` + ModifyTime string `json:"modify_time"` + ModifiedBy string `json:"modified_by"` + TeamID string `json:"team_id"` + ACLs []string `json:"acls"` + APIKeyID string `json:"api_key_id"` + TeamBlocked bool `json:"team_blocked"` + APIKeyBlocked bool `json:"api_key_blocked"` + APIKeyDisabled bool `json:"api_key_disabled"` +} + +// ValidateAPIKey validates the xAI API key and returns detailed information about it +func (x *xaiClient) ValidateAPIKey(ctx context.Context) (*XAIAPIKeyInfo, error) { + url := fmt.Sprintf("%s/v1/api-key", x.getBaseURL()) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create validation request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to validate API key: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read validation response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API key validation failed with status %d: %s", resp.StatusCode, string(body)) + } + + var keyInfo XAIAPIKeyInfo + if err := json.Unmarshal(body, &keyInfo); err != nil { + return nil, fmt.Errorf("failed to parse API key info: %w", err) + } + + logging.Debug("xAI API key validation successful", + "redacted_key", keyInfo.RedactedAPIKey, + "name", keyInfo.Name, + "team_id", keyInfo.TeamID, + "blocked", keyInfo.APIKeyBlocked, + "disabled", keyInfo.APIKeyDisabled, + "team_blocked", keyInfo.TeamBlocked) + + return &keyInfo, nil +} + +// IsAPIKeyValid performs a quick validation check and returns true if the key is valid and active +func (x *xaiClient) IsAPIKeyValid(ctx context.Context) bool { + keyInfo, err := x.ValidateAPIKey(ctx) + if err != nil { + logging.Debug("API key validation failed", "error", err) + return false + } + + // Check if key is blocked or disabled + if keyInfo.APIKeyBlocked || keyInfo.APIKeyDisabled || keyInfo.TeamBlocked { + logging.Warn("xAI API key is blocked or disabled", + "api_key_blocked", keyInfo.APIKeyBlocked, + "api_key_disabled", keyInfo.APIKeyDisabled, + "team_blocked", keyInfo.TeamBlocked) + return false + } + + return true +} + +// CheckPermissions validates that the API key has the required permissions for specific operations +func (x *xaiClient) CheckPermissions(ctx context.Context, requiredACLs []string) error { + keyInfo, err := x.ValidateAPIKey(ctx) + if err != nil { + return fmt.Errorf("failed to validate API key: %w", 
err) + } + + // Check if key is blocked or disabled + if keyInfo.APIKeyBlocked { + return fmt.Errorf("API key is blocked") + } + if keyInfo.APIKeyDisabled { + return fmt.Errorf("API key is disabled") + } + if keyInfo.TeamBlocked { + return fmt.Errorf("team is blocked") + } + + // Check if required ACLs are present + aclMap := make(map[string]bool) + for _, acl := range keyInfo.ACLs { + aclMap[acl] = true + } + + var missingACLs []string + for _, required := range requiredACLs { + if !aclMap[required] && !aclMap["api-key:endpoint:*"] && !aclMap["api-key:model:*"] { + // Check for wildcard permissions + found := false + for _, acl := range keyInfo.ACLs { + if acl == "api-key:endpoint:*" || acl == "api-key:model:*" { + found = true + break + } + } + if !found { + missingACLs = append(missingACLs, required) + } + } + } + + if len(missingACLs) > 0 { + return fmt.Errorf("API key missing required permissions: %v", missingACLs) + } + + logging.Debug("xAI API key permissions validated", + "required_acls", requiredACLs, + "available_acls", keyInfo.ACLs) + + return nil +} + +// GetAPIKeyInfo returns detailed information about the API key for debugging purposes +func (x *xaiClient) GetAPIKeyInfo(ctx context.Context) (map[string]interface{}, error) { + keyInfo, err := x.ValidateAPIKey(ctx) + if err != nil { + return nil, err + } + + return map[string]interface{}{ + "redacted_key": keyInfo.RedactedAPIKey, + "name": keyInfo.Name, + "team_id": keyInfo.TeamID, + "created": keyInfo.CreateTime, + "modified": keyInfo.ModifyTime, + "permissions": keyInfo.ACLs, + "status": map[string]interface{}{ + "active": !keyInfo.APIKeyBlocked && !keyInfo.APIKeyDisabled && !keyInfo.TeamBlocked, + "key_blocked": keyInfo.APIKeyBlocked, + "key_disabled": keyInfo.APIKeyDisabled, + "team_blocked": keyInfo.TeamBlocked, + }, + }, nil +} + +// ValidateForOperation checks if the API key is valid for a specific operation type +func (x *xaiClient) ValidateForOperation(ctx context.Context, operation string) error { + var requiredACLs []string + + switch operation { + case "chat": + requiredACLs = []string{"api-key:endpoint:chat", "api-key:model:*"} + case "image_generation": + requiredACLs = []string{"api-key:endpoint:images", "api-key:model:*"} + case "models": + requiredACLs = []string{"api-key:endpoint:models"} + default: + // For unknown operations, just check basic endpoint access + requiredACLs = []string{"api-key:endpoint:*"} + } + + return x.CheckPermissions(ctx, requiredACLs) +} + +// HealthCheck performs a comprehensive health check of the xAI API key and service +func (x *xaiClient) HealthCheck(ctx context.Context) map[string]interface{} { + result := map[string]interface{}{ + "timestamp": time.Now().Format(time.RFC3339), + "provider": "xai", + "model": string(x.providerOptions.model.ID), + } + + // Test API key validation + keyInfo, err := x.ValidateAPIKey(ctx) + if err != nil { + result["api_key_status"] = "invalid" + result["api_key_error"] = err.Error() + result["overall_status"] = "failed" + return result + } + + result["api_key_status"] = "valid" + result["api_key_name"] = keyInfo.Name + result["team_id"] = keyInfo.TeamID + + // Check if key is active + if keyInfo.APIKeyBlocked || keyInfo.APIKeyDisabled || keyInfo.TeamBlocked { + result["key_active"] = false + result["block_reasons"] = map[string]bool{ + "api_key_blocked": keyInfo.APIKeyBlocked, + "api_key_disabled": keyInfo.APIKeyDisabled, + "team_blocked": keyInfo.TeamBlocked, + } + result["overall_status"] = "blocked" + return result + } + + result["key_active"] = 
true + result["permissions"] = keyInfo.ACLs + + // Test model capabilities if available + caps, err := x.DiscoverModelCapabilities(ctx, string(x.providerOptions.model.ID)) + if err != nil { + result["model_capabilities"] = "unavailable" + result["capabilities_error"] = err.Error() + } else { + result["model_capabilities"] = map[string]interface{}{ + "supports_text": caps.SupportsText, + "supports_image_input": caps.SupportsImageInput, + "supports_image_output": caps.SupportsImageOutput, + "supports_web_search": caps.SupportsWebSearch, + "max_prompt_length": caps.MaxPromptLength, + "aliases": caps.Aliases, + } + } + + result["overall_status"] = "healthy" + return result +} diff --git a/internal/llm/provider/xai_validation_test.go b/internal/llm/provider/xai_validation_test.go new file mode 100644 index 00000000..4e0b0759 --- /dev/null +++ b/internal/llm/provider/xai_validation_test.go @@ -0,0 +1,279 @@ +package provider + +import ( + "context" + "os" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_APIKeyValidation(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("validate API key", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Validate API key + keyInfo, err := xaiClient.ValidateAPIKey(ctx) + require.NoError(t, err) + require.NotNil(t, keyInfo) + + // Check basic fields + assert.NotEmpty(t, keyInfo.RedactedAPIKey, "Should have redacted API key") + assert.NotEmpty(t, keyInfo.UserID, "Should have user ID") + assert.NotEmpty(t, keyInfo.TeamID, "Should have team ID") + assert.NotEmpty(t, keyInfo.APIKeyID, "Should have API key ID") + assert.NotEmpty(t, keyInfo.ACLs, "Should have ACLs") + + // Check key status + assert.False(t, keyInfo.APIKeyBlocked, "API key should not be blocked") + assert.False(t, keyInfo.APIKeyDisabled, "API key should not be disabled") + assert.False(t, keyInfo.TeamBlocked, "Team should not be blocked") + + t.Logf("API Key Info:") + t.Logf(" Redacted Key: %s", keyInfo.RedactedAPIKey) + t.Logf(" Name: %s", keyInfo.Name) + t.Logf(" Team ID: %s", keyInfo.TeamID) + t.Logf(" Created: %s", keyInfo.CreateTime) + t.Logf(" ACLs: %v", keyInfo.ACLs) + }) + + t.Run("check if API key is valid", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Check if key is valid + isValid := xaiClient.IsAPIKeyValid(ctx) + assert.True(t, isValid, "API key should be valid") + + t.Logf("API key 
validation result: %v", isValid) + }) + + t.Run("check permissions", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Check basic permissions + err = xaiClient.CheckPermissions(ctx, []string{"api-key:model:*"}) + assert.NoError(t, err, "Should have model permissions") + + err = xaiClient.CheckPermissions(ctx, []string{"api-key:endpoint:*"}) + assert.NoError(t, err, "Should have endpoint permissions") + + t.Log("Permission checks passed") + }) + + t.Run("validate for specific operations", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Test different operations + operations := []string{"chat", "image_generation", "models"} + for _, op := range operations { + err = xaiClient.ValidateForOperation(ctx, op) + if err != nil { + t.Logf("Operation %s validation failed: %v", op, err) + } else { + t.Logf("Operation %s validation passed", op) + } + } + }) + + t.Run("get API key info", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Get API key info + info, err := xaiClient.GetAPIKeyInfo(ctx) + require.NoError(t, err) + require.NotNil(t, info) + + // Check required fields + assert.Contains(t, info, "redacted_key") + assert.Contains(t, info, "name") + assert.Contains(t, info, "team_id") + assert.Contains(t, info, "permissions") + assert.Contains(t, info, "status") + + status, ok := info["status"].(map[string]interface{}) + require.True(t, ok, "Status should be a map") + assert.Contains(t, status, "active") + + t.Logf("API key info retrieved successfully") + t.Logf(" Active: %v", status["active"]) + t.Logf(" Permissions: %v", info["permissions"]) + }) + + t.Run("health check", func(t *testing.T) { + // Create provider + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Perform health check + health := xaiClient.HealthCheck(ctx) + require.NotNil(t, health) + + // Check required fields + assert.Contains(t, health, "timestamp") + assert.Contains(t, health, "provider") + assert.Contains(t, health, "model") + assert.Contains(t, health, "overall_status") + assert.Contains(t, health, "api_key_status") + + assert.Equal(t, "xai", health["provider"]) + assert.Equal(t, "grok-3-fast", 
health["model"]) + + overallStatus := health["overall_status"] + t.Logf("Health check status: %v", overallStatus) + + // Should be healthy or at least not failed + assert.NotEqual(t, "failed", overallStatus, "Health check should not fail completely") + + if overallStatus == "healthy" { + assert.Equal(t, "valid", health["api_key_status"]) + assert.True(t, health["key_active"].(bool)) + } + + t.Logf("Health check completed: %v", health) + }) +} + +func TestXAIProvider_InvalidAPIKey(t *testing.T) { + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("invalid API key", func(t *testing.T) { + // Create provider with invalid API key + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey("xai-invalid-key-12345"), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Validate API key should fail + _, err = xaiClient.ValidateAPIKey(ctx) + assert.Error(t, err, "Should fail with invalid API key") + + // IsAPIKeyValid should return false + isValid := xaiClient.IsAPIKeyValid(ctx) + assert.False(t, isValid, "Invalid API key should return false") + + // Health check should show failure + health := xaiClient.HealthCheck(ctx) + assert.Equal(t, "failed", health["overall_status"]) + assert.Equal(t, "invalid", health["api_key_status"]) + + t.Logf("Invalid key test completed: %v", health["api_key_error"]) + }) + + t.Run("empty API key", func(t *testing.T) { + // Create provider with empty API key + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(""), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + ctx := context.Background() + + // Should fail validation + isValid := xaiClient.IsAPIKeyValid(ctx) + assert.False(t, isValid, "Empty API key should be invalid") + }) +} diff --git a/internal/llm/provider/xai_vision_test.go b/internal/llm/provider/xai_vision_test.go new file mode 100644 index 00000000..23cc174f --- /dev/null +++ b/internal/llm/provider/xai_vision_test.go @@ -0,0 +1,220 @@ +package provider + +import ( + "context" + "encoding/base64" + "os" + "testing" + "time" + + "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXAIProvider_VisionSupport(t *testing.T) { + // Skip if no API key is provided + apiKey := os.Getenv("XAI_API_KEY") + if apiKey == "" { + t.Skip("XAI_API_KEY not set") + } + + // Initialize config for tests + tmpDir := t.TempDir() + _, err := config.Load(tmpDir, false) + require.NoError(t, err) + + t.Run("vision model detection", func(t *testing.T) { + // Test grok-2-vision model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Vision]), + WithMaxTokens(200), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be 
baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + assert.True(t, xaiClient.IsVisionCapable(), "grok-2-vision should be vision capable") + }) + + t.Run("non-vision model detection", func(t *testing.T) { + // Test grok-3-fast model (non-vision) + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok3Fast]), + WithMaxTokens(200), + ) + require.NoError(t, err) + + baseProvider, ok := provider.(*baseProvider[XAIClient]) + require.True(t, ok, "Provider should be baseProvider[XAIClient]") + xaiClient := baseProvider.client.(*xaiClient) + + assert.False(t, xaiClient.IsVisionCapable(), "grok-3-fast should not be vision capable") + }) + + t.Run("image recognition with base64", func(t *testing.T) { + // Create a simple test image (1x1 red pixel PNG) + redPixelPNG := []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d, + 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xde, 0x00, 0x00, 0x00, + 0x0c, 0x49, 0x44, 0x41, 0x54, 0x08, 0xd7, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0x18, 0xdd, 0x8d, 0xb4, 0x00, 0x00, 0x00, + 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82, + } + + // Create provider with vision model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Vision]), + WithMaxTokens(200), + WithSystemMessage("You are a helpful assistant that analyzes images concisely."), + ) + require.NoError(t, err) + + // Create message with image + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "What color is this 1x1 pixel image? 
Just say the color name."}, + message.BinaryContent{ + MIMEType: "image/png", + Data: redPixelPNG, + }, + }, + }, + } + + // Send request + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + response, err := provider.SendMessages(ctx, messages, nil) + require.NoError(t, err) + require.NotNil(t, response) + + // Check that we got a response about red color + assert.NotEmpty(t, response.Content) + assert.NotEmpty(t, response.SystemFingerprint) + assert.Greater(t, response.Usage.InputTokens, int64(0)) + assert.Greater(t, response.Usage.OutputTokens, int64(0)) + + t.Logf("Vision response: %s", response.Content) + t.Logf("System fingerprint: %s", response.SystemFingerprint) + t.Logf("Usage: %+v", response.Usage) + }) + + t.Run("image validation", func(t *testing.T) { + // Test valid image + validAttachment := message.Attachment{ + FileName: "test.jpg", + MimeType: "image/jpeg", + Content: make([]byte, 1024*1024), // 1MB + } + err := ValidateImageAttachment(validAttachment) + assert.NoError(t, err) + + // Test oversized image + oversizedAttachment := message.Attachment{ + FileName: "large.jpg", + MimeType: "image/jpeg", + Content: make([]byte, 21*1024*1024), // 21MB + } + err = ValidateImageAttachment(oversizedAttachment) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum allowed size") + + // Test unsupported format + unsupportedAttachment := message.Attachment{ + FileName: "test.gif", + MimeType: "image/gif", + Content: make([]byte, 1024), + } + err = ValidateImageAttachment(unsupportedAttachment) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported image format") + }) + + t.Run("streaming with images", func(t *testing.T) { + // Create a simple test image (base64 encoded small JPEG) + smallJPEG := "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/2wBDAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/wAARCAABAAEDAREAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAX/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCmAA//2Q==" + + // Extract base64 data + b64Data := smallJPEG[23:] // Skip "data:image/jpeg;base64," + imageData, err := base64.StdEncoding.DecodeString(b64Data) + require.NoError(t, err) + + // Create provider with vision model + provider, err := NewProvider( + models.ProviderXAI, + WithAPIKey(apiKey), + WithModel(models.SupportedModels[models.XAIGrok2Vision]), + WithMaxTokens(200), + WithSystemMessage("You are a helpful assistant."), + ) + require.NoError(t, err) + + // Create message with image + messages := []message.Message{ + { + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: "Describe this image in 5 words or less."}, + message.BinaryContent{ + MIMEType: "image/jpeg", + Data: imageData, + }, + }, + }, + } + + // Stream response + ctx := context.Background() + eventChan := provider.StreamResponse(ctx, messages, nil) + + // Collect events + var contentChunks []string + var finalResponse *ProviderResponse + hasContentDelta := false + + for event := range eventChan { + switch event.Type { + case EventContentDelta: + hasContentDelta = true + contentChunks = append(contentChunks, event.Content) + + case EventComplete: + finalResponse = event.Response + + case EventError: + t.Fatalf("Streaming error: %v", event.Error) + } + } + + // Verify streaming worked correctly + require.NotNil(t, 
finalResponse) + assert.True(t, hasContentDelta, "Should have received content deltas") + assert.NotEmpty(t, finalResponse.Content) + assert.NotEmpty(t, finalResponse.SystemFingerprint) + + t.Logf("Streaming vision response: %s", finalResponse.Content) + }) + + t.Run("deferred completion with images", func(t *testing.T) { + // Skip if not configured for deferred + t.Skip("Deferred completion with images test - enable when needed") + + // This test would verify deferred completions work with images + // Similar structure to above tests but using SendDeferred + }) +} diff --git a/internal/llm/tools/web_search.go b/internal/llm/tools/web_search.go new file mode 100644 index 00000000..ba6e109d --- /dev/null +++ b/internal/llm/tools/web_search.go @@ -0,0 +1,337 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" +) + +// WebSearchTool provides web search functionality for AI models that support it. +// This tool is primarily used by providers like xAI Grok models which have built-in +// web search capabilities. The actual search execution is handled by the model provider, +// not by this tool implementation. It supports advanced Live Search parameters for +// enhanced search control and filtering. +type WebSearchTool struct{} + +// WebSearchParameters defines the input parameters for web search requests. +// Supports xAI Live Search parameters for enhanced search capabilities. +type WebSearchParameters struct { + Query string `json:"query" jsonschema:"description=The search query to execute,required"` + Mode *string `json:"mode,omitempty" jsonschema:"description=Search mode: auto|on|off (default: auto)"` + MaxSearchResults *int `json:"max_search_results,omitempty" jsonschema:"description=Maximum number of search results (1-20, default: 20)"` + FromDate *string `json:"from_date,omitempty" jsonschema:"description=Start date for search results in YYYY-MM-DD format"` + ToDate *string `json:"to_date,omitempty" jsonschema:"description=End date for search results in YYYY-MM-DD format"` + ReturnCitations *bool `json:"return_citations,omitempty" jsonschema:"description=Whether to return citations (default: true)"` + Sources []WebSearchSource `json:"sources,omitempty" jsonschema:"description=List of data sources to search"` +} + +// WebSearchSource represents a data source for Live Search +type WebSearchSource struct { + Type string `json:"type" jsonschema:"description=Source type: web|x|news|rss,required"` + Country *string `json:"country,omitempty" jsonschema:"description=ISO alpha-2 country code (web, news)"` + ExcludedWebsites []string `json:"excluded_websites,omitempty" jsonschema:"description=Websites to exclude (max 5, web/news)"` + AllowedWebsites []string `json:"allowed_websites,omitempty" jsonschema:"description=Allowed websites only (max 5, web only)"` + SafeSearch *bool `json:"safe_search,omitempty" jsonschema:"description=Enable safe search (default: true, web/news)"` + IncludedXHandles []string `json:"included_x_handles,omitempty" jsonschema:"description=X handles to include (max 10, x only)"` + ExcludedXHandles []string `json:"excluded_x_handles,omitempty" jsonschema:"description=X handles to exclude (max 10, x only)"` + PostFavoriteCount *int `json:"post_favorite_count,omitempty" jsonschema:"description=Minimum favorite count for X posts"` + PostViewCount *int `json:"post_view_count,omitempty" jsonschema:"description=Minimum view count for X posts"` + Links []string `json:"links,omitempty" jsonschema:"description=RSS feed URLs (1 link max, rss only)"` +} + +// Info returns 
metadata about the web search tool including its parameters and description. +func (t *WebSearchTool) Info() ToolInfo { + return ToolInfo{ + Name: "web_search", + Description: "Search the web for current information with advanced Live Search capabilities. Supports multiple data sources (web, X, news, RSS), date filtering, and citation tracking.", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "The search query to execute", + }, + "mode": map[string]interface{}{ + "type": "string", + "description": "Search mode: 'auto' (default), 'on', or 'off'", + "enum": []string{"auto", "on", "off"}, + }, + "max_search_results": map[string]interface{}{ + "type": "integer", + "description": "Maximum number of search results (1-20, default: 20)", + "minimum": 1, + "maximum": 20, + }, + "from_date": map[string]interface{}{ + "type": "string", + "description": "Start date for search results in YYYY-MM-DD format", + "pattern": "^\\d{4}-\\d{2}-\\d{2}$", + }, + "to_date": map[string]interface{}{ + "type": "string", + "description": "End date for search results in YYYY-MM-DD format", + "pattern": "^\\d{4}-\\d{2}-\\d{2}$", + }, + "return_citations": map[string]interface{}{ + "type": "boolean", + "description": "Whether to return citations (default: true)", + }, + "sources": map[string]interface{}{ + "type": "array", + "description": "List of data sources to search", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "type": map[string]interface{}{ + "type": "string", + "description": "Source type", + "enum": []string{"web", "x", "news", "rss"}, + }, + "country": map[string]interface{}{ + "type": "string", + "description": "ISO alpha-2 country code (web, news)", + "pattern": "^[A-Z]{2}$", + }, + "excluded_websites": map[string]interface{}{ + "type": "array", + "description": "Websites to exclude (max 5, web/news)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 5, + }, + "allowed_websites": map[string]interface{}{ + "type": "array", + "description": "Allowed websites only (max 5, web only)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 5, + }, + "safe_search": map[string]interface{}{ + "type": "boolean", + "description": "Enable safe search (default: true, web/news)", + }, + "included_x_handles": map[string]interface{}{ + "type": "array", + "description": "X handles to include (max 10, x only)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 10, + }, + "excluded_x_handles": map[string]interface{}{ + "type": "array", + "description": "X handles to exclude (max 10, x only)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 10, + }, + "post_favorite_count": map[string]interface{}{ + "type": "integer", + "description": "Minimum favorite count for X posts", + "minimum": 0, + }, + "post_view_count": map[string]interface{}{ + "type": "integer", + "description": "Minimum view count for X posts", + "minimum": 0, + }, + "links": map[string]interface{}{ + "type": "array", + "description": "RSS feed URLs (1 link max, rss only)", + "items": map[string]interface{}{"type": "string", "format": "uri"}, + "maxItems": 1, + }, + }, + "required": []string{"type"}, + }, + }, + }, + "required": []string{"query"}, + }, + Required: []string{"query"}, + } +} + +// Run processes the web search request. 
Since the actual web search is performed +// by the AI model provider (e.g., xAI), this method validates the input parameters +// and returns a response indicating that the search will be handled by the provider. +func (t *WebSearchTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) { + var searchParams WebSearchParameters + + if err := json.Unmarshal([]byte(params.Input), &searchParams); err != nil { + return ToolResponse{ + Type: ToolResponseTypeText, + Content: fmt.Sprintf("Failed to parse web search parameters: %v", err), + IsError: true, + }, nil + } + + // Validate query is not empty + if searchParams.Query == "" { + return ToolResponse{ + Type: ToolResponseTypeText, + Content: "Search query cannot be empty", + IsError: true, + }, nil + } + + // Validate Live Search parameters + if err := t.validateLiveSearchParams(&searchParams); err != nil { + return ToolResponse{ + Type: ToolResponseTypeText, + Content: fmt.Sprintf("Invalid Live Search parameters: %v", err), + IsError: true, + }, nil + } + + // Build description of search configuration + description := fmt.Sprintf("Searching the web for: %s", searchParams.Query) + + if searchParams.Mode != nil && *searchParams.Mode != "auto" { + description += fmt.Sprintf(" (mode: %s)", *searchParams.Mode) + } + + if searchParams.MaxSearchResults != nil { + description += fmt.Sprintf(" (max results: %d)", *searchParams.MaxSearchResults) + } + + if searchParams.FromDate != nil || searchParams.ToDate != nil { + if searchParams.FromDate != nil && searchParams.ToDate != nil { + description += fmt.Sprintf(" (date range: %s to %s)", *searchParams.FromDate, *searchParams.ToDate) + } else if searchParams.FromDate != nil { + description += fmt.Sprintf(" (from: %s)", *searchParams.FromDate) + } else { + description += fmt.Sprintf(" (until: %s)", *searchParams.ToDate) + } + } + + if len(searchParams.Sources) > 0 { + sourceTypes := make([]string, len(searchParams.Sources)) + for i, source := range searchParams.Sources { + sourceTypes[i] = source.Type + } + description += fmt.Sprintf(" (sources: %v)", sourceTypes) + } + + // Return success response indicating the provider will handle the search + return ToolResponse{ + Type: ToolResponseTypeText, + Content: description, + IsError: false, + }, nil +} + +// validateLiveSearchParams validates Live Search parameters according to xAI specifications +func (t *WebSearchTool) validateLiveSearchParams(params *WebSearchParameters) error { + // Validate mode + if params.Mode != nil { + mode := *params.Mode + if mode != "auto" && mode != "on" && mode != "off" { + return fmt.Errorf("mode must be 'auto', 'on', or 'off', got: %s", mode) + } + } + + // Validate max_search_results range + if params.MaxSearchResults != nil { + if *params.MaxSearchResults < 1 || *params.MaxSearchResults > 20 { + return fmt.Errorf("max_search_results must be between 1 and 20, got: %d", *params.MaxSearchResults) + } + } + + // Validate date formats (basic YYYY-MM-DD validation) + if params.FromDate != nil { + date := *params.FromDate + if len(date) != 10 || date[4] != '-' || date[7] != '-' { + return fmt.Errorf("from_date must be in YYYY-MM-DD format, got: %s", date) + } + } + if params.ToDate != nil { + date := *params.ToDate + if len(date) != 10 || date[4] != '-' || date[7] != '-' { + return fmt.Errorf("to_date must be in YYYY-MM-DD format, got: %s", date) + } + } + + // Validate sources + for i, source := range params.Sources { + if err := t.validateSource(&source, i); err != nil { + return fmt.Errorf("source %d: %w", i, err) 
+ } + } + + return nil +} + +// validateSource validates individual source parameters +func (t *WebSearchTool) validateSource(source *WebSearchSource, index int) error { + // Validate source type + validTypes := map[string]bool{"web": true, "x": true, "news": true, "rss": true} + if !validTypes[source.Type] { + return fmt.Errorf("invalid source type: %s (must be web, x, news, or rss)", source.Type) + } + + // Validate website exclusions/inclusions + if len(source.ExcludedWebsites) > 5 { + return fmt.Errorf("excluded_websites cannot exceed 5 entries, got: %d", len(source.ExcludedWebsites)) + } + if len(source.AllowedWebsites) > 5 { + return fmt.Errorf("allowed_websites cannot exceed 5 entries, got: %d", len(source.AllowedWebsites)) + } + if len(source.ExcludedWebsites) > 0 && len(source.AllowedWebsites) > 0 { + return fmt.Errorf("cannot use both excluded_websites and allowed_websites in the same source") + } + + // Validate X handles + if len(source.IncludedXHandles) > 10 { + return fmt.Errorf("included_x_handles cannot exceed 10 entries, got: %d", len(source.IncludedXHandles)) + } + if len(source.ExcludedXHandles) > 10 { + return fmt.Errorf("excluded_x_handles cannot exceed 10 entries, got: %d", len(source.ExcludedXHandles)) + } + if len(source.IncludedXHandles) > 0 && len(source.ExcludedXHandles) > 0 { + return fmt.Errorf("cannot use both included_x_handles and excluded_x_handles in the same source") + } + + // Validate RSS links + if len(source.Links) > 1 { + return fmt.Errorf("RSS source can only have 1 link, got: %d", len(source.Links)) + } + + // Validate source-specific parameters + switch source.Type { + case "web": + if len(source.IncludedXHandles) > 0 || len(source.ExcludedXHandles) > 0 || + source.PostFavoriteCount != nil || source.PostViewCount != nil { + return fmt.Errorf("X-specific parameters not allowed for web source") + } + if len(source.Links) > 0 { + return fmt.Errorf("RSS links not allowed for web source") + } + case "x": + if source.Country != nil || len(source.ExcludedWebsites) > 0 || + len(source.AllowedWebsites) > 0 || source.SafeSearch != nil { + return fmt.Errorf("web/news-specific parameters not allowed for X source") + } + if len(source.Links) > 0 { + return fmt.Errorf("RSS links not allowed for X source") + } + case "news": + if len(source.IncludedXHandles) > 0 || len(source.ExcludedXHandles) > 0 || + source.PostFavoriteCount != nil || source.PostViewCount != nil { + return fmt.Errorf("X-specific parameters not allowed for news source") + } + if len(source.AllowedWebsites) > 0 { + return fmt.Errorf("allowed_websites not supported for news source") + } + if len(source.Links) > 0 { + return fmt.Errorf("RSS links not allowed for news source") + } + case "rss": + if source.Country != nil || len(source.ExcludedWebsites) > 0 || + len(source.AllowedWebsites) > 0 || source.SafeSearch != nil || + len(source.IncludedXHandles) > 0 || len(source.ExcludedXHandles) > 0 || + source.PostFavoriteCount != nil || source.PostViewCount != nil { + return fmt.Errorf("only links parameter allowed for RSS source") + } + if len(source.Links) == 0 { + return fmt.Errorf("RSS source requires at least one link") + } + } + + return nil +} diff --git a/internal/llm/tools/web_search_test.go b/internal/llm/tools/web_search_test.go new file mode 100644 index 00000000..a4f4be5e --- /dev/null +++ b/internal/llm/tools/web_search_test.go @@ -0,0 +1,485 @@ +package tools + +import ( + "context" + "encoding/json" + "strings" + "testing" +) + +func TestWebSearchTool(t *testing.T) { + tool := 
&WebSearchTool{} + ctx := context.Background() + + t.Run("Info returns correct metadata", func(t *testing.T) { + info := tool.Info() + + if info.Name != "web_search" { + t.Errorf("Expected tool name 'web_search', got '%s'", info.Name) + } + + if info.Description == "" { + t.Error("Tool description should not be empty") + } + + if !strings.Contains(info.Description, "Live Search") { + t.Error("Description should mention Live Search functionality") + } + + // Check parameters structure + if paramType, ok := info.Parameters["type"]; !ok || paramType != "object" { + t.Error("Parameters should have type 'object'") + } + + properties, ok := info.Parameters["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Parameters should have 'properties' field") + } + + // Check all Live Search parameters are present + expectedParams := []string{ + "query", "mode", "max_search_results", "from_date", "to_date", + "return_citations", "sources", + } + + for _, param := range expectedParams { + if _, hasParam := properties[param]; !hasParam { + t.Errorf("Parameters should have '%s' property", param) + } + } + + // Check mode enum values + if modeParam, ok := properties["mode"].(map[string]interface{}); ok { + if enumValues, ok := modeParam["enum"].([]string); ok { + expectedModes := []string{"auto", "on", "off"} + for _, mode := range expectedModes { + found := false + for _, enumVal := range enumValues { + if enumVal == mode { + found = true + break + } + } + if !found { + t.Errorf("Mode enum should include '%s'", mode) + } + } + } + } + + if len(info.Required) != 1 || info.Required[0] != "query" { + t.Errorf("Expected required fields to be ['query'], got %v", info.Required) + } + }) + + t.Run("Run with valid query", func(t *testing.T) { + params := WebSearchParameters{ + Query: "test search query", + } + + inputJSON, err := json.Marshal(params) + if err != nil { + t.Fatalf("Failed to marshal parameters: %v", err) + } + + toolCall := ToolCall{ + ID: "test-id", + Name: "web_search", + Input: string(inputJSON), + } + + response, err := tool.Run(ctx, toolCall) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + if response.IsError { + t.Errorf("Expected success response, got error: %s", response.Content) + } + + if response.Content == "" { + t.Error("Response content should not be empty") + } + + if !strings.Contains(response.Content, "test search query") { + t.Error("Response should mention the search query") + } + + if response.Type != ToolResponseTypeText { + t.Errorf("Expected response type %s, got %s", ToolResponseTypeText, response.Type) + } + }) + + t.Run("Run with invalid JSON", func(t *testing.T) { + toolCall := ToolCall{ + ID: "test-id", + Name: "web_search", + Input: "invalid json{", + } + + response, err := tool.Run(ctx, toolCall) + if err != nil { + t.Errorf("Expected no error from Run method, got: %v", err) + } + + if !response.IsError { + t.Error("Expected error response for invalid JSON") + } + + if !strings.Contains(response.Content, "parse") { + t.Error("Error message should mention parsing failure") + } + }) + + t.Run("Run with empty query", func(t *testing.T) { + params := WebSearchParameters{ + Query: "", + } + + inputJSON, _ := json.Marshal(params) + toolCall := ToolCall{ + ID: "test-id", + Name: "web_search", + Input: string(inputJSON), + } + + response, err := tool.Run(ctx, toolCall) + if err != nil { + t.Errorf("Expected no error from Run method, got: %v", err) + } + + if !response.IsError { + t.Error("Expected error response for empty query") + } + + if 
!strings.Contains(response.Content, "empty") { + t.Error("Error message should mention empty query") + } + }) + + // Live Search parameter tests + t.Run("Run with Live Search parameters", func(t *testing.T) { + mode := "auto" + maxResults := 10 + fromDate := "2025-01-01" + toDate := "2025-12-31" + returnCitations := true + + params := WebSearchParameters{ + Query: "AI developments 2025", + Mode: &mode, + MaxSearchResults: &maxResults, + FromDate: &fromDate, + ToDate: &toDate, + ReturnCitations: &returnCitations, + Sources: []WebSearchSource{ + { + Type: "web", + Country: stringPtr("US"), + }, + { + Type: "news", + ExcludedWebsites: []string{"example.com"}, + }, + }, + } + + inputJSON, err := json.Marshal(params) + if err != nil { + t.Fatalf("Failed to marshal parameters: %v", err) + } + + toolCall := ToolCall{ + ID: "test-id", + Name: "web_search", + Input: string(inputJSON), + } + + response, err := tool.Run(ctx, toolCall) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + if response.IsError { + t.Errorf("Expected success response, got error: %s", response.Content) + } + + content := response.Content + if !strings.Contains(content, "AI developments 2025") { + t.Error("Response should mention the search query") + } + + // Should include parameter details in response (mode auto is default, not shown) + if !strings.Contains(content, "max results: 10") { + t.Error("Response should mention max results") + } + + if !strings.Contains(content, "2025-01-01 to 2025-12-31") { + t.Error("Response should mention date range") + } + + if !strings.Contains(content, "[web news]") { + t.Error("Response should mention source types") + } + }) + + t.Run("Parameter validation tests", func(t *testing.T) { + testCases := []struct { + name string + params WebSearchParameters + expectError bool + errorMsg string + }{ + { + name: "invalid mode", + params: WebSearchParameters{ + Query: "test", + Mode: stringPtr("invalid"), + }, + expectError: true, + errorMsg: "mode must be", + }, + { + name: "max results too high", + params: WebSearchParameters{ + Query: "test", + MaxSearchResults: intPtr(25), + }, + expectError: true, + errorMsg: "between 1 and 20", + }, + { + name: "max results too low", + params: WebSearchParameters{ + Query: "test", + MaxSearchResults: intPtr(0), + }, + expectError: true, + errorMsg: "between 1 and 20", + }, + { + name: "invalid date format", + params: WebSearchParameters{ + Query: "test", + FromDate: stringPtr("2025/01/01"), + }, + expectError: true, + errorMsg: "YYYY-MM-DD format", + }, + { + name: "invalid source type", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + {Type: "invalid"}, + }, + }, + expectError: true, + errorMsg: "invalid source type", + }, + { + name: "too many excluded websites", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "web", + ExcludedWebsites: []string{"1.com", "2.com", "3.com", "4.com", "5.com", "6.com"}, + }, + }, + }, + expectError: true, + errorMsg: "cannot exceed 5", + }, + { + name: "conflicting website filters", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "web", + ExcludedWebsites: []string{"example.com"}, + AllowedWebsites: []string{"test.com"}, + }, + }, + }, + expectError: true, + errorMsg: "cannot use both", + }, + { + name: "too many X handles", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "x", + IncludedXHandles: []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", 
"10", "11"}, + }, + }, + }, + expectError: true, + errorMsg: "cannot exceed 10", + }, + { + name: "conflicting X handle filters", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "x", + IncludedXHandles: []string{"xai"}, + ExcludedXHandles: []string{"openai"}, + }, + }, + }, + expectError: true, + errorMsg: "cannot use both", + }, + { + name: "RSS without links", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + {Type: "rss"}, + }, + }, + expectError: true, + errorMsg: "requires at least one link", + }, + { + name: "RSS with too many links", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "rss", + Links: []string{"feed1.xml", "feed2.xml"}, + }, + }, + }, + expectError: true, + errorMsg: "can only have 1 link", + }, + { + name: "X parameters on web source", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "web", + IncludedXHandles: []string{"xai"}, + }, + }, + }, + expectError: true, + errorMsg: "X-specific parameters not allowed", + }, + { + name: "web parameters on X source", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "x", + Country: stringPtr("US"), + }, + }, + }, + expectError: true, + errorMsg: "web/news-specific parameters not allowed", + }, + { + name: "allowed websites on news source", + params: WebSearchParameters{ + Query: "test", + Sources: []WebSearchSource{ + { + Type: "news", + AllowedWebsites: []string{"news.com"}, + }, + }, + }, + expectError: true, + errorMsg: "allowed_websites not supported for news", + }, + { + name: "valid parameters", + params: WebSearchParameters{ + Query: "test query", + Mode: stringPtr("auto"), + MaxSearchResults: intPtr(10), + FromDate: stringPtr("2025-01-01"), + ToDate: stringPtr("2025-12-31"), + ReturnCitations: boolPtr(true), + Sources: []WebSearchSource{ + { + Type: "web", + Country: stringPtr("US"), + ExcludedWebsites: []string{"spam.com"}, + }, + { + Type: "x", + IncludedXHandles: []string{"xai"}, + PostFavoriteCount: intPtr(100), + }, + { + Type: "news", + Country: stringPtr("UK"), + }, + { + Type: "rss", + Links: []string{"https://feeds.example.com/rss.xml"}, + }, + }, + }, + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + inputJSON, err := json.Marshal(tc.params) + if err != nil { + t.Fatalf("Failed to marshal parameters: %v", err) + } + + toolCall := ToolCall{ + ID: "test-id", + Name: "web_search", + Input: string(inputJSON), + } + + response, err := tool.Run(ctx, toolCall) + if err != nil { + t.Errorf("Expected no error from Run method, got: %v", err) + } + + if tc.expectError { + if !response.IsError { + t.Errorf("Expected error response for %s", tc.name) + } + if !strings.Contains(response.Content, tc.errorMsg) { + t.Errorf("Expected error message to contain '%s', got: %s", tc.errorMsg, response.Content) + } + } else { + if response.IsError { + t.Errorf("Expected success response for %s, got error: %s", tc.name, response.Content) + } + } + }) + } + }) +} + +// Helper functions for pointer creation +func stringPtr(s string) *string { + return &s +} + +func intPtr(i int) *int { + return &i +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index a6c5a44e..eccbaae1 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -256,13 +256,24 @@ func (m 
*editorCmp) attachmentsContent() string { MarginLeft(1). Background(t.TextMuted()). Foreground(t.Text()) + for i, attachment := range m.attachments { + // Choose appropriate icon based on MIME type + icon := styles.DocumentIcon + if strings.HasPrefix(attachment.MimeType, "image/") { + icon = "🖼️" // Image icon for images + } + + // Format file size + sizeStr := formatFileSize(len(attachment.Content)) + var filename string if len(attachment.FileName) > 10 { - filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, attachment.FileName[0:7]) + filename = fmt.Sprintf(" %s %s... (%s)", icon, attachment.FileName[0:7], sizeStr) } else { - filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, attachment.FileName) + filename = fmt.Sprintf(" %s %s (%s)", icon, attachment.FileName, sizeStr) } + if m.deleteMode { filename = fmt.Sprintf("%d%s", i, filename) } @@ -272,6 +283,23 @@ func (m *editorCmp) attachmentsContent() string { return content } +// formatFileSize formats bytes into a human-readable string +func formatFileSize(bytes int) string { + const ( + KB = 1024 + MB = KB * 1024 + ) + + switch { + case bytes >= MB: + return fmt.Sprintf("%.1fMB", float64(bytes)/MB) + case bytes >= KB: + return fmt.Sprintf("%.1fKB", float64(bytes)/KB) + default: + return fmt.Sprintf("%dB", bytes) + } +} + func (m *editorCmp) BindingKeys() []key.Binding { bindings := []key.Binding{} bindings = append(bindings, layout.KeyMapToSlice(editorMaps)...) diff --git a/internal/tui/components/dialog/filepicker.go b/internal/tui/components/dialog/filepicker.go index 3b9a0dc6..e3148819 100644 --- a/internal/tui/components/dialog/filepicker.go +++ b/internal/tui/components/dialog/filepicker.go @@ -16,6 +16,7 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/opencode-ai/opencode/internal/app" "github.com/opencode-ai/opencode/internal/config" + "github.com/opencode-ai/opencode/internal/llm/provider" "github.com/opencode-ai/opencode/internal/logging" "github.com/opencode-ai/opencode/internal/message" "github.com/opencode-ai/opencode/internal/tui/image" @@ -234,13 +235,21 @@ func (f *filepickerCmp) addAttachmentToMessage() (tea.Model, tea.Cmd) { return f, nil } - isFileLarge, err := image.ValidateFileSize(selectedFilePath, maxAttachmentSize) + // Check if current model is xAI to apply specific size limit + cfg := config.Get() + modelInfo := GetSelectedModel(cfg) + maxSize := maxAttachmentSize // Default 5MB + if strings.HasPrefix(string(modelInfo.ID), "grok") { + maxSize = provider.MaxImageSize // xAI allows up to 20MB + } + + isFileLarge, err := image.ValidateFileSize(selectedFilePath, maxSize) if err != nil { logging.ErrorPersist("unable to read the image") return f, nil } if isFileLarge { - logging.ErrorPersist("file too large, max 5MB") + logging.ErrorPersist(fmt.Sprintf("file too large, max %.0fMB", float64(maxSize)/(1024*1024))) return f, nil } @@ -254,6 +263,14 @@ func (f *filepickerCmp) addAttachmentToMessage() (tea.Model, tea.Cmd) { mimeType := http.DetectContentType(content[:mimeBufferSize]) fileName := filepath.Base(selectedFilePath) attachment := message.Attachment{FilePath: selectedFilePath, FileName: fileName, MimeType: mimeType, Content: content} + + // Additional xAI-specific validation + if strings.HasPrefix(string(modelInfo.ID), "grok") { + if err := provider.ValidateImageAttachment(attachment); err != nil { + logging.ErrorPersist(fmt.Sprintf("Invalid image: %v", err)) + return f, nil + } + } f.selectedFile = "" return f, util.CmdHandler(AttachmentAddedMsg{attachment}) } diff --git 
a/internal/tui/components/dialog/models.go b/internal/tui/components/dialog/models.go index 77c2a02a..0e6e6d2c 100644 --- a/internal/tui/components/dialog/models.go +++ b/internal/tui/components/dialog/models.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/opencode-ai/opencode/internal/config" "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/provider" "github.com/opencode-ai/opencode/internal/tui/layout" "github.com/opencode-ai/opencode/internal/tui/styles" "github.com/opencode-ai/opencode/internal/tui/theme" @@ -208,7 +209,14 @@ func (m *modelDialogCmp) View() string { itemStyle = itemStyle.Background(t.Primary()). Foreground(t.Background()).Bold(true) } - modelItems = append(modelItems, itemStyle.Render(m.models[i].Name)) + + // Add vision indicator for models that support images + modelName := m.models[i].Name + if m.models[i].SupportsAttachments && provider.IsVisionModel(string(m.models[i].ID)) { + modelName += " 👁" // Eye emoji to indicate vision support + } + + modelItems = append(modelItems, itemStyle.Render(modelName)) } scrollIndicator := m.getScrollIndicators(maxDialogWidth) diff --git a/opencode-schema.json b/opencode-schema.json index 406c75f8..d1a46165 100644 --- a/opencode-schema.json +++ b/opencode-schema.json @@ -100,6 +100,10 @@ "high" ], "type": "string" + }, + "deferredCompletion": { + "description": "Override provider's deferred completion setting for this agent", + "type": "boolean" } }, "required": [ @@ -210,6 +214,10 @@ "high" ], "type": "string" + }, + "deferredCompletion": { + "description": "Override provider's deferred completion setting for this agent", + "type": "boolean" } }, "required": [ @@ -372,6 +380,49 @@ "description": "Whether the provider is disabled", "type": "boolean" }, + "maxConcurrentRequests": { + "description": "Maximum concurrent requests for providers that support it (e.g., xAI)", + "minimum": 1, + "type": "integer" + }, + "deferredCompletion": { + "description": "Deferred completion configuration for providers that support it (e.g., xAI)", + "type": "object", + "properties": { + "enabled": { + "description": "Enable deferred completions", + "type": "boolean" + }, + "timeout": { + "description": "Timeout duration (e.g., '10m', '30s')", + "type": "string", + "pattern": "^[0-9]+(ms|s|m|h)$" + }, + "pollInterval": { + "description": "Poll interval duration (e.g., '10s', '500ms')", + "type": "string", + "pattern": "^[0-9]+(ms|s|m|h)$" + }, + "autoEnable": { + "description": "Rules for automatically enabling deferred completions", + "type": "object", + "properties": { + "forModels": { + "description": "Enable for specific models", + "type": "array", + "items": { + "type": "string" + } + }, + "whenTokensExceed": { + "description": "Enable when max tokens exceed this value", + "type": "integer", + "minimum": 1 + } + } + } + } + }, "provider": { "description": "Provider type", "enum": [ @@ -383,7 +434,8 @@ "bedrock", "azure", "vertexai", - "copilot" + "copilot", + "xai" ], "type": "string" } From 91c192b28e8da6560189b1051171df06d8b1b2c3 Mon Sep 17 00:00:00 2001 From: askiiRobotics Date: Fri, 11 Jul 2025 17:32:43 +0400 Subject: [PATCH 2/9] Clearing out local config files .claude/settings.local.json - deleted --- .claude/settings.local.json | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index eb295133..00000000 --- 
a/.claude/settings.local.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(go fmt:*)", - "Bash(go vet:*)", - "Bash(go build:*)", - "Bash(go test:*)" - ], - "deny": [] - } -} \ No newline at end of file From 6cca518cec8b3ea25a36a23da53b7c95684ff574 Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Fri, 11 Jul 2025 17:39:25 +0400 Subject: [PATCH 3/9] Clearing out out of scope autofixes --- internal/llm/agent/agent.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 34626a7b..b0d1d81a 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -242,7 +242,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string } if len(msgs) == 0 { go func() { - defer logging.RecoverPanic("agent.generateTitle", func() { + defer logging.RecoverPanic("agent.Run", func() { logging.ErrorPersist("panic while generating title") }) titleErr := a.generateTitle(context.Background(), sessionID, content) @@ -373,7 +373,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg tool = availableTool break } - // TODO: Handle Copilot Sonnet-4 tool name repetition if needed + // Monkey patch for Copilot Sonnet-4 tool repetition obfuscation // if strings.HasPrefix(toolCall.Name, availableTool.Info().Name) && // strings.HasPrefix(toolCall.Name, availableTool.Info().Name+availableTool.Info().Name) { // tool = availableTool From 314cdb6794ed7f830055ba89f29fcb591d689920 Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Fri, 11 Jul 2025 17:39:58 +0400 Subject: [PATCH 4/9] Updated xAI default models --- internal/config/config.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 221f4911..0adec159 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -225,10 +225,9 @@ func Load(workingDir string, debug bool) (*Config, error) { } // Override the max tokens for title agent to ensure concise titles - if titleAgent, exists := cfg.Agents[AgentTitle]; exists { - titleAgent.MaxTokens = 80 - cfg.Agents[AgentTitle] = titleAgent - } + titleAgent := cfg.Agents[AgentTitle] + titleAgent.MaxTokens = 80 + cfg.Agents[AgentTitle] = titleAgent return cfg, nil } @@ -368,10 +367,10 @@ func setProviderDefaults() { // XAI configuration if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" { - viper.SetDefault("agents.coder.model", models.XAIGrok2) - viper.SetDefault("agents.summarizer.model", models.XAIGrok2) - viper.SetDefault("agents.task.model", models.XAIGrok2) - viper.SetDefault("agents.title.model", models.XAIGrok3MiniFast) + viper.SetDefault("agents.coder.model", models.XAIGrok4) // Most capable model with reasoning + vision + viper.SetDefault("agents.summarizer.model", models.XAIGrok3) // Good balance for summarization + viper.SetDefault("agents.task.model", models.XAIGrok3Mini) // Reasoning support for complex tasks + viper.SetDefault("agents.title.model", models.XAIGrok3MiniFast) // Fast + cheap for simple titles return } From 921316239ebb199d4bf99508b5479e6ba0ba20ba Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Fri, 11 Jul 2025 17:43:45 +0400 Subject: [PATCH 5/9] Cleared out obsolete part of the readme updates --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index dec333f0..aeb2f465 100644 --- a/README.md +++ b/README.md @@ -306,13 +306,6 @@ opencode -d opencode -c 
/path/to/project ``` -## Documentation - -- [Image Recognition](docs/image-recognition.md) - Guide for using vision-enabled models with images -- [Web Search](docs/web-search.md) - Using web search capabilities with supported models -- [Custom Commands](docs/custom-commands.md) - Creating custom commands with named arguments -- [Configuration](docs/configuration.md) - Detailed configuration options - ## Non-interactive Prompt Mode You can run OpenCode in non-interactive mode by passing a prompt directly as a command-line argument. This is useful for scripting, automation, or when you want a quick answer without launching the full TUI. From d1eb10fb906872d5a4a19734cc5b46ebd4326f9f Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Sat, 12 Jul 2025 00:39:10 +0400 Subject: [PATCH 6/9] fix: XAI provider reasoning and tool schema handling - Fix EventThinkingDelta to use event.Thinking field instead of event.Content - Add ReasoningContent field to ProviderResponse for XAI reasoning models - Fix tool schema double-wrapping issue causing "Invalid function schema" errors - Add reasoning_effort validation: convert "medium" to "high" for Grok-3-mini models - Add XAIHTTPClient for proper request handling with reasoning support - Add ReasoningHandler to process reasoning content and emit thinking deltas - Fix ls tool to handle uninitialized config in test environments - Fix message rendering to properly display reasoning content - Add content truncation for very long messages to prevent UI issues - Fix platform-specific help key bindings (macOS vs Linux) --- internal/llm/agent/agent.go | 8 +- internal/llm/models/models.go | 7 +- internal/llm/provider/provider.go | 1 + internal/llm/provider/xai.go | 199 ++++++++++++++++-------- internal/llm/provider/xai_deferred.go | 155 +++++++++--------- internal/llm/provider/xai_http.go | 153 ++++++++++++++++++ internal/llm/provider/xai_reasoning.go | 148 ++++++++++++++++++ internal/llm/tools/ls.go | 25 ++- internal/tui/components/chat/editor.go | 9 ++ internal/tui/components/chat/message.go | 55 ++++++- internal/tui/components/core/status.go | 12 +- internal/tui/tui.go | 49 ++++-- opencode-schema.json | 24 ++- 13 files changed, 682 insertions(+), 163 deletions(-) create mode 100644 internal/llm/provider/xai_http.go create mode 100644 internal/llm/provider/xai_reasoning.go diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b0d1d81a..aa4cc58b 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -455,7 +455,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg switch event.Type { case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) + assistantMsg.AppendReasoningContent(event.Thinking) return a.messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) @@ -485,6 +485,12 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) + + // Append content if not already present (for non-streaming responses like reasoning models) + if event.Response.Content != "" && assistantMsg.Content().String() == "" { + assistantMsg.AppendContent(event.Response.Content) + } + if err := a.messages.Update(ctx, *assistantMsg); err != nil { return fmt.Errorf("failed to update message: %w", err) } diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go 
index 31074e2e..18685681 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -43,9 +43,10 @@ var ProviderPopularity = map[ModelProvider]int{ ProviderGemini: 4, ProviderGROQ: 5, ProviderOpenRouter: 6, - ProviderBedrock: 7, - ProviderAzure: 8, - ProviderVertexAI: 9, + ProviderXAI: 7, + ProviderBedrock: 8, + ProviderAzure: 9, + ProviderVertexAI: 10, } var SupportedModels = map[ModelID]Model{ diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 2c77bff9..16357796 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -36,6 +36,7 @@ type TokenUsage struct { type ProviderResponse struct { Content string + ReasoningContent string // For xAI reasoning content (internal use) ToolCalls []message.ToolCall Usage TokenUsage FinishReason message.FinishReason diff --git a/internal/llm/provider/xai.go b/internal/llm/provider/xai.go index 1f3e6e57..314faa36 100644 --- a/internal/llm/provider/xai.go +++ b/internal/llm/provider/xai.go @@ -1,12 +1,8 @@ package provider import ( - "bytes" "context" - "encoding/json" "fmt" - "io" - "net/http" "sync" "time" @@ -39,6 +35,10 @@ type xaiClient struct { deferredOptions DeferredOptions // Options for deferred completions liveSearchEnabled bool // Enable Live Search liveSearchOptions LiveSearchOptions // Options for Live Search + + // New architectural components + reasoningHandler *ReasoningHandler // Handles reasoning content processing + httpClient *XAIHTTPClient // Custom HTTP client for xAI API } type XAIClient ProviderClient @@ -96,7 +96,7 @@ func WithLiveSearchOptions(opts LiveSearchOptions) XAIOption { func newXAIClient(opts providerClientOptions) XAIClient { // Create base OpenAI client with xAI-specific settings opts.openaiOptions = append(opts.openaiOptions, - WithOpenAIBaseURL("https://api.x.ai/v1"), + WithOpenAIBaseURL("https://api.x.ai"), ) baseClient := newOpenAIClient(opts) @@ -107,6 +107,15 @@ func newXAIClient(opts providerClientOptions) XAIClient { fingerprintHistory: make([]FingerprintRecord, 0), } + // Initialize new architectural components + xClient.reasoningHandler = NewReasoningHandler(xClient) + xClient.httpClient = NewXAIHTTPClient(HTTPClientConfig{ + BaseURL: "https://api.x.ai", + APIKey: opts.apiKey, + UserAgent: "opencode/1.0", + Timeout: 30 * time.Second, + }) + // Apply xAI-specific options if any for _, opt := range opts.xaiOptions { opt(xClient) @@ -210,20 +219,32 @@ func (x *xaiClient) calculateCacheCostSavings(usage TokenUsage) float64 { func (x *xaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { // Use deferred completion if enabled if x.deferredEnabled { + logging.Debug("Using deferred completion") return x.SendDeferred(ctx, messages, tools, x.deferredOptions) } // Use custom HTTP client for Live Search in regular completions if x.liveSearchEnabled { + logging.Debug("Using live search") return x.sendWithLiveSearch(ctx, messages, tools) } + // Use reasoning handler for models with reasoning capability + if x.reasoningHandler.ShouldUseReasoning() { + logging.Debug("Using reasoning handler for model", + "model", x.providerOptions.model.ID, + "reasoning_effort", x.options.reasoningEffort) + return x.sendWithReasoningSupport(ctx, messages, tools) + } + // Use concurrent client if configured if x.concurrent != nil { + logging.Debug("Using concurrent client") return x.concurrent.send(ctx, messages, tools) } // Call the base OpenAI implementation + logging.Debug("Using base 
OpenAI implementation") response, err := x.openaiClient.send(ctx, messages, tools) if err != nil { return nil, err @@ -237,13 +258,41 @@ func (x *xaiClient) send(ctx context.Context, messages []message.Message, tools return response, nil } +// sendWithReasoningSupport sends a request using the reasoning handler +func (x *xaiClient) sendWithReasoningSupport(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + // Build request body using reasoning handler + reqBody := x.reasoningHandler.BuildReasoningRequest(ctx, messages, tools) + + // Log the request for debugging + logging.Debug("Sending reasoning request", + "model", reqBody["model"], + "reasoning_effort", reqBody["reasoning_effort"], + "messages_count", len(messages)) + + // Send the request using HTTP client + result, err := x.httpClient.SendCompletionRequest(ctx, reqBody) + if err != nil { + return nil, fmt.Errorf("reasoning request failed: %w", err) + } + + // Convert result to ProviderResponse + response := x.convertDeferredResult(result) + + // Store reasoning content in the response for stream processing + if len(result.Choices) > 0 && result.Choices[0].Message.ReasoningContent != "" { + response.ReasoningContent = result.Choices[0].Message.ReasoningContent + } + + return response, nil +} + // sendWithLiveSearch sends a regular completion request with Live Search parameters func (x *xaiClient) sendWithLiveSearch(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { // Build request similar to deferred completions but without the deferred flag reqBody := map[string]interface{}{ "model": x.providerOptions.model.APIModel, - "messages": x.convertMessagesToAPI(messages), - "max_tokens": &x.providerOptions.maxTokens, + "messages": x.convertMessagesToAPI(messages), // Use the deferred method for proper conversion + "max_tokens": x.providerOptions.maxTokens, // Don't use pointer } // Add tools if provided @@ -273,66 +322,25 @@ func (x *xaiClient) sendWithLiveSearch(ctx context.Context, messages []message.M reqBody["parallel_tool_calls"] = x.options.parallelToolCalls } - // Add Live Search parameters - reqBody["search_parameters"] = x.liveSearchOptions - - // Send the request using custom HTTP client - return x.sendCustomHTTPRequest(ctx, reqBody) -} - -// sendCustomHTTPRequest sends a custom HTTP request to the xAI API -func (x *xaiClient) sendCustomHTTPRequest(ctx context.Context, reqBody map[string]interface{}) (*ProviderResponse, error) { - // Import required packages for this method - jsonBody, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request body: %w", err) - } - - // Get base URL (default to xAI API if not set) - baseURL := "https://api.x.ai" - if x.openaiClient.options.baseURL != "" { - baseURL = x.openaiClient.options.baseURL - } - - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/v1/chat/completions", bytes.NewReader(jsonBody)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+x.providerOptions.apiKey) - - // Send request - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + // Add Live Search parameters only if enabled + if x.liveSearchEnabled { + reqBody["search_parameters"] = 
x.liveSearchOptions } - defer resp.Body.Close() - // Read response body - body, err := io.ReadAll(resp.Body) + // Log the request for debugging + logging.Debug("Sending custom HTTP request", + "model", reqBody["model"], + "reasoning_effort", reqBody["reasoning_effort"], + "messages_count", len(x.convertMessagesToAPI(messages))) + + // Send the request using HTTP client + result, err := x.httpClient.SendCompletionRequest(ctx, reqBody) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("live search request failed: %w", err) } - - // Check status code - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Parse response as OpenAI-style completion result (same format as deferred) - var result DeferredResult - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - logging.Debug("Live Search completion received", "citations", len(result.Citations)) - - // Convert result to ProviderResponse (reuse existing conversion logic) - return x.convertDeferredResult(&result), nil + + // Convert result to ProviderResponse + return x.convertDeferredResult(result), nil } func (x *xaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { @@ -341,6 +349,11 @@ func (x *xaiClient) stream(ctx context.Context, messages []message.Message, tool return x.concurrent.stream(ctx, messages, tools) } + // Use reasoning handler for models with reasoning capability + if x.reasoningHandler.ShouldUseReasoning() { + return x.streamWithReasoning(ctx, messages, tools) + } + // Get the base stream baseChan := x.openaiClient.stream(ctx, messages, tools) @@ -364,6 +377,65 @@ func (x *xaiClient) stream(ctx context.Context, messages []message.Message, tool return eventChan } +// streamWithReasoning handles streaming for reasoning models +func (x *xaiClient) streamWithReasoning(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + logging.Debug("Using reasoning handler for stream", + "model", x.providerOptions.model.ID, + "reasoning_effort", x.options.reasoningEffort) + + // Create a channel to return events + eventChan := make(chan ProviderEvent) + + go func() { + defer close(eventChan) + + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in reasoning stream", "panic", r) + eventChan <- ProviderEvent{ + Type: EventError, + Error: fmt.Errorf("panic in reasoning stream: %v", r), + } + } + }() + + // Check context first + select { + case <-ctx.Done(): + logging.Debug("Context cancelled before reasoning request", "error", ctx.Err()) + eventChan <- ProviderEvent{ + Type: EventError, + Error: ctx.Err(), + } + return + default: + } + + logging.Debug("Starting reasoning request") + + // Get response using reasoning support + response, err := x.sendWithReasoningSupport(ctx, messages, tools) + if err != nil { + logging.Error("Reasoning request failed", "error", err) + eventChan <- ProviderEvent{ + Type: EventError, + Error: err, + } + return + } + + // Process response using reasoning handler + events := x.reasoningHandler.ProcessReasoningResponse(response) + + // Send all events + for _, event := range events { + eventChan <- event + } + }() + + return eventChan +} + // GetFingerprintHistory returns the fingerprint history for auditing and compliance func (x *xaiClient) GetFingerprintHistory() []FingerprintRecord { 
x.mu.Lock() @@ -415,6 +487,7 @@ func (x *xaiClient) StreamBatch(ctx context.Context, requests []BatchRequest) [] return channels } + // convertMessages overrides the base implementation to support xAI-specific image handling func (x *xaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { // Add system message first diff --git a/internal/llm/provider/xai_deferred.go b/internal/llm/provider/xai_deferred.go index 27d00083..e4618341 100644 --- a/internal/llm/provider/xai_deferred.go +++ b/internal/llm/provider/xai_deferred.go @@ -7,9 +7,9 @@ import ( "fmt" "io" "net/http" + "strings" "time" - "github.com/opencode-ai/opencode/internal/llm/models" "github.com/opencode-ai/opencode/internal/llm/tools" "github.com/opencode-ai/opencode/internal/logging" "github.com/opencode-ai/opencode/internal/message" @@ -170,7 +170,7 @@ func (x *xaiClient) sendDeferred(ctx context.Context, messages []message.Message } // Apply reasoning effort if applicable - if x.shouldApplyReasoningEffort() && x.options.reasoningEffort != "" { + if x.shouldApplyReasoningEffort() { reqBody.ReasoningEffort = x.options.reasoningEffort } @@ -377,11 +377,16 @@ func (x *xaiClient) convertDeferredResult(result *DeferredResult) *ProviderRespo } inputTokens = result.Usage.PromptTokens - cachedTokens + // Handle content and reasoning_content separately to maintain proper data structure + content := choice.Message.Content + reasoningContent := choice.Message.ReasoningContent + // Create response resp := &ProviderResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: finishReason, + Content: content, + ReasoningContent: reasoningContent, + ToolCalls: toolCalls, + FinishReason: finishReason, Usage: TokenUsage{ InputTokens: inputTokens, OutputTokens: result.Usage.CompletionTokens, @@ -404,84 +409,61 @@ func (x *xaiClient) convertDeferredResult(result *DeferredResult) *ProviderRespo func (x *xaiClient) convertMessagesToAPI(messages []message.Message) []map[string]interface{} { var apiMessages []map[string]interface{} - // Add system message + // Add system message first apiMessages = append(apiMessages, map[string]interface{}{ "role": "system", "content": x.providerOptions.systemMessage, }) - // Convert user messages for _, msg := range messages { - switch msg.Role { - case message.User: - // Check if message has images - hasImages := len(msg.BinaryContent()) > 0 || len(msg.ImageURLContent()) > 0 - - if hasImages { - // Build content array for multimodal message - var content []map[string]interface{} - - // Add text content if present - if msg.Content().String() != "" { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content().String(), - }) - } - - // Add binary images (base64 encoded) - for _, binaryContent := range msg.BinaryContent() { - content = append(content, map[string]interface{}{ - "type": "image_url", - "image_url": map[string]interface{}{ - "url": binaryContent.String(models.ProviderOpenAI), // data:image/jpeg;base64, - "detail": "high", // Default to high detail - }, - }) - } + apiMsg := map[string]interface{}{ + "role": string(msg.Role), + } - // Add image URLs (web URLs) - for _, imageURLContent := range msg.ImageURLContent() { - detail := imageURLContent.Detail - if detail == "" { - detail = "auto" - } + // Convert content based on message type + switch msg.Role { + case message.User, message.System: + // Handle potential multipart content + var content []map[string]interface{} + textParts := 
[]string{} + + for _, part := range msg.Parts { + switch p := part.(type) { + case message.TextContent: + textParts = append(textParts, p.Text) + case message.BinaryContent: + // xAI expects images in a specific format content = append(content, map[string]interface{}{ "type": "image_url", "image_url": map[string]interface{}{ - "url": imageURLContent.URL, - "detail": detail, + "url": fmt.Sprintf("data:%s;base64,%s", p.MIMEType, p.Data), }, }) } - - apiMsg := map[string]interface{}{ - "role": "user", - "content": content, - } - apiMessages = append(apiMessages, apiMsg) - } else { - // Simple text message - apiMsg := map[string]interface{}{ - "role": "user", - "content": msg.Content().String(), - } - apiMessages = append(apiMessages, apiMsg) } - case message.Assistant: - apiMsg := map[string]interface{}{ - "role": "assistant", + // If we have text parts, add them first + if len(textParts) > 0 { + content = append([]map[string]interface{}{{ + "type": "text", + "text": strings.Join(textParts, "\n"), + }}, content...) } - if msg.Content().String() != "" { + if len(content) > 0 { + apiMsg["content"] = content + } else { apiMsg["content"] = msg.Content().String() } - if len(msg.ToolCalls()) > 0 { - var toolCalls []map[string]interface{} - for _, tc := range msg.ToolCalls() { - toolCalls = append(toolCalls, map[string]interface{}{ + case message.Assistant: + apiMsg["content"] = msg.Content().String() + + // Add tool calls if present + if toolCalls := msg.ToolCalls(); len(toolCalls) > 0 { + var apiToolCalls []map[string]interface{} + for _, tc := range toolCalls { + apiToolCalls = append(apiToolCalls, map[string]interface{}{ "id": tc.ID, "type": "function", "function": map[string]interface{}{ @@ -490,20 +472,22 @@ func (x *xaiClient) convertMessagesToAPI(messages []message.Message) []map[strin }, }) } - apiMsg["tool_calls"] = toolCalls + apiMsg["tool_calls"] = apiToolCalls } - apiMessages = append(apiMessages, apiMsg) - case message.Tool: + // Handle tool results for _, result := range msg.ToolResults() { apiMessages = append(apiMessages, map[string]interface{}{ "role": "tool", - "content": result.Content, "tool_call_id": result.ToolCallID, + "content": result.Content, }) } + continue // Skip adding the message itself } + + apiMessages = append(apiMessages, apiMsg) } return apiMessages @@ -515,19 +499,46 @@ func (x *xaiClient) convertToolsToAPI(tools []tools.BaseTool) []map[string]inter for _, tool := range tools { info := tool.Info() + + // Check if Parameters already contains the full schema (with "type" and "properties") + var parameters map[string]interface{} + params := info.Parameters + if _, hasType := params["type"]; hasType { + // Parameters already contains the full schema + parameters = params + } else { + // Parameters only contains properties, wrap them + parameters = map[string]interface{}{ + "type": "object", + "properties": info.Parameters, + "required": info.Required, + } + } + apiTools = append(apiTools, map[string]interface{}{ "type": "function", "function": map[string]interface{}{ "name": info.Name, "description": info.Description, - "parameters": map[string]interface{}{ - "type": "object", - "properties": info.Parameters, - "required": info.Required, - }, + "parameters": parameters, }, }) } return apiTools } + +// sanitizeContent removes control characters that could corrupt terminal display +func sanitizeContent(content string) string { + // Remove ANSI escape sequences (ESC character) + content = strings.ReplaceAll(content, "\x1b", "") + // Remove carriage returns (which can 
cause display issues) + content = strings.ReplaceAll(content, "\r", "") + // Remove other control characters that might cause issues + content = strings.ReplaceAll(content, "\x00", "") // null + content = strings.ReplaceAll(content, "\x07", "") // bell + content = strings.ReplaceAll(content, "\x08", "") // backspace + // Replace form feed with newline to preserve structure + content = strings.ReplaceAll(content, "\x0c", "\n") + return content +} diff --git a/internal/llm/provider/xai_http.go b/internal/llm/provider/xai_http.go new file mode 100644 index 00000000..bbcfa454 --- /dev/null +++ b/internal/llm/provider/xai_http.go @@ -0,0 +1,153 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/opencode-ai/opencode/internal/logging" +) + +// HTTPClientConfig holds configuration for HTTP requests +type HTTPClientConfig struct { + BaseURL string + APIKey string + Timeout time.Duration + UserAgent string +} + +// XAIHTTPClient handles HTTP communication with xAI API +type XAIHTTPClient struct { + config HTTPClientConfig + client *http.Client +} + +// NewXAIHTTPClient creates a new XAI HTTP client +func NewXAIHTTPClient(config HTTPClientConfig) *XAIHTTPClient { + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + + if config.BaseURL == "" { + config.BaseURL = "https://api.x.ai" + } + + // Ensure HTTPS + if strings.HasPrefix(config.BaseURL, "http://") { + config.BaseURL = strings.Replace(config.BaseURL, "http://", "https://", 1) + logging.Debug("Converted HTTP to HTTPS", "url", config.BaseURL) + } + + return &XAIHTTPClient{ + config: config, + client: &http.Client{Timeout: config.Timeout}, + } +} + +// SendCompletionRequest sends a chat completion request to xAI API +func (c *XAIHTTPClient) SendCompletionRequest(ctx context.Context, reqBody map[string]interface{}) (*DeferredResult, error) { + // Marshal request body + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + // Create HTTP request + url := c.config.BaseURL + "/v1/chat/completions" + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + c.setRequestHeaders(req, len(jsonBody)) + + // Log request details (with masked API key) + c.logRequest(url, len(jsonBody)) + + // Send request + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Check status code + if resp.StatusCode != http.StatusOK { + logging.Error("HTTP request failed", + "status", resp.StatusCode, + "body", string(body)) + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + logging.Debug("HTTP response received", "status", resp.StatusCode, "body_size", len(body)) + + // Parse response + var result DeferredResult + if err := json.Unmarshal(body, &result); err != nil { + logging.Error("Failed to parse response", + "error", err, + "body", string(body)) + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Log the parsed result + c.logResponse(&result) + + return &result, nil +} + +// setRequestHeaders sets standard headers for xAI API requests +func (c *XAIHTTPClient) 
setRequestHeaders(req *http.Request, bodySize int) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.APIKey) + + if c.config.UserAgent != "" { + req.Header.Set("User-Agent", c.config.UserAgent) + } +} + +// logRequest logs request details with masked API key +func (c *XAIHTTPClient) logRequest(url string, bodySize int) { + maskedKey := c.getMaskedAPIKey() + logging.Debug("Sending HTTP request", + "url", url, + "body_size", bodySize, + "api_key_masked", maskedKey) +} + +// logResponse logs response details +func (c *XAIHTTPClient) logResponse(result *DeferredResult) { + if len(result.Choices) > 0 { + choice := result.Choices[0] + logging.Debug("XAI HTTP response parsed", + "citations", len(result.Citations), + "content_length", len(choice.Message.Content), + "reasoning_length", len(choice.Message.ReasoningContent), + "has_content", choice.Message.Content != "", + "has_reasoning", choice.Message.ReasoningContent != "", + "finish_reason", choice.FinishReason) + } else { + logging.Debug("No choices in HTTP response") + } +} + +// getMaskedAPIKey returns a masked version of the API key for logging +func (c *XAIHTTPClient) getMaskedAPIKey() string { + if len(c.config.APIKey) <= 6 { + return "***" + } + return c.config.APIKey[:3] + "***" + c.config.APIKey[len(c.config.APIKey)-3:] +} \ No newline at end of file diff --git a/internal/llm/provider/xai_reasoning.go b/internal/llm/provider/xai_reasoning.go new file mode 100644 index 00000000..d7011a00 --- /dev/null +++ b/internal/llm/provider/xai_reasoning.go @@ -0,0 +1,148 @@ +package provider + +import ( + "context" + "strings" + + "github.com/opencode-ai/opencode/internal/llm/models" + "github.com/opencode-ai/opencode/internal/llm/tools" + "github.com/opencode-ai/opencode/internal/logging" + "github.com/opencode-ai/opencode/internal/message" +) + +// ReasoningConfig holds configuration for reasoning requests +type ReasoningConfig struct { + Model string + ReasoningEffort string + MaxTokens int + Tools []tools.BaseTool +} + +// ReasoningHandler manages XAI reasoning content processing +type ReasoningHandler struct { + client *xaiClient +} + +// NewReasoningHandler creates a new reasoning handler +func NewReasoningHandler(client *xaiClient) *ReasoningHandler { + return &ReasoningHandler{ + client: client, + } +} + +// ShouldUseReasoning determines if reasoning should be used for a request +func (r *ReasoningHandler) ShouldUseReasoning() bool { + canReason := r.client.providerOptions.model.CanReason + hasReasoningEffort := r.client.options.reasoningEffort != "" + shouldApply := r.client.shouldApplyReasoningEffort() + + logging.Debug("Checking reasoning conditions", + "model", r.client.providerOptions.model.ID, + "can_reason", canReason, + "reasoning_effort", r.client.options.reasoningEffort, + "has_reasoning_effort", hasReasoningEffort, + "should_apply", shouldApply) + + return canReason && hasReasoningEffort && shouldApply +} + +// ProcessReasoningResponse handles reasoning content from API responses +func (r *ReasoningHandler) ProcessReasoningResponse(response *ProviderResponse) []ProviderEvent { + var events []ProviderEvent + + // Send reasoning content as thinking delta first (if present) + if response.ReasoningContent != "" { + sanitizedReasoning := r.sanitizeReasoningContent(response.ReasoningContent) + events = append(events, ProviderEvent{ + Type: EventThinkingDelta, + Thinking: sanitizedReasoning, + }) + + logging.Debug("Reasoning content 
processed", + "original_length", len(response.ReasoningContent), + "sanitized_length", len(sanitizedReasoning)) + } + + // Send regular content as delta if present + if response.Content != "" { + events = append(events, ProviderEvent{ + Type: EventContentDelta, + Content: response.Content, + }) + } + + // Clear reasoning content from response before sending complete event + response.ReasoningContent = "" + + // Send complete event + events = append(events, ProviderEvent{ + Type: EventComplete, + Response: response, + }) + + return events +} + +// sanitizeReasoningContent removes control characters that could corrupt terminal display +func (r *ReasoningHandler) sanitizeReasoningContent(content string) string { + // Remove ANSI escape sequences (ESC character) + content = strings.ReplaceAll(content, "\x1b", "") + // Remove carriage returns (which can cause display issues) + content = strings.ReplaceAll(content, "\r", "") + // Remove other control characters that might cause issues + content = strings.ReplaceAll(content, "\x00", "") // null + content = strings.ReplaceAll(content, "\x07", "") // bell + content = strings.ReplaceAll(content, "\x08", "") // backspace + // Replace form feed with newline to preserve structure + content = strings.ReplaceAll(content, "\x0c", "\n") + return content +} + +// BuildReasoningRequest creates a request body for reasoning models +func (r *ReasoningHandler) BuildReasoningRequest(ctx context.Context, messages []message.Message, tools []tools.BaseTool) map[string]interface{} { + reqBody := map[string]interface{}{ + "model": r.client.providerOptions.model.APIModel, + "messages": r.client.convertMessagesToAPI(messages), + "max_tokens": r.client.providerOptions.maxTokens, + "stream": false, // Explicitly disable streaming for reasoning requests + } + + // Add tools if provided + if len(tools) > 0 { + reqBody["tools"] = r.client.convertToolsToAPI(tools) + } + + // Apply reasoning effort only if the model supports it + // xAI grok models do not accept reasoning_effort parameter + if r.client.options.reasoningEffort != "" && r.client.shouldApplyReasoningEffort() { + reasoningEffort := r.client.options.reasoningEffort + + // Grok-3-mini models only support "low" or "high", not "medium" + if (r.client.providerOptions.model.ID == models.XAIGrok3Mini || + r.client.providerOptions.model.ID == models.XAIGrok3MiniFast) && + reasoningEffort == "medium" { + // Convert medium to high for Grok-3-mini models + reasoningEffort = "high" + logging.Debug("Converting reasoning effort from medium to high for Grok-3-mini") + } + + reqBody["reasoning_effort"] = reasoningEffort + } + + // Apply response format if configured + if r.client.options.responseFormat != nil { + reqBody["response_format"] = r.client.options.responseFormat + } + + // Apply tool choice if configured + if r.client.options.toolChoice != nil { + reqBody["tool_choice"] = r.client.options.toolChoice + } + + // Apply parallel tool calls if configured + if r.client.options.parallelToolCalls != nil { + reqBody["parallel_tool_calls"] = r.client.options.parallelToolCalls + } + + return reqBody +} diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index 0febbf8e..2c39f467 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -67,6 +67,26 @@ func NewLsTool() BaseTool { return &lsTool{} } +// getWorkingDirectory safely gets the working directory, falling back to os.Getwd() if config is not loaded +func getWorkingDirectory() string { + // Try config first, but handle panic if config not loaded + 
defer func() { + recover() // Silently recover from panic if config not loaded + }() + + if wd := config.WorkingDirectory(); wd != "" { + return wd + } + + // Fall back to current working directory + if wd, err := os.Getwd(); err == nil { + return wd + } + + // Last resort - return current directory + return "." +} + func (l *lsTool) Info() ToolInfo { return ToolInfo{ Name: LSToolName, @@ -96,11 +116,12 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { searchPath := params.Path if searchPath == "" { - searchPath = config.WorkingDirectory() + // Try to get working directory from config, fall back to current working directory + searchPath = getWorkingDirectory() } if !filepath.IsAbs(searchPath) { - searchPath = filepath.Join(config.WorkingDirectory(), searchPath) + searchPath = filepath.Join(getWorkingDirectory(), searchPath) } if _, err := os.Stat(searchPath); os.IsNotExist(err) { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index eccbaae1..2357dc46 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -163,6 +163,15 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } m.attachments = append(m.attachments, msg.Attachment) case tea.KeyMsg: + // Let specific global shortcuts pass through to higher level + globalKeys := []string{"ctrl+h", "ctrl+?", "cmd+?", "cmd+shift+/", "ctrl+shift+/", "ctrl+l", "ctrl+s", "ctrl+k", "ctrl+f", "ctrl+o", "ctrl+t"} + for _, globalKey := range globalKeys { + if msg.String() == globalKey { + // Don't consume global shortcuts, let them bubble up + return m, nil + } + } + if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) { m.deleteMode = true return m, nil diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index 0732366d..58c1efa4 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -28,6 +28,7 @@ const ( toolMessageType maxResultHeight = 10 + maxMessageLines = 100 // Limit very long messages to prevent display issues ) type uiMessage struct { @@ -39,8 +40,22 @@ type uiMessage struct { } func toMarkdown(content string, focused bool, width int) string { + // Ensure minimum width to prevent rendering issues + if width < 20 { + width = 80 + } + + // For very long content, use plain text to avoid markdown rendering issues + if len(content) > 3000 { + return content + } + r := styles.GetMarkdownRenderer(width) - rendered, _ := r.Render(content) + rendered, err := r.Render(content) + if err != nil { + // Fallback to plain content if markdown rendering fails + return content + } return rendered } @@ -135,6 +150,33 @@ func renderAssistantMessage( t := theme.CurrentTheme() baseStyle := styles.BaseStyle() + // Combine reasoning content and regular content if both are present + if thinkingContent != "" && content != "" { + // For very long reasoning content, truncate it more aggressively + if len(thinkingContent) > 2000 { + lines := strings.Split(thinkingContent, "\n") + if len(lines) > 20 { + thinkingContent = strings.Join(lines[:20], "\n") + "\n\n[Reasoning content truncated...]" + } + } + content = thinkingContent + "\n\n" + content + } else if thinkingContent != "" && content == "" { + // For standalone reasoning content, also truncate if very long + if len(thinkingContent) > 2000 { + lines := strings.Split(thinkingContent, "\n") + if len(lines) > 30 { + content = strings.Join(lines[:30], "\n") + "\n\n[Content truncated...]" + } else { 
+ content = thinkingContent + } + } else { + content = thinkingContent + } + } + + // Final truncation by line count to prevent display issues + content = truncateHeight(content, maxMessageLines) + // Add finish info if available if finished { switch finishData.Reason { @@ -185,7 +227,16 @@ func renderAssistantMessage( position++ // for the space } else if thinking && thinkingContent != "" { // Render the thinking content - content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width) + content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width, info...) + messages = append(messages, uiMessage{ + ID: msg.ID, + messageType: assistantMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + }) + position += messages[0].height + position++ // for the space } for i, toolCall := range msg.ToolCalls() { diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 0dc227a8..57b4ffea 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -2,6 +2,7 @@ package core import ( "fmt" + "runtime" "strings" "time" @@ -72,10 +73,19 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var helpWidget = "" +// getHelpKeyDisplay returns the primary help key for the current OS +func getHelpKeyDisplay() string { + if runtime.GOOS == "darwin" { + return "ctrl+h" + } + return "ctrl+?" +} + // getHelpWidget returns the help widget with current theme colors func getHelpWidget() string { t := theme.CurrentTheme() - helpText := "ctrl+? help" + // Use the correct help key for the current OS + helpText := getHelpKeyDisplay() + " help" return styles.Padded(). Background(t.TextMuted()). diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 1c9c2f03..a3eca477 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -3,6 +3,7 @@ package tui import ( "context" "fmt" + "runtime" "strings" "github.com/charmbracelet/bubbles/key" @@ -41,6 +42,30 @@ const ( quitKey = "q" ) +// getOSModifier returns the appropriate modifier key for the current OS +func getOSModifier() string { + if runtime.GOOS == "darwin" { + return "cmd" + } + return "ctrl" +} + +// getHelpKeyLabel returns the appropriate help key label for the current OS +func getHelpKeyLabel() string { + if runtime.GOOS == "darwin" { + return "ctrl+h" + } + return "ctrl+?" 
+} + +// getHelpKeys returns the appropriate help key bindings for the current OS +func getHelpKeys() []string { + if runtime.GOOS == "darwin" { + return []string{"cmd+shift+/", "cmd+?", "ctrl+h", "ctrl+?"} + } + return []string{"ctrl+shift+/", "ctrl+?", "ctrl+h"} +} + var keys = keyMap{ Logs: key.NewBinding( key.WithKeys("ctrl+l"), @@ -52,8 +77,8 @@ var keys = keyMap{ key.WithHelp("ctrl+c", "quit"), ), Help: key.NewBinding( - key.WithKeys("ctrl+_", "ctrl+h"), - key.WithHelp("ctrl+?", "toggle help"), + key.WithKeys(getHelpKeys()...), + key.WithHelp(getHelpKeyLabel(), "toggle help"), ), SwitchSession: key.NewBinding( @@ -81,8 +106,8 @@ var keys = keyMap{ } var helpEsc = key.NewBinding( - key.WithKeys("?"), - key.WithHelp("?", "toggle help"), + key.WithKeys("?", "cmd+shift+/", "ctrl+shift+/", "cmd+?", "ctrl+?"), + key.WithHelp(getHelpKeyLabel(), "toggle help"), ) var returnKey = key.NewBinding( @@ -505,7 +530,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showModelDialog = false return a, nil } - if a.currentPage == page.ChatPage && !a.showQuit && !a.showPermissions && !a.showSessionDialog && !a.showCommandDialog { + if !a.showQuit && !a.showPermissions && !a.showSessionDialog && !a.showCommandDialog { a.showModelDialog = true return a, nil } @@ -518,7 +543,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, a.themeDialog.Init() } return a, nil - case key.Matches(msg, returnKey) || key.Matches(msg): + case key.Matches(msg, returnKey): if msg.String() == quitKey { if a.currentPage == page.LogsPage { return a, a.moveToPage(page.ChatPage) @@ -550,6 +575,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } case key.Matches(msg, keys.Logs): + if a.currentPage == page.LogsPage { + // If already on logs page, switch back to chat + return a, a.moveToPage(page.ChatPage) + } return a, a.moveToPage(page.LogsPage) case key.Matches(msg, keys.Help): if a.showQuit { @@ -558,13 +587,11 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showHelp = !a.showHelp return a, nil case key.Matches(msg, helpEsc): - if a.app.CoderAgent.IsBusy() { - if a.showQuit { - return a, nil - } - a.showHelp = !a.showHelp + if a.showQuit { return a, nil } + a.showHelp = !a.showHelp + return a, nil case key.Matches(msg, keys.Filepicker): a.showFilepicker = !a.showFilepicker a.filepicker.ToggleFilepicker(a.showFilepicker) diff --git a/opencode-schema.json b/opencode-schema.json index d1a46165..1b3a7570 100644 --- a/opencode-schema.json +++ b/opencode-schema.json @@ -30,7 +30,7 @@ "o4-mini", "azure.gpt-4.1-mini", "openrouter.o3", - "grok-3-beta", + "grok-3", "o3-mini", "qwen-qwq", "azure.o1", @@ -54,15 +54,19 @@ "azure.o3", "azure.gpt-4.5-preview", "openrouter.claude-3-opus", - "grok-3-mini-fast-beta", + "grok-3-mini-fast", "claude-4-sonnet", "azure.o4-mini", - "grok-3-fast-beta", + "grok-3-fast", "claude-3.5-sonnet", "azure.o1-mini", "openrouter.claude-3.7-sonnet", "openrouter.gpt-4.5-preview", - "grok-3-mini-beta", + "grok-3-mini", + "grok-4-0709", + "grok-2-1212", + "grok-2-vision-1212", + "grok-2-image-1212", "claude-3.7-sonnet", "gemini-2.0-flash", "openrouter.deepseek-r1-free", @@ -144,7 +148,7 @@ "o4-mini", "azure.gpt-4.1-mini", "openrouter.o3", - "grok-3-beta", + "grok-3", "o3-mini", "qwen-qwq", "azure.o1", @@ -168,15 +172,19 @@ "azure.o3", "azure.gpt-4.5-preview", "openrouter.claude-3-opus", - "grok-3-mini-fast-beta", + "grok-3-mini-fast", "claude-4-sonnet", "azure.o4-mini", - "grok-3-fast-beta", + "grok-3-fast", "claude-3.5-sonnet", 
"azure.o1-mini", "openrouter.claude-3.7-sonnet", "openrouter.gpt-4.5-preview", - "grok-3-mini-beta", + "grok-3-mini", + "grok-4-0709", + "grok-2-1212", + "grok-2-vision-1212", + "grok-2-image-1212", "claude-3.7-sonnet", "gemini-2.0-flash", "openrouter.deepseek-r1-free", From 70cf02bc75dd1edaf02ecf8a103d25e63a0f250b Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Sat, 12 Jul 2025 01:12:55 +0400 Subject: [PATCH 7/9] fix: Handle Grok 4 automatic reasoning in XAI provider - Updated ShouldUseReasoning() to properly handle Grok 4's automatic reasoning - Grok 4 now uses the reasoning handler path without requiring reasoning_effort parameter - This ensures proper tool schema handling for Grok 4 models --- .claude/settings.local.json | 19 +++++++++ internal/llm/provider/xai.go | 49 +++++++++++----------- internal/llm/provider/xai_deferred.go | 12 +++--- internal/llm/provider/xai_deferred_test.go | 6 +-- internal/llm/provider/xai_http.go | 36 ++++++++-------- internal/llm/provider/xai_models.go | 16 +++---- internal/llm/provider/xai_reasoning.go | 13 +++++- internal/llm/provider/xai_validation.go | 2 +- 8 files changed, 91 insertions(+), 62 deletions(-) create mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 00000000..d9f3b944 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,19 @@ +{ + "permissions": { + "allow": [ + "Bash(ls:*)", + "Bash(export PATH=$PATH:/Users/alexbel/go/bin)", + "Bash(opencode:*)", + "Bash(go build:*)", + "Bash(go install:*)", + "Bash(go run:*)", + "Bash(go clean:*)", + "Bash(rm:*)", + "Bash(find:*)", + "Bash(sqlite3:*)", + "Bash(cp:*)", + "Bash(go test:*)" + ], + "deny": [] + } +} \ No newline at end of file diff --git a/internal/llm/provider/xai.go b/internal/llm/provider/xai.go index 314faa36..1810b543 100644 --- a/internal/llm/provider/xai.go +++ b/internal/llm/provider/xai.go @@ -35,7 +35,7 @@ type xaiClient struct { deferredOptions DeferredOptions // Options for deferred completions liveSearchEnabled bool // Enable Live Search liveSearchOptions LiveSearchOptions // Options for Live Search - + // New architectural components reasoningHandler *ReasoningHandler // Handles reasoning content processing httpClient *XAIHTTPClient // Custom HTTP client for xAI API @@ -96,7 +96,7 @@ func WithLiveSearchOptions(opts LiveSearchOptions) XAIOption { func newXAIClient(opts providerClientOptions) XAIClient { // Create base OpenAI client with xAI-specific settings opts.openaiOptions = append(opts.openaiOptions, - WithOpenAIBaseURL("https://api.x.ai"), + WithOpenAIBaseURL("https://api.x.ai/v1"), ) baseClient := newOpenAIClient(opts) @@ -110,7 +110,7 @@ func newXAIClient(opts providerClientOptions) XAIClient { // Initialize new architectural components xClient.reasoningHandler = NewReasoningHandler(xClient) xClient.httpClient = NewXAIHTTPClient(HTTPClientConfig{ - BaseURL: "https://api.x.ai", + BaseURL: "https://api.x.ai/v1", APIKey: opts.apiKey, UserAgent: "opencode/1.0", Timeout: 30 * time.Second, @@ -231,7 +231,7 @@ func (x *xaiClient) send(ctx context.Context, messages []message.Message, tools // Use reasoning handler for models with reasoning capability if x.reasoningHandler.ShouldUseReasoning() { - logging.Debug("Using reasoning handler for model", + logging.Debug("Using reasoning handler for model", "model", x.providerOptions.model.ID, "reasoning_effort", x.options.reasoningEffort) return x.sendWithReasoningSupport(ctx, messages, tools) @@ -262,27 +262,27 @@ func (x *xaiClient) send(ctx 
context.Context, messages []message.Message, tools func (x *xaiClient) sendWithReasoningSupport(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { // Build request body using reasoning handler reqBody := x.reasoningHandler.BuildReasoningRequest(ctx, messages, tools) - + // Log the request for debugging - logging.Debug("Sending reasoning request", + logging.Debug("Sending reasoning request", "model", reqBody["model"], "reasoning_effort", reqBody["reasoning_effort"], "messages_count", len(messages)) - + // Send the request using HTTP client result, err := x.httpClient.SendCompletionRequest(ctx, reqBody) if err != nil { return nil, fmt.Errorf("reasoning request failed: %w", err) } - + // Convert result to ProviderResponse response := x.convertDeferredResult(result) - + // Store reasoning content in the response for stream processing if len(result.Choices) > 0 && result.Choices[0].Message.ReasoningContent != "" { response.ReasoningContent = result.Choices[0].Message.ReasoningContent } - + return response, nil } @@ -292,7 +292,7 @@ func (x *xaiClient) sendWithLiveSearch(ctx context.Context, messages []message.M reqBody := map[string]interface{}{ "model": x.providerOptions.model.APIModel, "messages": x.convertMessagesToAPI(messages), // Use the deferred method for proper conversion - "max_tokens": x.providerOptions.maxTokens, // Don't use pointer + "max_tokens": x.providerOptions.maxTokens, // Don't use pointer } // Add tools if provided @@ -328,17 +328,17 @@ func (x *xaiClient) sendWithLiveSearch(ctx context.Context, messages []message.M } // Log the request for debugging - logging.Debug("Sending custom HTTP request", + logging.Debug("Sending custom HTTP request", "model", reqBody["model"], "reasoning_effort", reqBody["reasoning_effort"], "messages_count", len(x.convertMessagesToAPI(messages))) - + // Send the request using HTTP client result, err := x.httpClient.SendCompletionRequest(ctx, reqBody) if err != nil { return nil, fmt.Errorf("live search request failed: %w", err) } - + // Convert result to ProviderResponse return x.convertDeferredResult(result), nil } @@ -379,16 +379,16 @@ func (x *xaiClient) stream(ctx context.Context, messages []message.Message, tool // streamWithReasoning handles streaming for reasoning models func (x *xaiClient) streamWithReasoning(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - logging.Debug("Using reasoning handler for stream", + logging.Debug("Using reasoning handler for stream", "model", x.providerOptions.model.ID, "reasoning_effort", x.options.reasoningEffort) - + // Create a channel to return events eventChan := make(chan ProviderEvent) - + go func() { defer close(eventChan) - + defer func() { if r := recover(); r != nil { logging.Error("Panic in reasoning stream", "panic", r) @@ -398,7 +398,7 @@ func (x *xaiClient) streamWithReasoning(ctx context.Context, messages []message. } } }() - + // Check context first select { case <-ctx.Done(): @@ -410,9 +410,9 @@ func (x *xaiClient) streamWithReasoning(ctx context.Context, messages []message. return default: } - + logging.Debug("Starting reasoning request") - + // Get response using reasoning support response, err := x.sendWithReasoningSupport(ctx, messages, tools) if err != nil { @@ -423,16 +423,16 @@ func (x *xaiClient) streamWithReasoning(ctx context.Context, messages []message. 
} return } - + // Process response using reasoning handler events := x.reasoningHandler.ProcessReasoningResponse(response) - + // Send all events for _, event := range events { eventChan <- event } }() - + return eventChan } @@ -487,7 +487,6 @@ func (x *xaiClient) StreamBatch(ctx context.Context, requests []BatchRequest) [] return channels } - // convertMessages overrides the base implementation to support xAI-specific image handling func (x *xaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { // Add system message first diff --git a/internal/llm/provider/xai_deferred.go b/internal/llm/provider/xai_deferred.go index e4618341..5e83576b 100644 --- a/internal/llm/provider/xai_deferred.go +++ b/internal/llm/provider/xai_deferred.go @@ -201,13 +201,13 @@ func (x *xaiClient) sendDeferred(ctx context.Context, messages []message.Message } // Get base URL (default to xAI API if not set) - baseURL := "https://api.x.ai" + baseURL := "https://api.x.ai/v1" if x.openaiClient.options.baseURL != "" { baseURL = x.openaiClient.options.baseURL } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/v1/chat/completions", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/chat/completions", bytes.NewReader(jsonBody)) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } @@ -253,12 +253,12 @@ func (x *xaiClient) sendDeferred(ctx context.Context, messages []message.Message // pollDeferredResult polls for the deferred completion result func (x *xaiClient) pollDeferredResult(ctx context.Context, requestID string, opts DeferredOptions) (*DeferredResult, error) { // Get base URL (default to xAI API if not set) - baseURL := "https://api.x.ai" + baseURL := "https://api.x.ai/v1" if x.openaiClient.options.baseURL != "" { baseURL = x.openaiClient.options.baseURL } - url := fmt.Sprintf("%s/v1/chat/deferred-completion/%s", baseURL, requestID) + url := fmt.Sprintf("%s/chat/deferred-completion/%s", baseURL, requestID) // Create HTTP client client := &http.Client{Timeout: 30 * time.Second} @@ -499,7 +499,7 @@ func (x *xaiClient) convertToolsToAPI(tools []tools.BaseTool) []map[string]inter for _, tool := range tools { info := tool.Info() - + // Check if Parameters already contains the full schema (with "type" and "properties") var parameters map[string]interface{} params := info.Parameters @@ -514,7 +514,7 @@ func (x *xaiClient) convertToolsToAPI(tools []tools.BaseTool) []map[string]inter "required": info.Required, } } - + apiTools = append(apiTools, map[string]interface{}{ "type": "function", "function": map[string]interface{}{ diff --git a/internal/llm/provider/xai_deferred_test.go b/internal/llm/provider/xai_deferred_test.go index 1db8a58a..937e6e9a 100644 --- a/internal/llm/provider/xai_deferred_test.go +++ b/internal/llm/provider/xai_deferred_test.go @@ -171,7 +171,7 @@ func TestXAIProvider_DeferredCompletionsMock(t *testing.T) { count := atomic.AddInt32(&requestCount, 1) switch r.URL.Path { - case "/v1/chat/completions": + case "/chat/completions": // Initial deferred request assert.Equal(t, "POST", r.Method) @@ -187,7 +187,7 @@ func TestXAIProvider_DeferredCompletionsMock(t *testing.T) { RequestID: requestID, }) - case "/v1/chat/deferred-completion/" + requestID: + case "/chat/deferred-completion/" + requestID: // Polling request assert.Equal(t, "GET", r.Method) @@ -283,7 +283,7 @@ func TestXAIProvider_DeferredCompletionsMock(t *testing.T) { // Create mock server 
that always returns 202 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case "/v1/chat/completions": + case "/chat/completions": // Return request ID w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(DeferredCompletionResponse{ diff --git a/internal/llm/provider/xai_http.go b/internal/llm/provider/xai_http.go index bbcfa454..3ff997de 100644 --- a/internal/llm/provider/xai_http.go +++ b/internal/llm/provider/xai_http.go @@ -15,10 +15,10 @@ import ( // HTTPClientConfig holds configuration for HTTP requests type HTTPClientConfig struct { - BaseURL string - APIKey string - Timeout time.Duration - UserAgent string + BaseURL string + APIKey string + Timeout time.Duration + UserAgent string } // XAIHTTPClient handles HTTP communication with xAI API @@ -32,17 +32,17 @@ func NewXAIHTTPClient(config HTTPClientConfig) *XAIHTTPClient { if config.Timeout == 0 { config.Timeout = 30 * time.Second } - + if config.BaseURL == "" { - config.BaseURL = "https://api.x.ai" + config.BaseURL = "https://api.x.ai/v1" } - + // Ensure HTTPS if strings.HasPrefix(config.BaseURL, "http://") { config.BaseURL = strings.Replace(config.BaseURL, "http://", "https://", 1) logging.Debug("Converted HTTP to HTTPS", "url", config.BaseURL) } - + return &XAIHTTPClient{ config: config, client: &http.Client{Timeout: config.Timeout}, @@ -58,7 +58,7 @@ func (c *XAIHTTPClient) SendCompletionRequest(ctx context.Context, reqBody map[s } // Create HTTP request - url := c.config.BaseURL + "/v1/chat/completions" + url := c.config.BaseURL + "/chat/completions" req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -85,18 +85,18 @@ func (c *XAIHTTPClient) SendCompletionRequest(ctx context.Context, reqBody map[s // Check status code if resp.StatusCode != http.StatusOK { - logging.Error("HTTP request failed", + logging.Error("HTTP request failed", "status", resp.StatusCode, "body", string(body)) return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) } - + logging.Debug("HTTP response received", "status", resp.StatusCode, "body_size", len(body)) // Parse response var result DeferredResult if err := json.Unmarshal(body, &result); err != nil { - logging.Error("Failed to parse response", + logging.Error("Failed to parse response", "error", err, "body", string(body)) return nil, fmt.Errorf("failed to parse response: %w", err) @@ -104,7 +104,7 @@ func (c *XAIHTTPClient) SendCompletionRequest(ctx context.Context, reqBody map[s // Log the parsed result c.logResponse(&result) - + return &result, nil } @@ -113,7 +113,7 @@ func (c *XAIHTTPClient) setRequestHeaders(req *http.Request, bodySize int) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+c.config.APIKey) - + if c.config.UserAgent != "" { req.Header.Set("User-Agent", c.config.UserAgent) } @@ -122,8 +122,8 @@ func (c *XAIHTTPClient) setRequestHeaders(req *http.Request, bodySize int) { // logRequest logs request details with masked API key func (c *XAIHTTPClient) logRequest(url string, bodySize int) { maskedKey := c.getMaskedAPIKey() - logging.Debug("Sending HTTP request", - "url", url, + logging.Debug("Sending HTTP request", + "url", url, "body_size", bodySize, "api_key_masked", maskedKey) } @@ -132,7 +132,7 @@ func (c *XAIHTTPClient) logRequest(url string, bodySize int) { func 
(c *XAIHTTPClient) logResponse(result *DeferredResult) { if len(result.Choices) > 0 { choice := result.Choices[0] - logging.Debug("XAI HTTP response parsed", + logging.Debug("XAI HTTP response parsed", "citations", len(result.Citations), "content_length", len(choice.Message.Content), "reasoning_length", len(choice.Message.ReasoningContent), @@ -150,4 +150,4 @@ func (c *XAIHTTPClient) getMaskedAPIKey() string { return "***" } return c.config.APIKey[:3] + "***" + c.config.APIKey[len(c.config.APIKey)-3:] -} \ No newline at end of file +} diff --git a/internal/llm/provider/xai_models.go b/internal/llm/provider/xai_models.go index 3abb8cb0..c3022f8d 100644 --- a/internal/llm/provider/xai_models.go +++ b/internal/llm/provider/xai_models.go @@ -47,12 +47,12 @@ type XAIImageModelInfo struct { Aliases []string `json:"aliases"` } -// XAILanguageModelsResponse represents the response from /v1/language-models +// XAILanguageModelsResponse represents the response from /language-models type XAILanguageModelsResponse struct { Models []XAIModelInfo `json:"models"` } -// XAIImageModelsResponse represents the response from /v1/image-generation-models +// XAIImageModelsResponse represents the response from /image-generation-models type XAIImageModelsResponse struct { Models []XAIImageModelInfo `json:"models"` } @@ -87,7 +87,7 @@ func (x *xaiClient) DiscoverModelCapabilities(ctx context.Context, modelID strin // getLanguageModelCapabilities fetches capabilities from language models endpoint func (x *xaiClient) getLanguageModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { - url := fmt.Sprintf("%s/v1/language-models/%s", x.getBaseURL(), modelID) + url := fmt.Sprintf("%s/language-models/%s", x.getBaseURL(), modelID) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -156,7 +156,7 @@ func (x *xaiClient) getLanguageModelCapabilities(ctx context.Context, modelID st // getImageModelCapabilities fetches capabilities from image generation models endpoint func (x *xaiClient) getImageModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { - url := fmt.Sprintf("%s/v1/image-generation-models/%s", x.getBaseURL(), modelID) + url := fmt.Sprintf("%s/image-generation-models/%s", x.getBaseURL(), modelID) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -202,7 +202,7 @@ func (x *xaiClient) getImageModelCapabilities(ctx context.Context, modelID strin // getBasicModelCapabilities fetches basic model info as fallback func (x *xaiClient) getBasicModelCapabilities(ctx context.Context, modelID string) (*ModelCapabilities, error) { - url := fmt.Sprintf("%s/v1/models/%s", x.getBaseURL(), modelID) + url := fmt.Sprintf("%s/models/%s", x.getBaseURL(), modelID) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -235,7 +235,7 @@ func (x *xaiClient) getBaseURL() string { if x.options.baseURL != "" { return x.options.baseURL } - return "https://api.x.ai" + return "https://api.x.ai/v1" } // ListAllModels fetches all available models from xAI API @@ -244,7 +244,7 @@ func (x *xaiClient) ListAllModels(ctx context.Context) ([]XAIModelInfo, []XAIIma var imageModels []XAIImageModelInfo // Fetch language models - langURL := fmt.Sprintf("%s/v1/language-models", x.getBaseURL()) + langURL := fmt.Sprintf("%s/language-models", x.getBaseURL()) langReq, err := http.NewRequestWithContext(ctx, "GET", langURL, nil) if err != nil { return nil, nil, err @@ -267,7 +267,7 @@ func (x *xaiClient) ListAllModels(ctx 
context.Context) ([]XAIModelInfo, []XAIIma } // Fetch image generation models - imgURL := fmt.Sprintf("%s/v1/image-generation-models", x.getBaseURL()) + imgURL := fmt.Sprintf("%s/image-generation-models", x.getBaseURL()) imgReq, err := http.NewRequestWithContext(ctx, "GET", imgURL, nil) if err != nil { return languageModels, nil, err diff --git a/internal/llm/provider/xai_reasoning.go b/internal/llm/provider/xai_reasoning.go index d7011a00..bdc74413 100644 --- a/internal/llm/provider/xai_reasoning.go +++ b/internal/llm/provider/xai_reasoning.go @@ -36,13 +36,24 @@ func (r *ReasoningHandler) ShouldUseReasoning() bool { hasReasoningEffort := r.client.options.reasoningEffort != "" shouldApply := r.client.shouldApplyReasoningEffort() + // Special case for Grok 4: it has automatic reasoning and should always use + // the reasoning handler when CanReason is true, regardless of reasoning_effort + isGrok4 := r.client.providerOptions.model.ID == models.XAIGrok4 + logging.Debug("Checking reasoning conditions", "model", r.client.providerOptions.model.ID, "can_reason", canReason, "reasoning_effort", r.client.options.reasoningEffort, "has_reasoning_effort", hasReasoningEffort, - "should_apply", shouldApply) + "should_apply", shouldApply, + "is_grok4", isGrok4) + + // Grok 4 uses reasoning handler whenever it can reason + if isGrok4 && canReason { + return true + } + // Other models need reasoning effort to be set and applicable return canReason && hasReasoningEffort && shouldApply } diff --git a/internal/llm/provider/xai_validation.go b/internal/llm/provider/xai_validation.go index f4e553bc..1895853e 100644 --- a/internal/llm/provider/xai_validation.go +++ b/internal/llm/provider/xai_validation.go @@ -29,7 +29,7 @@ type XAIAPIKeyInfo struct { // ValidateAPIKey validates the xAI API key and returns detailed information about it func (x *xaiClient) ValidateAPIKey(ctx context.Context) (*XAIAPIKeyInfo, error) { - url := fmt.Sprintf("%s/v1/api-key", x.getBaseURL()) + url := fmt.Sprintf("%s/api-key", x.getBaseURL()) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { From 684538d1ba755f4f92b99b1dac1dab4cba71e027 Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Sat, 12 Jul 2025 01:13:20 +0400 Subject: [PATCH 8/9] Clear out out of scope configs --- .claude/settings.local.json | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index d9f3b944..00000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(ls:*)", - "Bash(export PATH=$PATH:/Users/alexbel/go/bin)", - "Bash(opencode:*)", - "Bash(go build:*)", - "Bash(go install:*)", - "Bash(go run:*)", - "Bash(go clean:*)", - "Bash(rm:*)", - "Bash(find:*)", - "Bash(sqlite3:*)", - "Bash(cp:*)", - "Bash(go test:*)" - ], - "deny": [] - } -} \ No newline at end of file From cebdf6a3046509da4222f3c27328c86d140b9554 Mon Sep 17 00:00:00 2001 From: Alex Belets Date: Sat, 12 Jul 2025 03:21:30 +0400 Subject: [PATCH 9/9] feat: Add provider-based tool filtering for model compatibility - Implement FilterToolsByProvider to restrict tools based on provider capabilities - Add Providers field to ToolInfo struct for specifying tool availability - Integrate filtering across all provider clients (OpenAI, Gemini, xAI, Copilot) - Add comprehensive tests for tool filtering logic with case-insensitive matching - Update web search tool to be xAI-only as it requires 
live search capabilities - Clarify Grok 4 reasoning behavior in docs and comments (internal reasoning only) --- README.md | 4 +- internal/llm/agent/tools_test.go | 5 + internal/llm/models/xai.go | 6 +- internal/llm/provider/copilot.go | 8 +- internal/llm/provider/gemini.go | 40 ++++- internal/llm/provider/openai.go | 13 +- internal/llm/provider/provider.go | 2 +- internal/llm/provider/tool_filter.go | 50 ++++++ internal/llm/provider/tool_filter_test.go | 191 ++++++++++++++++++++++ internal/llm/provider/xai.go | 43 ++++- internal/llm/provider/xai_deferred.go | 6 +- internal/llm/provider/xai_reasoning.go | 106 +++++++----- internal/llm/tools/web_search.go | 177 ++++++++++---------- internal/llm/tools/web_search_test.go | 5 + 14 files changed, 500 insertions(+), 156 deletions(-) create mode 100644 internal/llm/provider/tool_filter.go create mode 100644 internal/llm/provider/tool_filter_test.go diff --git a/README.md b/README.md index aeb2f465..7323f3d1 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ OpenCode supports a variety of AI models from different providers: ### xAI -- Grok 4 (grok-4-0709) - Most capable, with reasoning_effort support +- Grok 4 (grok-4-0709) - Most capable, with internal reasoning - Grok 3 (grok-3) - Advanced model (no reasoning support) - Grok 3 Fast (grok-3-fast) - Optimized for speed (no reasoning support) - Grok 3 Mini (grok-3-mini) - Smaller model with reasoning_effort support @@ -261,7 +261,7 @@ OpenCode supports a variety of AI models from different providers: **Special Features:** - **Web Search**: All xAI models support live web search for current information - **Reasoning Support** (verified via API): - - Grok 4 (grok-4-0709): Has automatic reasoning (returns reasoning_content) but does NOT accept `reasoningEffort` parameter + - Grok 4 (grok-4-0709): Has internal reasoning capabilities but does NOT expose reasoning_content or accept reasoningEffort parameter - Grok 3 Mini models: Support `reasoningEffort` parameter (only "low" or "high", not "medium") - Grok 2 models, Grok 3/3-fast: No reasoning support - **Vision Support**: grok-2-vision-1212 supports image understanding diff --git a/internal/llm/agent/tools_test.go b/internal/llm/agent/tools_test.go index 07cbc7ac..7f680406 100644 --- a/internal/llm/agent/tools_test.go +++ b/internal/llm/agent/tools_test.go @@ -76,6 +76,11 @@ func TestCoderAgentToolsIncludesWebSearch(t *testing.T) { if len(info.Required) == 0 || info.Required[0] != "query" { t.Error("Web search tool should require 'query' parameter") } + + // Verify provider restrictions + if len(info.Providers) != 1 || info.Providers[0] != "xai" { + t.Errorf("Web search tool should be restricted to xai provider, got %v", info.Providers) + } } else { t.Error("Web search tool does not implement BaseTool interface correctly") } diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go index 95877e0d..03c23a7d 100644 --- a/internal/llm/models/xai.go +++ b/internal/llm/models/xai.go @@ -2,7 +2,7 @@ package models // xAI Model Capabilities (verified via API testing): // - Reasoning support: -// - grok-4-0709: Has reasoning (returns reasoning_content) but does NOT accept reasoning_effort parameter +// - grok-4-0709: Has internal reasoning capabilities but does NOT expose reasoning_content or accept reasoning_effort parameter // - grok-3-mini, grok-3-mini-fast: Support reasoning_effort parameter ("low" or "high" only, NOT "medium") // - grok-2 models, grok-3, grok-3-fast: No reasoning support // - Vision support: grok-2-vision-1212 and grok-4 
support image understanding @@ -136,9 +136,9 @@ var XAIModels = map[ModelID]Model{ CostPer1MOutCached: 0, ContextWindow: 131_072, DefaultMaxTokens: 20_000, - CanReason: true, // Automatic reasoning (no reasoning_effort parameter) + CanReason: true, // Has reasoning capabilities but doesn't expose reasoning content SupportsAttachments: true, // Grok 4 supports vision SupportsImageGeneration: false, // Will be detected dynamically via API - // Capabilities: streaming, function calling, structured outputs, automatic reasoning, web search, vision + // Capabilities: streaming, function calling, structured outputs, web search, vision }, } diff --git a/internal/llm/provider/copilot.go b/internal/llm/provider/copilot.go index 5d70e718..cb820731 100644 --- a/internal/llm/provider/copilot.go +++ b/internal/llm/provider/copilot.go @@ -247,9 +247,12 @@ func (c *copilotClient) convertMessages(messages []message.Message) (copilotMess } func (c *copilotClient) convertTools(tools []toolsPkg.BaseTool) []openai.ChatCompletionToolParam { - copilotTools := make([]openai.ChatCompletionToolParam, len(tools)) + // Filter tools based on provider compatibility + providerName := string(c.providerOptions.model.Provider) + filteredTools := FilterToolsByProvider(tools, providerName) - for i, tool := range tools { + copilotTools := make([]openai.ChatCompletionToolParam, len(filteredTools)) + for i, tool := range filteredTools { info := tool.Info() copilotTools[i] = openai.ChatCompletionToolParam{ Function: openai.FunctionDefinitionParam{ @@ -668,4 +671,3 @@ func WithCopilotBearerToken(bearerToken string) CopilotOption { options.bearerToken = bearerToken } } - diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index ebc36119..f7afe176 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -132,11 +132,15 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont return history } -func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { +func (g *geminiClient) convertTools(inputTools []tools.BaseTool) []*genai.Tool { + // Filter tools based on provider compatibility + providerName := string(g.providerOptions.model.Provider) + filteredTools := FilterToolsByProvider(inputTools, providerName) + geminiTool := &genai.Tool{} - geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools)) + geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(filteredTools)) - for _, tool := range tools { + for _, tool := range filteredTools { info := tool.Info() declaration := &genai.FunctionDeclaration{ Name: info.Name, @@ -225,7 +229,11 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too content = string(part.Text) case part.FunctionCall != nil: id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) + args, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + logging.Error("Failed to marshal function call args", "error", err, "function", part.FunctionCall.Name) + args = []byte("{}") + } toolCalls = append(toolCalls, message.ToolCall{ ID: id, Name: part.FunctionCall.Name, @@ -274,7 +282,17 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if len(tools) > 0 { config.Tools = g.convertTools(tools) } - chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history) + chat, err := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history) + if err != 
nil { + eventChan := make(chan ProviderEvent, 1) + go func() { + defer close(eventChan) + eventChan <- ProviderEvent{ + Error: fmt.Errorf("failed to create Gemini chat: %w", err), + } + }() + return eventChan + } attempts := 0 eventChan := make(chan ProviderEvent) @@ -337,7 +355,11 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t } case part.FunctionCall != nil: id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) + args, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + logging.Error("Failed to marshal function call args", "error", err, "function", part.FunctionCall.Name) + args = []byte("{}") + } newCall := message.ToolCall{ ID: id, Name: part.FunctionCall.Name, @@ -430,7 +452,11 @@ func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message. for _, part := range resp.Candidates[0].Content.Parts { if part.FunctionCall != nil { id := "call_" + uuid.New().String() - args, _ := json.Marshal(part.FunctionCall.Args) + args, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + logging.Error("Failed to marshal function call args", "error", err, "function", part.FunctionCall.Name) + args = []byte("{}") + } toolCalls = append(toolCalls, message.ToolCall{ ID: id, Name: part.FunctionCall.Name, diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 0f956267..2bd10eb3 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -128,10 +128,13 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag return } -func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { - openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) +func (o *openaiClient) convertTools(inputTools []tools.BaseTool) []openai.ChatCompletionToolParam { + // Filter tools based on provider compatibility + providerName := string(o.providerOptions.model.Provider) + filteredTools := FilterToolsByProvider(inputTools, providerName) - for i, tool := range tools { + openaiTools := make([]openai.ChatCompletionToolParam, len(filteredTools)) + for i, tool := range filteredTools { info := tool.Info() openaiTools[i] = openai.ChatCompletionToolParam{ Function: openai.FunctionDefinitionParam{ @@ -221,9 +224,9 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar // shouldApplyReasoningEffort determines if the reasoning_effort parameter should be applied // based on the model and provider. Some models support reasoning but do not accept -// the reasoning_effort parameter (e.g., xAI's grok-4 has automatic reasoning). +// the reasoning_effort parameter (e.g., xAI's grok-4 has internal reasoning but doesn't accept the parameter). 
func (o *openaiClient) shouldApplyReasoningEffort() bool { - // xAI grok-4 supports reasoning but does not accept reasoning_effort parameter + // xAI grok-4 has internal reasoning but does not accept reasoning_effort parameter if o.providerOptions.model.Provider == models.ProviderXAI && o.providerOptions.model.ID == models.XAIGrok4 { return false diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 16357796..aa228d2a 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -36,7 +36,7 @@ type TokenUsage struct { type ProviderResponse struct { Content string - ReasoningContent string // For xAI reasoning content (internal use) + ReasoningContent string // For xAI reasoning content (internal use) ToolCalls []message.ToolCall Usage TokenUsage FinishReason message.FinishReason diff --git a/internal/llm/provider/tool_filter.go b/internal/llm/provider/tool_filter.go new file mode 100644 index 00000000..b37fa7aa --- /dev/null +++ b/internal/llm/provider/tool_filter.go @@ -0,0 +1,50 @@ +package provider + +import ( + "strings" + + "github.com/opencode-ai/opencode/internal/llm/tools" +) + +// FilterToolsByProvider filters tools based on provider compatibility. +// If a tool has no provider restrictions (empty Providers field), it's available to all providers. +// Otherwise, it's only available to providers listed in the Providers field. +// The provider name comparison is case-insensitive. +func FilterToolsByProvider(inputTools []tools.BaseTool, providerName string) []tools.BaseTool { + if len(inputTools) == 0 { + return nil + } + + // Pre-allocate slice with capacity to avoid reallocation + // In most cases, most tools will be available + filteredTools := make([]tools.BaseTool, 0, len(inputTools)) + + for _, tool := range inputTools { + if isToolAvailableForProvider(tool, providerName) { + filteredTools = append(filteredTools, tool) + } + } + + return filteredTools +} + +// isToolAvailableForProvider checks if a tool is available for the given provider. +// Returns true if the tool has no provider restrictions or if the provider is in the allowed list. +// The comparison is case-insensitive to handle variations in provider name casing. 
+func isToolAvailableForProvider(tool tools.BaseTool, providerName string) bool { + info := tool.Info() + + // If no providers specified, tool is universally available + if len(info.Providers) == 0 { + return true + } + + // Check if this provider is in the allowed list (case-insensitive) + for _, allowedProvider := range info.Providers { + if strings.EqualFold(allowedProvider, providerName) { + return true + } + } + + return false +} diff --git a/internal/llm/provider/tool_filter_test.go b/internal/llm/provider/tool_filter_test.go new file mode 100644 index 00000000..4285a713 --- /dev/null +++ b/internal/llm/provider/tool_filter_test.go @@ -0,0 +1,191 @@ +package provider + +import ( + "context" + "testing" + + "github.com/opencode-ai/opencode/internal/llm/tools" +) + +// mockTool is a test implementation of tools.BaseTool +type mockTool struct { + name string + providers []string +} + +func (m *mockTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: m.name, + Description: "Mock tool for testing", + Parameters: map[string]any{"test": "param"}, + Required: []string{"test"}, + Providers: m.providers, + } +} + +func (m *mockTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + return tools.ToolResponse{ + Type: tools.ToolResponseTypeText, + Content: "mock response", + }, nil +} + +func TestFilterToolsByProvider(t *testing.T) { + tests := []struct { + name string + tools []tools.BaseTool + providerName string + wantCount int + wantTools []string + }{ + { + name: "no provider restrictions - all tools available", + tools: []tools.BaseTool{ + &mockTool{name: "tool1", providers: []string{}}, + &mockTool{name: "tool2", providers: nil}, + &mockTool{name: "tool3", providers: []string{}}, + }, + providerName: "openai", + wantCount: 3, + wantTools: []string{"tool1", "tool2", "tool3"}, + }, + { + name: "provider specific tools - only matching tools returned", + tools: []tools.BaseTool{ + &mockTool{name: "universal", providers: []string{}}, + &mockTool{name: "xai_only", providers: []string{"xai"}}, + &mockTool{name: "openai_only", providers: []string{"openai"}}, + &mockTool{name: "multi_provider", providers: []string{"openai", "anthropic"}}, + }, + providerName: "openai", + wantCount: 3, + wantTools: []string{"universal", "openai_only", "multi_provider"}, + }, + { + name: "case insensitive provider matching", + tools: []tools.BaseTool{ + &mockTool{name: "tool1", providers: []string{"OpenAI"}}, + &mockTool{name: "tool2", providers: []string{"OPENAI"}}, + &mockTool{name: "tool3", providers: []string{"openai"}}, + }, + providerName: "openai", + wantCount: 3, + wantTools: []string{"tool1", "tool2", "tool3"}, + }, + { + name: "xai provider with web search tool", + tools: []tools.BaseTool{ + &mockTool{name: "general_tool", providers: []string{}}, + &mockTool{name: "web_search", providers: []string{"xai"}}, + &mockTool{name: "other_xai_tool", providers: []string{"xai"}}, + }, + providerName: "xai", + wantCount: 3, + wantTools: []string{"general_tool", "web_search", "other_xai_tool"}, + }, + { + name: "no matching tools for provider", + tools: []tools.BaseTool{ + &mockTool{name: "xai_only", providers: []string{"xai"}}, + &mockTool{name: "openai_only", providers: []string{"openai"}}, + }, + providerName: "anthropic", + wantCount: 0, + wantTools: []string{}, + }, + { + name: "empty tool list", + tools: []tools.BaseTool{}, + providerName: "openai", + wantCount: 0, + wantTools: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
filtered := FilterToolsByProvider(tt.tools, tt.providerName) + + if len(filtered) != tt.wantCount { + t.Errorf("FilterToolsByProvider() returned %d tools, want %d", len(filtered), tt.wantCount) + } + + // Verify the correct tools were returned + gotTools := make(map[string]bool) + for _, tool := range filtered { + gotTools[tool.Info().Name] = true + } + + for _, wantTool := range tt.wantTools { + if !gotTools[wantTool] { + t.Errorf("Expected tool %s not found in filtered results", wantTool) + } + } + + // Ensure no extra tools were included + for toolName := range gotTools { + found := false + for _, wantTool := range tt.wantTools { + if toolName == wantTool { + found = true + break + } + } + if !found { + t.Errorf("Unexpected tool %s found in filtered results", toolName) + } + } + }) + } +} + +func BenchmarkFilterToolsByProvider(b *testing.B) { + // Create a mix of tools with different provider restrictions + tools := []tools.BaseTool{ + &mockTool{name: "universal1", providers: []string{}}, + &mockTool{name: "universal2", providers: nil}, + &mockTool{name: "xai_only", providers: []string{"xai"}}, + &mockTool{name: "openai_only", providers: []string{"openai"}}, + &mockTool{name: "multi_provider1", providers: []string{"openai", "anthropic"}}, + &mockTool{name: "multi_provider2", providers: []string{"xai", "gemini", "openai"}}, + &mockTool{name: "anthropic_only", providers: []string{"anthropic"}}, + &mockTool{name: "gemini_only", providers: []string{"gemini"}}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = FilterToolsByProvider(tools, "openai") + } +} + +func TestFilterToolsByProvider_EdgeCases(t *testing.T) { + t.Run("nil tools slice", func(t *testing.T) { + result := FilterToolsByProvider(nil, "openai") + if result != nil { + t.Errorf("Expected nil result for nil input, got %v", result) + } + }) + + t.Run("empty provider name", func(t *testing.T) { + tools := []tools.BaseTool{ + &mockTool{name: "tool1", providers: []string{""}}, + &mockTool{name: "tool2", providers: []string{"openai", ""}}, + } + result := FilterToolsByProvider(tools, "") + if len(result) != 2 { + t.Errorf("Expected 2 tools for empty provider match, got %d", len(result)) + } + }) + + t.Run("whitespace in provider names", func(t *testing.T) { + tools := []tools.BaseTool{ + &mockTool{name: "tool1", providers: []string{" openai "}}, + &mockTool{name: "tool2", providers: []string{"openai"}}, + } + result := FilterToolsByProvider(tools, "openai") + // Should only match tool2 since we don't trim whitespace + if len(result) != 1 || result[0].Info().Name != "tool2" { + t.Errorf("Expected only tool2 to match, got %d tools", len(result)) + } + }) +} \ No newline at end of file diff --git a/internal/llm/provider/xai.go b/internal/llm/provider/xai.go index 1810b543..b75ef83b 100644 --- a/internal/llm/provider/xai.go +++ b/internal/llm/provider/xai.go @@ -13,6 +13,16 @@ import ( "github.com/opencode-ai/opencode/internal/message" ) +const ( + // defaultXAITimeout is the standard timeout for regular API requests + defaultXAITimeout = 30 * time.Second + // reasoningXAITimeout is an extended timeout for reasoning models + // which can take several minutes to process complex requests + reasoningXAITimeout = 5 * time.Minute + // defaultXAIBaseURL is the default base URL for xAI API + defaultXAIBaseURL = "https://api.x.ai/v1" +) + // FingerprintRecord tracks system fingerprint information for auditing and compliance purposes. // It helps monitor xAI system changes and optimize caching performance. 
 type FingerprintRecord struct {
@@ -96,11 +106,15 @@ func WithLiveSearchOptions(opts LiveSearchOptions) XAIOption {
 func newXAIClient(opts providerClientOptions) XAIClient {
 	// Create base OpenAI client with xAI-specific settings
 	opts.openaiOptions = append(opts.openaiOptions,
-		WithOpenAIBaseURL("https://api.x.ai/v1"),
+		WithOpenAIBaseURL(defaultXAIBaseURL),
 	)
 
 	baseClient := newOpenAIClient(opts)
-	openaiClientImpl := baseClient.(*openaiClient)
+	openaiClientImpl, ok := baseClient.(*openaiClient)
+	if !ok {
+		// This is a programming error - xAI client extends openAI client
+		panic("internal error: xAI client requires openaiClient implementation")
+	}
 
 	xClient := &xaiClient{
 		openaiClient: *openaiClientImpl,
@@ -109,11 +123,21 @@ func newXAIClient(opts providerClientOptions) XAIClient {
 
 	// Initialize new architectural components
 	xClient.reasoningHandler = NewReasoningHandler(xClient)
+
+	// Use the base URL from OpenAI client options if set, otherwise default to xAI API
+	baseURL := defaultXAIBaseURL
+	if openaiClientImpl.options.baseURL != "" {
+		baseURL = openaiClientImpl.options.baseURL
+	}
+
+	// Configure HTTP client with appropriate timeout based on model capabilities
+	timeout := selectTimeoutForModel(opts.model)
+
 	xClient.httpClient = NewXAIHTTPClient(HTTPClientConfig{
-		BaseURL:   "https://api.x.ai/v1",
+		BaseURL:   baseURL,
 		APIKey:    opts.apiKey,
 		UserAgent: "opencode/1.0",
-		Timeout:   30 * time.Second,
+		Timeout:   timeout,
 	})
 
 	// Apply xAI-specific options if any
@@ -124,9 +148,18 @@ func newXAIClient(opts providerClientOptions) XAIClient {
 	return xClient
 }
 
+// selectTimeoutForModel returns the appropriate timeout based on model capabilities
+func selectTimeoutForModel(model models.Model) time.Duration {
+	if model.CanReason {
+		return reasoningXAITimeout
+	}
+	return defaultXAITimeout
+}
+
 // shouldApplyReasoningEffort overrides the base implementation for xAI-specific logic
 func (x *xaiClient) shouldApplyReasoningEffort() bool {
-	// xAI grok-4 supports reasoning but does not accept reasoning_effort parameter
+	// Grok-4 has internal reasoning capabilities but doesn't accept the reasoning_effort parameter
+	// or expose its reasoning process. Other xAI thinking models accept reasoning_effort with values "low" or "high"
 	if x.providerOptions.model.ID == models.XAIGrok4 {
 		return false
 	}
diff --git a/internal/llm/provider/xai_deferred.go b/internal/llm/provider/xai_deferred.go
index 5e83576b..36410a8a 100644
--- a/internal/llm/provider/xai_deferred.go
+++ b/internal/llm/provider/xai_deferred.go
@@ -497,7 +497,11 @@ func (x *xaiClient) convertMessagesToAPI(messages []message.Message) []map[strin
 func (x *xaiClient) convertToolsToAPI(tools []tools.BaseTool) []map[string]interface{} {
 	var apiTools []map[string]interface{}
 
-	for _, tool := range tools {
+	// Filter tools based on provider compatibility
+	providerName := string(x.providerOptions.model.Provider)
+	filteredTools := FilterToolsByProvider(tools, providerName)
+
+	for _, tool := range filteredTools {
 		info := tool.Info()
 
 		// Check if Parameters already contains the full schema (with "type" and "properties")
diff --git a/internal/llm/provider/xai_reasoning.go b/internal/llm/provider/xai_reasoning.go
index bdc74413..2b76db3b 100644
--- a/internal/llm/provider/xai_reasoning.go
+++ b/internal/llm/provider/xai_reasoning.go
@@ -30,31 +30,57 @@ func NewReasoningHandler(client *xaiClient) *ReasoningHandler {
 	}
 }
 
-// ShouldUseReasoning determines if reasoning should be used for a request
+// ShouldUseReasoning determines if reasoning should be used for a request.
+// It checks if the model supports reasoning and handles special cases like Grok-4.
 func (r *ReasoningHandler) ShouldUseReasoning() bool {
-	canReason := r.client.providerOptions.model.CanReason
-	hasReasoningEffort := r.client.options.reasoningEffort != ""
+	model := r.client.providerOptions.model
+
+	// Early return if model doesn't support reasoning
+	if !model.CanReason {
+		return false
+	}
+
+	// Special case: Grok-4 always uses reasoning handler when it can reason
+	// even though it doesn't accept reasoning_effort parameter
+	if model.ID == models.XAIGrok4 {
+		return true
+	}
+
+	// For other models, check if reasoning effort is configured
+	reasoningEffort := r.client.options.reasoningEffort
+	if reasoningEffort == "" {
+		return false
+	}
+
+	// Check if reasoning should be applied based on client-specific logic
 	shouldApply := r.client.shouldApplyReasoningEffort()
 
-	// Special case for Grok 4: it has automatic reasoning and should always use
-	// the reasoning handler when CanReason is true, regardless of reasoning_effort
-	isGrok4 := r.client.providerOptions.model.ID == models.XAIGrok4
+	logging.Debug("Reasoning conditions evaluated",
+		"model", model.ID,
+		"can_reason", model.CanReason,
+		"reasoning_effort", reasoningEffort,
+		"should_apply", shouldApply)
 
-	logging.Debug("Checking reasoning conditions",
-		"model", r.client.providerOptions.model.ID,
-		"can_reason", canReason,
-		"reasoning_effort", r.client.options.reasoningEffort,
-		"has_reasoning_effort", hasReasoningEffort,
-		"should_apply", shouldApply,
-		"is_grok4", isGrok4)
+	return shouldApply
+}
 
-	// Grok 4 uses reasoning handler whenever it can reason
-	if isGrok4 && canReason {
-		return true
+// normalizeReasoningEffort adjusts reasoning effort values based on model capabilities.
+// xAI's thinking models (all except Grok-4) only support "low" or "high", not "medium".
+// Grok-4 has internal reasoning but doesn't accept the reasoning_effort parameter or expose reasoning content.
+func (r *ReasoningHandler) normalizeReasoningEffort(effort string) string { + model := r.client.providerOptions.model + + // All xAI thinking models except Grok-4 only support "low" or "high" + // Grok-4 has internal reasoning but doesn't accept this parameter + if model.ID != models.XAIGrok4 && effort == "medium" { + logging.Debug("Normalizing reasoning effort for xAI thinking model", + "model", model.ID, + "original", effort, + "normalized", "high") + return "high" } - // Other models need reasoning effort to be set and applicable - return canReason && hasReasoningEffort && shouldApply + return effort } // ProcessReasoningResponse handles reasoning content from API responses @@ -96,16 +122,24 @@ func (r *ReasoningHandler) ProcessReasoningResponse(response *ProviderResponse) // sanitizeReasoningContent removes control characters that could corrupt terminal display func (r *ReasoningHandler) sanitizeReasoningContent(content string) string { - // Remove ANSI escape sequences (ESC character) - content = strings.ReplaceAll(content, "\x1b", "") - // Remove carriage returns (which can cause display issues) - content = strings.ReplaceAll(content, "\r", "") - // Remove other control characters that might cause issues - content = strings.ReplaceAll(content, "\x00", "") // null - content = strings.ReplaceAll(content, "\x07", "") // bell - content = strings.ReplaceAll(content, "\x08", "") // backspace - // Replace form feed with newline to preserve structure - content = strings.ReplaceAll(content, "\x0c", "\n") + // Define control characters to remove + replacements := []struct { + old string + new string + }{ + {"\x1b", ""}, // ANSI escape sequences (ESC character) + {"\r", ""}, // carriage returns (can cause display issues) + {"\x00", ""}, // null + {"\x07", ""}, // bell + {"\x08", ""}, // backspace + {"\x0c", "\n"}, // form feed - replace with newline to preserve structure + } + + // Apply all replacements + for _, repl := range replacements { + content = strings.ReplaceAll(content, repl.old, repl.new) + } + return content } @@ -123,20 +157,10 @@ func (r *ReasoningHandler) BuildReasoningRequest(ctx context.Context, messages [ reqBody["tools"] = r.client.convertToolsToAPI(tools) } - // Apply reasoning effort only if the model supports it - // xAI grok models do not accept reasoning_effort parameter + // Apply reasoning effort parameter for models that support it + // Grok-4 has internal reasoning but doesn't accept this parameter if r.client.options.reasoningEffort != "" && r.client.shouldApplyReasoningEffort() { - reasoningEffort := r.client.options.reasoningEffort - - // Grok-3-mini models only support "low" or "high", not "medium" - if (r.client.providerOptions.model.ID == models.XAIGrok3Mini || - r.client.providerOptions.model.ID == models.XAIGrok3MiniFast) && - reasoningEffort == "medium" { - // Convert medium to high for Grok-3-mini models - reasoningEffort = "high" - logging.Debug("Converting reasoning effort from medium to high for Grok-3-mini") - } - + reasoningEffort := r.normalizeReasoningEffort(r.client.options.reasoningEffort) reqBody["reasoning_effort"] = reasoningEffort } diff --git a/internal/llm/tools/web_search.go b/internal/llm/tools/web_search.go index ba6e109d..24ebb0b5 100644 --- a/internal/llm/tools/web_search.go +++ b/internal/llm/tools/web_search.go @@ -51,99 +51,100 @@ func (t *WebSearchTool) Info() ToolInfo { "type": "string", "description": "The search query to execute", }, - "mode": map[string]interface{}{ - "type": "string", - "description": "Search mode: 'auto' 
(default), 'on', or 'off'", - "enum": []string{"auto", "on", "off"}, - }, - "max_search_results": map[string]interface{}{ - "type": "integer", - "description": "Maximum number of search results (1-20, default: 20)", - "minimum": 1, - "maximum": 20, - }, - "from_date": map[string]interface{}{ - "type": "string", - "description": "Start date for search results in YYYY-MM-DD format", - "pattern": "^\\d{4}-\\d{2}-\\d{2}$", - }, - "to_date": map[string]interface{}{ - "type": "string", - "description": "End date for search results in YYYY-MM-DD format", - "pattern": "^\\d{4}-\\d{2}-\\d{2}$", - }, - "return_citations": map[string]interface{}{ - "type": "boolean", - "description": "Whether to return citations (default: true)", - }, - "sources": map[string]interface{}{ - "type": "array", - "description": "List of data sources to search", - "items": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "type": map[string]interface{}{ - "type": "string", - "description": "Source type", - "enum": []string{"web", "x", "news", "rss"}, - }, - "country": map[string]interface{}{ - "type": "string", - "description": "ISO alpha-2 country code (web, news)", - "pattern": "^[A-Z]{2}$", - }, - "excluded_websites": map[string]interface{}{ - "type": "array", - "description": "Websites to exclude (max 5, web/news)", - "items": map[string]interface{}{"type": "string"}, - "maxItems": 5, - }, - "allowed_websites": map[string]interface{}{ - "type": "array", - "description": "Allowed websites only (max 5, web only)", - "items": map[string]interface{}{"type": "string"}, - "maxItems": 5, - }, - "safe_search": map[string]interface{}{ - "type": "boolean", - "description": "Enable safe search (default: true, web/news)", - }, - "included_x_handles": map[string]interface{}{ - "type": "array", - "description": "X handles to include (max 10, x only)", - "items": map[string]interface{}{"type": "string"}, - "maxItems": 10, - }, - "excluded_x_handles": map[string]interface{}{ - "type": "array", - "description": "X handles to exclude (max 10, x only)", - "items": map[string]interface{}{"type": "string"}, - "maxItems": 10, - }, - "post_favorite_count": map[string]interface{}{ - "type": "integer", - "description": "Minimum favorite count for X posts", - "minimum": 0, - }, - "post_view_count": map[string]interface{}{ - "type": "integer", - "description": "Minimum view count for X posts", - "minimum": 0, - }, - "links": map[string]interface{}{ - "type": "array", - "description": "RSS feed URLs (1 link max, rss only)", - "items": map[string]interface{}{"type": "string", "format": "uri"}, - "maxItems": 1, - }, + "mode": map[string]interface{}{ + "type": "string", + "description": "Search mode: 'auto' (default), 'on', or 'off'", + "enum": []string{"auto", "on", "off"}, + }, + "max_search_results": map[string]interface{}{ + "type": "integer", + "description": "Maximum number of search results (1-20, default: 20)", + "minimum": 1, + "maximum": 20, + }, + "from_date": map[string]interface{}{ + "type": "string", + "description": "Start date for search results in YYYY-MM-DD format", + "pattern": "^\\d{4}-\\d{2}-\\d{2}$", + }, + "to_date": map[string]interface{}{ + "type": "string", + "description": "End date for search results in YYYY-MM-DD format", + "pattern": "^\\d{4}-\\d{2}-\\d{2}$", + }, + "return_citations": map[string]interface{}{ + "type": "boolean", + "description": "Whether to return citations (default: true)", + }, + "sources": map[string]interface{}{ + "type": "array", + "description": "List of data 
sources to search", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "type": map[string]interface{}{ + "type": "string", + "description": "Source type", + "enum": []string{"web", "x", "news", "rss"}, + }, + "country": map[string]interface{}{ + "type": "string", + "description": "ISO alpha-2 country code (web, news)", + "pattern": "^[A-Z]{2}$", + }, + "excluded_websites": map[string]interface{}{ + "type": "array", + "description": "Websites to exclude (max 5, web/news)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 5, + }, + "allowed_websites": map[string]interface{}{ + "type": "array", + "description": "Allowed websites only (max 5, web only)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 5, + }, + "safe_search": map[string]interface{}{ + "type": "boolean", + "description": "Enable safe search (default: true, web/news)", + }, + "included_x_handles": map[string]interface{}{ + "type": "array", + "description": "X handles to include (max 10, x only)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 10, + }, + "excluded_x_handles": map[string]interface{}{ + "type": "array", + "description": "X handles to exclude (max 10, x only)", + "items": map[string]interface{}{"type": "string"}, + "maxItems": 10, + }, + "post_favorite_count": map[string]interface{}{ + "type": "integer", + "description": "Minimum favorite count for X posts", + "minimum": 0, + }, + "post_view_count": map[string]interface{}{ + "type": "integer", + "description": "Minimum view count for X posts", + "minimum": 0, + }, + "links": map[string]interface{}{ + "type": "array", + "description": "RSS feed URLs (1 link max, rss only)", + "items": map[string]interface{}{"type": "string", "format": "uri"}, + "maxItems": 1, }, - "required": []string{"type"}, }, + "required": []string{"type"}, }, }, - "required": []string{"query"}, }, - Required: []string{"query"}, + "required": []string{"query"}, + }, + Required: []string{"query"}, + Providers: []string{"xai"}, // Web search is currently only supported by xAI models } } diff --git a/internal/llm/tools/web_search_test.go b/internal/llm/tools/web_search_test.go index a4f4be5e..c875f779 100644 --- a/internal/llm/tools/web_search_test.go +++ b/internal/llm/tools/web_search_test.go @@ -70,6 +70,11 @@ func TestWebSearchTool(t *testing.T) { if len(info.Required) != 1 || info.Required[0] != "query" { t.Errorf("Expected required fields to be ['query'], got %v", info.Required) } + + // Check provider restriction + if len(info.Providers) != 1 || info.Providers[0] != "xai" { + t.Errorf("Expected providers to be ['xai'], got %v", info.Providers) + } }) t.Run("Run with valid query", func(t *testing.T) {