Update Paddle and support invoice line items
This commit is contained in:
387
scripts/ppstructure_line_items_poc.py
Normal file
387
scripts/ppstructure_line_items_poc.py
Normal file
@@ -0,0 +1,387 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
PP-StructureV3 Line Items Extraction POC
|
||||
|
||||
Tests line items extraction from Swedish invoices using PP-StructureV3.
|
||||
Parses HTML table structure to extract structured line item data.
|
||||
|
||||
Run with invoice-sm120 conda environment.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import re
|
||||
from pathlib import Path
|
||||
from html.parser import HTMLParser
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root / "packages" / "backend"))
|
||||
|
||||
from paddleocr import PPStructureV3
|
||||
import fitz # PyMuPDF
|
||||
|
||||
|
||||
@dataclass
class LineItem:
    """Single line item from invoice.

    All value fields are kept as raw OCR strings (no numeric parsing here);
    any column that was absent or empty in the source table is None.
    """
    # Zero-based position of the row within its table's data rows.
    row_index: int
    article_number: str | None
    description: str | None
    quantity: str | None
    unit: str | None
    unit_price: str | None
    amount: str | None
    vat_rate: str | None
    # Fixed default assigned by this POC — not a model-derived probability.
    confidence: float = 0.9
|
||||
|
||||
|
||||
class TableHTMLParser(HTMLParser):
    """Collect an HTML table's header row and body rows as lists of cell text.

    After ``feed()``:
      * ``header_row`` holds the cells of the last row closed inside <thead>
        (empty list when the table has no <thead>).
      * ``rows`` holds every other non-empty row, as stripped cell strings.
    """

    def __init__(self):
        super().__init__()
        self.rows: list[list[str]] = []
        self.current_row: list[str] = []
        self.current_cell: str = ""
        self.in_td = False
        self.in_thead = False
        self.header_row: list[str] = []

    def handle_starttag(self, tag, attrs):
        # Entering a structural tag resets the matching accumulator.
        if tag in ("td", "th"):
            self.current_cell = ""
            self.in_td = True
        elif tag == "tr":
            self.current_row = []
        elif tag == "thead":
            self.in_thead = True

    def handle_endtag(self, tag):
        if tag == "thead":
            self.in_thead = False
        elif tag in ("td", "th"):
            # A closed cell is committed to the row, stripped of edge whitespace.
            self.current_row.append(self.current_cell.strip())
            self.in_td = False
        elif tag == "tr" and self.current_row:
            # Rows closed inside <thead> become the header; all others are data.
            if self.in_thead:
                self.header_row = self.current_row
            else:
                self.rows.append(self.current_row)

    def handle_data(self, data):
        # Text may arrive in several chunks per cell; append while inside one.
        if self.in_td:
            self.current_cell += data
|
||||
|
||||
|
||||
# Swedish column name mappings: canonical LineItem field -> lowercase header
# aliases/substrings, matched against normalize_header() output.
# Note: Some headers may contain multiple column names merged together.
# NOTE(review): patterns containing '.' or '-' (e.g. 'art.nr', 'á-pris') can
# never exact-match a normalize_header() result, which strips those characters;
# they only survive as redundancy — confirm before pruning.
COLUMN_MAPPINGS = {
    'article_number': ['art nummer', 'artikelnummer', 'artikel', 'artnr', 'art.nr', 'art nr'],
    'description': ['beskrivning', 'produktbeskrivning', 'produkt', 'tjänst', 'text', 'benämning', 'vara/tjänst', 'vara'],
    'quantity': ['antal', 'qty', 'st', 'pcs', 'kvantitet'],
    'unit': ['enhet', 'unit'],
    'unit_price': ['á-pris', 'a-pris', 'pris', 'styckpris', 'enhetspris', 'à pris'],
    'amount': ['belopp', 'summa', 'total', 'netto', 'rad summa'],
    'vat_rate': ['moms', 'moms%', 'vat', 'skatt', 'moms %'],
}
|
||||
|
||||
|
||||
def normalize_header(header: str) -> str:
    """Lowercase and trim a header, dropping '.' and turning '-' into spaces."""
    cleaned = header.lower().strip()
    # Single C-level pass instead of chained str.replace calls.
    return cleaned.translate(str.maketrans({".": "", "-": " "}))
|
||||
|
||||
|
||||
def map_columns(headers: list[str]) -> dict[int, str]:
    """Associate header column indices with canonical LineItem field names.

    An exact (normalized) match against COLUMN_MAPPINGS wins outright for a
    column; otherwise the longest embedded pattern of at least 3 characters
    decides. Columns with empty or unrecognized headers are omitted.
    """
    EXACT_BONUS = 100  # pushes exact matches above any substring match
    mapping: dict[int, str] = {}

    for idx, raw_header in enumerate(headers):
        candidate = normalize_header(raw_header)
        if not candidate.strip():
            continue  # nothing to match against

        chosen = None
        chosen_score = 0

        for field, patterns in COLUMN_MAPPINGS.items():
            for pattern in patterns:
                if pattern == candidate:
                    # Exact match — take this field immediately.
                    chosen = field
                    chosen_score = len(pattern) + EXACT_BONUS
                    break
                if len(pattern) >= 3 and pattern in candidate and len(pattern) > chosen_score:
                    # Longest substring match so far (ties keep the earlier field).
                    chosen = field
                    chosen_score = len(pattern)
            if chosen_score > EXACT_BONUS:
                break  # exact match found; stop scanning other fields

        if chosen:
            mapping[idx] = chosen

    return mapping
