Skip to content

Commit ff18b34

Browse files
committed
fix(ai): secure runtime credentials across stack
Secure runtime API key handling across front/back-end, fix client leaks Rationale A. Prevent API keys from being exposed in request payloads and logs B. Remove key material from runtime cache hashing and reuse logic C. Ensure runtime clients and HTTP transports are closed when unused Changes A. Frontend now injects API keys via X-Auth header only and extends tests B. Backend propagates metadata API keys, hardens availability errors, and warns on stale clients C. Universal client pool adds reference counting with proper Close housekeeping Impact A. Aligns with security redline by avoiding key exposure and leaking sockets B. Backward compatible for existing UI/API consumers C. Low risk; added logs aid debugging and stale clients close gracefully Test A. go test ./... B. npm run test -- --run Refs A. Security issue 8, Resource leak issue 5, Error swallowing issue 7
1 parent 3bb9062 commit ff18b34

File tree

9 files changed

+396
-202
lines changed

9 files changed

+396
-202
lines changed

frontend/src/services/aiService.ts

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ export const aiService = {
5757
provider: config.provider,
5858
endpoint: config.endpoint,
5959
model: config.model,
60-
api_key: config.apiKey,
6160
max_tokens: config.maxTokens,
6261
timeout: formatTimeout(config.timeout)
6362
}
@@ -67,7 +66,7 @@ export const aiService = {
6766
message: string
6867
provider: string
6968
error?: string
70-
}>('test_connection', payload)
69+
}>('test_connection', payload, { apiKey: config.apiKey })
7170

7271
return {
7372
success: toBoolean(result.success),
@@ -130,12 +129,11 @@ export const aiService = {
130129
include_explanation: request.includeExplanation,
131130
provider: request.provider,
132131
endpoint: request.endpoint,
133-
api_key: request.apiKey,
134132
max_tokens: request.maxTokens,
135133
timeout: formatTimeout(request.timeout),
136134
database_type: request.databaseDialect
137135
})
138-
})
136+
}, { apiKey: request.apiKey })
139137

140138
console.log('📥 [aiService] Received backend result', {
141139
hasContent: !!result.content,
@@ -214,12 +212,11 @@ export const aiService = {
214212
provider: config.provider,
215213
endpoint: config.endpoint,
216214
model: config.model,
217-
api_key: config.apiKey,
218215
max_tokens: config.maxTokens,
219216
timeout: formatTimeout(config.timeout),
220217
database_type: config.databaseDialect
221218
}
222-
})
219+
}, { apiKey: config.apiKey })
223220
}
224221
}
225222

