Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions docling/backend/code_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Union, Set
from pathlib import Path
from io import BytesIO

from docling_core.types.doc import DoclingDocument, DocumentOrigin
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat

class CodeFileBackend(DeclarativeDocumentBackend):
    """Declarative backend that turns a source-code file into a document.

    The whole file is read as UTF-8 text when the backend is constructed and
    is emitted by ``convert()`` as a single code item whose language is
    derived from the file extension.
    """

    # Lower-cased file suffix (including the dot) -> language name.
    LANGUAGE_MAPPINGS = {
        ".py": "python",
        ".js": "javascript",
        ".java": "java",
    }

    # Language reported when the suffix is not in LANGUAGE_MAPPINGS.
    _FALLBACK_LANGUAGE = "text"

    def __init__(self, in_doc, path_or_stream: Union[BytesIO, Path]):
        """Read the source text and detect its language.

        Args:
            in_doc: Input document descriptor, forwarded to the base class.
            path_or_stream: Either an in-memory ``BytesIO`` holding the file
                content or a ``Path`` to a file on disk. Content must be
                UTF-8 encoded.

        Raises:
            RuntimeError: If the input is of an unsupported type, cannot be
                read or decoded, or the language cannot be determined. The
                underlying error is chained as ``__cause__``.
        """
        super().__init__(in_doc, path_or_stream)
        self.path_or_stream = path_or_stream
        # Only flipped to True once the source has actually been loaded.
        self.valid = False
        self.source_code = ""

        try:
            if isinstance(self.path_or_stream, BytesIO):
                self.source_code = self.path_or_stream.getvalue().decode("utf-8")
            elif isinstance(self.path_or_stream, Path):
                with open(self.path_or_stream, encoding="utf-8") as f:
                    self.source_code = f.read()
            else:
                # Fail here rather than leaving ``source_code`` unset and
                # crashing later inside ``convert()``.
                raise TypeError(
                    "Unsupported input type for code backend: "
                    f"{type(self.path_or_stream).__name__}"
                )

            self.language = self._detect_language()
            self.valid = True

        except Exception as e:
            raise RuntimeError(
                f"Could not initialize code backend for file with hash {self.document_hash}."
            ) from e

    def _detect_language(self) -> str:
        """Detect programming language from file extension.

        Returns:
            The language name mapped from the lower-cased suffix of
            ``self.file``, or ``'text'`` when the suffix is not recognised.
        """
        # NOTE(review): ``self.file`` is not set in this class; presumably
        # the base-class constructor provides it - confirm.
        file_ext = self.file.suffix.lower()
        return self.LANGUAGE_MAPPINGS.get(file_ext, self._FALLBACK_LANGUAGE)

    @classmethod
    def supported_formats(cls) -> Set[InputFormat]:
        """Return the input formats this backend handles."""
        return {
            InputFormat.CODE_PYTHON,
            InputFormat.CODE_JAVASCRIPT,
            InputFormat.CODE_JAVA,
        }

    @classmethod
    def supports_pagination(cls) -> bool:
        """Return ``False``: this backend does not paginate."""
        return False

    def is_valid(self) -> bool:
        """Return whether the source was loaded successfully."""
        return self.valid

    def _mime_type(self) -> str:
        """Return the media type recorded in the document origin."""
        if self.language == self._FALLBACK_LANGUAGE:
            # "text/x-text" is not a conventional media type; unrecognised
            # sources are reported as plain text instead.
            return "text/plain"
        if self.language == "java":
            return f"text/x-{self.language}-source"
        return f"text/x-{self.language}"

    def convert(self) -> DoclingDocument:
        """Build a ``DoclingDocument`` holding the source as one code item.

        Returns:
            A document named after the file stem (``"file"`` if the stem is
            empty). When the backend is valid, the full source text is added
            through ``add_code`` with the detected language.
        """
        origin = DocumentOrigin(
            filename=self.file.name or f"file{self.file.suffix}",
            mimetype=self._mime_type(),
            binary_hash=self.document_hash,
        )

        doc = DoclingDocument(name=self.file.stem or "file", origin=origin)

        if self.is_valid():
            # NOTE(review): the language is passed as a lower-case string;
            # confirm ``add_code`` accepts these values for ``code_language``.
            doc.add_code(
                text=self.source_code,
                code_language=self.language,
            )

        return doc
6 changes: 6 additions & 0 deletions docling/chunking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
#

from docling_core.transforms.chunker.base import BaseChunk, BaseChunker, BaseMeta
from docling_core.transforms.chunker.base_code_chunker import CodeChunk
from docling_core.transforms.chunker.hierarchical_chunker import (
DocChunk,
DocMeta,
HierarchicalChunker,
)
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.language_code_chunkers import (
JavaFunctionChunker,
JavaScriptFunctionChunker,
PythonFunctionChunker,
)
10 changes: 10 additions & 0 deletions docling/datamodel/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class InputFormat(str, Enum):
JSON_DOCLING = "json_docling"
AUDIO = "audio"
VTT = "vtt"
CODE_PYTHON = "py"
CODE_JAVASCRIPT = "js"
CODE_JAVA = "java"



