Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 139 additions & 9 deletions commonforms/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from huggingface_hub import hf_hub_download
from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge

from commonforms.utils import BoundingBox, Page, Widget
from commonforms.utils import BoundingBox, Page, TextFragment, Widget
from commonforms.form_creator import PyPdfFormCreator
from commonforms.exceptions import EncryptedPdfError

import formalpdf
import pypdfium2
import logging
import PIL
Expand Down Expand Up @@ -38,7 +37,9 @@ def batch(lst: list, n: int = 8):
class FFDetrDetector:
def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
self.device = device
self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))
self.model = RFDETRMedium(
pretrain_weights=self.get_model_path(model_or_path), device=device
)

self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

Expand Down Expand Up @@ -73,7 +74,9 @@ def extract_widgets(
image_size = 1024
results = []
for b in batch([p.image for p in pages], n=batch_size):
predictions = self.model.predict(b, threshold=confidence)
predictions = self.model.predict(
b, threshold=confidence, device=self.device
)
if isinstance(predictions, list):
results.extend(predictions)
else:
Expand Down Expand Up @@ -229,16 +232,137 @@ def sort_widgets(widgets: list[Widget]) -> list[Widget]:
return [widget for line in lines for widget in line]


def extract_text_fragments(page: pypdfium2.PdfPage) -> list[TextFragment]:
    """Return one ``TextFragment`` per non-empty text line found on ``page``.

    The page text is split into lines and each line is located again through
    the text page's search API so that its bounding rectangles can be read.
    Searches resume after the previous match, so a repeated line (or a short
    line that also occurs inside earlier text) is mapped to its own occurrence
    rather than to the first one on the page.

    Args:
        page: The pypdfium2 page to read text from.

    Returns:
        Fragments in text order. ``text`` is the stripped line, ``x0`` is the
        smallest left edge of the line's rectangles divided by the page width,
        and ``y0`` is ``1 - top / page_height`` using the largest top edge.
        Lines that cannot be located or have no rectangles are omitted.
    """
    # Page dimensions do not change while iterating, so read them once.
    page_width = page.get_width()
    page_height = page.get_height()

    textpage = page.get_textpage()
    try:
        fragments = []
        # Character index just past the most recent match.
        cursor = 0
        for term in textpage.get_text_range().splitlines():
            text = term.strip()
            if not text:
                continue

            # Search forward from the cursor first; if that finds nothing,
            # fall back to a whole-page search so no line is lost that a
            # search from the start of the page would have found.
            match = None
            for start in (cursor, 0) if cursor else (0,):
                searcher = textpage.search(
                    term, index=start, match_case=False, consecutive=True
                )
                try:
                    match = searcher.get_next()
                finally:
                    searcher.close()
                if match is not None:
                    break

            if match is None:
                continue

            index, count = match
            # Never move the cursor backwards after a fallback match.
            cursor = max(cursor, index + count)

            # count_rects() must run before get_rect(): it prepares the
            # rectangles for the matched character range.
            rect_count = textpage.count_rects(index, count)
            rects = [textpage.get_rect(i) for i in range(rect_count)]
            if not rects:
                continue

            # Rectangles are indexed as (left, bottom, right, top).
            left = min(rect[0] for rect in rects)
            top = max(rect[3] for rect in rects)
            fragments.append(
                TextFragment(
                    text=text,
                    x0=left / page_width,
                    y0=1 - (top / page_height),
                )
            )

        return fragments
    finally:
        textpage.close()


def render_pdf(pdf_path: str) -> list[Page]:
    """Render every page of the PDF at ``pdf_path`` and collect its text.

    Args:
        pdf_path: Path of the PDF file to open with pypdfium2.

    Returns:
        One ``Page`` per document page, in document order. Each holds the
        rendered PIL image, the image's pixel width and height, and the text
        fragments extracted from that page.
    """
    # Open the document exactly once; the handle is released in ``finally``
    # even if rendering or text extraction fails part-way through.
    doc = pypdfium2.PdfDocument(pdf_path)
    try:
        pages = []
        for page in doc:
            # scale=2 renders at twice the page's native size.
            image = page.render(scale=2).to_pil()
            pages.append(
                Page(
                    image=image,
                    width=image.width,
                    height=image.height,
                    text_fragments=extract_text_fragments(page),
                )
            )
        return pages
    finally:
        doc.close()


def group_widget_rows(
    widgets: list[Widget], y_threshold: float = 0.015
) -> list[list[Widget]]:
    """Cluster widgets into rows based on the ``y0`` of their bounding boxes.

    Widgets are visited in ascending ``bounding_box.y0`` order. A widget joins
    the current row when its ``y0`` differs from the ``y0`` of that row's
    first widget by at most ``y_threshold``; otherwise it opens a new row.

    Args:
        widgets: Widgets to group; the input list is not modified.
        y_threshold: Largest allowed ``y0`` difference to the row's first
            widget for a widget to be placed in the same row.

    Returns:
        The rows, ordered by ascending ``y0``; empty when ``widgets`` is empty.
    """
    ordered = sorted(widgets, key=lambda item: item.bounding_box.y0)

    grouped: list[list[Widget]] = []
    anchor_y = None  # y0 of the first widget in the row currently being built
    for current in ordered:
        top = current.bounding_box.y0
        if anchor_y is not None and abs(top - anchor_y) <= y_threshold:
            grouped[-1].append(current)
            continue
        # Too far from the row's anchor (or no row yet): start a new row.
        grouped.append([current])
        anchor_y = top
    return grouped


def promote_signature_widgets(
    pages: list[Page],
    results: dict[int, list[Widget]],
    signature_label_terms: tuple[str, ...] = ("signature",),
) -> dict[int, list[Widget]]:
    """Re-type one text box per page as a signature field, guided by labels.

    A page is only considered when it has no ``"Signature"`` widget yet, at
    least one text fragment containing one of ``signature_label_terms``
    (case-insensitive substring match), and at least one ``"TextBox"`` widget.
    For such a page every (text-box row, label) pair is scored and the
    leftmost widget of the best-scoring row becomes a ``"Signature"`` widget.

    Scores are tuples compared lexicographically, smaller is better:
    horizontal distance of the label's ``x0`` outside the row's span, vertical
    distance between the row's mean ``y0`` and the label's ``y0``, distance
    between the row's left edge and the label's ``x0``, then wider rows and
    rows with a larger ``y0`` win ties.

    Args:
        pages: Pages indexed by the keys of ``results``.
        results: Detected widgets per page index. The widget lists are
            updated in place.
        signature_label_terms: Terms that mark a text fragment as a
            signature label.

    Returns:
        The same ``results`` mapping that was passed in.
    """
    needles = tuple(term.lower() for term in signature_label_terms)

    for page_ix, widgets in results.items():
        # Pages that already carry a signature widget are left untouched.
        if any(widget.widget_type == "Signature" for widget in widgets):
            continue

        labels = [
            fragment
            for fragment in pages[page_ix].text_fragments
            if any(needle in fragment.text.lower() for needle in needles)
        ]
        if not labels:
            continue

        text_boxes = [widget for widget in widgets if widget.widget_type == "TextBox"]
        rows = group_widget_rows(text_boxes)
        if not rows:
            continue

        # Keep a running minimum; the strict ``<`` keeps the earliest pair
        # when scores tie.
        best_score = None
        best_row = None
        for row in rows:
            left_edge = min(widget.bounding_box.x0 for widget in row)
            right_edge = max(widget.bounding_box.x1 for widget in row)
            mean_y = sum(widget.bounding_box.y0 for widget in row) / len(row)
            span = right_edge - left_edge

            for label in labels:
                # Distance by which the label starts outside the row's span.
                if label.x0 < left_edge:
                    overshoot = left_edge - label.x0
                elif label.x0 > right_edge:
                    overshoot = label.x0 - right_edge
                else:
                    overshoot = 0.0

                score = (
                    overshoot,
                    abs(mean_y - label.y0),
                    abs(left_edge - label.x0),
                    -span,
                    -mean_y,
                )
                if best_score is None or score < best_score:
                    best_score = score
                    best_row = row

        if best_row is None:
            continue

        leftmost = min(best_row, key=lambda widget: widget.bounding_box.x0)
        position = widgets.index(leftmost)
        widgets[position] = leftmost.model_copy(update={"widget_type": "Signature"})

    return results


def prepare_form(
Expand All @@ -254,11 +378,12 @@ def prepare_form(
fast: bool = False,
multiline: bool = False,
batch_size: int = 4,
signature_label_terms: tuple[str, ...] = ("signature",),
):
if "FFDNET" in model_or_path.upper():
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
else:
detector = FFDetrDetector(model_or_path)
detector = FFDetrDetector(model_or_path, device=device)

try:
pages = render_pdf(input_path)
Expand All @@ -274,6 +399,11 @@ def prepare_form(
pages, confidence=confidence, image_size=image_size
)

if use_signature_fields:
results = promote_signature_widgets(
pages, results, signature_label_terms=signature_label_terms
)

writer = PyPdfFormCreator(input_path)
if not keep_existing_fields:
writer.clear_existing_fields()
Expand Down
7 changes: 7 additions & 0 deletions commonforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ class Widget(BaseModel):
page: int


class TextFragment(BaseModel):
    """A piece of page text together with the position of one of its corners.

    The producer visible in this change (``extract_text_fragments`` in
    ``commonforms/inference.py``) stores the stripped text of one line, sets
    ``x0`` to the line's left edge divided by the page width and ``y0`` to
    ``1 - top / page_height``.
    """

    # The fragment's text.
    text: str
    # Left edge as a fraction of the page width.
    x0: float
    # 1 - (top edge / page height); presumably the distance from the top of
    # the page as a fraction of its height - confirm the PDF coordinate origin.
    y0: float


@dataclass
class Page:
    """A rendered page image plus the text fragments that belong to it."""

    # The rendered page image.
    image: Image.Image
    # Image dimensions; render_pdf passes image.width and image.height here.
    width: float
    height: float
    # Text fragments for this page. NOTE(review): there is no default, so
    # every place that constructs a Page must pass this field explicitly.
    text_fragments: list[TextFragment]
70 changes: 70 additions & 0 deletions tests/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

import formalpdf
import pytest
from PIL import Image

from commonforms.inference import promote_signature_widgets
from commonforms.utils import BoundingBox, Page, TextFragment, Widget


def test_inference(tmp_path):
Expand Down Expand Up @@ -67,6 +71,72 @@ def test_inference_ffdetr(tmp_path):
doc.document.close()


def test_promote_signature_widgets_uses_signature_label_on_test_pdf():
    """A signature label over a text-box row promotes that row's leftmost box."""

    def make_page(fragments):
        # A 1x1 placeholder image is enough: promotion never reads the pixels.
        return Page(
            image=Image.new("RGB", (1, 1)),
            width=1,
            height=1,
            text_fragments=fragments,
        )

    label = TextFragment(
        text="POLICYHOLDER/PATIENT SIGNATURE FAMILY RELATIONSHIP, IF NOT POLICYHOLDER DATE",
        x0=0.37,
        y0=0.61,
    )
    pages = [make_page([]), make_page([label])]

    left_box = Widget(
        widget_type="TextBox",
        bounding_box=BoundingBox(x0=0.089, y0=0.857, x1=0.384, y1=0.895),
        page=1,
    )
    right_box = Widget(
        widget_type="TextBox",
        bounding_box=BoundingBox(x0=0.752, y0=0.859, x1=0.927, y1=0.896),
        page=1,
    )
    results = {1: [left_box, right_box]}

    promoted = promote_signature_widgets(pages, results)

    # Only the leftmost text box of the matched row changes type.
    assert promoted[1][0].widget_type == "Signature"
    assert promoted[1][1].widget_type == "TextBox"


def test_promote_signature_widgets_skips_pages_without_signature_label():
    """Without any signature label on the page, text boxes keep their type."""
    unrelated_text = TextFragment(text="General contact information", x0=0.1, y0=0.2)
    page = Page(
        image=Image.new("RGB", (1, 1)),
        width=1,
        height=1,
        text_fragments=[unrelated_text],
    )
    text_box = Widget(
        widget_type="TextBox",
        bounding_box=BoundingBox(x0=0.1, y0=0.8, x1=0.3, y1=0.84),
        page=0,
    )

    promoted = promote_signature_widgets([page], {0: [text_box]})

    assert promoted[0][0].widget_type == "TextBox"


# TODO(joe): future tests around handling encrypted PDFs
# 1. add a --password flag and test that inference doesn't fail
# 2. if a password is provided, ensure that the _output_ PDF remains encrypted
Expand Down
Loading