@@ -239,7 +236,7 @@ function formatTimeout(timeout: number | undefined): string {
239236
* is designed for database queries and transforms the request format.
240237
* The AI plugin expects: {type: 'ai', key: 'operation', sql: 'params_json'}
241238
*/
242-
async function callAPI<T>(key: string, data: any): Promise<T> {
239+
async function callAPI<T>(key: string, data: any, options: { apiKey?: string } = {}): Promise<T> {
243240
const requestBody = {
244241
type: 'ai',
245242
key,
@@ -254,12 +251,18 @@ async function callAPI<T>(key: string, data: any): Promise<T> {
254251
})
255252

256253
try {
254+
const headers: Record<string, string> = {
255+
'Content-Type': 'application/json',
256+
'X-Store-Name': API_STORE
257+
}
258+
259+
if (options.apiKey) {
260+
headers['X-Auth'] = `Bearer ${options.apiKey}`
261+
}
262+
257263
const response = await fetch(API_BASE, {
258264
method: 'POST',
259-
headers: {
260-
'Content-Type': 'application/json',
261-
'X-Store-Name': API_STORE
262-
},
265+
headers,
263266
body: JSON.stringify(requestBody)
264267
})
265268

frontend/tests/services/aiService.spec.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ describe('aiService', () => {
4141
})
4242
const payload = JSON.parse(body.sql)
4343
expect(payload.config).toContain('timeout')
44+
expect(payload.config).not.toContain('api_key')
4445

4546
return createFetchResponse({
4647
data: [
@@ -68,6 +69,38 @@ describe('aiService', () => {
6869
expect(response.meta).toEqual({ confidence: 0.9, model: 'demo' })
6970
})
7071

72+
it('sends API key through authorization header only', async () => {
73+
const apiKey = 'sk-secure'
74+
fetchMock.mockImplementationOnce(async (_url: FetchArgs[0], options: FetchArgs[1]) => {
75+
const headers = options?.headers as Record<string, string>
76+
expect(headers['X-Auth']).toBe(`Bearer ${apiKey}`)
77+
78+
const body = JSON.parse(String(options?.body))
79+
const payload = JSON.parse(body.sql)
80+
expect(payload.config).not.toContain('api_key')
81+
82+
return createFetchResponse({
83+
data: [
84+
{ key: 'success', value: true },
85+
{ key: 'content', value: 'sql:SELECT 1;' },
86+
{ key: 'meta', value: '{}' }
87+
]
88+
})
89+
})
90+
91+
await aiService.generateSQL({
92+
provider: 'openai',
93+
endpoint: 'https://api.openai.com',
94+
apiKey,
95+
model: 'gpt-5',
96+
prompt: 'SELECT 1;',
97+
timeout: 30,
98+
maxTokens: 256,
99+
includeExplanation: false,
100+
databaseDialect: 'postgresql'
101+
})
102+
})
103+
71104
it('parses health check response when backend returns boolean healthy flag', async () => {
72105
fetchMock.mockResolvedValueOnce(
73106
createFetchResponse({

pkg/ai/engine.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type GenerateSQLRequest struct {
6464
NaturalLanguage string `json:"natural_language"`
6565
DatabaseType string `json:"database_type"`
6666
Context map[string]string `json:"context,omitempty"`
67+
RuntimeAPIKey string `json:"-"`
6768
}
6869

6970
// GenerateSQLResponse represents an AI SQL generation response
@@ -113,7 +114,11 @@ func NewEngine(cfg config.AIConfig) (Engine, error) {
113114

114115
engine, err := newEngineFromManager(manager, cfg)
115116
if err != nil {
116-
_ = manager.Close()
117+
if closeErr := manager.Close(); closeErr != nil {
118+
logging.Logger.Warn("Failed to close AI manager after initialization error",
119+
"provider", cfg.DefaultService,
120+
"error", closeErr)
121+
}
117122
return nil, err
118123
}
119124
return engine, nil
@@ -163,7 +168,11 @@ func NewEngineWithManager(manager *Manager, cfg config.AIConfig) (Engine, error)
163168

164169
engine, err := newEngineFromManager(manager, cfg)
165170
if err != nil {
166-
_ = manager.Close()
171+
if closeErr := manager.Close(); closeErr != nil {
172+
logging.Logger.Warn("Failed to close AI manager after initialization error",
173+
"provider", cfg.DefaultService,
174+
"error", closeErr)
175+
}
167176
return nil, err
168177
}
169178
return engine, nil
@@ -206,6 +215,10 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G
206215
MaxTokens: defaultMaxTokens,
207216
}
208217

218+
if req.RuntimeAPIKey != "" {
219+
options.APIKey = req.RuntimeAPIKey
220+
}
221+
209222
// Add context if provided and extract preferred_model and runtime config
210223
var runtimeConfig map[string]interface{}
211224
if len(req.Context) > 0 {
@@ -236,7 +249,11 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G
236249
options.Provider = provider
237250
}
238251
if apiKey, ok := runtimeConfig["api_key"].(string); ok && apiKey != "" {
239-
options.APIKey = apiKey
252+
if options.APIKey == "" {
253+
options.APIKey = apiKey
254+
} else if options.APIKey != apiKey {
255+
logging.Logger.Warn("Runtime config API key differs from secured metadata; using secured value")
256+
}
240257
}
241258
if endpoint, ok := runtimeConfig["endpoint"].(string); ok && endpoint != "" {
242259
options.Endpoint = endpoint
@@ -313,6 +330,9 @@ func (e *aiEngine) Close() {
313330
e.generator.Close()
314331
}
315332
if e.manager != nil {
316-
_ = e.manager.Close()
333+
if err := e.manager.Close(); err != nil {
334+
logging.Logger.Warn("Failed to close AI manager during engine shutdown",
335+
"error", err)
336+
}
317337
}
318338
}

pkg/ai/generator.go

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package ai
1818

1919
import (
20+
"bytes"
2021
"context"
2122
"crypto/sha256"
2223
"encoding/hex"
@@ -38,10 +39,15 @@ type SQLGenerator struct {
3839
sqlDialects map[string]SQLDialect
3940
config config.AIConfig
4041
capabilities *SQLCapabilities
41-
runtimeClients map[string]interfaces.AIClient
42+
runtimeClients map[string]*runtimeClientEntry
4243
runtimeMu sync.RWMutex
4344
}
4445

46+
type runtimeClientEntry struct {
47+
client interfaces.AIClient
48+
apiKeyFingerprint []byte
49+
}
50+
4551
// Table represents a database table structure
4652
type Table struct {
4753
Name string `json:"name"`
@@ -142,7 +148,7 @@ func NewSQLGenerator(aiClient interfaces.AIClient, config config.AIConfig) (*SQL
142148
aiClient: aiClient,
143149
config: config,
144150
sqlDialects: make(map[string]SQLDialect),
145-
runtimeClients: make(map[string]interfaces.AIClient),
151+
runtimeClients: make(map[string]*runtimeClientEntry),
146152
}
147153

148154
// Initialize SQL dialects
@@ -630,23 +636,38 @@ func runtimeClientKey(options *GenerateOptions) string {
630636
hasher.Write([]byte(options.Endpoint))
631637
hasher.Write([]byte("|"))
632638
hasher.Write([]byte(options.Model))
633-
hasher.Write([]byte("|"))
634-
hasher.Write([]byte(options.APIKey))
635639
return hex.EncodeToString(hasher.Sum(nil))
636640
}
637641

642+
func runtimeAPIKeyFingerprint(apiKey string) []byte {
643+
if apiKey == "" {
644+
return nil
645+
}
646+
sum := sha256.Sum256([]byte(apiKey))
647+
fingerprint := make([]byte, len(sum))
648+
copy(fingerprint, sum[:])
649+
return fingerprint
650+
}
651+
638652
func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (interfaces.AIClient, bool, error) {
639653
key := runtimeClientKey(options)
654+
fingerprint := runtimeAPIKeyFingerprint(options.APIKey)
655+
640656
g.runtimeMu.RLock()
641-
if client, ok := g.runtimeClients[key]; ok {
642-
g.runtimeMu.RUnlock()
643-
return client, true, nil
657+
if entry, ok := g.runtimeClients[key]; ok {
658+
if bytes.Equal(entry.apiKeyFingerprint, fingerprint) {
659+
client := entry.client
660+
g.runtimeMu.RUnlock()
661+
return client, true, nil
662+
}
644663
}
645664
g.runtimeMu.RUnlock()
646665

647666
runtimeConfig := map[string]any{
648667
"provider": options.Provider,
649-
"api_key": options.APIKey,
668+
}
669+
if options.APIKey != "" {
670+
runtimeConfig["api_key"] = options.APIKey
650671
}
651672
if options.Endpoint != "" {
652673
runtimeConfig["base_url"] = options.Endpoint
@@ -664,23 +685,55 @@ func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (inter
664685
}
665686

666687
g.runtimeMu.Lock()
667-
if existing, ok := g.runtimeClients[key]; ok {
668-
g.runtimeMu.Unlock()
669-
_ = client.Close()
670-
return existing, true, nil
688+
var (
689+
existingEntry *runtimeClientEntry
690+
exists bool
691+
)
692+
if existingEntry, exists = g.runtimeClients[key]; exists {
693+
if bytes.Equal(existingEntry.apiKeyFingerprint, fingerprint) {
694+
g.runtimeMu.Unlock()
695+
if err := client.Close(); err != nil {
696+
logging.Logger.Warn("Failed to close redundant runtime client",
697+
"provider", options.Provider,
698+
"endpoint", options.Endpoint,
699+
"error", err)
700+
}
701+
return existingEntry.client, true, nil
702+
}
703+
}
704+
705+
g.runtimeClients[key] = &runtimeClientEntry{
706+
client: client,
707+
apiKeyFingerprint: fingerprint,
671708
}
672-
g.runtimeClients[key] = client
673709
g.runtimeMu.Unlock()
674710

711+
if exists && existingEntry != nil && existingEntry.client != nil {
712+
if err := existingEntry.client.Close(); err != nil {
713+
logging.Logger.Warn("Failed to close stale runtime client",
714+
"provider", options.Provider,
715+
"endpoint", options.Endpoint,
716+
"error", err)
717+
}
718+
}
719+
675720
return client, false, nil
676721
}
677722

678723
// Close releases all cached runtime clients held by the generator.
679724
func (g *SQLGenerator) Close() {
680725
g.runtimeMu.Lock()
681726
defer g.runtimeMu.Unlock()
682-
for key, client := range g.runtimeClients {
683-
_ = client.Close()
727+
for key, entry := range g.runtimeClients {
728+
if entry == nil || entry.client == nil {
729+
delete(g.runtimeClients, key)
730+
continue
731+
}
732+
if err := entry.client.Close(); err != nil {
733+
logging.Logger.Warn("Failed to close runtime client during generator shutdown",
734+
"key", key,
735+
"error", err)
736+
}
684737
delete(g.runtimeClients, key)
685738
}
686739
}

pkg/ai/generator_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ package ai
33
import (
44
"testing"
55

6-
"github.com/linuxsuren/atest-ext-ai/pkg/interfaces"
76
"github.com/stretchr/testify/require"
87
)
98

109
func TestRuntimeClientReuseAndClose(t *testing.T) {
1110
generator := &SQLGenerator{
12-
runtimeClients: make(map[string]interfaces.AIClient),
11+
runtimeClients: make(map[string]*runtimeClientEntry),
1312
}
1413

1514
options := &GenerateOptions{

0 commit comments

Comments
 (0)