Skip to content

Commit 3261524

Browse files
committed
fix(layout,table): perform orientation detection at the table level
Signed-off-by: Clément Doumouro <clement.doumouro@gmail.com>
1 parent 79f4da1 commit 3261524

File tree

1 file changed

+113
-119
lines changed

1 file changed

+113
-119
lines changed

docling/models/table_structure_model.py

Lines changed: 113 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import operator
23
import warnings
34
from pathlib import Path
45
from typing import Iterable, Optional, Tuple, cast
@@ -189,144 +190,137 @@ def __call__(
189190
page.predictions.tablestructure = (
190191
TableStructurePrediction()
191192
) # dummy
192-
cells_orientation = detect_orientation(page.cells)
193193
# Keep only table bboxes
194194
in_tables_clusters = [
195195
cluster
196196
for cluster in page.predictions.layout.clusters
197197
if cluster.label in self._table_labels
198198
]
199199

200-
if not len(in_tables_clusters):
200+
if not in_tables_clusters:
201201
yield page
202202
continue
203203
# Rotate and scale table image
204204
page_im = cast(Image, page.get_image())
205-
original_scaled_page_size = (
206-
int(page_im.size[0] * self.scale),
207-
int(page_im.size[1] * self.scale),
208-
)
209-
scaled_page_im: Image = cast(
205+
original_scaled_im: Image = cast(
210206
Image, page.get_image(scale=self.scale)
211207
)
212-
if cells_orientation:
213-
scaled_page_im = scaled_page_im.rotate(
214-
-cells_orientation, expand=True
215-
)
216-
page_input = {
217-
"width": scaled_page_im.size[0],
218-
"height": scaled_page_im.size[1],
219-
"image": numpy.asarray(scaled_page_im),
220-
}
221-
# Rotate and scale table cells
222-
in_tables = [
223-
(
224-
c,
225-
[
226-
round(x) * self.scale
227-
for x in _rotate_bbox(
228-
c.bbox,
229-
orientation=-cells_orientation,
230-
im_size=page_im.size,
231-
)
232-
.to_top_left_origin(page_im.size[1])
233-
.as_tuple()
234-
],
235-
)
236-
for c in in_tables_clusters
237-
]
238-
table_clusters, table_bboxes = zip(*in_tables)
239-
240-
if len(table_bboxes):
241-
for table_cluster, tbl_box in in_tables:
242-
# Check if word-level cells are available from backend:
243-
sp = page._backend.get_segmented_page()
244-
if sp is not None:
245-
tcells = sp.get_cells_in_bbox(
246-
cell_unit=TextCellUnit.WORD,
247-
bbox=table_cluster.bbox,
248-
)
249-
if len(tcells) == 0:
250-
# In case word-level cells yield empty
251-
tcells = table_cluster.cells
252-
else:
253-
# Otherwise - we use normal (line/phrase) cells
254-
tcells = table_cluster.cells
255-
tokens = []
256-
for c in tcells:
257-
# Only allow non empty strings (spaces) into the cells of a table
258-
if len(c.text.strip()) > 0:
259-
new_cell = copy.deepcopy(c)
260-
new_cell.rect = BoundingRectangle.from_bounding_box(
261-
new_cell.rect.to_bounding_box().scaled(
262-
scale=self.scale
263-
)
264-
)
265-
# _rotate_bbox expects the size of the image in
266-
# which the bbox was found
267-
new_bbox = _rotate_bbox(
268-
new_cell.to_bounding_box(),
269-
orientation=-cells_orientation,
270-
im_size=original_scaled_page_size,
271-
).model_dump()
272-
tokens.append(
273-
{
274-
"id": new_cell.index,
275-
"text": new_cell.text,
276-
"bbox": new_bbox,
277-
}
278-
)
279-
page_input["tokens"] = tokens
208+
original_scaled_page_size = original_scaled_im.size
209+
clusters_with_orientations = sorted(
210+
((c, detect_orientation(c.cells)) for c in in_tables_clusters),
211+
key=operator.itemgetter(1),
212+
)
280213

281-
tf_output = self.tf_predictor.multi_table_predict(
282-
page_input, [tbl_box], do_matching=self.do_cell_matching
214+
previous_orientation = None
215+
for table_cluster, table_orientation in clusters_with_orientations:
216+
# Rotate the image if needed
217+
if previous_orientation != table_orientation:
218+
scaled_page_im = original_scaled_im
219+
if table_orientation:
220+
scaled_page_im = original_scaled_im.rotate(
221+
-table_orientation, expand=True
222+
)
223+
page_input = {
224+
"width": scaled_page_im.size[0],
225+
"height": scaled_page_im.size[1],
226+
"image": numpy.asarray(scaled_page_im),
227+
}
228+
previous_orientation = table_orientation
229+
# Rotate and scale the table bbox
230+
tbl_box = [
231+
round(x) * self.scale
232+
for x in _rotate_bbox(
233+
table_cluster.bbox,
234+
orientation=-table_orientation,
235+
im_size=page_im.size,
283236
)
284-
table_out = tf_output[0]
285-
table_cells = []
286-
for element in table_out["tf_responses"]:
287-
if not self.do_cell_matching:
288-
the_bbox = BoundingBox.model_validate(
289-
element["bbox"]
290-
).scaled(1 / self.scale)
291-
text_piece = page._backend.get_text_in_rect(
292-
the_bbox
237+
.to_top_left_origin(page_im.size[1])
238+
.as_tuple()
239+
]
240+
# Check if word-level cells are available from backend:
241+
sp = page._backend.get_segmented_page()
242+
if sp is not None:
243+
tcells = sp.get_cells_in_bbox(
244+
cell_unit=TextCellUnit.WORD,
245+
bbox=table_cluster.bbox,
246+
)
247+
if not tcells:
248+
# In case word-level cells yield empty
249+
tcells = table_cluster.cells
250+
else:
251+
# Otherwise - we use normal (line/phrase) cells
252+
tcells = table_cluster.cells
253+
tokens = []
254+
for c in tcells:
255+
# Only allow non empty strings (spaces) into the cells of a table
256+
if c.text.strip():
257+
new_cell = copy.deepcopy(c)
258+
new_cell.rect = BoundingRectangle.from_bounding_box(
259+
new_cell.rect.to_bounding_box().scaled(
260+
scale=self.scale
293261
)
294-
element["bbox"]["token"] = text_piece
295-
element["bbox"] = _rotate_bbox(
296-
BoundingBox.model_validate(element["bbox"]),
297-
orientation=cells_orientation,
298-
im_size=scaled_page_im.size,
262+
)
263+
# _rotate_bbox expects the size of the image in
264+
# which the bbox was found
265+
new_bbox = _rotate_bbox(
266+
new_cell.to_bounding_box(),
267+
orientation=-table_orientation,
268+
im_size=original_scaled_page_size,
299269
).model_dump()
300-
tc = TableCell.model_validate(element)
301-
if tc.bbox is not None:
302-
tc.bbox = tc.bbox.scaled(1 / self.scale)
303-
table_cells.append(tc)
304-
305-
assert "predict_details" in table_out
306-
307-
# Retrieving cols/rows, after post processing:
308-
num_rows = table_out["predict_details"].get("num_rows", 0)
309-
num_cols = table_out["predict_details"].get("num_cols", 0)
310-
otsl_seq = (
311-
table_out["predict_details"]
312-
.get("prediction", {})
313-
.get("rs_seq", [])
314-
)
270+
tokens.append(
271+
{
272+
"id": new_cell.index,
273+
"text": new_cell.text,
274+
"bbox": new_bbox,
275+
}
276+
)
277+
page_input["tokens"] = tokens
315278

316-
tbl = Table(
317-
otsl_seq=otsl_seq,
318-
table_cells=table_cells,
319-
num_rows=num_rows,
320-
num_cols=num_cols,
321-
id=table_cluster.id,
322-
page_no=page.page_no,
323-
cluster=table_cluster,
324-
label=table_cluster.label,
325-
)
279+
tf_output = self.tf_predictor.multi_table_predict(
280+
page_input, [tbl_box], do_matching=self.do_cell_matching
281+
)
282+
table_out = tf_output[0]
283+
table_cells = []
284+
for element in table_out["tf_responses"]:
285+
if not self.do_cell_matching:
286+
the_bbox = BoundingBox.model_validate(
287+
element["bbox"]
288+
).scaled(1 / self.scale)
289+
text_piece = page._backend.get_text_in_rect(the_bbox)
290+
element["bbox"]["token"] = text_piece
291+
element["bbox"] = _rotate_bbox(
292+
BoundingBox.model_validate(element["bbox"]),
293+
orientation=table_orientation,
294+
im_size=scaled_page_im.size,
295+
).model_dump()
296+
tc = TableCell.model_validate(element)
297+
if tc.bbox is not None:
298+
tc.bbox = tc.bbox.scaled(1 / self.scale)
299+
table_cells.append(tc)
300+
301+
assert "predict_details" in table_out
302+
303+
# Retrieving cols/rows, after post processing:
304+
num_rows = table_out["predict_details"].get("num_rows", 0)
305+
num_cols = table_out["predict_details"].get("num_cols", 0)
306+
otsl_seq = (
307+
table_out["predict_details"]
308+
.get("prediction", {})
309+
.get("rs_seq", [])
310+
)
311+
312+
tbl = Table(
313+
otsl_seq=otsl_seq,
314+
table_cells=table_cells,
315+
num_rows=num_rows,
316+
num_cols=num_cols,
317+
id=table_cluster.id,
318+
page_no=page.page_no,
319+
cluster=table_cluster,
320+
label=table_cluster.label,
321+
)
326322

327-
page.predictions.tablestructure.table_map[
328-
table_cluster.id
329-
] = tbl
323+
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
330324

331325
# For debugging purposes:
332326
if settings.debug.visualize_tables:

0 commit comments

Comments
 (0)