388 lines
13 KiB
Python
388 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
PP-StructureV3 Line Items Extraction POC

Tests line-item extraction from Swedish invoices using PP-StructureV3,
parsing the HTML table structure it returns into structured line-item data.

Run inside the invoice-sm120 conda environment.
|
|
"""
|
|
|
|
import sys
|
|
import re
|
|
from pathlib import Path
|
|
from html.parser import HTMLParser
|
|
from dataclasses import dataclass
|
|
|
|
# Put packages/backend first on the import path so that modules living there
# can be imported when this script is run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root.joinpath("packages", "backend")))
|
|
|
|
from paddleocr import PPStructureV3
|
|
import fitz # PyMuPDF
|
|
|
|
|
|
@dataclass
|
|
class LineItem:
|
|
"""Single line item from invoice."""
|
|
row_index: int
|
|
article_number: str | None
|
|
description: str | None
|
|
quantity: str | None
|
|
unit: str | None
|
|
unit_price: str | None
|
|
amount: str | None
|
|
vat_rate: str | None
|
|
confidence: float = 0.9
|
|
|
|
|
|
class TableHTMLParser(HTMLParser):
    """Parse an HTML table into header and data rows of cell texts.

    Rows inside ``<thead>`` are stored in ``header_row`` (each ``<thead>``
    row replaces the previous one, so the last one wins); every other
    non-empty row is appended to ``rows``.

    Cell text is whitespace-normalized (runs of spaces/newlines collapse to a
    single space), ``<br>`` acts as a word break, and a cell with
    ``colspan=N`` is followed by N-1 empty placeholder cells so that column
    indices in data rows stay aligned with the header columns.
    """

    # HTML treats colspan values above 1000 as 1000; also guards the padding.
    _MAX_COLSPAN = 1000

    def __init__(self):
        super().__init__()
        self.rows: list[list[str]] = []
        self.current_row: list[str] = []
        self.current_cell: str = ""
        self.in_td = False
        self.in_thead = False
        self.header_row: list[str] = []
        # Number of extra (placeholder) columns spanned by the open cell.
        self._extra_span = 0

    @classmethod
    def _colspan(cls, attrs) -> int:
        """Return the cell's ``colspan`` as an int in [1, _MAX_COLSPAN]."""
        for name, value in attrs:
            if name == "colspan":
                try:
                    return min(max(int(value), 1), cls._MAX_COLSPAN)
                except (TypeError, ValueError):
                    return 1  # missing or malformed value
        return 1

    def handle_starttag(self, tag, attrs):
        if tag == "tr":
            self.current_row = []
        elif tag in ("td", "th"):
            self.in_td = True
            self.current_cell = ""
            self._extra_span = self._colspan(attrs) - 1
        elif tag == "thead":
            self.in_thead = True
        elif tag == "br" and self.in_td:
            # "Art<br>nr" should read "Art nr", not "Artnr".
            self.current_cell += " "

    def handle_endtag(self, tag):
        if tag in ("td", "th"):
            if not self.in_td:
                return  # stray closing tag; don't re-append the previous cell
            self.in_td = False
            self.current_row.append(" ".join(self.current_cell.split()))
            # Placeholders keep later cells under the right header column.
            self.current_row.extend([""] * self._extra_span)
            self._extra_span = 0
        elif tag == "tr":
            if self.current_row:
                if self.in_thead:
                    self.header_row = self.current_row
                else:
                    self.rows.append(self.current_row)
        elif tag == "thead":
            self.in_thead = False

    def handle_data(self, data):
        if self.in_td:
            self.current_cell += data
|
|
|
|
|
|
# Swedish (and a few English) header keywords for each LineItem field,
# compared against lower-cased table header text. Some headers may contain
# several column names merged together, so a keyword may also be found
# inside a longer header rather than matching it exactly.
# Order matters: when two keywords match a header equally well, the one
# listed first wins.
COLUMN_MAPPINGS = {
    'article_number': [
        'art nummer', 'artikelnummer', 'artikel', 'artnr', 'art.nr', 'art nr',
    ],
    'description': [
        'beskrivning', 'produktbeskrivning', 'produkt', 'tjänst', 'text',
        'benämning', 'vara/tjänst', 'vara',
    ],
    'quantity': [
        'antal', 'qty', 'st', 'pcs', 'kvantitet',
    ],
    'unit': [
        'enhet', 'unit',
    ],
    'unit_price': [
        'á-pris', 'a-pris', 'pris', 'styckpris', 'enhetspris', 'à pris',
    ],
    'amount': [
        'belopp', 'summa', 'total', 'netto', 'rad summa',
    ],
    'vat_rate': [
        'moms', 'moms%', 'vat', 'skatt', 'moms %',
    ],
}
|
|
|
|
|
|
def normalize_header(header: str) -> str:
    """Normalize header text for keyword matching.

    Lower-cases the text, removes dots, turns hyphens into spaces and
    collapses every run of whitespace (spaces, line breaks, non-breaking
    spaces) into a single space with no leading/trailing space, e.g.
    ``"  Art.  Nr "`` -> ``"art nr"`` and ``"Á-pris"`` -> ``"á pris"``.
    """
    text = header.lower().replace(".", "").replace("-", " ")
    # Collapse whitespace last so the "-" -> " " replacement can't leave
    # stray or doubled spaces behind.
    return " ".join(text.split())
