|
| 1 | +"""Metadata Extractor: Extracts structured metadata from text based on a JSON schema.""" |
| 2 | + |
| 3 | +from typing import Any, Dict, List, Literal, Optional |
| 4 | + |
| 5 | +from langchain_core.documents import Document |
| 6 | +from langchain_core.language_models import BaseChatModel |
| 7 | +from langchain_core.prompts import ChatPromptTemplate |
| 8 | +from pydantic import Field, create_model |
| 9 | +from tqdm import tqdm |
| 10 | + |
| 11 | + |
class MetadataExtractor:
    """Extracts structured metadata from text using a language model and a JSON schema.

    This class leverages a function-calling or structured-output-capable LLM to extract
    metadata fields defined by a user-provided JSON schema. The schema is dynamically
    converted into a Pydantic model, which is used to enforce type safety and
    validation. Only fields that are successfully extracted (i.e., not null) are
    included in the result.

    This metadata can be used for filtering search results, or in the case of queries,
    it can be used to filter out documents that are not relevant to the query.

    Attributes:
        llm (BaseChatModel): The language model used for metadata extraction. Must
            support structured output (e.g., via `.with_structured_output()`).
    """

    # Supported JSON-schema primitive types mapped to their Python equivalents.
    # Hoisted to class level so the mapping is not rebuilt per field per call.
    _TYPE_MAP: Dict[str, type] = {
        "string": str,
        "number": float,
        "boolean": bool,
    }

    def __init__(self, llm: BaseChatModel) -> None:
        """Initializes the MetadataExtractor with a language model.

        Args:
            llm (BaseChatModel): A LangChain-compatible chat model that supports
                structured output (e.g., ChatGroq, ChatOpenAI). The model should be
                capable of returning Pydantic-model-compliant responses.
        """
        self.llm = llm

    def _build_extraction_model(self, schema: Dict[str, Any]) -> Any:
        """Builds a dynamic Pydantic model from the schema's "properties" in one pass.

        Every field is made optional with a default of ``None`` so the LLM can
        legitimately return null for information absent from the text.

        Args:
            schema (Dict[str, Any]): A JSON schema with a top-level "properties" key.
                Each property may specify "type" ("string", "number", or "boolean"),
                an optional "enum" (string fields only), and an optional
                "description" used in the extraction prompt.

        Returns:
            Any: A dynamically created Pydantic model class named
                "ExtractedMetadata".

        Raises:
            ValueError: If the schema has no "properties" key, a property type is
                not string/number/boolean, an enum is attached to a non-string
                field, or an enum list is empty.
        """
        if "properties" not in schema:
            raise ValueError("Schema must contain a 'properties' key.")

        allowed_types = set(self._TYPE_MAP)
        field_definitions: Dict[str, Any] = {}

        for name, spec in schema["properties"].items():
            json_type = spec.get("type", "string")
            if json_type not in allowed_types:
                raise ValueError(
                    f"Unsupported type '{json_type}' for field '{name}'. "
                    f"Only {allowed_types} are allowed."
                )

            py_type: Any = self._TYPE_MAP[json_type]

            enum_vals = spec.get("enum")
            if enum_vals is not None:
                # Enums are only meaningful for string fields here.
                if json_type != "string":
                    raise ValueError(
                        f"Enum is only supported for string fields. "
                        f"Field '{name}' has type '{json_type}'."
                    )
                if not enum_vals:
                    # Literal[()] raises TypeError at model-build time; fail with
                    # the documented ValueError and a clear message instead.
                    raise ValueError(f"Enum for field '{name}' must not be empty.")
                py_type = Literal[tuple(enum_vals)]  # type: ignore[valid-type]

            # Optional so the model can return null rather than a placeholder.
            field_definitions[name] = (
                Optional[py_type],
                Field(default=None, description=spec.get("description", "")),
            )

        return create_model("ExtractedMetadata", **field_definitions)

    def invoke(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Extracts metadata from the given text according to the provided JSON schema.

        The method constructs a dynamic Pydantic model from the schema, instructs the
        LLM to extract only explicitly stated information, and returns a dictionary
        containing only the successfully extracted (non-null) fields.

        Args:
            text (str): The input text from which metadata should be extracted.
            schema (Dict[str, Any]): A JSON schema defining the expected metadata
                structure. Must contain a top-level "properties" key.
                Each property may specify:
                - "type": one of "string", "number", or "boolean"
                - "enum": (optional) a non-empty list of allowed values (for
                  string fields)
                - "description": (optional) a field description used in the prompt

                Example:
                    {
                        "properties": {
                            "movie_title": {"type": "string"},
                            "rating": {"type": "number"},
                            "is_positive": {"type": "boolean"},
                            "tone": {"type": "string", "enum": ["positive", "negative"]}
                        }
                    }

        Returns:
            Dict[str, Any]: A dictionary of extracted metadata. Only fields that were
                present and non-null in the LLM's response are included. Fields that
                could not be extracted are omitted entirely.

        Raises:
            ValueError: If the provided schema does not contain a "properties" key.
            ValueError: If any property type is not string, number, or boolean.
            ValueError: If an enum is attached to a non-string field or is empty.

        Note:
            The LLM is explicitly instructed not to hallucinate or use placeholder
            values (e.g., "Unknown"). Missing fields are returned as null by the model
            and then excluded from the final result. Any runtime failure of the LLM
            call (parsing, validation, provider error) yields an empty dict.
        """
        # Schema validation errors propagate; only the LLM call is best-effort.
        dynamic_pydantic_model = self._build_extraction_model(schema)

        # Clear, strict prompt
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a precise metadata extraction system. "
                    "Extract only the fields specified in the schema from the input text. "
                    "For any field not explicitly mentioned in the text, return null for that field. "
                    "Do NOT use placeholders like 'Unknown', 'N/A', 'Not specified', or make up values. "
                    "Only use facts directly stated in the input.",
                ),
                ("human", "{input}"),
            ]
        )

        structured_llm = self.llm.with_structured_output(dynamic_pydantic_model)
        chain = prompt | structured_llm

        try:
            result_obj = chain.invoke({"input": text})
        except Exception:
            # Best-effort: on any error (parsing, validation, LLM failure) return
            # an empty dict rather than aborting a long document pipeline.
            return {}

        if result_obj is None:
            # Some providers return None from structured output when the model
            # refuses or fails to comply; treat as "nothing extracted".
            return {}

        # Keep only fields the model actually populated (non-null).
        return {
            field: value
            for field in schema["properties"]
            if (value := getattr(result_obj, field, None)) is not None
        }

    def transform_documents(
        self, documents: List[Document], schema: Dict[str, Any]
    ) -> List[Document]:
        """Applies metadata extraction to a list of LangChain Documents.

        For each document, metadata is extracted from its `page_content` using the
        provided schema and merged into the document's existing metadata. The
        original document content is preserved.

        Args:
            documents (List[Document]): A list of LangChain Document objects to process.
            schema (Dict[str, Any]): A JSON schema defining the metadata structure to
                extract. See `invoke()` for schema format details.

        Returns:
            List[Document]: A new list of Document objects with enriched metadata.
                Each document's `metadata` field is updated with the extracted fields
                (excluding any that were null or missing). Extracted values override
                pre-existing keys of the same name.
        """
        transformed_documents: List[Document] = []

        for doc in tqdm(documents):
            extracted = self.invoke(doc.page_content, schema)
            # Extracted fields win over any pre-existing metadata keys.
            new_metadata = {**doc.metadata, **extracted}
            transformed_documents.append(
                Document(page_content=doc.page_content, metadata=new_metadata)
            )

        return transformed_documents
0 commit comments