Files
invoice-master-poc-v2/scripts/ppstructure_line_items_poc.py
2026-02-03 21:28:06 +01:00

388 lines
13 KiB
Python

#!/usr/bin/env python3
"""
PP-StructureV3 Line Items Extraction POC
Tests line items extraction from Swedish invoices using PP-StructureV3.
Parses HTML table structure to extract structured line item data.
Run with invoice-sm120 conda environment.
"""
import sys
import re
from pathlib import Path
from html.parser import HTMLParser
from dataclasses import dataclass
# Make the backend package importable when this script is run directly:
# project_root is two directory levels above this file, and
# <project_root>/packages/backend is put at the front of the import path.
project_root = Path(__file__).parent.parent
_backend_dir = project_root.joinpath("packages", "backend")
sys.path.insert(0, str(_backend_dir))
from paddleocr import PPStructureV3
import fitz # PyMuPDF
@dataclass
class LineItem:
"""Single line item from invoice."""
row_index: int
article_number: str | None
description: str | None
quantity: str | None
unit: str | None
unit_price: str | None
amount: str | None
vat_rate: str | None
confidence: float = 0.9
class TableHTMLParser(HTMLParser):
    """Parse an HTML table into a header row and a list of body rows.

    Rows inside ``<thead>`` are stored in ``header_row`` (the last such row
    wins). Every other ``<tr>`` with at least one cell is appended to
    ``rows``. A cell is the concatenated, stripped text of a ``<td>``/``<th>``.

    A cell with ``colspan="k"`` is followed by ``k - 1`` empty placeholder
    cells. This keeps column indices aligned with the header grid, so index
    ``i`` in a data row refers to the same column as index ``i`` in the header.
    """

    def __init__(self):
        super().__init__()
        self.rows: list[list[str]] = []
        self.current_row: list[str] = []
        self.current_cell: str = ""
        self.in_td = False
        self.in_thead = False
        self.header_row: list[str] = []
        # Number of grid columns occupied by the currently open cell.
        self.current_colspan: int = 1

    @staticmethod
    def _parse_colspan(attrs) -> int:
        """Return the cell's colspan from its attribute list (1 if absent or invalid)."""
        for name, value in attrs:
            if name == "colspan":
                try:
                    return max(1, int(value))
                except (TypeError, ValueError):
                    # e.g. bare `colspan` (value None) or a non-numeric value.
                    return 1
        return 1

    def handle_starttag(self, tag, attrs):
        if tag == "tr":
            self.current_row = []
        elif tag in ("td", "th"):
            self.in_td = True
            self.current_cell = ""
            self.current_colspan = self._parse_colspan(attrs)
        elif tag == "thead":
            self.in_thead = True

    def handle_endtag(self, tag):
        if tag in ("td", "th"):
            self.in_td = False
            self.current_row.append(self.current_cell.strip())
            # Pad spanned columns so later cells keep their header index.
            self.current_row.extend([""] * (self.current_colspan - 1))
            self.current_colspan = 1
        elif tag == "tr":
            if self.current_row:
                if self.in_thead:
                    self.header_row = self.current_row
                else:
                    self.rows.append(self.current_row)
        elif tag == "thead":
            self.in_thead = False

    def handle_data(self, data):
        # Text is only collected inside cells; nested inline tags still count.
        if self.in_td:
            self.current_cell += data
# Header keywords (mostly Swedish, a few English) for each LineItem field.
# One header cell may contain several column names merged together, which is
# why map_columns also matches these keywords as substrings of a header.
# Field order matters: map_columns scans fields in this order.
COLUMN_MAPPINGS = dict(
    article_number=[
        'art nummer', 'artikelnummer', 'artikel', 'artnr', 'art.nr', 'art nr',
    ],
    description=[
        'beskrivning', 'produktbeskrivning', 'produkt', 'tjänst', 'text',
        'benämning', 'vara/tjänst', 'vara',
    ],
    quantity=['antal', 'qty', 'st', 'pcs', 'kvantitet'],
    unit=['enhet', 'unit'],
    unit_price=[
        'á-pris', 'a-pris', 'pris', 'styckpris', 'enhetspris', 'à pris',
    ],
    amount=['belopp', 'summa', 'total', 'netto', 'rad summa'],
    vat_rate=['moms', 'moms%', 'vat', 'skatt', 'moms %'],
)
# Translation table: delete periods, turn hyphens into spaces.
_HEADER_PUNCTUATION = str.maketrans({".": None, "-": " "})