|
|
|
|
|
|
def map_columns(headers: list[str]) -> dict[int, str]:
    """Map column indices to LineItem field names using COLUMN_MAPPINGS.

    Each header is normalized with normalize_header() and compared with every
    keyword (normalized the same way). An exact match wins immediately;
    otherwise the longest keyword of at least three characters that occurs
    inside the header is used, with the earliest-listed keyword winning ties.
    Columns whose header is empty or matches no keyword are left out.

    Args:
        headers: Header cell texts, in column order.

    Returns:
        ``{column_index: field_name}`` for every recognized column.
    """
    # Normalize keywords once, exactly like the headers; otherwise keywords
    # containing "." or "-" ("art.nr", "á-pris", "a-pris") could never match.
    keywords = [
        (field, normalize_header(pattern))
        for field, patterns in COLUMN_MAPPINGS.items()
        for pattern in patterns
    ]

    mapping: dict[int, str] = {}
    for idx, header in enumerate(headers):
        normalized = normalize_header(header)
        if not normalized.strip():
            continue  # skip empty headers

        best_match = None
        best_match_len = 0
        for field, keyword in keywords:
            if keyword == normalized:
                best_match = field  # exact match - use immediately
                break
            # Substring match: prefer longer keywords; very short ones
            # (e.g. "st") are too ambiguous to match inside a header.
            if (
                len(keyword) >= 3
                and len(keyword) > best_match_len
                and keyword in normalized
            ):
                best_match = field
                best_match_len = len(keyword)

        if best_match:
            mapping[idx] = best_match

    return mapping
|
|
|
|
|
|
def parse_table_html(html: str) -> tuple[list[str], list[list[str]]]:
    """Parse table HTML into its header row and data rows.

    Returns ``(header, rows)``: ``header`` is the last row found inside
    ``<thead>`` (empty if there is none) and ``rows`` are the other non-empty
    table rows as lists of cell texts. ``None`` or empty input yields
    ``([], [])``.
    """
    parser = TableHTMLParser()
    parser.feed(html or "")
    # close() flushes any text HTMLParser is still buffering (e.g. when the
    # HTML is truncated), as the html.parser docs require at end of input.
    parser.close()
    return parser.header_row, parser.rows
|
|
|
|
|
|
def _keyword_in_row(keyword: str, row_text: str, cells: list[str]) -> bool:
    """Return True if the lower-cased *keyword* occurs in the row.

    Keywords shorter than three characters (e.g. "st") also occur inside
    ordinary words ("tjänst", "Stockholm"), so they only count when they make
    up a whole cell; longer keywords may appear anywhere in the row text.
    """
    if len(keyword) < 3:
        return keyword in cells
    return keyword in row_text


def detect_header_row(rows: list[list[str]]) -> tuple[int, list[str], bool]:
    """
    Detect which row is the header based on content patterns.

    Each non-empty row is scored by the number of distinct COLUMN_MAPPINGS
    fields it mentions; a field counts once even if several of its keywords
    match (so a "Moms %" column is not counted twice). The first row with
    the highest score is chosen if that score is at least 2.

    Returns (header_row_index, header_row, is_at_end).
    is_at_end is True when the header is the last row or lies past the
    middle of the table (the table is then treated as reversed, with the
    data rows before the header).
    Returns (-1, [], False) if no header detected.
    """
    best_idx, best_row, best_score = -1, [], 0

    for i, row in enumerate(rows):
        cells = [cell.strip().lower() for cell in row]
        if not any(cells):
            continue  # skip empty rows

        row_text = " ".join(cells)
        score = sum(
            1
            for patterns in COLUMN_MAPPINGS.values()
            if any(_keyword_in_row(p.lower(), row_text, cells) for p in patterns)
        )

        # Strict ">" keeps the earliest row on ties.
        if score > best_score:
            best_idx, best_row, best_score = i, row, score

    if best_score >= 2:
        is_at_end = best_idx == len(rows) - 1 or best_idx > len(rows) // 2
        return best_idx, best_row, is_at_end

    return -1, [], False
