class CodeFileBackend(DeclarativeDocumentBackend):
    """Declarative backend that wraps a source-code file in a single code item.

    The whole input is read as UTF-8 text and added to the resulting
    ``DoclingDocument`` as one code item; no parsing or splitting of the
    source happens in this backend.
    """

    # File extension (lower-case, with leading dot) -> language identifier.
    LANGUAGE_MAPPINGS = {
        ".py": "python",
        ".js": "javascript",
        ".java": "java",
    }

    # Language identifier -> MIME type reported in the document origin.
    # An explicit table avoids fabricating types such as "text/x-text" for
    # languages that are not listed in LANGUAGE_MAPPINGS.
    MIME_TYPES = {
        "python": "text/x-python",
        "javascript": "text/x-javascript",
        "java": "text/x-java-source",
    }

    # Fallbacks used when the file extension is not in LANGUAGE_MAPPINGS.
    DEFAULT_LANGUAGE = "text"
    DEFAULT_MIME_TYPE = "text/plain"

    def __init__(self, in_doc, path_or_stream: Union[BytesIO, Path]):
        """Load the source text and detect its language.

        Args:
            in_doc: Input document descriptor, forwarded to the base class.
            path_or_stream: In-memory byte stream or filesystem path holding
                the source file.

        Raises:
            RuntimeError: If the content cannot be read or decoded as UTF-8.
        """
        super().__init__(in_doc, path_or_stream)
        self.path_or_stream = path_or_stream
        # The backend only becomes valid once the source text is loaded, so an
        # unsupported input type is reported through is_valid() instead of
        # surfacing later as a missing attribute in convert().
        self.valid = False
        self.source_code = ""
        self.language = self.DEFAULT_LANGUAGE

        try:
            if isinstance(self.path_or_stream, BytesIO):
                self.source_code = self.path_or_stream.getvalue().decode("utf-8")
                self.valid = True
            elif isinstance(self.path_or_stream, Path):
                self.source_code = self.path_or_stream.read_text(encoding="utf-8")
                self.valid = True

            self.language = self._detect_language()

        except Exception as e:
            raise RuntimeError(
                f"Could not initialize code backend for file with hash {self.document_hash}."
            ) from e

    def _detect_language(self) -> str:
        """Return the language identifier for the file's extension.

        The lookup is case-insensitive; unknown extensions map to
        ``DEFAULT_LANGUAGE``.
        """
        file_ext = self.file.suffix.lower()
        return self.LANGUAGE_MAPPINGS.get(file_ext, self.DEFAULT_LANGUAGE)

    @classmethod
    def supported_formats(cls) -> Set[InputFormat]:
        """Return the input formats this backend can handle."""
        return {
            InputFormat.CODE_PYTHON,
            InputFormat.CODE_JAVASCRIPT,
            InputFormat.CODE_JAVA,
        }

    @classmethod
    def supports_pagination(cls) -> bool:
        """Return ``False``: this backend reports no pagination support."""
        return False

    def is_valid(self) -> bool:
        """Return whether the source text was loaded successfully."""
        return self.valid

    def convert(self) -> DoclingDocument:
        """Build a ``DoclingDocument`` holding the source as one code item.

        Returns:
            A document whose origin describes the input file. The code item is
            only added when the backend is valid; otherwise the document is
            returned without content.
        """
        mime_type = self.MIME_TYPES.get(self.language, self.DEFAULT_MIME_TYPE)

        origin = DocumentOrigin(
            filename=self.file.name or f"file{self.file.suffix}",
            mimetype=mime_type,
            binary_hash=self.document_hash,
        )

        doc = DoclingDocument(name=self.file.stem or "file", origin=origin)

        if self.is_valid():
            # NOTE(review): code_language receives the lower-case identifier
            # from LANGUAGE_MAPPINGS -- confirm the installed docling-core
            # accepts these values for add_code().
            doc.add_code(text=self.source_code, code_language=self.language)

        return doc