|
||||
|
||||
|
||||
def parse_table_html(html: str) -> tuple[list[str], list[list[str]]]:
    """Feed *html* through TableHTMLParser and return (header cells, data rows)."""
    table_parser = TableHTMLParser()
    table_parser.feed(html)
    return table_parser.header_row, table_parser.rows
|
||||
|
||||
|
||||
def detect_header_row(rows: list[list[str]]) -> tuple[int, list[str], bool]:
    """Locate the header row by counting known column keywords per row.

    Returns (header_row_index, header_row, is_at_end). ``is_at_end`` flags a
    header sitting at the very end or in the lower half of the table, which
    indicates the table was recognized in reversed order. Returns
    (-1, [], False) when no row scores at least two keyword hits.
    """
    # Every alias from every field counts as a header keyword.
    keywords = {p.lower() for patterns in COLUMN_MAPPINGS.values() for p in patterns}

    best_idx, best_row, best_score = -1, [], 0
    for idx, row in enumerate(rows):
        if not any(cell.strip() for cell in row):
            continue  # ignore fully blank rows

        joined = " ".join(cell.lower() for cell in row)
        score = sum(kw in joined for kw in keywords)
        if score > best_score:
            best_idx, best_row, best_score = idx, row, score

    if best_score < 2:
        return -1, [], False

    # Header at the last row, or anywhere past the midpoint, implies reversal.
    at_end = best_idx == len(rows) - 1 or best_idx > len(rows) // 2
    return best_idx, best_row, at_end
|
||||
|
||||
|
||||
def extract_line_items(html: str) -> list[LineItem]:
    """Extract structured line items from a PP-Structure HTML table.

    Parses the table, finds the header (from <thead>, by keyword detection,
    or falling back to the first non-empty row), maps columns to LineItem
    fields, and returns one LineItem per data row that carries at least a
    description or an amount.

    Fix vs. original: the detected ``is_reversed`` flag was computed but never
    applied, leaving reversed tables' rows (and row_index) in bottom-up order.
    """
    header, rows = parse_table_html(html)

    if not header:
        # No <thead>: try to detect the header row from cell content.
        header_idx, detected_header, is_at_end = detect_header_row(rows)
        if header_idx >= 0:
            header = detected_header
            if is_at_end:
                # Header at the bottom means the table was recognized
                # bottom-up: the data rows precede the header in reverse
                # reading order, so restore top-to-bottom order.
                rows = rows[:header_idx][::-1]
            else:
                rows = rows[header_idx + 1:]  # data rows follow the header
        elif rows:
            # Fall back to treating the first non-empty row as the header.
            for i, row in enumerate(rows):
                if any(cell.strip() for cell in row):
                    header = row
                    rows = rows[i + 1:]
                    break

    column_map = map_columns(header)

    items: list[LineItem] = []
    for row_idx, row in enumerate(rows):
        item_data = {
            'row_index': row_idx,
            'article_number': None,
            'description': None,
            'quantity': None,
            'unit': None,
            'unit_price': None,
            'amount': None,
            'vat_rate': None,
        }

        for col_idx, cell in enumerate(row):
            if col_idx in column_map:
                # Empty cells become None rather than empty strings.
                item_data[column_map[col_idx]] = cell if cell else None

        # Rows with neither a description nor an amount are treated as noise
        # (separators, unmapped subtotal rows, etc.).
        if item_data['description'] or item_data['amount']:
            items.append(LineItem(**item_data))

    return items
