|  | 
| 1 | 1 | import copy | 
|  | 2 | +import operator | 
| 2 | 3 | import warnings | 
| 3 | 4 | from pathlib import Path | 
| 4 | 5 | from typing import Iterable, Optional, Tuple, cast | 
| @@ -189,144 +190,137 @@ def __call__( | 
| 189 | 190 |                     page.predictions.tablestructure = ( | 
| 190 | 191 |                         TableStructurePrediction() | 
| 191 | 192 |                     )  # dummy | 
| 192 |  | -                    cells_orientation = detect_orientation(page.cells) | 
| 193 | 193 |                     # Keep only table bboxes | 
| 194 | 194 |                     in_tables_clusters = [ | 
| 195 | 195 |                         cluster | 
| 196 | 196 |                         for cluster in page.predictions.layout.clusters | 
| 197 | 197 |                         if cluster.label in self._table_labels | 
| 198 | 198 |                     ] | 
| 199 | 199 | 
 | 
| 200 |  | -                    if not len(in_tables_clusters): | 
|  | 200 | +                    if not in_tables_clusters: | 
| 201 | 201 |                         yield page | 
| 202 | 202 |                         continue | 
| 203 | 203 |                     # Rotate and scale table image | 
| 204 | 204 |                     page_im = cast(Image, page.get_image()) | 
| 205 |  | -                    original_scaled_page_size = ( | 
| 206 |  | -                        int(page_im.size[0] * self.scale), | 
| 207 |  | -                        int(page_im.size[1] * self.scale), | 
| 208 |  | -                    ) | 
| 209 |  | -                    scaled_page_im: Image = cast( | 
|  | 205 | +                    original_scaled_im: Image = cast( | 
| 210 | 206 |                         Image, page.get_image(scale=self.scale) | 
| 211 | 207 |                     ) | 
| 212 |  | -                    if cells_orientation: | 
| 213 |  | -                        scaled_page_im = scaled_page_im.rotate( | 
| 214 |  | -                            -cells_orientation, expand=True | 
| 215 |  | -                        ) | 
| 216 |  | -                    page_input = { | 
| 217 |  | -                        "width": scaled_page_im.size[0], | 
| 218 |  | -                        "height": scaled_page_im.size[1], | 
| 219 |  | -                        "image": numpy.asarray(scaled_page_im), | 
| 220 |  | -                    } | 
| 221 |  | -                    # Rotate and scale table cells | 
| 222 |  | -                    in_tables = [ | 
| 223 |  | -                        ( | 
| 224 |  | -                            c, | 
| 225 |  | -                            [ | 
| 226 |  | -                                round(x) * self.scale | 
| 227 |  | -                                for x in _rotate_bbox( | 
| 228 |  | -                                    c.bbox, | 
| 229 |  | -                                    orientation=-cells_orientation, | 
| 230 |  | -                                    im_size=page_im.size, | 
| 231 |  | -                                ) | 
| 232 |  | -                                .to_top_left_origin(page_im.size[1]) | 
| 233 |  | -                                .as_tuple() | 
| 234 |  | -                            ], | 
| 235 |  | -                        ) | 
| 236 |  | -                        for c in in_tables_clusters | 
| 237 |  | -                    ] | 
| 238 |  | -                    table_clusters, table_bboxes = zip(*in_tables) | 
| 239 |  | - | 
| 240 |  | -                    if len(table_bboxes): | 
| 241 |  | -                        for table_cluster, tbl_box in in_tables: | 
| 242 |  | -                            # Check if word-level cells are available from backend: | 
| 243 |  | -                            sp = page._backend.get_segmented_page() | 
| 244 |  | -                            if sp is not None: | 
| 245 |  | -                                tcells = sp.get_cells_in_bbox( | 
| 246 |  | -                                    cell_unit=TextCellUnit.WORD, | 
| 247 |  | -                                    bbox=table_cluster.bbox, | 
| 248 |  | -                                ) | 
| 249 |  | -                                if len(tcells) == 0: | 
| 250 |  | -                                    # In case word-level cells yield empty | 
| 251 |  | -                                    tcells = table_cluster.cells | 
| 252 |  | -                            else: | 
| 253 |  | -                                # Otherwise - we use normal (line/phrase) cells | 
| 254 |  | -                                tcells = table_cluster.cells | 
| 255 |  | -                            tokens = [] | 
| 256 |  | -                            for c in tcells: | 
| 257 |  | -                                # Only allow non empty strings (spaces) into the cells of a table | 
| 258 |  | -                                if len(c.text.strip()) > 0: | 
| 259 |  | -                                    new_cell = copy.deepcopy(c) | 
| 260 |  | -                                    new_cell.rect = BoundingRectangle.from_bounding_box( | 
| 261 |  | -                                        new_cell.rect.to_bounding_box().scaled( | 
| 262 |  | -                                            scale=self.scale | 
| 263 |  | -                                        ) | 
| 264 |  | -                                    ) | 
| 265 |  | -                                    # _rotate_bbox expects the size of the image in | 
| 266 |  | -                                    # which the bbox was found | 
| 267 |  | -                                    new_bbox = _rotate_bbox( | 
| 268 |  | -                                        new_cell.to_bounding_box(), | 
| 269 |  | -                                        orientation=-cells_orientation, | 
| 270 |  | -                                        im_size=original_scaled_page_size, | 
| 271 |  | -                                    ).model_dump() | 
| 272 |  | -                                    tokens.append( | 
| 273 |  | -                                        { | 
| 274 |  | -                                            "id": new_cell.index, | 
| 275 |  | -                                            "text": new_cell.text, | 
| 276 |  | -                                            "bbox": new_bbox, | 
| 277 |  | -                                        } | 
| 278 |  | -                                    ) | 
| 279 |  | -                            page_input["tokens"] = tokens | 
|  | 208 | +                    original_scaled_page_size = original_scaled_im.size | 
|  | 209 | +                    clusters_with_orientations = sorted( | 
|  | 210 | +                        ((c, detect_orientation(c.cells)) for c in in_tables_clusters), | 
|  | 211 | +                        key=operator.itemgetter(1), | 
|  | 212 | +                    ) | 
| 280 | 213 | 
 | 
| 281 |  | -                            tf_output = self.tf_predictor.multi_table_predict( | 
| 282 |  | -                                page_input, [tbl_box], do_matching=self.do_cell_matching | 
|  | 214 | +                    previous_orientation = None | 
|  | 215 | +                    for table_cluster, table_orientation in clusters_with_orientations: | 
|  | 216 | +                        # Rotate the image if needed | 
|  | 217 | +                        if previous_orientation != table_orientation: | 
|  | 218 | +                            scaled_page_im = original_scaled_im | 
|  | 219 | +                            if table_orientation: | 
|  | 220 | +                                scaled_page_im = original_scaled_im.rotate( | 
|  | 221 | +                                    -table_orientation, expand=True | 
|  | 222 | +                                ) | 
|  | 223 | +                        page_input = { | 
|  | 224 | +                            "width": scaled_page_im.size[0], | 
|  | 225 | +                            "height": scaled_page_im.size[1], | 
|  | 226 | +                            "image": numpy.asarray(scaled_page_im), | 
|  | 227 | +                        } | 
|  | 228 | +                        previous_orientation = table_orientation | 
|  | 229 | +                        # Rotate and scale the table bbox | 
|  | 230 | +                        tbl_box = [ | 
|  | 231 | +                            round(x) * self.scale | 
|  | 232 | +                            for x in _rotate_bbox( | 
|  | 233 | +                                table_cluster.bbox, | 
|  | 234 | +                                orientation=-table_orientation, | 
|  | 235 | +                                im_size=page_im.size, | 
| 283 | 236 |                             ) | 
| 284 |  | -                            table_out = tf_output[0] | 
| 285 |  | -                            table_cells = [] | 
| 286 |  | -                            for element in table_out["tf_responses"]: | 
| 287 |  | -                                if not self.do_cell_matching: | 
| 288 |  | -                                    the_bbox = BoundingBox.model_validate( | 
| 289 |  | -                                        element["bbox"] | 
| 290 |  | -                                    ).scaled(1 / self.scale) | 
| 291 |  | -                                    text_piece = page._backend.get_text_in_rect( | 
| 292 |  | -                                        the_bbox | 
|  | 237 | +                            .to_top_left_origin(page_im.size[1]) | 
|  | 238 | +                            .as_tuple() | 
|  | 239 | +                        ] | 
|  | 240 | +                        # Check if word-level cells are available from backend: | 
|  | 241 | +                        sp = page._backend.get_segmented_page() | 
|  | 242 | +                        if sp is not None: | 
|  | 243 | +                            tcells = sp.get_cells_in_bbox( | 
|  | 244 | +                                cell_unit=TextCellUnit.WORD, | 
|  | 245 | +                                bbox=table_cluster.bbox, | 
|  | 246 | +                            ) | 
|  | 247 | +                            if not tcells: | 
|  | 248 | +                                # In case word-level cells yield empty | 
|  | 249 | +                                tcells = table_cluster.cells | 
|  | 250 | +                        else: | 
|  | 251 | +                            # Otherwise - we use normal (line/phrase) cells | 
|  | 252 | +                            tcells = table_cluster.cells | 
|  | 253 | +                        tokens = [] | 
|  | 254 | +                        for c in tcells: | 
|  | 255 | +                            # Only allow non empty strings (spaces) into the cells of a table | 
|  | 256 | +                            if c.text.strip(): | 
|  | 257 | +                                new_cell = copy.deepcopy(c) | 
|  | 258 | +                                new_cell.rect = BoundingRectangle.from_bounding_box( | 
|  | 259 | +                                    new_cell.rect.to_bounding_box().scaled( | 
|  | 260 | +                                        scale=self.scale | 
| 293 | 261 |                                     ) | 
| 294 |  | -                                    element["bbox"]["token"] = text_piece | 
| 295 |  | -                                element["bbox"] = _rotate_bbox( | 
| 296 |  | -                                    BoundingBox.model_validate(element["bbox"]), | 
| 297 |  | -                                    orientation=cells_orientation, | 
| 298 |  | -                                    im_size=scaled_page_im.size, | 
|  | 262 | +                                ) | 
|  | 263 | +                                # _rotate_bbox expects the size of the image in | 
|  | 264 | +                                # which the bbox was found | 
|  | 265 | +                                new_bbox = _rotate_bbox( | 
|  | 266 | +                                    new_cell.to_bounding_box(), | 
|  | 267 | +                                    orientation=-table_orientation, | 
|  | 268 | +                                    im_size=original_scaled_page_size, | 
| 299 | 269 |                                 ).model_dump() | 
| 300 |  | -                                tc = TableCell.model_validate(element) | 
| 301 |  | -                                if tc.bbox is not None: | 
| 302 |  | -                                    tc.bbox = tc.bbox.scaled(1 / self.scale) | 
| 303 |  | -                                table_cells.append(tc) | 
| 304 |  | - | 
| 305 |  | -                            assert "predict_details" in table_out | 
| 306 |  | - | 
| 307 |  | -                            # Retrieving cols/rows, after post processing: | 
| 308 |  | -                            num_rows = table_out["predict_details"].get("num_rows", 0) | 
| 309 |  | -                            num_cols = table_out["predict_details"].get("num_cols", 0) | 
| 310 |  | -                            otsl_seq = ( | 
| 311 |  | -                                table_out["predict_details"] | 
| 312 |  | -                                .get("prediction", {}) | 
| 313 |  | -                                .get("rs_seq", []) | 
| 314 |  | -                            ) | 
|  | 270 | +                                tokens.append( | 
|  | 271 | +                                    { | 
|  | 272 | +                                        "id": new_cell.index, | 
|  | 273 | +                                        "text": new_cell.text, | 
|  | 274 | +                                        "bbox": new_bbox, | 
|  | 275 | +                                    } | 
|  | 276 | +                                ) | 
|  | 277 | +                        page_input["tokens"] = tokens | 
| 315 | 278 | 
 | 
| 316 |  | -                            tbl = Table( | 
| 317 |  | -                                otsl_seq=otsl_seq, | 
| 318 |  | -                                table_cells=table_cells, | 
| 319 |  | -                                num_rows=num_rows, | 
| 320 |  | -                                num_cols=num_cols, | 
| 321 |  | -                                id=table_cluster.id, | 
| 322 |  | -                                page_no=page.page_no, | 
| 323 |  | -                                cluster=table_cluster, | 
| 324 |  | -                                label=table_cluster.label, | 
| 325 |  | -                            ) | 
|  | 279 | +                    tf_output = self.tf_predictor.multi_table_predict( | 
|  | 280 | +                        page_input, [tbl_box], do_matching=self.do_cell_matching | 
|  | 281 | +                    ) | 
|  | 282 | +                    table_out = tf_output[0] | 
|  | 283 | +                    table_cells = [] | 
|  | 284 | +                    for element in table_out["tf_responses"]: | 
|  | 285 | +                        if not self.do_cell_matching: | 
|  | 286 | +                            the_bbox = BoundingBox.model_validate( | 
|  | 287 | +                                element["bbox"] | 
|  | 288 | +                            ).scaled(1 / self.scale) | 
|  | 289 | +                            text_piece = page._backend.get_text_in_rect(the_bbox) | 
|  | 290 | +                            element["bbox"]["token"] = text_piece | 
|  | 291 | +                        element["bbox"] = _rotate_bbox( | 
|  | 292 | +                            BoundingBox.model_validate(element["bbox"]), | 
|  | 293 | +                            orientation=table_orientation, | 
|  | 294 | +                            im_size=scaled_page_im.size, | 
|  | 295 | +                        ).model_dump() | 
|  | 296 | +                        tc = TableCell.model_validate(element) | 
|  | 297 | +                        if tc.bbox is not None: | 
|  | 298 | +                            tc.bbox = tc.bbox.scaled(1 / self.scale) | 
|  | 299 | +                        table_cells.append(tc) | 
|  | 300 | + | 
|  | 301 | +                    assert "predict_details" in table_out | 
|  | 302 | + | 
|  | 303 | +                    # Retrieving cols/rows, after post processing: | 
|  | 304 | +                    num_rows = table_out["predict_details"].get("num_rows", 0) | 
|  | 305 | +                    num_cols = table_out["predict_details"].get("num_cols", 0) | 
|  | 306 | +                    otsl_seq = ( | 
|  | 307 | +                        table_out["predict_details"] | 
|  | 308 | +                        .get("prediction", {}) | 
|  | 309 | +                        .get("rs_seq", []) | 
|  | 310 | +                    ) | 
|  | 311 | + | 
|  | 312 | +                    tbl = Table( | 
|  | 313 | +                        otsl_seq=otsl_seq, | 
|  | 314 | +                        table_cells=table_cells, | 
|  | 315 | +                        num_rows=num_rows, | 
|  | 316 | +                        num_cols=num_cols, | 
|  | 317 | +                        id=table_cluster.id, | 
|  | 318 | +                        page_no=page.page_no, | 
|  | 319 | +                        cluster=table_cluster, | 
|  | 320 | +                        label=table_cluster.label, | 
|  | 321 | +                    ) | 
| 326 | 322 | 
 | 
| 327 |  | -                            page.predictions.tablestructure.table_map[ | 
| 328 |  | -                                table_cluster.id | 
| 329 |  | -                            ] = tbl | 
|  | 323 | +                    page.predictions.tablestructure.table_map[table_cluster.id] = tbl | 
| 330 | 324 | 
 | 
| 331 | 325 |                     # For debugging purposes: | 
| 332 | 326 |                     if settings.debug.visualize_tables: | 
|  | 
0 commit comments