class OutputFormat(str, Enum):
Expand Down Expand Up @@ -96,6 +100,9 @@ class OutputFormat(str, Enum):
InputFormat.JSON_DOCLING: ["json"],
InputFormat.AUDIO: ["wav", "mp3"],
InputFormat.VTT: ["vtt"],
InputFormat.CODE_PYTHON: ["py"],
InputFormat.CODE_JAVASCRIPT: ["js"],
InputFormat.CODE_JAVA: ["java"],
}

FormatToMimeType: dict[InputFormat, list[str]] = {
Expand Down Expand Up @@ -130,6 +137,9 @@ class OutputFormat(str, Enum):
InputFormat.JSON_DOCLING: ["application/json"],
InputFormat.AUDIO: ["audio/x-wav", "audio/mpeg", "audio/wav", "audio/mp3"],
InputFormat.VTT: ["text/vtt"],
InputFormat.CODE_PYTHON: ["text/x-python"],
InputFormat.CODE_JAVASCRIPT: ["text/x-javascript"],
InputFormat.CODE_JAVA: ["text/x-java-source"],
}

MimeTypeToFormat: dict[str, list[InputFormat]] = {
Expand Down
6 changes: 6 additions & 0 deletions docling/datamodel/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ def _mime_from_extension(ext):
mime = FormatToMimeType[InputFormat.XLSX][0]
elif ext in FormatToExtensions[InputFormat.VTT]:
mime = FormatToMimeType[InputFormat.VTT][0]
elif ext in FormatToExtensions[InputFormat.CODE_JAVA]:
mime = FormatToMimeType[InputFormat.CODE_JAVA][0]
elif ext in FormatToExtensions[InputFormat.CODE_PYTHON]:
mime = FormatToMimeType[InputFormat.CODE_PYTHON][0]
elif ext in FormatToExtensions[InputFormat.CODE_JAVASCRIPT]:
mime = FormatToMimeType[InputFormat.CODE_JAVASCRIPT][0]

return mime

Expand Down
16 changes: 16 additions & 0 deletions docling/document_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.asciidoc_backend import AsciiDocBackend
from docling.backend.code_backend import CodeFileBackend
from docling.backend.csv_backend import CsvDocumentBackend
from docling.backend.csv_backend import CsvDocumentBackend
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
Expand Down Expand Up @@ -74,6 +76,11 @@ class CsvFormatOption(FormatOption):
backend: Type[AbstractDocumentBackend] = CsvDocumentBackend


class CodeFormatOption(FormatOption):
    """Format option for source-code inputs.

    Declares ``SimplePipeline`` as the default pipeline class and
    ``CodeFileBackend`` as the default backend.
    """

    pipeline_cls: Type = SimplePipeline
    backend: Type[AbstractDocumentBackend] = CodeFileBackend


class ExcelFormatOption(FormatOption):
pipeline_cls: Type = SimplePipeline
backend: Type[AbstractDocumentBackend] = MsExcelDocumentBackend
Expand Down Expand Up @@ -174,6 +181,15 @@ def _get_default_option(format: InputFormat) -> FormatOption:
InputFormat.VTT: FormatOption(
pipeline_cls=SimplePipeline, backend=WebVTTDocumentBackend
),
InputFormat.CODE_JAVA: FormatOption(
pipeline_cls=SimplePipeline, backend=CodeFileBackend
),
InputFormat.CODE_PYTHON: FormatOption(
pipeline_cls=SimplePipeline, backend=CodeFileBackend
),
InputFormat.CODE_JAVASCRIPT: FormatOption(
pipeline_cls=SimplePipeline, backend=CodeFileBackend
),
}
if (options := format_to_default_options.get(format)) is not None:
return options
Expand Down
133 changes: 133 additions & 0 deletions tests/data/java/FlightLoader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2013 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package com.acmeair.loader;

import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.util.*;
import java.math.*;

import com.acmeair.entities.AirportCodeMapping;
import com.acmeair.service.FlightService;
import com.acmeair.service.ServiceLocator;




/**
 * Loads airport mappings, flight segments and flights from the
 * {@code /mileage.csv} classpath resource and stores them through
 * {@link FlightService}.
 */
public class FlightLoader {

    /** Number of consecutive days, starting today, for which a flight is created per segment. */
    private static final int MAX_FLIGHTS_PER_SEGMENT = 30;

    private static final String MILEAGE_RESOURCE = "/mileage.csv";

    private FlightService flightService = ServiceLocator.instance().getService(FlightService.class);