|
||||
|
||||
|
||||
def render_pdf_to_image(pdf_path: str, dpi: int = 200, page_index: int = 0) -> bytes:
    """Render one page of a PDF to PNG-encoded bytes.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Target rendering resolution (PDF coordinate space is 72 dpi).
        page_index: Zero-based page to render; defaults to the first page,
            matching the original behavior.

    Returns:
        The rendered page as PNG bytes.
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc[page_index]
        # Scale from the 72-dpi PDF coordinate space up to the requested dpi.
        zoom = dpi / 72
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        return pix.tobytes("png")
    finally:
        # Always release the document, even if page access or rendering raises.
        doc.close()
|
||||
|
||||
|
||||
def test_line_items_extraction(pdf_path: str) -> dict:
    """Run PP-StructureV3 on one invoice PDF and report extracted line items.

    Renders the first page to a temporary PNG, runs the structure pipeline,
    and for every detected table prints debug info, classifies it as a
    line-items table or not, and extracts LineItem rows where applicable.

    Returns a dict with keys "pdf" (input path), "tables" (per-table metadata
    dicts) and "line_items" (all extracted LineItem objects).
    """
    print(f"\n{'='*70}")
    print(f"Testing Line Items Extraction: {Path(pdf_path).name}")
    print(f"{'='*70}")

    # Render PDF to image
    print("Rendering PDF to image...")
    img_bytes = render_pdf_to_image(pdf_path)

    # Save temp image (pipeline.predict takes a file path, not bytes)
    temp_img_path = "/tmp/test_invoice.png"
    with open(temp_img_path, "wb") as f:
        f.write(img_bytes)

    # Initialize PP-StructureV3 (orientation/unwarp stages disabled for speed)
    print("Initializing PP-StructureV3...")
    pipeline = PPStructureV3(
        device="gpu:0",
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
    )

    # Run detection
    print("Running table detection...")
    results = pipeline.predict(temp_img_path)

    all_line_items = []
    table_details = []

    for result in results if results else []:
        # NOTE(review): result objects appear dict-like; the hasattr guard
        # tolerates other shapes — confirm against the PaddleOCR version used.
        table_res_list = result.get("table_res_list") if hasattr(result, "get") else None

        if table_res_list:
            print(f"\nFound {len(table_res_list)} tables")

            for i, table_res in enumerate(table_res_list):
                html = table_res.get("pred_html", "")
                ocr_pred = table_res.get("table_ocr_pred", {})

                print(f"\n--- Table {i+1} ---")

                # Debug: show full HTML for first table
                if i == 0:
                    print(f" Full HTML:\n{html}")

                # Debug: inspect table_ocr_pred structure (shape varies by version)
                if isinstance(ocr_pred, dict):
                    print(f" table_ocr_pred keys: {list(ocr_pred.keys())}")
                    # Check if rec_texts exists (actual OCR text)
                    if "rec_texts" in ocr_pred:
                        texts = ocr_pred["rec_texts"]
                        print(f" OCR texts count: {len(texts)}")
                        print(f" Sample OCR texts: {texts[:5]}")
                elif isinstance(ocr_pred, list):
                    print(f" table_ocr_pred is list with {len(ocr_pred)} items")
                    if ocr_pred:
                        print(f" First item type: {type(ocr_pred[0])}")
                        print(f" First few items: {ocr_pred[:3]}")

                # Parse HTML
                header, rows = parse_table_html(html)
                print(f" HTML Header (from thead): {header}")
                print(f" HTML Rows: {len(rows)}")

                # Try to detect header if not in thead
                detected_header = None
                is_reversed = False
                if not header and rows:
                    header_idx, detected_header, is_at_end = detect_header_row(rows)
                    if header_idx >= 0:
                        is_reversed = is_at_end
                        print(f" Detected header at row {header_idx}: {detected_header}")
                        print(f" Table is {'REVERSED (header at bottom)' if is_reversed else 'normal'}")
                        header = detected_header

                if rows:
                    print(f" First row: {rows[0]}")
                    if len(rows) > 1:
                        print(f" Second row: {rows[1]}")

                # Check if this looks like a line items table: any of the three
                # strongest field signals mapped from the header qualifies.
                column_map = map_columns(header) if header else {}
                print(f" Column mapping: {column_map}")

                is_line_items_table = (
                    'description' in column_map.values() or
                    'amount' in column_map.values() or
                    'article_number' in column_map.values()
                )

                if is_line_items_table:
                    print(f" >>> This appears to be a LINE ITEMS table!")
                    items = extract_line_items(html)
                    print(f" Extracted {len(items)} line items:")
                    for item in items:
                        print(f" - {item.description}: {item.quantity} x {item.unit_price} = {item.amount}")
                    all_line_items.extend(items)
                else:
                    print(f" >>> This is NOT a line items table (summary/payment)")

                table_details.append({
                    "index": i,
                    "header": header,
                    "row_count": len(rows),
                    "is_line_items": is_line_items_table,
                    "column_map": column_map,
                })

    print(f"\n{'='*70}")
    print(f"EXTRACTION SUMMARY")
    print(f"{'='*70}")
    print(f"Total tables: {len(table_details)}")
    print(f"Line items tables: {sum(1 for t in table_details if t['is_line_items'])}")
    print(f"Total line items: {len(all_line_items)}")

    return {
        "pdf": pdf_path,
        "tables": table_details,
        "line_items": all_line_items,
    }
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run extraction on --pdf, or on the bundled sample."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="Test line items extraction")
    arg_parser.add_argument("--pdf", type=str, help="Path to PDF file")
    args = arg_parser.parse_args()

    if not args.pdf:
        # No argument given: fall back to the sample invoice in the repo.
        default_pdf = project_root / "exampl" / "Faktura54011.pdf"
        if default_pdf.exists():
            test_line_items_extraction(str(default_pdf))
        else:
            print(f"Default PDF not found: {default_pdf}")
            print("Usage: python ppstructure_line_items_poc.py --pdf <path>")
        return

    # Resolve the provided path, retrying relative to the project root.
    pdf_path = Path(args.pdf)
    if not pdf_path.exists():
        pdf_path = project_root / args.pdf
        if not pdf_path.exists():
            print(f"PDF not found: {args.pdf}")
            return
    test_line_items_extraction(str(pdf_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user