|
|
|
|
|
|
def _split_header_and_data_rows(
    header: list[str], rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
    """Choose the header row and the data rows for a parsed table.

    A ``<thead>`` header is used as-is. Otherwise the header is detected from
    the row contents (data rows come after it, or before it when it sits at
    the end of the table), falling back to the first non-empty row.
    """
    if header:
        return header, rows

    header_idx, detected_header, is_at_end = detect_header_row(rows)
    if header_idx >= 0:
        if is_at_end:
            # Header is at the end - table is reversed; data rows precede it.
            return detected_header, rows[:header_idx]
        return detected_header, rows[header_idx + 1:]

    # Fall back to the first non-empty row.
    for i, row in enumerate(rows):
        if any(cell.strip() for cell in row):
            return row, rows[i + 1:]

    return [], rows


def extract_line_items(html: str) -> list[LineItem]:
    """Extract line items from an HTML table.

    Columns are mapped to LineItem fields via the header row. A data row
    becomes a LineItem when it has a description or an amount; rows that
    repeat the header verbatim are skipped. ``row_index`` is the row's
    position among the data rows (skipped rows still consume an index).
    """
    header, rows = _split_header_and_data_rows(*parse_table_html(html))
    column_map = map_columns(header)

    items = []
    for row_idx, row in enumerate(rows):
        if header and row == header:
            continue  # repeated header row, not a line item

        fields: dict[str, str | None] = {
            'article_number': None,
            'description': None,
            'quantity': None,
            'unit': None,
            'unit_price': None,
            'amount': None,
            'vat_rate': None,
        }
        for col_idx, cell in enumerate(row):
            field = column_map.get(col_idx)
            # Store only non-empty cells so an empty cell in a later column
            # mapped to the same field cannot erase an earlier value.
            if field and cell:
                fields[field] = cell

        # Only add if we have at least description or amount
        if fields['description'] or fields['amount']:
            items.append(LineItem(row_index=row_idx, **fields))

    return items
|
|
|
|
|
|
def render_pdf_to_image(pdf_path: str, dpi: int = 200, page_number: int = 0) -> bytes:
    """Render one page of a PDF to PNG image bytes.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Output resolution. PDF user space has 72 units per inch, so the
            page is scaled by ``dpi / 72``.
        page_number: Zero-based index of the page to render (default: the
            first page).

    Returns:
        The rendered page encoded as PNG.
    """
    # The context manager closes the document even if rendering raises.
    with fitz.open(pdf_path) as doc:
        page = doc[page_number]
        zoom = dpi / 72
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        return pix.tobytes("png")
|
|
|
|
|
|
def _print_ocr_pred_debug(ocr_pred) -> None:
    """Print the structure of a table's ``table_ocr_pred`` for debugging."""
    if isinstance(ocr_pred, dict):
        print(f" table_ocr_pred keys: {list(ocr_pred.keys())}")
        # Check if rec_texts exists (actual OCR text)
        if "rec_texts" in ocr_pred:
            texts = ocr_pred["rec_texts"]
            print(f" OCR texts count: {len(texts)}")
            print(f" Sample OCR texts: {texts[:5]}")
    elif isinstance(ocr_pred, list):
        print(f" table_ocr_pred is list with {len(ocr_pred)} items")
        if ocr_pred:
            print(f" First item type: {type(ocr_pred[0])}")
            print(f" First few items: {ocr_pred[:3]}")


def _analyze_table(i: int, table_res) -> tuple[dict, list[LineItem]]:
    """Print diagnostics for one detected table and extract its line items.

    Returns the table's summary dict and the line items extracted from it
    (empty when the table does not look like a line-items table).
    """
    html = table_res.get("pred_html") or ""
    ocr_pred = table_res.get("table_ocr_pred", {})

    print(f"\n--- Table {i+1} ---")

    # Debug: show full HTML for first table
    if i == 0:
        print(f" Full HTML:\n{html}")

    _print_ocr_pred_debug(ocr_pred)

    header, rows = parse_table_html(html)
    print(f" HTML Header (from thead): {header}")
    print(f" HTML Rows: {len(rows)}")

    # Try to detect header if not in thead
    if not header and rows:
        header_idx, detected_header, is_at_end = detect_header_row(rows)
        if header_idx >= 0:
            print(f" Detected header at row {header_idx}: {detected_header}")
            print(f" Table is {'REVERSED (header at bottom)' if is_at_end else 'normal'}")
            header = detected_header

    if rows:
        print(f" First row: {rows[0]}")
        if len(rows) > 1:
            print(f" Second row: {rows[1]}")

    # Check if this looks like a line items table
    column_map = map_columns(header) if header else {}
    print(f" Column mapping: {column_map}")

    is_line_items_table = bool(
        set(column_map.values()) & {'description', 'amount', 'article_number'}
    )

    items: list[LineItem] = []
    if is_line_items_table:
        print(" >>> This appears to be a LINE ITEMS table!")
        items = extract_line_items(html)
        print(f" Extracted {len(items)} line items:")
        for item in items:
            print(f" - {item.description}: {item.quantity} x {item.unit_price} = {item.amount}")
    else:
        print(" >>> This is NOT a line items table (summary/payment)")

    details = {
        "index": i,
        "header": header,
        "row_count": len(rows),
        "is_line_items": is_line_items_table,
        "column_map": column_map,
    }
    return details, items


def test_line_items_extraction(pdf_path: str, device: str = "gpu:0") -> dict:
    """Run PP-StructureV3 on the first page of a PDF and report line items.

    Renders the page to a PNG in a private temporary directory (removed
    afterwards), runs PP-StructureV3 on it, prints per-table diagnostics and
    extracts line items from every table whose header maps to a description,
    amount or article-number column.

    Args:
        pdf_path: PDF file to analyze.
        device: Device string passed to PPStructureV3 (default "gpu:0").

    Returns:
        Dict with keys "pdf", "tables" (per-table summary dicts) and
        "line_items" (all extracted LineItem objects).
    """
    import tempfile

    print(f"\n{'='*70}")
    print(f"Testing Line Items Extraction: {Path(pdf_path).name}")
    print(f"{'='*70}")

    # Render PDF to image
    print("Rendering PDF to image...")
    img_bytes = render_pdf_to_image(pdf_path)

    # A private temp dir instead of a fixed /tmp path: portable, no clashes
    # between concurrent runs, and the image is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_img_path = str(Path(tmp_dir) / "test_invoice.png")
        with open(temp_img_path, "wb") as f:
            f.write(img_bytes)

        print("Initializing PP-StructureV3...")
        pipeline = PPStructureV3(
            device=device,
            use_doc_orientation_classify=False,
            use_doc_unwarping=False,
        )

        print("Running table detection...")
        # Materialize while the image still exists, in case predict() is lazy.
        results = list(pipeline.predict(temp_img_path) or [])

    all_line_items: list[LineItem] = []
    table_details: list[dict] = []

    for result in results:
        table_res_list = result.get("table_res_list") if hasattr(result, "get") else None
        if not table_res_list:
            continue

        print(f"\nFound {len(table_res_list)} tables")
        for i, table_res in enumerate(table_res_list):
            details, items = _analyze_table(i, table_res)
            table_details.append(details)
            all_line_items.extend(items)

    print(f"\n{'='*70}")
    print("EXTRACTION SUMMARY")
    print(f"{'='*70}")
    print(f"Total tables: {len(table_details)}")
    print(f"Line items tables: {sum(1 for t in table_details if t['is_line_items'])}")
    print(f"Total line items: {len(all_line_items)}")

    return {
        "pdf": pdf_path,
        "tables": table_details,
        "line_items": all_line_items,
    }
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point.

    With ``--pdf``, analyze that file (tried as given, then relative to the
    project root); otherwise analyze the default example invoice.

    Returns:
        Process exit status: 0 on success, 1 if the PDF could not be found.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test line items extraction")
    parser.add_argument("--pdf", type=str, help="Path to PDF file")
    args = parser.parse_args()

    if args.pdf:
        # Test specific PDF
        pdf_path = Path(args.pdf)
        if not pdf_path.exists():
            # Try relative to project root
            pdf_path = project_root / args.pdf
            if not pdf_path.exists():
                print(f"PDF not found: {args.pdf}", file=sys.stderr)
                return 1
    else:
        # Test default invoice
        pdf_path = project_root / "exampl" / "Faktura54011.pdf"
        if not pdf_path.exists():
            print(f"Default PDF not found: {pdf_path}", file=sys.stderr)
            print("Usage: python ppstructure_line_items_poc.py --pdf <path>", file=sys.stderr)
            return 1

    test_line_items_extraction(str(pdf_path))
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a missing PDF yields a non-zero exit code.
    sys.exit(main())
|