From 7c6d01c506593b9b74cfbda97d6da3eac389ac57 Mon Sep 17 00:00:00 2001 From: Varun3011 Date: Wed, 20 May 2026 13:55:31 -0700 Subject: [PATCH 1/2] Fix missed signature field promotion --- commonforms/inference.py | 109 +++++++++++++++++++++++++++++++++++++++ tests/inference_test.py | 41 +++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/commonforms/inference.py b/commonforms/inference.py index 9eed964..426d88d 100644 --- a/commonforms/inference.py +++ b/commonforms/inference.py @@ -1,6 +1,7 @@ from __future__ import annotations from ultralytics import YOLO from pathlib import Path +from dataclasses import dataclass from huggingface_hub import hf_hub_download from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge @@ -9,6 +10,7 @@ from commonforms.exceptions import EncryptedPdfError import formalpdf +import pypdf import pypdfium2 import logging import PIL @@ -241,6 +243,112 @@ def render_pdf(pdf_path: str) -> list[Page]: doc.document.close() +@dataclass +class TextFragment: + text: str + x0: float + y0: float + + +def extract_text_fragments(input_path: str | Path) -> dict[int, list[TextFragment]]: + reader = pypdf.PdfReader(str(input_path)) + try: + fragments = {} + for page_ix, page in enumerate(reader.pages): + box = page.cropbox if page.cropbox else page.mediabox + page_width = float(box.right - box.left) + page_height = float(box.top - box.bottom) + page_fragments: list[TextFragment] = [] + + def visitor(text, cm, tm, font_dict, font_size): + if not text.strip(): + return + + x0 = float(tm[4] - box.left) / page_width + y0 = 1 - (float(tm[5] - box.bottom) / page_height) + page_fragments.append(TextFragment(text=text, x0=x0, y0=y0)) + + page.extract_text(visitor_text=visitor) + fragments[page_ix] = page_fragments + + return fragments + finally: + reader.close() + + +def group_widget_rows(widgets: list[Widget], y_threshold: float = 0.015) -> list[list[Widget]]: + rows: list[list[Widget]] = [] + for widget in sorted(widgets, key=lambda item: item.bounding_box.y0): + if ( + rows + and abs(widget.bounding_box.y0 - rows[-1][0].bounding_box.y0) <= y_threshold + ): + rows[-1].append(widget) + else: + rows.append([widget]) + return rows + + +def promote_signature_widgets( + input_path: str | Path, results: dict[int, list[Widget]] +) -> dict[int, list[Widget]]: + """Promote likely signature fields by matching signature labels to nearby rows.""" + text_fragments = extract_text_fragments(input_path) + + for page_ix, widgets in results.items(): + if any(widget.widget_type == "Signature" for widget in widgets): + continue + + signature_labels = [ + fragment + for fragment in text_fragments.get(page_ix, []) + if "signature" in fragment.text.lower() + ] + if not signature_labels: + continue + + textbox_rows = group_widget_rows( + [widget for widget in widgets if widget.widget_type == "TextBox"] + ) + if not textbox_rows: + continue + + scored_rows = [] + for row in textbox_rows: + row_left = min(widget.bounding_box.x0 for widget in row) + row_right = max(widget.bounding_box.x1 for widget in row) + row_y = sum(widget.bounding_box.y0 for widget in row) / len(row) + row_width = row_right - row_left + + for label in signature_labels: + if row_y <= label.y0: + continue + + horizontal_penalty = 0.0 + if label.x0 < row_left: + horizontal_penalty = row_left - label.x0 + elif label.x0 > row_right: + horizontal_penalty = label.x0 - row_right + + score = ( + horizontal_penalty, + abs(row_left - label.x0), + -row_y, + -row_width, + ) + scored_rows.append((score, row)) + + if not scored_rows: + continue + + best_row = min(scored_rows, key=lambda item: item[0])[1] + candidate = min(best_row, key=lambda widget: widget.bounding_box.x0) + widget_ix = widgets.index(candidate) + widgets[widget_ix] = candidate.model_copy(update={"widget_type": "Signature"}) + + return results + + def prepare_form( input_path: str | Path, output_path: str | Path, @@ -273,6 +381,7 @@ def prepare_form( results = detector.extract_widgets( pages, confidence=confidence, image_size=image_size ) + results = promote_signature_widgets(input_path, results) writer = PyPdfFormCreator(input_path) if not keep_existing_fields: diff --git a/tests/inference_test.py b/tests/inference_test.py index 5b8693b..ec64446 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -4,6 +4,9 @@ import formalpdf import pytest +from commonforms.inference import promote_signature_widgets +from commonforms.utils import BoundingBox, Widget + def test_inference(tmp_path): # tmp_path is a built-in pythest fixture where we'll write the outputs @@ -67,6 +70,44 @@ def test_inference_ffdetr(tmp_path): doc.document.close() +def test_promote_signature_widgets_uses_signature_label_on_test_pdf(): + results = { + 1: [ + Widget( + widget_type="TextBox", + bounding_box=BoundingBox(x0=0.089, y0=0.857, x1=0.384, y1=0.895), + page=1, + ), + Widget( + widget_type="TextBox", + bounding_box=BoundingBox(x0=0.752, y0=0.859, x1=0.927, y1=0.896), + page=1, + ), + ] + } + + promoted = promote_signature_widgets("./tests/resources/input.pdf", results) + + assert promoted[1][0].widget_type == "Signature" + assert promoted[1][1].widget_type == "TextBox" + + +def test_promote_signature_widgets_skips_pages_without_signature_label(): + results = { + 0: [ + Widget( + widget_type="TextBox", + bounding_box=BoundingBox(x0=0.1, y0=0.8, x1=0.3, y1=0.84), + page=0, + ) + ] + } + + promoted = promote_signature_widgets("./tests/resources/input.pdf", results) + + assert promoted[0][0].widget_type == "TextBox" + + # TODO(joe): future tests around handling encrypted PDFs # 1. add a --password flag and test that inference doesn't fail # 2. if a password is provided, ensure that the _output_ PDF remains encrpyted From 956d9d4d758fbbdb9245b1084b123a8f007826a8 Mon Sep 17 00:00:00 2001 From: varun3011 Date: Wed, 20 May 2026 16:13:55 -0700 Subject: [PATCH 2/2] Address review feedback and fix FFDetr device handling --- commonforms/inference.py | 123 +++++++++++++++++++++++---------------- commonforms/utils.py | 7 +++ tests/inference_test.py | 35 ++++++++++- 3 files changed, 111 insertions(+), 54 deletions(-) diff --git a/commonforms/inference.py b/commonforms/inference.py index 426d88d..f8d5a6b 100644 --- a/commonforms/inference.py +++ b/commonforms/inference.py @@ -1,16 +1,13 @@ from __future__ import annotations from ultralytics import YOLO from pathlib import Path -from dataclasses import dataclass from huggingface_hub import hf_hub_download from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge -from commonforms.utils import BoundingBox, Page, Widget +from commonforms.utils import BoundingBox, Page, TextFragment, Widget from commonforms.form_creator import PyPdfFormCreator from commonforms.exceptions import EncryptedPdfError -import formalpdf -import pypdf import pypdfium2 import logging import PIL @@ -40,7 +37,9 @@ def batch(lst: list, n: int = 8): class FFDetrDetector: def __init__(self, model_or_path: str, device: int | str = "cpu") -> None: self.device = device - self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path)) + self.model = RFDETRMedium( + pretrain_weights=self.get_model_path(model_or_path), device=device + ) self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"} @@ -75,7 +74,9 @@ def extract_widgets( image_size = 1024 results = [] for b in batch([p.image for p in pages], n=batch_size): - predictions = self.model.predict(b, threshold=confidence) + predictions = self.model.predict( + b, threshold=confidence, device=self.device + ) if isinstance(predictions, list): results.extend(predictions) else: @@ -231,52 +232,67 @@ def sort_widgets(widgets: list[Widget]) -> list[Widget]: return [widget for line in lines for widget in line] -def render_pdf(pdf_path: str) -> list[Page]: - pages = [] - doc = formalpdf.open(pdf_path) +def extract_text_fragments(page: pypdfium2.PdfPage) -> list[TextFragment]: + textpage = page.get_textpage() try: - for page in doc: - image = page.render(dpi=144) - pages.append(Page(image=image, width=image.width, height=image.height)) - return pages - finally: - doc.document.close() - + fragments = [] + for term in textpage.get_text_range().splitlines(): + text = term.strip() + if not text: + continue -@dataclass -class TextFragment: - text: str - x0: float - y0: float + searcher = textpage.search(term, match_case=False, consecutive=True) + try: + match = searcher.get_next() + finally: + searcher.close() + if match is None: + continue -def extract_text_fragments(input_path: str | Path) -> dict[int, list[TextFragment]]: - reader = pypdf.PdfReader(str(input_path)) - try: - fragments = {} - for page_ix, page in enumerate(reader.pages): - box = page.cropbox if page.cropbox else page.mediabox - page_width = float(box.right - box.left) - page_height = float(box.top - box.bottom) - page_fragments: list[TextFragment] = [] + index, count = match + rect_count = textpage.count_rects(index, count) + rects = [textpage.get_rect(i) for i in range(rect_count)] + if not rects: + continue - def visitor(text, cm, tm, font_dict, font_size): - if not text.strip(): - return + left = min(rect[0] for rect in rects) + top = max(rect[3] for rect in rects) + fragments.append( + TextFragment( + text=text, + x0=left / page.get_width(), + y0=1 - (top / page.get_height()), + ) + ) - x0 = float(tm[4] - box.left) / page_width - y0 = 1 - (float(tm[5] - box.bottom) / page_height) - page_fragments.append(TextFragment(text=text, x0=x0, y0=y0)) + return fragments + finally: + textpage.close() - page.extract_text(visitor_text=visitor) - fragments[page_ix] = page_fragments - return fragments +def render_pdf(pdf_path: str) -> list[Page]: + pages = [] + doc = pypdfium2.PdfDocument(pdf_path) + try: + for page in doc: + image = page.render(scale=2).to_pil() + pages.append( + Page( + image=image, + width=image.width, + height=image.height, + text_fragments=extract_text_fragments(page), + ) + ) + return pages finally: - reader.close() + doc.close() -def group_widget_rows(widgets: list[Widget], y_threshold: float = 0.015) -> list[list[Widget]]: +def group_widget_rows( + widgets: list[Widget], y_threshold: float = 0.015 +) -> list[list[Widget]]: rows: list[list[Widget]] = [] for widget in sorted(widgets, key=lambda item: item.bounding_box.y0): if ( @@ -290,10 +306,12 @@ def group_widget_rows(widgets: list[Widget], y_threshold: float = 0.015) -> list def promote_signature_widgets( - input_path: str | Path, results: dict[int, list[Widget]] + pages: list[Page], + results: dict[int, list[Widget]], + signature_label_terms: tuple[str, ...] = ("signature",), ) -> dict[int, list[Widget]]: """Promote likely signature fields by matching signature labels to nearby rows.""" - text_fragments = extract_text_fragments(input_path) + normalized_terms = tuple(term.lower() for term in signature_label_terms) for page_ix, widgets in results.items(): if any(widget.widget_type == "Signature" for widget in widgets): @@ -301,8 +319,8 @@ def promote_signature_widgets( signature_labels = [ fragment - for fragment in text_fragments.get(page_ix, []) - if "signature" in fragment.text.lower() + for fragment in pages[page_ix].text_fragments + if any(term in fragment.text.lower() for term in normalized_terms) ] if not signature_labels: continue @@ -321,9 +339,6 @@ def promote_signature_widgets( row_width = row_right - row_left for label in signature_labels: - if row_y <= label.y0: - continue - horizontal_penalty = 0.0 if label.x0 < row_left: horizontal_penalty = row_left - label.x0 @@ -332,9 +347,10 @@ def promote_signature_widgets( score = ( horizontal_penalty, + abs(row_y - label.y0), abs(row_left - label.x0), - -row_y, -row_width, + -row_y, ) scored_rows.append((score, row)) @@ -362,11 +378,12 @@ def prepare_form( fast: bool = False, multiline: bool = False, batch_size: int = 4, + signature_label_terms: tuple[str, ...] = ("signature",), ): if "FFDNET" in model_or_path.upper(): detector = FFDNetDetector(model_or_path, device=device, fast=fast) else: - detector = FFDetrDetector(model_or_path) + detector = FFDetrDetector(model_or_path, device=device) try: pages = render_pdf(input_path) @@ -381,7 +398,11 @@ def prepare_form( results = detector.extract_widgets( pages, confidence=confidence, image_size=image_size ) - results = promote_signature_widgets(input_path, results) + + if use_signature_fields: + results = promote_signature_widgets( + pages, results, signature_label_terms=signature_label_terms + ) writer = PyPdfFormCreator(input_path) if not keep_existing_fields: diff --git a/commonforms/utils.py b/commonforms/utils.py index b85ac92..9d5e0f2 100644 --- a/commonforms/utils.py +++ b/commonforms/utils.py @@ -26,8 +26,15 @@ class Widget(BaseModel): page: int +class TextFragment(BaseModel): + text: str + x0: float + y0: float + + @dataclass class Page: image: Image.Image width: float height: float + text_fragments: list[TextFragment] diff --git a/tests/inference_test.py b/tests/inference_test.py index ec64446..f70ec5b 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -3,9 +3,10 @@ import formalpdf import pytest +from PIL import Image from commonforms.inference import promote_signature_widgets -from commonforms.utils import BoundingBox, Widget +from commonforms.utils import BoundingBox, Page, TextFragment, Widget def test_inference(tmp_path): @@ -71,6 +72,26 @@ def test_inference_ffdetr(tmp_path): def test_promote_signature_widgets_uses_signature_label_on_test_pdf(): + pages = [ + Page( + image=Image.new("RGB", (1, 1)), + width=1, + height=1, + text_fragments=[], + ), + Page( + image=Image.new("RGB", (1, 1)), + width=1, + height=1, + text_fragments=[ + TextFragment( + text="POLICYHOLDER/PATIENT SIGNATURE FAMILY RELATIONSHIP, IF NOT POLICYHOLDER DATE", + x0=0.37, + y0=0.61, + ) + ], + ), + ] results = { 1: [ Widget( @@ -86,13 +107,21 @@ def test_promote_signature_widgets_uses_signature_label_on_test_pdf(): ] } - promoted = promote_signature_widgets("./tests/resources/input.pdf", results) + promoted = promote_signature_widgets(pages, results) assert promoted[1][0].widget_type == "Signature" assert promoted[1][1].widget_type == "TextBox" def test_promote_signature_widgets_skips_pages_without_signature_label(): + pages = [ + Page( + image=Image.new("RGB", (1, 1)), + width=1, + height=1, + text_fragments=[TextFragment(text="General contact information", x0=0.1, y0=0.2)], + ) + ] results = { 0: [ Widget( @@ -103,7 +132,7 @@ def test_promote_signature_widgets_skips_pages_without_signature_label(): ] } - promoted = promote_signature_widgets("./tests/resources/input.pdf", results) + promoted = promote_signature_widgets(pages, results) assert promoted[0][0].widget_type == "TextBox"