def normalize_header(header: str) -> str:
    """Return a canonical form of a header cell for keyword matching.

    The text is lower-cased and trimmed, then periods are deleted and hyphens
    become spaces, e.g. ``" Art.Nr "`` -> ``"artnr"`` and
    ``"Á-pris"`` -> ``"á pris"``.
    """
    trimmed = header.lower().strip()
    return trimmed.translate(_HEADER_PUNCTUATION)
def map_columns(headers: list[str]) -> dict[int, str]:
    """Map header column indices to LineItem field names.

    Each header is normalized with normalize_header and compared against the
    keywords in COLUMN_MAPPINGS:

    * an exact match wins immediately;
    * otherwise the longest keyword of at least 3 characters that occurs
      inside the header is used. This handles header cells in which several
      column names were merged together.

    Empty headers and headers with no matching keyword are left unmapped.

    Args:
        headers: Header cell texts, in column order.

    Returns:
        ``{column_index: field_name}`` for every recognised column.
    """
    # Normalize keywords exactly like headers. Otherwise keywords containing
    # "." or "-" ('art.nr', 'a-pris', 'á-pris') could never match a header,
    # because normalize_header strips those characters from headers.
    normalized_mappings = {
        field: [normalize_header(p) for p in patterns]
        for field, patterns in COLUMN_MAPPINGS.items()
    }

    mapping: dict[int, str] = {}
    for idx, header in enumerate(headers):
        normalized = normalize_header(header)
        if not normalized.strip():
            continue

        best_match: str | None = None
        best_score = 0
        for field, patterns in normalized_mappings.items():
            for pattern in patterns:
                if pattern == normalized:
                    # Exact match: boost past any substring score and stop.
                    best_match = field
                    best_score = len(pattern) + 100
                    break
                # Substring match: prefer longer keywords. Keywords under 3
                # chars (e.g. 'st') are too ambiguous inside other words.
                if len(pattern) >= 3 and len(pattern) > best_score and pattern in normalized:
                    best_match = field
                    best_score = len(pattern)
            if best_score > 100:  # exact match found
                break

        if best_match:
            mapping[idx] = best_match
    return mapping
def parse_table_html(html: str) -> tuple[list[str], list[list[str]]]:
    """Parse an HTML table string into ``(header_row, body_rows)``.

    ``header_row`` comes from the table's ``<thead>`` and is ``[]`` when there
    is none. ``body_rows`` holds every other row that has at least one cell,
    as lists of stripped cell strings. ``None`` or empty input yields
    ``([], [])``.
    """
    parser = TableHTMLParser()
    if html:
        parser.feed(html)
    # close() forces processing of any data HTMLParser is still buffering.
    parser.close()
    return parser.header_row, parser.rows
def detect_header_row(rows: list[list[str]]) -> tuple[int, list[str], bool]:
    """
    Detect which row is the header based on content patterns.

    Each non-empty row is scored by how many COLUMN_MAPPINGS keywords it
    contains. The highest-scoring row (the earliest one on ties) is the
    header if it scores at least 2.

    Returns (header_row_index, header_row, is_at_end).
    is_at_end is True when the header is the last row or lies in the lower
    half of the table, i.e. the data rows sit above it (table is reversed).
    Returns (-1, [], False) if no header detected.
    """
    keywords = {p.lower() for patterns in COLUMN_MAPPINGS.values() for p in patterns}
    # Same rule as map_columns: keywords shorter than 3 characters (e.g. 'st')
    # occur inside ordinary words ("kostnad", "postgiro"). They therefore only
    # count when they form an entire cell, not as substrings of the row text.
    substring_keywords = {kw for kw in keywords if len(kw) >= 3}
    whole_cell_keywords = keywords - substring_keywords

    best_idx, best_row, best_score = -1, [], 0
    for i, row in enumerate(rows):
        # Skip empty rows
        if all(not cell.strip() for cell in row):
            continue
        lowered = [cell.lower() for cell in row]
        row_text = " ".join(lowered)
        cell_set = {cell.strip() for cell in lowered}
        score = sum(1 for kw in substring_keywords if kw in row_text)
        score += sum(1 for kw in whole_cell_keywords if kw in cell_set)
        if score > best_score:
            best_idx, best_row, best_score = i, row, score

    if best_score >= 2:
        is_at_end = best_idx == len(rows) - 1 or best_idx > len(rows) // 2
        return best_idx, best_row, is_at_end
    return -1, [], False
