diff --git a/src/client/hive.ts b/src/client/hive.ts
new file mode 100644
index 0000000..1d08d4c
--- /dev/null
+++ b/src/client/hive.ts
@@ -0,0 +1,312 @@
+import { Config, QueryResult as TDQueryResult } from '../types';
+import { maskApiKey } from '../config';
+import { getTdApiEndpointForSite } from './tdapi/endpoints';
+
+type JobShowResponse = {
+  hive_result_schema?: string | null;
+  debug?: {
+    cmdout?: string | null;
+    stderr?: string | null;
+  };
+  error?: string | null;
+  status?: string | null;
+  query?: string | null;
+  database?: string | null;
+};
+
+export interface HiveIssueJobOptions {
+  priority?: number;
+  retry_limit?: number;
+  pool_name?: string;
+}
+
+export interface HiveJobStatus {
+  job_id: string;
+  status: string;
+  url?: string;
+  result_size?: number;
+  num_records?: number;
+  database?: string;
+  type?: string;
+  query?: string;
+  start_at?: string;
+  end_at?: string;
+  created_at?: string;
+  updated_at?: string;
+  result_schema?: Array<{ name: string; type: string }>;
+  error?: string;
+}
+
+/**
+ * Minimal REST client for Treasure Data Hive jobs (v3 API)
+ */
+export class TDHiveClient {
+  private readonly apiKey: string;
+  private readonly baseUrl: string;
+  readonly database: string;
+
+  constructor(config: Config) {
+    this.apiKey = config.td_api_key;
+    this.baseUrl = getTdApiEndpointForSite(config.site);
+    this.database = config.database || 'sample_datasets';
+  }
+
+  private headers(): Record<string, string> {
+    return {
+      Authorization: `TD1 ${this.apiKey}`,
+      'Content-Type': 'application/json',
+      Accept: 'application/json',
+    };
+  }
+
+  private async request<T>(method: string, path: string, body?: unknown): Promise<T> {
+    const url = `${this.baseUrl}${path}`;
+    try {
+      const res = await fetch(url, {
+        method,
+        headers: this.headers(),
+        body: body ? JSON.stringify(body) : undefined,
+      });
+
+      if (!res.ok) {
+        const text = await res.text().catch(() => '');
+        if (process.env.TD_MCP_LOG_TO_CONSOLE === 'true') {
+          console.error(`[Hive] ${method} ${path} -> ${res.status}: ${text || res.statusText}`);
+        }
+        throw new Error(`Hive API error ${res.status}: ${text || res.statusText}`);
+      }
+
+      // Some endpoints return empty body on 204
+      const ct = res.headers.get('content-type') || '';
+      if (ct.includes('application/json')) {
+        return (await res.json()) as T;
+      }
+      return (await res.text()) as unknown as T;
+    } catch (e) {
+      if (e instanceof Error) {
+        e.message = e.message.replace(this.apiKey, maskApiKey(this.apiKey));
+      }
+      throw e;
+    }
+  }
+
+  private async requestWithFallback<T>(
+    method: string,
+    paths: string[],
+    body?: unknown
+  ): Promise<T> {
+    let lastError: unknown;
+    for (const p of paths) {
+      try {
+        return await this.request<T>(method, p, body);
+      } catch (e) {
+        lastError = e;
+        const msg = e instanceof Error ? e.message : '';
+        if (!/Path and method do not match any API endpoint/i.test(msg)) {
+          throw e; // not a routing 404, rethrow
+        }
+        // otherwise try next path
+      }
+    }
+    throw lastError instanceof Error ? lastError : new Error('Hive API request failed');
+  }
+
+  private async requestText(method: string, path: string, body?: unknown): Promise<string> {
+    const url = `${this.baseUrl}${path}`;
+    try {
+      const res = await fetch(url, {
+        method,
+        headers: this.headers(),
+        body: body ? JSON.stringify(body) : undefined,
+      });
+
+      if (!res.ok) {
+        const text = await res.text().catch(() => '');
+        if (process.env.TD_MCP_LOG_TO_CONSOLE === 'true') {
+          console.error(`[Hive] ${method} ${path} -> ${res.status}: ${text || res.statusText}`);
+        }
+        throw new Error(`Hive API error ${res.status}: ${text || res.statusText}`);
+      }
+
+      return await res.text();
+    } catch (e) {
+      if (e instanceof Error) {
+        e.message = e.message.replace(this.apiKey, maskApiKey(this.apiKey));
+      }
+      throw e;
+    }
+  }
+
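+  // The `format=json` result body may arrive either as a single JSON array of
+  // rows or as newline-delimited JSON entries; parseResultText accepts both
+  // and falls back to one-column string rows for lines that fail to parse.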
+  private parseResultText(text: string): unknown[][] {
+    const trimmed = (text || '').trim();
+    if (!trimmed) return [];
+    if (trimmed.startsWith('[')) {
+      try {
+        const parsed = JSON.parse(trimmed);
+        if (Array.isArray(parsed)) {
+          if (parsed.length > 0 && Array.isArray(parsed[0])) {
+            return parsed as unknown[][];
+          }
+          return [parsed as unknown[]];
+        }
+      } catch {
+        // fall through
+      }
+    }
+    const rows: unknown[][] = [];
+    for (const line of trimmed.split(/\r?\n/)) {
+      const l = line.trim();
+      if (!l) continue;
+      try {
+        const entry = JSON.parse(l);
+        rows.push(Array.isArray(entry) ? entry : [entry]);
+      } catch {
+        rows.push([l]);
+      }
+    }
+    return rows;
+  }
+
+  /**
+   * Issues a Hive job and returns job id
+   */
+  async issueHive(
+    query: string,
+    database?: string,
+    options?: HiveIssueJobOptions
+  ): Promise<{ job_id: string }> {
+    if (!query) throw new Error('Query is required');
+    const payload: Record<string, unknown> = {
+      query,
+      db: database || this.database,
+      priority: options?.priority,
+      retry_limit: options?.retry_limit,
+      pool_name: options?.pool_name,
+    };
+    // Try v3 endpoints first (spec: /job/issue/{job_type}/{database_name})
+    const db = (database || this.database).trim();
+    if (!db) throw new Error('Database name is required');
+    return await this.requestWithFallback<{ job_id: string }>(
+      'POST',
+      [
+        `/v3/job/issue/hive/${encodeURIComponent(db)}`,
+        `/v3/jobs/issue/hive/${encodeURIComponent(db)}`,
+      ],
+      payload
+    );
+  }
+
+  /**
+   * Gets job status/details
+   */
+  async jobStatus(jobId: string): Promise<HiveJobStatus> {
+    if (!jobId) throw new Error('job_id is required');
+    return this.requestWithFallback<HiveJobStatus>('GET', [
+      `/v3/job/status/${encodeURIComponent(jobId)}`,
+      `/v3/jobs/status/${encodeURIComponent(jobId)}`,
+      `/v3/jobs/${encodeURIComponent(jobId)}/status`,
+    ]);
+  }
+
+  /**
+   * Fetches job result as JSON rows (array of arrays)
+   */
+  async jobResult(jobId: string): Promise<{ rows: unknown[][] }> {
+    if (!jobId) throw new Error('job_id is required');
+    const text = await this.requestText(
+      'GET',
+      `/v3/job/result/${encodeURIComponent(jobId)}?format=json`
+    );
+    return { rows: this.parseResultText(text) };
+  }
+
+  /**
+   * Fetches job result schema (column names/types)
+   */
+  async jobResultSchema(jobId: string): Promise<Array<{ name: string; type: string }>> {
+    if (!jobId) throw new Error('job_id is required');
+    const show = await this.request<JobShowResponse>(
+      'GET',
+      `/v3/job/show/${encodeURIComponent(jobId)}`
+    );
+    const raw = show?.hive_result_schema;
+    if (typeof raw === 'string') {
+      try {
+        const arr = JSON.parse(raw) as Array<[string, string]>;
+        return Array.isArray(arr)
+          ? arr.map((c) => ({ name: String(c[0]), type: String(c[1]) }))
+          : [];
+      } catch {
+        return [];
+      }
+    }
+    return [];
+  }
+
+  /**
+   * Helper that waits for job completion
+   */
+  async waitForCompletion(
+    jobId: string,
+    opts?: { pollMs?: number; timeoutMs?: number }
+  ): Promise<HiveJobStatus> {
+    const poll = opts?.pollMs ?? 2000;
+    const timeout = opts?.timeoutMs ??
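+      // Hive batch jobs routinely run for minutes, so the fallback ceiling is generous.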
+      15 * 60 * 1000; // 15m default
+    const start = Date.now();
+    // Typical terminal states: success, error, killed
+    while (true) {
+      const s = await this.jobStatus(jobId);
+      if (['success', 'error', 'killed'].includes(s.status)) return s;
+      if (Date.now() - start > timeout) {
+        throw new Error(`Timed out waiting for job ${jobId} to complete`);
+      }
+      await new Promise((r) => setTimeout(r, poll));
+    }
+  }
+
+  /**
+   * Convenience method: run read-only query via Hive and return results
+   */
+  async query(sql: string, database?: string): Promise<TDQueryResult> {
+    const { job_id } = await this.issueHive(sql, database);
+    const status = await this.waitForCompletion(job_id);
+    if (status.status !== 'success') {
+      const details = await this.getJobErrorDetails(job_id);
+      const message = details || status.error || '';
+      throw new Error(`Hive job failed (${status.status})${message ? `: ${message}` : ''}`);
+    }
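+    // Schema lookup failures are swallowed so results still come back,
+    // just without column names.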
+    const [schema, result] = await Promise.all([
+      this.jobResultSchema(job_id).catch(() => []),
+      this.jobResult(job_id),
+    ]);
+
+    // Map to common TDQueryResult shape
+    const columns = Array.isArray(schema)
+      ? schema.map((c) => ({ name: c.name, type: c.type }))
+      : [];
+    const rows = result.rows || [];
+    const data = rows.map((arr) => {
+      const o: Record<string, unknown> = {};
+      columns.forEach((c, i) => (o[c.name] = (arr as unknown[])[i]));
+      return o;
+    });
+    return { columns, data, rowCount: data.length };
+  }
+
+  private async getJobErrorDetails(jobId: string): Promise<string | undefined> {
+    try {
+      const show = await this.request<JobShowResponse>(
+        'GET',
+        `/v3/job/show/${encodeURIComponent(jobId)}`
+      );
+      const parts: string[] = [];
+      if (show?.error) parts.push(String(show.error));
+      if (show?.debug?.stderr) parts.push(String(show.debug.stderr));
+      if (show?.debug?.cmdout) parts.push(String(show.debug.cmdout));
+      const text = parts.join('\n').trim();
+      return text.length > 0 ? text : undefined;
+    } catch {
+      return undefined;
+    }
+  }
+}
diff --git a/src/client/tdapi/endpoints.ts b/src/client/tdapi/endpoints.ts
new file mode 100644
index 0000000..9f9a267
--- /dev/null
+++ b/src/client/tdapi/endpoints.ts
@@ -0,0 +1,21 @@
+import { TDSite } from '../../types';
+
+/**
+ * Mapping of Treasure Data sites to their main REST API endpoints
+ */
+export const TD_API_ENDPOINTS: Record<TDSite, string> = {
+  us01: 'https://api.treasuredata.com',
+  jp01: 'https://api.treasuredata.co.jp',
+  eu01: 'https://api.eu01.treasuredata.com',
+  ap02: 'https://api.ap02.treasuredata.com',
+  ap03: 'https://api.ap03.treasuredata.com',
+  // Development endpoint (best effort; can be overridden later if needed)
+  dev: 'https://api-development.us01.treasuredata.com',
+} as const;
+
+export function getTdApiEndpointForSite(site: TDSite): string {
+  const endpoint = TD_API_ENDPOINTS[site];
+  if (!endpoint) throw new Error(`Unknown TD site: ${site}`);
+  return endpoint;
+}
+
diff --git a/src/security/audit-logger.ts b/src/security/audit-logger.ts
index d56633e..76dc177 100644
--- a/src/security/audit-logger.ts
+++ b/src/security/audit-logger.ts
@@ -203,7 +203,7 @@ export class AuditLogger {
     const database = entry.database ? `[${entry.database}]` : '';
     const rowCount = entry.rowCount !== undefined ? ` -> ${entry.rowCount} rows` : '';
-    console.log(
+    console.error(
       `[${entry.timestamp.toISOString()}] ${status} ${entry.queryType} ${database}${duration}${rowCount}`
     );
diff --git a/src/server.ts b/src/server.ts
index fdcf829..5291c5d 100644
--- a/src/server.ts
+++ b/src/server.ts
@@ -24,6 +24,7 @@ import {
   segmentSql,
   getSegment
 } from './tools/cdp';
+import { hiveQuery, hiveExecute, hiveJobStatus, hiveJobResult } from './tools/hive';
 import {
   listProjects,
   listWorkflows,
@@ -231,6 +232,27 @@
         properties: {},
       },
     },
+    // Hive Tools
+    {
+      name: hiveQuery.name,
+      description: hiveQuery.description,
+      inputSchema: hiveQuery.inputSchema,
+    },
+    {
+      name: hiveExecute.name,
+      description: hiveExecute.description,
+      inputSchema: hiveExecute.inputSchema,
+    },
+    {
+      name: hiveJobStatus.name,
+      description: hiveJobStatus.description,
+      inputSchema: hiveJobStatus.inputSchema,
+    },
+    {
+      name: hiveJobResult.name,
+      description: hiveJobResult.description,
+      inputSchema: hiveJobResult.inputSchema,
+    },
     // CDP Tools
     {
       name: listParentSegmentsTool.name,
@@ -636,6 +658,40 @@
        };
      }
 
+      // Hive tools
+      case hiveQuery.name: {
+        const result = await hiveQuery.handler(args || {});
+        return {
+          content: [
+            { type: 'text', text: JSON.stringify(result, null, 2) },
+          ],
+        };
+      }
+      case hiveExecute.name: {
+        const result = await hiveExecute.handler(args || {});
+        return {
+          content: [
+            { type: 'text', text: JSON.stringify(result, null, 2) },
+          ],
+        };
+      }
+      case hiveJobStatus.name: {
+        const result = await hiveJobStatus.handler(args || {});
+        return {
+          content: [
+            { type: 'text', text: JSON.stringify(result, null, 2) },
+          ],
+        };
+      }
+      case hiveJobResult.name: {
+        const result = await hiveJobResult.handler(args || {});
+        return {
+          content: [
+            { type: 'text', text: JSON.stringify(result, null, 2) },
+          ],
+        };
+      }
+
       default:
         throw new McpError(
           ErrorCode.MethodNotFound,
@@ -700,4 +756,4 @@
     await this.server.connect(transport);
     console.error('TD MCP Server started');
   }
-}
\ No newline at end of file
+}
diff --git a/src/tools/hive/execute.ts b/src/tools/hive/execute.ts
new file mode 100644
index 0000000..685f321
--- /dev/null
+++ b/src/tools/hive/execute.ts
@@ -0,0 +1,97 @@
+import { z } from 'zod';
+import { loadConfig } from '../../config';
+import { TDHiveClient } from '../../client/hive';
+import { QueryValidator } from '../../security/query-validator';
+import { AuditLogger } from '../../security/audit-logger';
+import { hiveQuery } from './query';
+
+type HiveExecuteWriteResult = {
+  job_id: string;
+  status: string;
+  success: boolean;
+  message: string;
+};
+
+type HiveQueryResult = {
+  columns: Array<{ name: string; type: string }>;
+  rows: unknown[][];
+  rowCount: number;
+  truncated: boolean;
+};
+
+const inputSchema = z.object({
+  sql: z.string().min(1).describe('The Hive SQL to execute (write operations allowed when enable_updates=true)'),
+  database: z.string().optional().describe('Database to execute against'),
+  priority: z.number().int().optional().describe('Job priority (optional)'),
+  retry_limit: z.number().int().optional().describe('Retry limit (optional)'),
+  pool_name: z.string().optional().describe('Resource pool name (optional)'),
+  limit: z
+    .number()
+    .int()
+    .min(1)
+    .max(10000)
+    .optional()
+    .describe('Max rows to return for read-only SQL routed to hive_query (default 40)')
+});
+
+export const hiveExecute = {
+  name: 'hive_execute',
+  description: 'Execute write operations using Treasure Data Hive (requires enable_updates=true). Returns job id and final status.',
+  inputSchema: {
+    type: 'object',
+    properties: {
+      sql: { type: 'string', description: 'Hive SQL statement to execute.' },
+      database: { type: 'string', description: 'Database (optional).' },
+      priority: { type: 'number', description: 'Job priority (optional).' },
+      retry_limit: { type: 'number', description: 'Retry limit (optional).' },
+      pool_name: { type: 'string', description: 'Resource pool name (optional).' },
+      limit: { type: 'number', description: 'Max rows to return for read-only SQL (default 40).', minimum: 1, maximum: 10000 }
+    },
+    required: ['sql']
+  },
+  handler: async (args: unknown): Promise<HiveExecuteWriteResult | HiveQueryResult> => {
+    const { sql, database, priority, retry_limit, pool_name, limit } = inputSchema.parse(args);
+
+    const config = loadConfig();
+    const client = new TDHiveClient({ ...config, database: database || config.database });
+    const validator = new QueryValidator(config.enable_updates);
+    const auditor = new AuditLogger({ logToConsole: process.env.TD_MCP_LOG_TO_CONSOLE === 'true' });
+
+    // Validate the statement
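+    // (QueryValidator is constructed with enable_updates, so write statements
+    // are rejected at this gate whenever updates are disabled in the config.)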
+    const validation = validator.validate(sql);
+    if (!validation.isValid) {
+      throw new Error(`Query validation failed: ${validation.error}`);
+    }
+
+    if (validator.isReadOnly(validation.queryType)) {
+      // Reuse hive_query for SELECT/SHOW/DESCRIBE to return data rows
+      return await hiveQuery.handler({ sql, database, limit: limit ?? 40 });
+    }
+
+    const start = Date.now();
+    try {
+      const { job_id } = await client.issueHive(sql, database, { priority, retry_limit, pool_name });
+      const status = await client.waitForCompletion(job_id);
+      const duration = Date.now() - start;
+      const success = status.status === 'success';
+
+      if (success) {
+        auditor.logSuccess(validation.queryType, sql, client.database, duration, status.num_records || 0);
+      } else {
+        auditor.logFailure(validation.queryType, sql, status.error || status.status, client.database, duration);
+      }
+
+      return {
+        job_id,
+        status: status.status,
+        success,
+        message: success ? 'Job completed successfully' : (status.error || 'Job did not succeed')
+      };
+    } catch (e) {
+      const duration = Date.now() - start;
+      const msg = e instanceof Error ? e.message : String(e);
+      auditor.logFailure(validation.queryType, sql, msg, client.database, duration);
+      throw new Error(`Hive execute failed: ${msg}`);
+    }
+  }
+};
diff --git a/src/tools/hive/index.ts b/src/tools/hive/index.ts
new file mode 100644
index 0000000..38e47d1
--- /dev/null
+++ b/src/tools/hive/index.ts
@@ -0,0 +1,5 @@
+export { hiveQuery } from './query.js';
+export { hiveExecute } from './execute.js';
+export { hiveJobStatus } from './job-status.js';
+export { hiveJobResult } from './job-result.js';
+
diff --git a/src/tools/hive/job-result.ts b/src/tools/hive/job-result.ts
new file mode 100644
index 0000000..906343e
--- /dev/null
+++ b/src/tools/hive/job-result.ts
@@ -0,0 +1,39 @@
+import { z } from 'zod';
+import { loadConfig } from '../../config';
+import { TDHiveClient } from '../../client/hive';
+
+const inputSchema = z.object({
+  job_id: z.string().min(1).describe('TD job id to fetch results for')
+});
+
+export const hiveJobResult = {
+  name: 'hive_job_result',
+  description: 'Fetch result rows and schema for a completed Treasure Data Hive job.',
+  inputSchema: {
+    type: 'object',
+    properties: {
+      job_id: { type: 'string', description: 'Completed job ID.' }
+    },
+    required: ['job_id']
+  },
+  handler: async (args: unknown) => {
+    const { job_id } = inputSchema.parse(args);
+    const config = loadConfig();
+    const client = new TDHiveClient(config);
+    const [schema, result] = await Promise.all([
+      client.jobResultSchema(job_id).catch(() => []),
+      client.jobResult(job_id),
+    ]);
+
+    const columns = Array.isArray(schema)
+      ? schema.map((c) => ({ name: c.name, type: c.type }))
+      : [];
+    const rows = result.rows || [];
+    return {
+      columns,
+      rows,
+      rowCount: rows.length,
+    };
+  }
+};
+
diff --git a/src/tools/hive/job-status.ts b/src/tools/hive/job-status.ts
new file mode 100644
index 0000000..5eed585
--- /dev/null
+++ b/src/tools/hive/job-status.ts
@@ -0,0 +1,27 @@
+import { z } from 'zod';
+import { loadConfig } from '../../config';
+import { TDHiveClient } from '../../client/hive';
+
+const inputSchema = z.object({
+  job_id: z.string().min(1).describe('TD job id to check')
+});
+
+export const hiveJobStatus = {
+  name: 'hive_job_status',
+  description: 'Get status/details for a Treasure Data Hive job by job_id.',
+  inputSchema: {
+    type: 'object',
+    properties: {
+      job_id: { type: 'string', description: 'Job ID to check.' }
+    },
+    required: ['job_id']
+  },
+  handler: async (args: unknown) => {
+    const { job_id } = inputSchema.parse(args);
+    const config = loadConfig();
+    const client = new TDHiveClient(config);
+    const status = await client.jobStatus(job_id);
+    return status;
+  }
+};
+
diff --git a/src/tools/hive/query.ts b/src/tools/hive/query.ts
new file mode 100644
index 0000000..1118d26
--- /dev/null
+++ b/src/tools/hive/query.ts
@@ -0,0 +1,75 @@
+import { z } from 'zod';
+import { loadConfig } from '../../config';
+import { TDHiveClient } from '../../client/hive';
+import { QueryValidator } from '../../security/query-validator';
+import { AuditLogger } from '../../security/audit-logger';
+
+const inputSchema = z.object({
+  sql: z.string().min(1).describe('The Hive SQL query to execute (read-only)'),
+  database: z.string().optional().describe('The database name (optional if TD_DATABASE is configured)'),
+  limit: z.number().int().min(1).max(10000).optional().describe('Max rows to return (default 40)')
+});
+
+export const hiveQuery = {
+  name: 'hive_query',
+  description: 'Execute a read-only SQL query using Treasure Data Hive (v3 API). Supports LIMIT injection and returns rows + schema.',
+  inputSchema: {
+    type: 'object',
+    properties: {
+      sql: { type: 'string', description: 'Hive SQL query (SELECT, SHOW, DESCRIBE only).' },
+      database: { type: 'string', description: 'Database to query (optional).' },
+      limit: { type: 'number', description: 'Max rows to return (default 40).', minimum: 1, maximum: 10000 }
+    },
+    required: ['sql']
+  },
+  handler: async (args: unknown) => {
+    const { sql, database, limit } = inputSchema.parse(args);
+
+    const config = loadConfig();
+    const client = new TDHiveClient({ ...config, database: database || config.database });
+    const validator = new QueryValidator(false); // disallow writes regardless of env flag for this tool
+    const auditor = new AuditLogger({ logToConsole: process.env.TD_MCP_LOG_TO_CONSOLE === 'true' });
+
+    // Validate as read-only
+    const validation = validator.validate(sql);
+    if (!validation.isValid || !validator.isReadOnly(validation.queryType)) {
+      throw new Error('Only read-only queries are allowed in hive_query. Use hive_execute for write operations.');
+    }
+
+    const finalLimit = limit ?? 40;
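+    // injectLimit appends a LIMIT only when the statement lacks one; the
+    // truncated flag below is a heuristic (a result of exactly `limit` rows
+    // is flagged even when it happens to be complete).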
+    const processedSql = injectLimit(sql, finalLimit);
+    const start = Date.now();
+    try {
+      const result = await client.query(processedSql, database);
+      const duration = Date.now() - start;
+
+      // Convert to rows[][] for MCP
+      const rows = result.data.map((row) => result.columns.map((c) => row[c.name]));
+      const truncated = rows.length === finalLimit && !hasExplicitLimit(sql);
+
+      auditor.logSuccess(validation.queryType, processedSql, client.database, duration, rows.length);
+
+      return {
+        columns: result.columns,
+        rows,
+        rowCount: rows.length,
+        truncated
+      };
+    } catch (e) {
+      const duration = Date.now() - start;
+      const msg = e instanceof Error ? e.message : String(e);
+      auditor.logFailure(validation.queryType, processedSql, msg, client.database, duration);
+      throw new Error(`Hive query failed: ${msg}`);
+    }
+  }
+};
+
+function injectLimit(sql: string, limit: number): string {
+  if (hasExplicitLimit(sql)) return sql;
+  const trimmed = sql.trim().replace(/;+$/, '');
+  // Only SELECT/WITH statements accept LIMIT; leave SHOW/DESCRIBE untouched.
+  if (!/^\s*(SELECT|WITH)\b/i.test(trimmed)) return trimmed;
+  return `${trimmed} LIMIT ${limit}`;
+}
+
+function hasExplicitLimit(sql: string): boolean {
+  // Heuristic: also matches a LIMIT inside subqueries or string literals.
+  return /\bLIMIT\s+\d+/i.test(sql);
+}
diff --git a/tests/client/hive.test.ts b/tests/client/hive.test.ts
new file mode 100644
index 0000000..ea7068f
--- /dev/null
+++ b/tests/client/hive.test.ts
@@ -0,0 +1,99 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { TDHiveClient } from '../../src/client/hive';
+
+// Mock maskApiKey to predictable masking
+vi.mock('../../src/config', () => ({
+  maskApiKey: (k: string) => `${k.slice(0,4)}...${k.slice(-4)}`,
+}));
+
+// Mock fetch globally
+global.fetch = vi.fn();
+
+describe('TDHiveClient', () => {
+  const apiKey = '1111222233334444';
+  const baseConfig = {
+    td_api_key: apiKey,
+    site: 'us01' as const,
+    database: 'sample_datasets',
+  };
+  const mockFetch = global.fetch as unknown as ReturnType<typeof vi.fn>;
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+  });
+
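+  // Minimal Response stub: models only the fields TDHiveClient reads
+  // (ok, status, headers.get, json, text).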
+  function mockRes({ ok = true, status = 200, json, text, contentType = 'application/json' }: any) {
+    return {
+      ok,
+      status,
+      headers: { get: (k: string) => (k.toLowerCase() === 'content-type' ? contentType : null) },
+      json: json ? async () => json : undefined,
+      text: text ? async () => text : async () => '',
+    } as any;
+  }
+
+  it('issues hive job with correct payload', async () => {
+    const client = new TDHiveClient(baseConfig);
+    mockFetch.mockResolvedValueOnce(mockRes({ json: { job_id: '123' } }));
+    const res = await client.issueHive('SELECT 1', 'db1', { priority: 1, retry_limit: 2, pool_name: 'gold' });
+    expect(res).toEqual({ job_id: '123' });
+    expect(mockFetch).toHaveBeenCalledWith(
+      'https://api.treasuredata.com/v3/job/issue/hive/db1',
+      expect.objectContaining({
+        method: 'POST',
+        headers: expect.objectContaining({ Authorization: `TD1 ${apiKey}` }),
+        body: JSON.stringify({ query: 'SELECT 1', db: 'db1', priority: 1, retry_limit: 2, pool_name: 'gold' }),
+      })
+    );
+  });
+
+  it('gets job status', async () => {
+    const client = new TDHiveClient(baseConfig);
+    mockFetch.mockResolvedValueOnce(mockRes({ json: { job_id: '123', status: 'running' } }));
+    const res = await client.jobStatus('123');
+    expect(res.status).toBe('running');
+    expect(mockFetch).toHaveBeenCalledWith(
+      'https://api.treasuredata.com/v3/job/status/123',
+      expect.objectContaining({ method: 'GET' })
+    );
+  });
+
+  it('fetches job result with format=json', async () => {
+    const client = new TDHiveClient(baseConfig);
+    mockFetch.mockResolvedValueOnce({ ok: true, text: async () => '[[1],[2]]' } as any);
+    const res = await client.jobResult('123');
+    expect(res.rows).toEqual([[1],[2]]);
+    expect(mockFetch).toHaveBeenCalledWith(
+      'https://api.treasuredata.com/v3/job/result/123?format=json',
+      expect.objectContaining({ method: 'GET' })
+    );
+  });
+
+  it('fetches job result schema via job/show', async () => {
+    const client = new TDHiveClient(baseConfig);
+    const hiveSchema = JSON.stringify([["c1","int"],["c2","varchar"]]);
+    mockFetch.mockResolvedValueOnce(mockRes({ json: { hive_result_schema: hiveSchema } }));
+    const res = await client.jobResultSchema('123');
+    expect(res).toEqual([{ name: 'c1', type: 'int' }, { name: 'c2', type: 'varchar' }]);
+    expect(mockFetch).toHaveBeenCalledWith(
+      'https://api.treasuredata.com/v3/job/show/123',
+      expect.objectContaining({ method: 'GET' })
+    );
+  });
+
+  it('waits for completion until terminal state', async () => {
+    const client = new TDHiveClient(baseConfig);
+    mockFetch
+      .mockResolvedValueOnce(mockRes({ json: { status: 'queued' } }))
+      .mockResolvedValueOnce(mockRes({ json: { status: 'running' } }))
+      .mockResolvedValueOnce(mockRes({ json: { status: 'success' } }));
+    const status = await client.waitForCompletion('123', { pollMs: 1, timeoutMs: 1000 });
+    expect(status.status).toBe('success');
+  });
+
+  it('masks API key in error messages', async () => {
+    const client = new TDHiveClient(baseConfig);
+    mockFetch.mockResolvedValueOnce(mockRes({ ok: false, status: 403, text: `Forbidden for ${apiKey}` }));
+    await expect(client.issueHive('SELECT 1')).rejects.toThrow(/1111\.\.\.4444/);
+  });
+});
diff --git a/tests/security/audit-logger.test.ts b/tests/security/audit-logger.test.ts
index d0357f6..c0f9560 100644
--- a/tests/security/audit-logger.test.ts
+++ b/tests/security/audit-logger.test.ts
@@ -72,8 +72,8 @@ describe('AuditLogger', () => {
       logger = new AuditLogger({ logToConsole: true });
       logger.logSuccess('SELECT', 'SELECT 1', 'mydb', 100, 1);
 
-      expect(consoleLogSpy).toHaveBeenCalled();
-      const logMessage = consoleLogSpy.mock.calls[0][0] as string;
+      expect(consoleErrorSpy).toHaveBeenCalled();
+      const logMessage = consoleErrorSpy.mock.calls[0][0] as string;
       expect(logMessage).toContain('✓ SELECT [mydb] (100ms) -> 1 rows');
     });
@@ -81,17 +81,16 @@ describe('AuditLogger', () => {
       logger = new AuditLogger({ logToConsole: true });
       logger.logFailure('UPDATE', 'UPDATE x', 'Permission denied');
 
-      expect(consoleLogSpy).toHaveBeenCalled();
       expect(consoleErrorSpy).toHaveBeenCalled();
-      const errorMessage = consoleErrorSpy.mock.calls[0][0] as string;
-      expect(errorMessage).toContain('Permission denied');
+      const messages = consoleErrorSpy.mock.calls.map(call => call[0] as string);
+      expect(messages.some(m => m.includes('Permission denied'))).toBe(true);
     });
 
     it('should not log to console when disabled', () => {
       logger = new AuditLogger({ logToConsole: false });
       logger.logSuccess('SELECT', 'SELECT 1');
 
-      expect(consoleLogSpy).not.toHaveBeenCalled();
+      expect(consoleErrorSpy).not.toHaveBeenCalled();
     });
   });
@@ -205,4 +204,4 @@ describe('AuditLogger', () => {
       expect(logs1).toEqual(logs2);
     });
   });
-});
\ No newline at end of file
+});
diff --git a/tests/tools/hive/execute.test.ts b/tests/tools/hive/execute.test.ts
new file mode 100644
index 0000000..bb281b9
--- /dev/null
+++ b/tests/tools/hive/execute.test.ts
@@ -0,0 +1,65 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { hiveExecute } from '../../../src/tools/hive/execute';
+import { loadConfig } from '../../../src/config';
+import { TDHiveClient } from '../../../src/client/hive';
+
+vi.mock('../../../src/config');
+vi.mock('../../../src/client/hive');
+
+describe('hiveExecute tool', () => {
+  const mockLoadConfig = loadConfig as any;
+  const MockHiveClient = TDHiveClient as any;
+  let mockClient: any;
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    mockLoadConfig.mockReturnValue({ td_api_key: 'k', site: 'us01', database: 'db', enable_updates: true });
+    mockClient = { issueHive: vi.fn(), waitForCompletion: vi.fn(), query: vi.fn() };
+    MockHiveClient.mockImplementation(() => mockClient);
+  });
+
+  it('routes read-only statements to hive_query and returns rows', async () => {
+    mockLoadConfig.mockReturnValueOnce({ td_api_key: 'k', site: 'us01', database: 'db', enable_updates: true });
+    // hive_query will call client.query and expect columns/data
+    mockClient.query.mockResolvedValue({
+      columns: [{ name: 'c1', type: 'int' }],
+      data: [{ c1: 1 }],
+      rowCount: 1,
+    });
+
+    const res = await hiveExecute.handler({ sql: 'SELECT 1' });
+    expect(res).toEqual({
+      columns: [{ name: 'c1', type: 'int' }],
+      rows: [[1]],
+      rowCount: 1,
+      truncated: false,
+    });
+    expect(mockClient.query).toHaveBeenCalledWith('SELECT 1 LIMIT 40', undefined);
+  });
+
+  it('routes read-only with custom limit', async () => {
+    mockLoadConfig.mockReturnValueOnce({ td_api_key: 'k', site: 'us01', database: 'db', enable_updates: true });
+    mockClient.query.mockResolvedValue({
+      columns: [{ name: 'c1', type: 'int' }],
+      data: [{ c1: 1 }],
+      rowCount: 1,
+    });
+
+    const res = await hiveExecute.handler({ sql: 'SELECT 1', limit: 5 });
+    expect(res).toEqual({
+      columns: [{ name: 'c1', type: 'int' }],
+      rows: [[1]],
+      rowCount: 1,
+      truncated: false,
+    });
+    expect(mockClient.query).toHaveBeenCalledWith('SELECT 1 LIMIT 5', undefined);
+  });
+
+  it('executes write and returns status', async () => {
+    mockClient.issueHive.mockResolvedValue({ job_id: 'jid' });
+    mockClient.waitForCompletion.mockResolvedValue({ status: 'success', num_records: 0 });
+    const res = await hiveExecute.handler({ sql: 'INSERT INTO t SELECT 1', database: 'db1' });
+    expect(res).toEqual({ job_id: 'jid', status: 'success', success: true, message: 'Job completed successfully' });
+    expect(mockClient.issueHive).toHaveBeenCalledWith('INSERT INTO t SELECT 1', 'db1', { priority: undefined, retry_limit: undefined, pool_name: undefined });
+  });
+});
diff --git a/tests/tools/hive/job-result.test.ts b/tests/tools/hive/job-result.test.ts
new file mode 100644
index 0000000..09b774d
--- /dev/null
+++ b/tests/tools/hive/job-result.test.ts
@@ -0,0 +1,28 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { hiveJobResult } from '../../../src/tools/hive/job-result';
+import { loadConfig } from '../../../src/config';
+import { TDHiveClient } from '../../../src/client/hive';
+
+vi.mock('../../../src/config');
+vi.mock('../../../src/client/hive');
+
+describe('hiveJobResult tool', () => {
+  const mockLoadConfig = loadConfig as any;
+  const MockHiveClient = TDHiveClient as any;
+  let mockClient: any;
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    mockLoadConfig.mockReturnValue({ td_api_key: 'k', site: 'us01' });
+    mockClient = { jobResultSchema: vi.fn(), jobResult: vi.fn() };
+    MockHiveClient.mockImplementation(() => mockClient);
+  });
+
+  it('returns columns and rows', async () => {
+    mockClient.jobResultSchema.mockResolvedValue([{ name: 'c1', type: 'int' }]);
+    mockClient.jobResult.mockResolvedValue({ rows: [[1]] });
+    const res = await hiveJobResult.handler({ job_id: 'jid' });
+    expect(res).toEqual({ columns: [{ name: 'c1', type: 'int' }], rows: [[1]], rowCount: 1 });
+  });
+});
+
diff --git a/tests/tools/hive/job-status.test.ts b/tests/tools/hive/job-status.test.ts
new file mode 100644
index 0000000..ee0c8bb
--- /dev/null
+++ b/tests/tools/hive/job-status.test.ts
@@ -0,0 +1,28 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { hiveJobStatus } from '../../../src/tools/hive/job-status';
+import { loadConfig } from '../../../src/config';
+import { TDHiveClient } from '../../../src/client/hive';
+
+vi.mock('../../../src/config');
+vi.mock('../../../src/client/hive');
+
+describe('hiveJobStatus tool', () => {
+  const mockLoadConfig = loadConfig as any;
+  const MockHiveClient = TDHiveClient as any;
+  let mockClient: any;
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    mockLoadConfig.mockReturnValue({ td_api_key: 'k', site: 'us01' });
+    mockClient = { jobStatus: vi.fn() };
+    MockHiveClient.mockImplementation(() => mockClient);
+  });
+
+  it('returns job status', async () => {
+    mockClient.jobStatus.mockResolvedValue({ job_id: 'jid', status: 'running' });
+    const res = await hiveJobStatus.handler({ job_id: 'jid' });
+    expect(res).toEqual({ job_id: 'jid', status: 'running' });
+    expect(mockClient.jobStatus).toHaveBeenCalledWith('jid');
+  });
+});
+
diff --git a/tests/tools/hive/query.test.ts b/tests/tools/hive/query.test.ts
new file mode 100644
index 0000000..aec47c7
--- /dev/null
+++ b/tests/tools/hive/query.test.ts
@@ -0,0 +1,45 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { hiveQuery } from '../../../src/tools/hive/query';
+import { loadConfig } from '../../../src/config';
+import { TDHiveClient } from '../../../src/client/hive';
+
+vi.mock('../../../src/config');
+vi.mock('../../../src/client/hive');
+
+describe('hiveQuery tool', () => {
+  const mockLoadConfig = loadConfig as any;
+  const MockHiveClient = TDHiveClient as any;
+  let mockClient: any;
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    mockLoadConfig.mockReturnValue({ td_api_key: 'k', site: 'us01', database: 'db' });
+    mockClient = { query: vi.fn() };
+    MockHiveClient.mockImplementation(() => mockClient);
+  });
+
+  it('executes read-only query and injects limit', async () => {
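+    // The handler should append "LIMIT 5" before delegating to client.query.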
+    mockClient.query.mockResolvedValue({
+      columns: [{ name: 'c1', type: 'int' }],
+      data: [{ c1: 1 }],
+      rowCount: 1,
+    });
+
+    const res = await hiveQuery.handler({ sql: 'SELECT 1', limit: 5 });
+    expect(res).toEqual({
+      columns: [{ name: 'c1', type: 'int' }],
+      rows: [[1]],
+      rowCount: 1,
+      truncated: false,
+    });
+    expect(MockHiveClient).toHaveBeenCalled();
+    expect(mockClient.query).toHaveBeenCalledWith('SELECT 1 LIMIT 5', undefined);
+  });
+
+  it('rejects write statements', async () => {
+    await expect(hiveQuery.handler({ sql: 'UPDATE t SET a=1' })).rejects.toThrow(
+      'Only read-only queries are allowed in hive_query'
+    );
+  });
+});
+