    /**
     * Reads the mileage matrix and populates the flight service.
     *
     * <p>Expected layout of the CSV: line one holds airport names, line two
     * holds the matching airport codes, and every following line holds an
     * airport name, its code, and one distance (or {@code NA}) per column of
     * the first two lines. Reading stops at the first blank line or at end
     * of file.
     *
     * @throws IllegalStateException if the resource is missing or does not
     *         contain both header lines
     * @throws Exception on any I/O, parsing or service failure
     */
    public void loadFlights() throws Exception {
        InputStream csvInputStream = FlightLoader.class.getResourceAsStream(MILEAGE_RESOURCE);
        if (csvInputStream == null) {
            // getResourceAsStream signals "not found" with null; fail with a clear message
            // instead of a NullPointerException from the reader constructor.
            throw new IllegalStateException("Resource " + MILEAGE_RESOURCE + " was not found on the classpath");
        }

        LineNumberReader lnr = new LineNumberReader(new InputStreamReader(csvInputStream));
        try {
            ArrayList<AirportCodeMapping> airports = new ArrayList<AirportCodeMapping>();

            // read the first line which are airport names
            String line1 = lnr.readLine();
            if (line1 == null) {
                throw new IllegalStateException(MILEAGE_RESOURCE + " is empty: missing airport name line");
            }
            StringTokenizer st = new StringTokenizer(line1, ",");
            while (st.hasMoreTokens()) {
                AirportCodeMapping acm = flightService.createAirportCodeMapping(null, st.nextToken());
                airports.add(acm);
            }

            // read the second line which contains matching airport codes for the first line
            String line2 = lnr.readLine();
            if (line2 == null) {
                throw new IllegalStateException(MILEAGE_RESOURCE + " is truncated: missing airport code line");
            }
            st = new StringTokenizer(line2, ",");
            int ii = 0;
            while (st.hasMoreTokens()) {
                String airportCode = st.nextToken();
                airports.get(ii).setAirportCode(airportCode);
                ii++;
            }

            // read the other lines which are of format:
            // airport name, airport code, distance from this airport to whatever airport
            // is in the column from lines one and two
            String line;
            int flightNumber = 0;
            while (true) {
                line = lnr.readLine();
                if (line == null || line.trim().equals("")) {
                    break;
                }
                st = new StringTokenizer(line, ",");
                String airportName = st.nextToken();
                String airportCode = st.nextToken();
                if (!alreadyInCollection(airportCode, airports)) {
                    AirportCodeMapping acm = flightService.createAirportCodeMapping(airportCode, airportName);
                    airports.add(acm);
                }
                int indexIntoTopLine = 0;
                while (st.hasMoreTokens()) {
                    String milesString = st.nextToken();
                    if (milesString.equals("NA")) {
                        // no route to this column's airport; still advance the column index
                        indexIntoTopLine++;
                        continue;
                    }
                    int miles = Integer.parseInt(milesString);
                    String toAirport = airports.get(indexIntoTopLine).getAirportCode();
                    String flightId = "AA" + flightNumber;
                    flightService.storeFlightSegment(flightId, airportCode, toAirport, miles);
                    Date now = new Date();
                    for (int daysFromNow = 0; daysFromNow < MAX_FLIGHTS_PER_SEGMENT; daysFromNow++) {
                        // departure is midnight (default time zone) of today + daysFromNow
                        Calendar c = Calendar.getInstance();
                        c.setTime(now);
                        c.set(Calendar.HOUR_OF_DAY, 0);
                        c.set(Calendar.MINUTE, 0);
                        c.set(Calendar.SECOND, 0);
                        c.set(Calendar.MILLISECOND, 0);
                        c.add(Calendar.DATE, daysFromNow);
                        Date departureTime = c.getTime();
                        Date arrivalTime = getArrivalTime(departureTime, miles);
                        // The trailing arguments are fixed values; their meaning is defined by FlightService.
                        flightService.createNewFlight(flightId, departureTime, arrivalTime, new BigDecimal(500), new BigDecimal(200), 10, 200, "B747");
                    }
                    flightNumber++;
                    indexIntoTopLine++;
                }
            }

            for (int jj = 0; jj < airports.size(); jj++) {
                flightService.storeAirportMapping(airports.get(jj));
            }
        } finally {
            // Always release the reader (and the underlying resource stream), even on failure.
            lnr.close();
        }
    }

    /**
     * Computes the arrival time for a flight of the given length, assuming a
     * constant average speed of 600 miles per hour.
     *
     * @param departureTime time the flight departs
     * @param mileage       flight distance in miles
     * @return departure time plus the whole hours and whole minutes of flight time
     */
    private static Date getArrivalTime(Date departureTime, int mileage) {
        double averageSpeed = 600.0; // 600 miles/hours
        double hours = (double) mileage / averageSpeed; // miles / miles/hour = hours
        double partsOfHour = hours % 1.0;
        int minutes = (int) (60.0 * partsOfHour);
        Calendar c = Calendar.getInstance();
        c.setTime(departureTime);
        c.add(Calendar.HOUR, (int) hours);
        c.add(Calendar.MINUTE, minutes);
        return c.getTime();
    }

    /**
     * Tells whether an airport with the given code is already in the list.
     *
     * @param airportCode code to look for
     * @param airports    mappings collected so far
     * @return {@code true} if any mapping carries {@code airportCode}
     */
    static private boolean alreadyInCollection(String airportCode, ArrayList<AirportCodeMapping> airports) {
        for (AirportCodeMapping airport : airports) {
            // Compare from the known argument: a mapping created from the first
            // header line may still have a null code if the second line was shorter.
            if (airportCode.equals(airport.getAirportCode())) {
                return true;
            }
        }
        return false;
    }
}
Loading