diff --git a/node-zerox/src/index.ts b/node-zerox/src/index.ts index bf22c35e..0cdd2e5b 100644 --- a/node-zerox/src/index.ts +++ b/node-zerox/src/index.ts @@ -34,6 +34,7 @@ import { OperationMode, Page, PageStatus, + ValidationLog, ZeroxArgs, ZeroxOutput, } from "./types"; @@ -82,6 +83,7 @@ export const zerox = async ({ let priorPage: string = ""; let pages: Page[] = []; let imagePaths: string[] = []; + let validationLog: ValidationLog = { extracted: [] }; const startTime = new Date(); if (openaiAPIKey && openaiAPIKey.length > 0) { @@ -276,10 +278,10 @@ export const zerox = async ({ }); } - const response = CompletionProcessor.process( - OperationMode.OCR, - rawResponse - ); + const response = CompletionProcessor.process({ + mode: OperationMode.OCR, + response: rawResponse as CompletionResponse, + }); inputTokenCount += response.inputTokens; outputTokenCount += response.outputTokens; @@ -361,6 +363,7 @@ export const zerox = async ({ schema: Record ): Promise> => { let result: Record = {}; + let validationResult: Record | null = null; try { await runRetries( async () => { @@ -381,16 +384,19 @@ export const zerox = async ({ }); } - const response = CompletionProcessor.process( - OperationMode.EXTRACTION, - rawResponse - ); + const response = CompletionProcessor.process({ + mode: OperationMode.EXTRACTION, + response: rawResponse as ExtractionResponse, + schema, + }); inputTokenCount += response.inputTokens; outputTokenCount += response.outputTokens; numSuccessfulExtractionRequests++; - + if (response.issues && response.issues.length > 0) { + validationResult = { page: pageNumber, issues: response.issues }; + } for (const key of Object.keys(schema?.properties ?? {})) { const value = response.extracted[key]; if (value !== null && value !== undefined) { @@ -409,7 +415,7 @@ export const zerox = async ({ throw error; } - return result; + return { result, validationResult }; }; if (perPageSchema) { @@ -438,6 +444,7 @@ export const zerox = async ({ extractionTasks.push( (async () => { let result: Record = {}; + let validationResult: Record | null = null; try { await runRetries( async () => { @@ -459,20 +466,25 @@ export const zerox = async ({ }); } - const response = CompletionProcessor.process( - OperationMode.EXTRACTION, - rawResponse - ); + const response = CompletionProcessor.process({ + mode: OperationMode.EXTRACTION, + response: rawResponse as ExtractionResponse, + schema, + }); inputTokenCount += response.inputTokens; outputTokenCount += response.outputTokens; numSuccessfulExtractionRequests++; + if (response.issues && response.issues.length > 0) { + validationResult = { page: null, issues: response.issues }; + } + result = response.extracted; }, maxRetries, 0 ); - return result; + return { result, validationResult }; } catch (error) { numFailedExtractionRequests++; throw error; @@ -482,8 +494,12 @@ export const zerox = async ({ } const results = await Promise.all(extractionTasks); - extracted = results.reduce((acc, result) => { - Object.entries(result || {}).forEach(([key, value]) => { + validationLog.extracted = results.reduce( + (acc, result) => (result.validationResult ? [...acc, result.validationResult] : acc), + [] + ); + extracted = results.reduce((acc, resultObj) => { + Object.entries(resultObj?.result || {}).forEach(([key, value]) => { if (!acc[key]) { acc[key] = []; } @@ -573,6 +589,7 @@ export const zerox = async ({ } : null, }, + validationLog, }; } finally { if (correctOrientation && scheduler) { diff --git a/node-zerox/src/types.ts b/node-zerox/src/types.ts index 3d72836d..eef2aa50 100644 --- a/node-zerox/src/types.ts +++ b/node-zerox/src/types.ts @@ -49,6 +49,7 @@ export interface ZeroxOutput { outputTokens: number; pages: Page[]; summary: Summary; + validationLog: ValidationLog; } export interface AzureCredentials { @@ -177,7 +178,11 @@ export interface ExtractionResponse { outputTokens: number; } -export type ProcessedExtractionResponse = Omit; +// export type ProcessedExtractionResponse = Omit; + +export interface ProcessedExtractionResponse extends Omit { + issues: any; +} interface BaseLLMParams { frequencyPenalty?: number; @@ -254,3 +259,24 @@ export interface ExcelSheetContent { contentLength: number; sheetName: string; } + +// Define extraction-specific parameters +export interface ExtractionProcessParams { + mode: OperationMode.EXTRACTION; + response: ExtractionResponse; + schema: Record; +} + +// Define OCR-specific parameters +export interface CompletionProcessParams { + mode: OperationMode.OCR; + response: CompletionResponse; + schema?: undefined; +} + +// Union type for all possible parameter combinations +export type ProcessParams = ExtractionProcessParams | CompletionProcessParams; + +export interface ValidationLog { + extracted: { page: number | null, issues: any }[]; +} \ No newline at end of file diff --git a/node-zerox/src/utils/common.ts b/node-zerox/src/utils/common.ts index de6b6529..78b5bf9f 100644 --- a/node-zerox/src/utils/common.ts +++ b/node-zerox/src/utils/common.ts @@ -1,3 +1,5 @@ +import { parse } from "flatted"; + export const camelToSnakeCase = (str: string) => str.replace(/[A-Z]/g, (letter: string) => `_${letter.toLowerCase()}`); @@ -119,3 +121,17 @@ export const splitSchema = ( : null, }; }; + +export const formatJsonValue = ( + value: any, + useFlatted: boolean = false +): any => { + if (typeof value === "string") { + try { + return useFlatted ? parse(value) : JSON.parse(value); + } catch { + return value; + } + } + return value; +}; diff --git a/node-zerox/src/utils/fixSchemaValidationErrors.ts b/node-zerox/src/utils/fixSchemaValidationErrors.ts new file mode 100644 index 00000000..4d6bdd2f --- /dev/null +++ b/node-zerox/src/utils/fixSchemaValidationErrors.ts @@ -0,0 +1,114 @@ +import { formatJsonValue } from "../utils"; +import { JSONSchema, JSONSchemaDefinition } from "openai/lib/jsonschema"; +import { ZodError } from "zod"; + +/** + * Handles specific cases of ZodError by traversing the + * error paths in the original value and modifying invalid entries (e.g., replacing + * invalid enum values with null or converting strings to booleans or numbers) + * + * Handled cases: + * - Boolean strings ("true" or "false") should be converted to actual booleans + * - Numeric strings (e.g., "123") should be converted to numbers + * - For other cases, default to the default value or null + * + * @param {ZodError} err - The error object containing validation details + * @returns {any} - The modified value object with resolved issues + */ +export const fixSchemaValidationErrors = ({ + err, + schema, + value: originalValue, +}: { + err: ZodError>; + schema: JSONSchema; + value: Record; +}) => { + const errors = err.issues; + let value = originalValue; + + errors.forEach((error) => { + const lastKey = error.path[error.path.length - 1]; + + let parent = value; + for (let i = 0; i < error.path.length - 1; i++) { + parent = parent?.[error.path[i]]; + } + + let defaultValue = null; + if (schema) { + let schemaProperty = schema; + let properties: JSONSchema | JSONSchemaDefinition[] = + schemaProperty.properties || schemaProperty; + + for (let i = 0; i < error.path.length; i++) { + const pathKey = error.path[i]; + if (properties && properties[pathKey as keyof typeof properties]) { + schemaProperty = properties[ + pathKey as keyof typeof properties + ] as JSONSchema; + if (schemaProperty.type === "array" && schemaProperty.items) { + // If array of object (table) + if ((schemaProperty.items as JSONSchema).type === "object") { + properties = (schemaProperty.items as JSONSchema).properties || {}; + i++; // Skip the numeric path (row index) + } else { + properties = schemaProperty.items as JSONSchema; + } + } else { + properties = schemaProperty.properties || {}; + } + } + } + + if (schemaProperty && "default" in schemaProperty) { + defaultValue = schemaProperty.default; + } + } + + if (parent && typeof parent === "object") { + const currentValue = parent[lastKey]; + + if ( + error.code === "invalid_type" && + error.expected === "boolean" && + error.received === "string" && + (currentValue === "true" || currentValue === "false") + ) { + parent[lastKey] = currentValue === "true"; + } else if ( + error.code === "invalid_type" && + error.expected === "number" && + error.received === "string" && + !isNaN(Number(currentValue)) + ) { + parent[lastKey] = Number(currentValue); + } else if ( + error.code === "invalid_type" && + error.expected === "array" && + error.received === "string" + ) { + // TODO: could this be problematic? no check if the parsed array conformed to the schema + const value = formatJsonValue(currentValue); + if (Array.isArray(value)) { + parent[lastKey] = value; + } + } else if ( + error.code === "invalid_type" && + (error.expected === "array" || + error.expected === "boolean" || + error.expected === "integer" || + error.expected === "number" || + error.expected === "string" || + error.expected.includes(" | ")) && // `Expected` for enums comes back as z.enum(['a', 'b']) => expected: "'a' | 'b'" + currentValue === undefined + ) { + parent[lastKey] = defaultValue !== null ? defaultValue : null; + } else { + parent[lastKey] = defaultValue !== null ? defaultValue : null; + } + } + }); + + return value; +}; diff --git a/node-zerox/src/utils/model.ts b/node-zerox/src/utils/model.ts index 7cc0f176..de6ba555 100644 --- a/node-zerox/src/utils/model.ts +++ b/node-zerox/src/utils/model.ts @@ -1,13 +1,29 @@ import { + CompletionProcessParams, CompletionResponse, + ExtractionProcessParams, ExtractionResponse, LLMParams, ModelProvider, OperationMode, ProcessedCompletionResponse, ProcessedExtractionResponse, + ProcessParams, } from "../types"; import { formatMarkdown } from "./common"; +import { validate } from "./validate"; + +const isExtractionParams = ( + params: ProcessParams +): params is ExtractionProcessParams => { + return params.mode === OperationMode.EXTRACTION; +}; + +const isCompletionParams = ( + params: ProcessParams +): params is CompletionProcessParams => { + return params.mode === OperationMode.OCR; +}; export const isCompletionResponse = ( mode: OperationMode, @@ -16,45 +32,43 @@ export const isCompletionResponse = ( return mode === OperationMode.OCR; }; -const isExtractionResponse = ( - mode: OperationMode, - response: CompletionResponse | ExtractionResponse -): response is ExtractionResponse => { - return mode === OperationMode.EXTRACTION; -}; - export class CompletionProcessor { - static process( - mode: T, - response: CompletionResponse | ExtractionResponse - ): T extends OperationMode.EXTRACTION - ? ProcessedExtractionResponse - : ProcessedCompletionResponse { - const { logprobs, ...responseWithoutLogprobs } = response; - if (isCompletionResponse(mode, response)) { + // Overload for extraction mode + static process(params: ExtractionProcessParams): ProcessedExtractionResponse; + + // Overload for OCR mode + static process(params: CompletionProcessParams): ProcessedCompletionResponse; + + static process( + params: ProcessParams + ): ProcessedExtractionResponse | ProcessedCompletionResponse { + if (isCompletionParams(params)) { + const { response } = params; + const { logprobs, ...responseWithoutLogprobs } = response; + const content = response.content; return { ...responseWithoutLogprobs, content: typeof content === "string" ? formatMarkdown(content) : content, contentLength: response.content?.length || 0, - } as T extends OperationMode.EXTRACTION - ? ProcessedExtractionResponse - : ProcessedCompletionResponse; + } as ProcessedCompletionResponse; } - if (isExtractionResponse(mode, response)) { - const extracted = response.extracted; + if (isExtractionParams(params)) { + const { response, schema } = params; + const { logprobs, ...responseWithoutLogprobs } = response; + const extracted = + typeof response.extracted === "object" + ? response.extracted + : JSON.parse(response.extracted); + const result = validate({ schema, value: extracted }); return { ...responseWithoutLogprobs, - extracted: - typeof extracted === "object" ? extracted : JSON.parse(extracted), - } as T extends OperationMode.EXTRACTION - ? ProcessedExtractionResponse - : ProcessedCompletionResponse; + extracted: result.value, + issues: result.issues, + } as ProcessedExtractionResponse; } - return responseWithoutLogprobs as T extends OperationMode.EXTRACTION - ? ProcessedExtractionResponse - : ProcessedCompletionResponse; + throw new Error(`Unsupported operation mode: ${params["mode"]}`); } } diff --git a/node-zerox/src/utils/validate.ts b/node-zerox/src/utils/validate.ts new file mode 100644 index 00000000..48c37443 --- /dev/null +++ b/node-zerox/src/utils/validate.ts @@ -0,0 +1,67 @@ +import { z } from "zod"; +import { fixSchemaValidationErrors } from "./fixSchemaValidationErrors"; + +const zodTypeMapping = { + array: (itemSchema: any) => z.array(itemSchema), + boolean: z.boolean(), + integer: z.number().int(), + number: z.number(), + object: (properties: any) => z.object(properties).strict(), + string: z.string(), +}; + +export const generateZodSchema = (schemaDef: any): z.ZodObject => { + const properties: Record = {}; + + for (const [key, value] of Object.entries(schemaDef.properties) as any) { + let zodType; + + if (value.enum && Array.isArray(value.enum) && value.enum.length > 0) { + zodType = z.enum(value.enum as [string, ...string[]]); + } else { + // @ts-ignore + zodType = zodTypeMapping[value.type]; + } + + if (value.type === "array" && value.items.type === "object") { + properties[key] = zodType(generateZodSchema(value.items)); + } else if (value.type === "array" && value.items.type !== "object") { + // @ts-ignore + properties[key] = zodType(zodTypeMapping[value.items.type]); + } else if (value.type === "object") { + properties[key] = generateZodSchema(value); + } else { + properties[key] = zodType; + } + + // Make properties nullable by default + properties[key] = properties?.[key]?.nullable(); + + if (value.description) { + properties[key] = properties?.[key]?.describe(value?.description); + } + } + + return z.object(properties).strict(); +}; + +export const validate = ({ + schema, + value, +}: { + schema: Record; + value: unknown; +}) => { + const zodSchema = generateZodSchema(schema); + + const result = zodSchema.safeParse(value); + if (result.success) return { value: result.data, issues: [] }; + + const fixedData = fixSchemaValidationErrors({ + err: result.error, + schema, + value: value as Record, + }); + + return { issues: result.error.issues, value: fixedData }; +}; diff --git a/package-lock.json b/package-lock.json index a4992e5a..a6f0a449 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "@google/generative-ai": "^0.21.0", "axios": "^1.7.2", "child_process": "^1.0.2", + "flatted": "^3.3.1", "fs-extra": "^11.2.0", "heic-convert": "^2.1.0", "libreoffice-convert": "^1.6.0", @@ -27,7 +28,8 @@ "tesseract.js": "^5.1.1", "util": "^0.12.5", "uuid": "^11.0.3", - "xlsx": "^0.18.5" + "xlsx": "^0.18.5", + "zod": "^3.23.8" }, "devDependencies": { "@types/fs-extra": "^11.0.4", @@ -3986,6 +3988,11 @@ "node": ">=8" } }, + "node_modules/flatted": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.1.tgz", + "integrity": "sha512-X8cqMLLie7KsNUDSdzeN8FYK9rEt4Dt67OsG/DNGnYTSDBG4uFAJFBnUeiV+zCVAvwFy56IjM9sH51jVaEhNxw==" + }, "node_modules/follow-redirects": { "version": "1.15.6", "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", @@ -6804,6 +6811,14 @@ "engines": { "node": "*" } + }, + "node_modules/zod": { + "version": "3.23.8", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.23.8.tgz", + "integrity": "sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } } } } diff --git a/package.json b/package.json index e850420a..0ae27724 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ "@google/generative-ai": "^0.21.0", "axios": "^1.7.2", "child_process": "^1.0.2", + "flatted": "^3.3.1", "fs-extra": "^11.2.0", "heic-convert": "^2.1.0", "libreoffice-convert": "^1.6.0", @@ -31,7 +32,8 @@ "tesseract.js": "^5.1.1", "util": "^0.12.5", "uuid": "^11.0.3", - "xlsx": "^0.18.5" + "xlsx": "^0.18.5", + "zod": "^3.23.8" }, "devDependencies": { "@types/fs-extra": "^11.0.4",