Skip to content

Commit c3ae347

Browse files
authored
Refactor(ai): replace dynamic metadata maps (refactor/ai-typed-metadata) (#28)
1 parent 54aeb07 commit c3ae347

File tree

4 files changed

+118
-61
lines changed

4 files changed

+118
-61
lines changed

pkg/ai/capabilities.go

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,24 @@ type CapabilitiesRequest struct {
3838

3939
// CapabilitiesResponse defines the complete capability information for the AI plugin
4040
type CapabilitiesResponse struct {
41-
Version string `json:"version"`
42-
Models []ModelCapability `json:"models"`
43-
Databases []DatabaseCapability `json:"databases"`
44-
Features []FeatureCapability `json:"features"`
45-
Health HealthStatusReport `json:"health"`
46-
Limits ResourceLimits `json:"limits"`
47-
LastUpdated time.Time `json:"last_updated"`
48-
Metadata map[string]interface{} `json:"metadata,omitempty"`
41+
Version string `json:"version"`
42+
Models []ModelCapability `json:"models"`
43+
Databases []DatabaseCapability `json:"databases"`
44+
Features []FeatureCapability `json:"features"`
45+
Health HealthStatusReport `json:"health"`
46+
Limits ResourceLimits `json:"limits"`
47+
LastUpdated time.Time `json:"last_updated"`
48+
Metadata CapabilityMetadata `json:"metadata"`
49+
}
50+
51+
// CapabilityMetadata provides contextual information about the response.
52+
type CapabilityMetadata struct {
53+
CacheStatus string `json:"cache_status"`
54+
Source string `json:"source"`
55+
ModelCount int `json:"model_count"`
56+
FeatureCount int `json:"feature_count"`
57+
HealthChecked bool `json:"health_checked"`
58+
GeneratedAt time.Time `json:"generated_at"`
4959
}
5060

5161
// ModelCapability represents the capabilities of an AI model
@@ -206,10 +216,17 @@ func (d *CapabilityDetector) GetCapabilities(ctx context.Context, req *Capabilit
206216
}
207217

208218
// Build fresh capability response
219+
now := time.Now()
220+
metadata := CapabilityMetadata{
221+
CacheStatus: "miss",
222+
Source: "live",
223+
HealthChecked: req.CheckHealth,
224+
GeneratedAt: now,
225+
}
209226
response := &CapabilitiesResponse{
210227
Version: "1.0.0",
211-
LastUpdated: time.Now(),
212-
Metadata: make(map[string]interface{}),
228+
LastUpdated: now,
229+
Metadata: metadata,
213230
}
214231

215232
// Collect models if requested
@@ -219,6 +236,7 @@ func (d *CapabilityDetector) GetCapabilities(ctx context.Context, req *Capabilit
219236
return nil, fmt.Errorf("failed to detect model capabilities: %w", err)
220237
}
221238
response.Models = models
239+
metadata.ModelCount = len(models)
222240
}
223241

224242
// Collect database capabilities if requested
@@ -228,7 +246,9 @@ func (d *CapabilityDetector) GetCapabilities(ctx context.Context, req *Capabilit
228246

229247
// Collect feature capabilities if requested
230248
if req.IncludeFeatures {
231-
response.Features = d.detectFeatureCapabilities()
249+
features := d.detectFeatureCapabilities()
250+
response.Features = features
251+
metadata.FeatureCount = len(features)
232252
}
233253

234254
// Perform health checks if requested
@@ -243,6 +263,9 @@ func (d *CapabilityDetector) GetCapabilities(ctx context.Context, req *Capabilit
243263
// Always include resource limits
244264
response.Limits = d.getResourceLimits()
245265

266+
// Attach finalized metadata snapshot
267+
response.Metadata = metadata
268+
246269
// Update cache
247270
d.cache.update(response)
248271
d.lastUpdate = time.Now()
@@ -646,10 +669,13 @@ func (d *CapabilityDetector) getCachedCapabilities(req *CapabilitiesRequest) (*C
646669
}
647670

648671
// Create a filtered response based on request
672+
metadata := d.cache.data.Metadata
673+
metadata.HealthChecked = req.CheckHealth
674+
649675
response := &CapabilitiesResponse{
650676
Version: d.cache.data.Version,
651677
LastUpdated: d.cache.data.LastUpdated,
652-
Metadata: d.cache.data.Metadata,
678+
Metadata: metadata,
653679
}
654680

655681
if req.IncludeModels {

pkg/ai/manager.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ var (
5151
ErrInvalidConfig = errors.New("invalid configuration")
5252
)
5353

54+
// ProviderConfigInfo captures metadata about a provider's requirements.
55+
type ProviderConfigInfo struct {
56+
RequiresAPIKey bool `json:"requires_api_key"`
57+
ProviderType string `json:"provider_type"`
58+
}
59+
5460
// ProviderInfo represents information about an AI provider
5561
type ProviderInfo struct {
5662
Name string `json:"name"`
@@ -59,7 +65,7 @@ type ProviderInfo struct {
5965
Endpoint string `json:"endpoint"`
6066
Models []interfaces.ModelInfo `json:"models"`
6167
LastChecked time.Time `json:"last_checked"`
62-
Config map[string]interface{} `json:"config,omitempty"`
68+
Config ProviderConfigInfo `json:"config"`
6369
Health *interfaces.HealthStatus `json:"health,omitempty"`
6470
}
6571

@@ -360,6 +366,10 @@ func (m *Manager) DiscoverProviders(ctx context.Context) ([]*ProviderInfo, error
360366
Endpoint: endpoint,
361367
Models: models,
362368
LastChecked: time.Now(),
369+
Config: ProviderConfigInfo{
370+
ProviderType: "local",
371+
RequiresAPIKey: false,
372+
},
363373
}
364374

365375
providers = append(providers, provider)
@@ -630,9 +640,9 @@ func (m *Manager) getOnlineProviders() []*ProviderInfo {
630640
Endpoint: entry.Endpoint,
631641
Models: entry.Models,
632642
LastChecked: time.Now(),
633-
Config: map[string]interface{}{
634-
"requires_api_key": entry.RequiresAPIKey,
635-
"provider_type": providerType,
643+
Config: ProviderConfigInfo{
644+
RequiresAPIKey: entry.RequiresAPIKey,
645+
ProviderType: providerType,
636646
},
637647
})
638648
}

pkg/plugin/service.go

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -229,30 +229,12 @@ func (s *AIPluginService) defaultDatabaseType() string {
229229
return "mysql"
230230
}
231231

232-
func extractDatabaseTypeFromMap(values map[string]any) string {
233-
if values == nil {
234-
return ""
235-
}
236-
237-
keys := []string{"database_type", "databaseDialect", "database_dialect", "dialect"}
238-
for _, key := range keys {
239-
if raw, ok := values[key]; ok {
240-
if str, ok := raw.(string); ok {
241-
if normalized := normalizeDatabaseType(str); normalized != "" {
242-
return normalized
243-
}
244-
}
245-
}
246-
}
247-
return ""
248-
}
249-
250-
func (s *AIPluginService) resolveDatabaseType(explicit string, configMap map[string]any) string {
232+
func (s *AIPluginService) resolveDatabaseType(explicit string, overrides GenerationConfigOverrides) string {
251233
if normalized := normalizeDatabaseType(explicit); normalized != "" {
252234
return normalized
253235
}
254236

255-
if fromConfig := extractDatabaseTypeFromMap(configMap); fromConfig != "" {
237+
if fromConfig := overrides.preferredDatabaseType(); fromConfig != "" {
256238
return fromConfig
257239
}
258240

@@ -703,6 +685,45 @@ var (
703685
MinCompatibleAPITestingVersion = GRPCInterfaceVersion
704686
)
705687

688+
// GenerationMetadata describes metadata returned with AI generation responses.
689+
type GenerationMetadata struct {
690+
Confidence float32 `json:"confidence"`
691+
Model string `json:"model,omitempty"`
692+
Dialect string `json:"dialect"`
693+
}
694+
695+
// CapabilitySummary is returned when the capability detector is unavailable.
696+
type CapabilitySummary struct {
697+
PluginReady bool `json:"plugin_ready"`
698+
AIAvailable bool `json:"ai_available"`
699+
DegradedMode bool `json:"degraded_mode"`
700+
PluginVersion string `json:"plugin_version"`
701+
APIVersion string `json:"api_version"`
702+
}
703+
704+
// GenerationConfigOverrides captures optional generation configuration overrides.
705+
type GenerationConfigOverrides struct {
706+
DatabaseTypePrimary string `json:"database_type"`
707+
DatabaseDialect string `json:"databaseDialect"`
708+
DatabaseDialectAlt string `json:"database_dialect"`
709+
Dialect string `json:"dialect"`
710+
}
711+
712+
func (g GenerationConfigOverrides) preferredDatabaseType() string {
713+
candidates := []string{
714+
g.DatabaseTypePrimary,
715+
g.DatabaseDialect,
716+
g.DatabaseDialectAlt,
717+
g.Dialect,
718+
}
719+
for _, candidate := range candidates {
720+
if normalized := normalizeDatabaseType(candidate); normalized != "" {
721+
return normalized
722+
}
723+
}
724+
return ""
725+
}
726+
706727
// handleAIGenerate handles ai.generate calls
707728
func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.DataQuery) (*server.DataQueryResult, error) {
708729
start := time.Now()
@@ -732,9 +753,9 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data
732753
}
733754

734755
// Parse optional config
735-
var configMap map[string]interface{}
756+
var generationOverrides GenerationConfigOverrides
736757
if params.Config != "" {
737-
if err := json.Unmarshal([]byte(params.Config), &configMap); err != nil {
758+
if err := json.Unmarshal([]byte(params.Config), &generationOverrides); err != nil {
738759
logging.Logger.Warn("Failed to parse config JSON", "error", err)
739760
}
740761
}
@@ -757,7 +778,7 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data
757778
}
758779

759780
// Get database type from configuration, fallback to mysql if not configured
760-
databaseType := s.resolveDatabaseType(params.DatabaseType, configMap)
781+
databaseType := s.resolveDatabaseType(params.DatabaseType, generationOverrides)
761782
context["database_type"] = databaseType
762783

763784
sqlResult, err := s.aiEngine.GenerateSQL(ctx, &ai.GenerateSQLRequest{
@@ -790,12 +811,12 @@ func (s *AIPluginService) handleAIGenerate(ctx context.Context, req *server.Data
790811
simpleFormat := fmt.Sprintf("sql:%s\nexplanation:%s", sqlResult.SQL, sqlResult.Explanation)
791812

792813
// Build minimal meta information for UI display
793-
metaData := map[string]interface{}{
794-
"confidence": sqlResult.ConfidenceScore,
795-
"model": sqlResult.ModelUsed,
796-
"dialect": databaseType,
814+
meta := GenerationMetadata{
815+
Confidence: sqlResult.ConfidenceScore,
816+
Model: sqlResult.ModelUsed,
817+
Dialect: databaseType,
797818
}
798-
metaJSON, err := json.Marshal(metaData)
819+
metaJSON, err := json.Marshal(meta)
799820
if err != nil {
800821
metaJSON = []byte(fmt.Sprintf(`{"confidence": %f, "model": "%s"}`,
801822
sqlResult.ConfidenceScore, sqlResult.ModelUsed))
@@ -851,14 +872,14 @@ func (s *AIPluginService) handleAICapabilities(ctx context.Context, req *server.
851872
if s.capabilityDetector == nil {
852873
logging.Logger.Warn("Capability detector not available - returning minimal capabilities")
853874
// Return minimal capabilities when detector is not available
854-
minimalCaps := map[string]interface{}{
855-
"plugin_ready": true,
856-
"ai_available": false,
857-
"degraded_mode": true,
858-
"plugin_version": PluginVersion,
859-
"api_version": APIVersion,
875+
fallback := CapabilitySummary{
876+
PluginReady: true,
877+
AIAvailable: false,
878+
DegradedMode: true,
879+
PluginVersion: PluginVersion,
880+
APIVersion: APIVersion,
860881
}
861-
capsJSON, _ := json.Marshal(minimalCaps)
882+
capsJSON, _ := json.Marshal(fallback)
862883
return &server.DataQueryResult{
863884
Data: []*server.Pair{
864885
{Key: "api_version", Value: APIVersion},
@@ -1102,12 +1123,12 @@ func (s *AIPluginService) handleLegacyQuery(ctx context.Context, req *server.Dat
11021123
simpleFormat := fmt.Sprintf("sql:%s\nexplanation:%s", sqlResult.SQL, sqlResult.Explanation)
11031124

11041125
// Build minimal meta information for UI display
1105-
metaData := map[string]interface{}{
1106-
"confidence": sqlResult.ConfidenceScore,
1107-
"model": sqlResult.ModelUsed,
1108-
"dialect": databaseType,
1126+
meta := GenerationMetadata{
1127+
Confidence: sqlResult.ConfidenceScore,
1128+
Model: sqlResult.ModelUsed,
1129+
Dialect: databaseType,
11091130
}
1110-
metaJSON, err := json.Marshal(metaData)
1131+
metaJSON, err := json.Marshal(meta)
11111132
if err != nil {
11121133
metaJSON = []byte(fmt.Sprintf(`{"confidence": %f, "model": "%s"}`,
11131134
sqlResult.ConfidenceScore, sqlResult.ModelUsed))

pkg/plugin/service_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -400,19 +400,19 @@ func TestResolveDatabaseType(t *testing.T) {
400400
}
401401

402402
t.Run("uses explicit value when provided", func(t *testing.T) {
403-
assert.Equal(t, "mysql", svc.resolveDatabaseType("mysql", nil))
403+
assert.Equal(t, "mysql", svc.resolveDatabaseType("mysql", GenerationConfigOverrides{}))
404404
})
405405

406406
t.Run("normalizes postgres aliases", func(t *testing.T) {
407-
assert.Equal(t, "postgresql", svc.resolveDatabaseType("pg", nil))
407+
assert.Equal(t, "postgresql", svc.resolveDatabaseType("pg", GenerationConfigOverrides{}))
408408
})
409409

410410
t.Run("falls back to config default", func(t *testing.T) {
411-
assert.Equal(t, "postgresql", svc.resolveDatabaseType("", nil))
411+
assert.Equal(t, "postgresql", svc.resolveDatabaseType("", GenerationConfigOverrides{}))
412412
})
413413

414414
t.Run("uses config map overrides", func(t *testing.T) {
415-
configMap := map[string]any{"database_dialect": "sqlite3"}
416-
assert.Equal(t, "sqlite", svc.resolveDatabaseType("", configMap))
415+
overrides := GenerationConfigOverrides{DatabaseDialectAlt: "sqlite3"}
416+
assert.Equal(t, "sqlite", svc.resolveDatabaseType("", overrides))
417417
})
418418
}

0 commit comments

Comments
 (0)