+ CODE_JAVASCRIPT = "js" + CODE_JAVA = "java" + class OutputFormat(str, Enum): @@ -96,6 +100,9 @@ class OutputFormat(str, Enum): InputFormat.JSON_DOCLING: ["json"], InputFormat.AUDIO: ["wav", "mp3"], InputFormat.VTT: ["vtt"], + InputFormat.CODE_PYTHON: ["py"], + InputFormat.CODE_JAVASCRIPT: ["js"], + InputFormat.CODE_JAVA: ["java"], } FormatToMimeType: dict[InputFormat, list[str]] = { @@ -130,6 +137,9 @@ class OutputFormat(str, Enum): InputFormat.JSON_DOCLING: ["application/json"], InputFormat.AUDIO: ["audio/x-wav", "audio/mpeg", "audio/wav", "audio/mp3"], InputFormat.VTT: ["text/vtt"], + InputFormat.CODE_PYTHON: ["text/x-python"], + InputFormat.CODE_JAVASCRIPT: ["text/x-javascript"], + InputFormat.CODE_JAVA: ["text/x-java-source"], } MimeTypeToFormat: dict[str, list[InputFormat]] = { diff --git a/docling/datamodel/document.py b/docling/datamodel/document.py index 8ea454826..772d1dc43 100644 --- a/docling/datamodel/document.py +++ b/docling/datamodel/document.py @@ -396,6 +396,12 @@ def _mime_from_extension(ext): mime = FormatToMimeType[InputFormat.XLSX][0] elif ext in FormatToExtensions[InputFormat.VTT]: mime = FormatToMimeType[InputFormat.VTT][0] + elif ext in FormatToExtensions[InputFormat.CODE_JAVA]: + mime = FormatToMimeType[InputFormat.CODE_JAVA][0] + elif ext in FormatToExtensions[InputFormat.CODE_PYTHON]: + mime = FormatToMimeType[InputFormat.CODE_PYTHON][0] + elif ext in FormatToExtensions[InputFormat.CODE_JAVASCRIPT]: + mime = FormatToMimeType[InputFormat.CODE_JAVASCRIPT][0] return mime diff --git a/docling/document_converter.py b/docling/document_converter.py index 5d64d6336..93e2f0ff2 100644 --- a/docling/document_converter.py +++ b/docling/document_converter.py @@ -15,6 +15,8 @@ from docling.backend.abstract_backend import AbstractDocumentBackend from docling.backend.asciidoc_backend import AsciiDocBackend +from docling.backend.code_backend import CodeFileBackend +from docling.backend.csv_backend import CsvDocumentBackend from 
class CodeFormatOption(FormatOption):
    """Format option for source-code inputs.

    Pairs ``SimplePipeline`` with ``CodeFileBackend`` as the defaults for the
    ``pipeline_cls`` and ``backend`` fields.
    """

    pipeline_cls: Type = SimplePipeline
    backend: Type[AbstractDocumentBackend] = CodeFileBackend
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +package com.acmeair.loader; + +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.LineNumberReader; +import java.util.*; +import java.math.*; + +import com.acmeair.entities.AirportCodeMapping; +import com.acmeair.service.FlightService; +import com.acmeair.service.ServiceLocator; + + + + +public class FlightLoader { + + private static final int MAX_FLIGHTS_PER_SEGMENT = 30; + + + private FlightService flightService = ServiceLocator.instance().getService(FlightService.class); + + public void loadFlights() throws Exception { + InputStream csvInputStream = FlightLoader.class.getResourceAsStream("/mileage.csv"); + + LineNumberReader lnr = new LineNumberReader(new InputStreamReader(csvInputStream)); + String line1 = lnr.readLine(); + StringTokenizer st = new StringTokenizer(line1, ","); + ArrayList airports = new ArrayList(); + + // read the first line which are airport names + while (st.hasMoreTokens()) { + AirportCodeMapping acm = flightService.createAirportCodeMapping(null, st.nextToken()); + // acm.setAirportName(st.nextToken()); + airports.add(acm); + } + // read the second line which contains matching airport codes for the first line + String line2 = lnr.readLine(); + st = new StringTokenizer(line2, ","); + int ii = 0; + while (st.hasMoreTokens()) { + String airportCode = st.nextToken(); + airports.get(ii).setAirportCode(airportCode); + ii++; + } + // read the other lines which are of format: + // airport name, aiport code, distance from this airport to whatever airport is in the column from lines one and two + String line; + int flightNumber = 0; + while (true) { + line = lnr.readLine(); + if (line == null || line.trim().equals("")) { + break; + } + st = new StringTokenizer(line, ","); + String airportName = st.nextToken(); + String airportCode = 
st.nextToken(); + if (!alreadyInCollection(airportCode, airports)) { + AirportCodeMapping acm = flightService.createAirportCodeMapping(airportCode, airportName); + airports.add(acm); + } + int indexIntoTopLine = 0; + while (st.hasMoreTokens()) { + String milesString = st.nextToken(); + if (milesString.equals("NA")) { + indexIntoTopLine++; + continue; + } + int miles = Integer.parseInt(milesString); + String toAirport = airports.get(indexIntoTopLine).getAirportCode(); + String flightId = "AA" + flightNumber; + flightService.storeFlightSegment(flightId, airportCode, toAirport, miles); + Date now = new Date(); + for (int daysFromNow = 0; daysFromNow < MAX_FLIGHTS_PER_SEGMENT; daysFromNow++) { + Calendar c = Calendar.getInstance(); + c.setTime(now); + c.set(Calendar.HOUR_OF_DAY, 0); + c.set(Calendar.MINUTE, 0); + c.set(Calendar.SECOND, 0); + c.set(Calendar.MILLISECOND, 0); + c.add(Calendar.DATE, daysFromNow); + Date departureTime = c.getTime(); + Date arrivalTime = getArrivalTime(departureTime, miles); + flightService.createNewFlight(flightId, departureTime, arrivalTime, new BigDecimal(500), new BigDecimal(200), 10, 200, "B747"); + + } + flightNumber++; + indexIntoTopLine++; + } + } + + for (int jj = 0; jj < airports.size(); jj++) { + flightService.storeAirportMapping(airports.get(jj)); + } + lnr.close(); + } + + private static Date getArrivalTime(Date departureTime, int mileage) { + double averageSpeed = 600.0; // 600 miles/hours + double hours = (double) mileage / averageSpeed; // miles / miles/hour = hours + double partsOfHour = hours % 1.0; + int minutes = (int)(60.0 * partsOfHour); + Calendar c = Calendar.getInstance(); + c.setTime(departureTime); + c.add(Calendar.HOUR, (int)hours); + c.add(Calendar.MINUTE, minutes); + return c.getTime(); + } + + static private boolean alreadyInCollection(String airportCode, ArrayList airports) { + for (int ii = 0; ii < airports.size(); ii++) { + if (airports.get(ii).getAirportCode().equals(airportCode)) { + return true; + } + } + 
# Java source file used as input by every test in this module.
JAVA_FIXTURE = Path(__file__).parent / "data" / "java" / "FlightLoader.java"

# Values of ``chunk.meta.chunk_type`` accepted by the tests.
VALID_CHUNK_TYPES = ["function", "class", "preamble"]

# Lower-case tokens of which at least one must appear in every chunk.
JAVA_KEYWORDS = ["public", "private", "class", "void", "return", "if", "for", "while"]


def _convert_java_fixture(converter=None):
    """Convert the Java fixture and return the resulting document.

    Skips the calling test when the fixture file is absent and fails it when
    the conversion does not report success, so no test goes on to chunk a
    missing or broken document.

    Args:
        converter: Optional ``DocumentConverter`` to reuse across calls; a new
            one is created when omitted.
    """
    if not JAVA_FIXTURE.exists():
        pytest.skip(f"Test file not found at {JAVA_FIXTURE}")

    if converter is None:
        converter = DocumentConverter()
    result = converter.convert(JAVA_FIXTURE)

    assert result.status.value == "success", f"Conversion failed: {result.errors}"
    assert result.document is not None, "Document should not be None"
    return result.document


def test_java_function_chunking():
    """Test Java function chunking with DocumentConverter."""
    doc = _convert_java_fixture()

    assert doc.texts, "Document should have text content"
    code_texts = [text for text in doc.texts if text.label.value == "code"]
    assert code_texts, "Document should have code content"

    chunker = JavaFunctionChunker(max_tokens=5000)
    chunks = list(chunker.chunk(dl_doc=doc))

    assert chunks, "Should produce at least one chunk"

    for i, chunk in enumerate(chunks):
        # Check the type first so the truthiness check below is meaningful.
        assert isinstance(chunk.text, str), f"Chunk {i} text should be a string"
        assert chunk.text, f"Chunk {i} should have text content"

        assert chunk.meta is not None, f"Chunk {i} should have metadata"
        assert chunk.meta.part_name, f"Chunk {i} should have a part_name"
        assert chunk.meta.start_line is not None, f"Chunk {i} should have start_line"
        assert chunk.meta.end_line is not None, f"Chunk {i} should have end_line"
        assert chunk.meta.sha256 is not None, f"Chunk {i} should have sha256 hash"

        assert chunk.meta.chunk_type in VALID_CHUNK_TYPES, (
            f"Chunk {i} should have a valid chunk_type"
        )

    function_chunks = [chunk for chunk in chunks if chunk.meta.chunk_type == "function"]
    assert function_chunks, "Should have at least one function chunk"

    for chunk in chunks:
        chunk_text_lower = chunk.text.lower()
        assert any(keyword in chunk_text_lower for keyword in JAVA_KEYWORDS), (
            f"Chunk should contain Java code: {chunk.text[:100]}..."
        )


def test_java_function_chunking_deterministic():
    """Test that Java function chunking produces deterministic results."""
    # One converter and one chunker are deliberately reused for all runs.
    converter = DocumentConverter()
    chunker = JavaFunctionChunker(max_tokens=5000)

    results = [
        list(chunker.chunk(dl_doc=_convert_java_fixture(converter)))
        for _ in range(3)
    ]

    baseline = results[0]
    # Without this, three empty runs would trivially compare equal.
    assert baseline, "Should produce at least one chunk"

    chunk_counts = [len(chunks) for chunks in results]
    assert all(count == chunk_counts[0] for count in chunk_counts), (
        f"Chunk counts should be identical: {chunk_counts}"
    )

    for i, chunks in enumerate(results[1:], 1):
        assert len(chunks) == len(baseline), f"Run {i} should have same number of chunks"

        for j, (chunk1, chunk2) in enumerate(zip(baseline, chunks)):
            assert chunk1.text == chunk2.text, (
                f"Chunk {j} text should be identical between runs"
            )
            assert chunk1.meta.sha256 == chunk2.meta.sha256, (
                f"Chunk {j} sha256 should be identical between runs"
            )
            assert chunk1.meta.part_name == chunk2.meta.part_name, (
                f"Chunk {j} part_name should be identical between runs"
            )


def test_java_function_chunking_with_different_max_tokens():
    """Test Java function chunking with different max_tokens settings."""
    doc = _convert_java_fixture()

    max_tokens_values = [1000, 5000, 10000]

    # Chunk once per setting and keep the results for both checks below.
    chunks_by_max_tokens = {
        max_tokens: list(JavaFunctionChunker(max_tokens=max_tokens).chunk(dl_doc=doc))
        for max_tokens in max_tokens_values
    }

    chunk_counts = [len(chunks) for chunks in chunks_by_max_tokens.values()]
    assert all(count > 0 for count in chunk_counts), (
        f"All max_tokens values should produce chunks: {chunk_counts}"
    )

    for max_tokens, chunks in chunks_by_max_tokens.items():
        for chunk in chunks:
            # Coarse sanity bound on character length, not an exact token count.
            assert len(chunk.text) < max_tokens * 10, (
                f"Chunk text length should be reasonable for max_tokens={max_tokens}"
            )


if __name__ == "__main__":
    # Run through pytest so skips and failures are reported properly instead
    # of escaping as uncaught exceptions from direct function calls.
    raise SystemExit(pytest.main([__file__]))