def _split_header_and_rows(
    header: list[str], rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
    """Choose the header for a parsed table and return it with its data rows."""
    if header:
        # A <thead> header was found: every body row is data.
        return header, rows

    header_idx, detected_header, is_at_end = detect_header_row(rows)
    if header_idx >= 0:
        if is_at_end:
            # Header sits at the bottom: only the rows above it are data.
            return detected_header, rows[:header_idx]
        return detected_header, rows[header_idx + 1:]

    # No keyword-based header: fall back to the first non-empty row.
    for i, row in enumerate(rows):
        if any(cell.strip() for cell in row):
            return row, rows[i + 1:]
    return header, rows


def extract_line_items(html: str) -> list[LineItem]:
    """Extract line items from an HTML table.

    Header selection uses the ``<thead>`` row if present. Otherwise it uses a
    keyword-detected header row, and failing that the first non-empty row.
    Each data row's cells are then mapped to LineItem fields via
    map_columns. Empty cells become ``None``.

    Returns:
        One LineItem per data row that has a description or an amount.
        ``row_index`` is the row's position among the data rows.
    """
    header, rows = _split_header_and_rows(*parse_table_html(html))
    column_map = map_columns(header)

    items = []
    for row_idx, row in enumerate(rows):
        values = dict.fromkeys((
            'article_number', 'description', 'quantity', 'unit',
            'unit_price', 'amount', 'vat_rate',
        ))
        for col_idx, cell in enumerate(row):
            field = column_map.get(col_idx)
            if field is not None:
                values[field] = cell or None
        # Only add if we have at least description or amount
        if values['description'] or values['amount']:
            items.append(LineItem(row_index=row_idx, **values))
    return items
def render_pdf_to_image(pdf_path: str, dpi: int = 200, page_index: int = 0) -> bytes:
    """Render one page of a PDF to PNG bytes.

    Args:
        pdf_path: Path of the PDF file.
        dpi: Output resolution; the page is scaled by ``dpi / 72``.
        page_index: 0-based page to render (default: the first page).

    Returns:
        The rendered page encoded as PNG.
    """
    zoom = dpi / 72  # PyMuPDF's base rendering resolution is 72 dpi
    # The context manager closes the document even if rendering raises
    # (e.g. page_index out of range for an empty PDF).
    with fitz.open(pdf_path) as doc:
        page = doc[page_index]
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        return pix.tobytes("png")
def _print_ocr_pred_debug(ocr_pred) -> None:
    """Print a short summary of a table's ``table_ocr_pred`` payload (dict or list)."""
    if isinstance(ocr_pred, dict):
        print(f" table_ocr_pred keys: {list(ocr_pred.keys())}")
        # Check if rec_texts exists (actual OCR text)
        if "rec_texts" in ocr_pred:
            texts = ocr_pred["rec_texts"]
            print(f" OCR texts count: {len(texts)}")
            print(f" Sample OCR texts: {texts[:5]}")
    elif isinstance(ocr_pred, list):
        print(f" table_ocr_pred is list with {len(ocr_pred)} items")
        if ocr_pred:
            print(f" First item type: {type(ocr_pred[0])}")
            print(f" First few items: {ocr_pred[:3]}")


def _analyze_table(i: int, table_res) -> tuple[dict, list[LineItem]]:
    """Print diagnostics for one detected table and extract its line items.

    Returns the table's summary dict and its line items. The item list is
    empty unless the table's header maps to a description, amount, or
    article-number column.
    """
    html = table_res.get("pred_html", "")
    ocr_pred = table_res.get("table_ocr_pred", {})
    print(f"\n--- Table {i+1} ---")
    # Debug: show full HTML for first table
    if i == 0:
        print(f" Full HTML:\n{html}")
    _print_ocr_pred_debug(ocr_pred)

    header, rows = parse_table_html(html)
    print(f" HTML Header (from thead): {header}")
    print(f" HTML Rows: {len(rows)}")
    # Without a <thead>, try keyword-based header detection.
    if not header and rows:
        header_idx, detected_header, is_at_end = detect_header_row(rows)
        if header_idx >= 0:
            print(f" Detected header at row {header_idx}: {detected_header}")
            print(f" Table is {'REVERSED (header at bottom)' if is_at_end else 'normal'}")
            header = detected_header
    if rows:
        print(f" First row: {rows[0]}")
        if len(rows) > 1:
            print(f" Second row: {rows[1]}")

    # Check if this looks like a line items table
    column_map = map_columns(header) if header else {}
    print(f" Column mapping: {column_map}")
    is_line_items_table = bool(
        set(column_map.values()) & {'description', 'amount', 'article_number'}
    )

    items: list[LineItem] = []
    if is_line_items_table:
        print(" >>> This appears to be a LINE ITEMS table!")
        items = extract_line_items(html)
        print(f" Extracted {len(items)} line items:")
        for item in items:
            print(f" - {item.description}: {item.quantity} x {item.unit_price} = {item.amount}")
    else:
        print(" >>> This is NOT a line items table (summary/payment)")

    details = {
        "index": i,
        "header": header,
        "row_count": len(rows),
        "is_line_items": is_line_items_table,
        "column_map": column_map,
    }
    return details, items


def test_line_items_extraction(pdf_path: str, device: str = "gpu:0") -> dict:
    """Run PP-StructureV3 on the first page of a PDF and report line items.

    Args:
        pdf_path: Path of the invoice PDF.
        device: Device string passed to PPStructureV3 (default ``"gpu:0"``).

    Returns:
        A dict with the ``"pdf"`` path, per-table ``"tables"`` summaries, and
        all extracted ``"line_items"``.
    """
    import tempfile

    print(f"\n{'='*70}")
    print(f"Testing Line Items Extraction: {Path(pdf_path).name}")
    print(f"{'='*70}")

    print("Rendering PDF to image...")
    img_bytes = render_pdf_to_image(pdf_path)

    print("Initializing PP-StructureV3...")
    pipeline = PPStructureV3(
        device=device,
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
    )

    # The pipeline is given an image file path. Write the rendered page to a
    # private temp dir rather than a fixed /tmp path: that is portable, avoids
    # clobbering between concurrent runs, and is cleaned up afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_img_path = Path(tmp_dir) / "test_invoice.png"
        temp_img_path.write_bytes(img_bytes)
        print("Running table detection...")
        results = pipeline.predict(str(temp_img_path))
        # Materialize while the temp image still exists (predict may be lazy).
        results = list(results) if results else []

    all_line_items: list[LineItem] = []
    table_details: list[dict] = []
    for result in results:
        table_res_list = result.get("table_res_list") if hasattr(result, "get") else None
        if not table_res_list:
            continue
        print(f"\nFound {len(table_res_list)} tables")
        for i, table_res in enumerate(table_res_list):
            details, items = _analyze_table(i, table_res)
            table_details.append(details)
            all_line_items.extend(items)

    print(f"\n{'='*70}")
    print("EXTRACTION SUMMARY")
    print(f"{'='*70}")
    print(f"Total tables: {len(table_details)}")
    print(f"Line items tables: {sum(1 for t in table_details if t['is_line_items'])}")
    print(f"Total line items: {len(all_line_items)}")
    return {
        "pdf": pdf_path,
        "tables": table_details,
        "line_items": all_line_items,
    }
def main() -> int:
    """Command-line entry point.

    With ``--pdf``, the path is tried as given and then relative to
    ``project_root``. Without it, ``project_root/exampl/Faktura54011.pdf``
    is used.

    Returns:
        Process exit code: 0 on success, 1 if the PDF cannot be found.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test line items extraction")
    parser.add_argument("--pdf", type=str, help="Path to PDF file")
    args = parser.parse_args()

    if args.pdf:
        # Test specific PDF
        pdf_path = Path(args.pdf)
        if not pdf_path.exists():
            # Try relative to project root
            pdf_path = project_root / args.pdf
        if not pdf_path.exists():
            print(f"PDF not found: {args.pdf}", file=sys.stderr)
            return 1
    else:
        # Test default invoice.
        # NOTE(review): "exampl" may be a typo for "examples" — confirm against the repo layout.
        pdf_path = project_root / "exampl" / "Faktura54011.pdf"
        if not pdf_path.exists():
            print(f"Default PDF not found: {pdf_path}", file=sys.stderr)
            print("Usage: python ppstructure_line_items_poc.py --pdf <path>", file=sys.stderr)
            return 1

    test_line_items_extraction(str(pdf_path))
    return 0


if __name__ == "__main__":
    sys.exit(main())