Update Paddle, and support invoice line items

This commit is contained in:
Yaojia Wang
2026-02-03 21:28:06 +01:00
parent c4e3773df1
commit 35988b1ebf
41 changed files with 6832 additions and 48 deletions

View File

@@ -1,5 +1,18 @@
from .pipeline import InferencePipeline, InferenceResult
from .pipeline import (
InferencePipeline,
InferenceResult,
CrossValidationResult,
BUSINESS_FEATURES_AVAILABLE,
)
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
__all__ = [
'InferencePipeline',
'InferenceResult',
'CrossValidationResult',
'YOLODetector',
'Detection',
'FieldExtractor',
'BUSINESS_FEATURES_AVAILABLE',
]

View File

@@ -2,19 +2,39 @@
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
Supports both basic field extraction and business invoice features
(line items, VAT extraction, cross-validation).
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import logging
import time
import re
logger = logging.getLogger(__name__)
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
# Business invoice feature imports (optional - for extract_line_items mode)
try:
    from ..table.line_items_extractor import LineItem, LineItemsResult, LineItemsExtractor
    from ..table.structure_detector import TableDetector
    from ..vat.vat_extractor import VATSummary, VATExtractor
    from ..validation.vat_validator import VATValidationResult, VATValidator

    BUSINESS_FEATURES_AVAILABLE = True
except ImportError:
    # Keep EVERY optionally-imported name defined so that module-level
    # re-exports (e.g. "from .pipeline import ...") never raise NameError.
    # Previously LineItemsExtractor, VATExtractor and VATValidator were
    # left undefined when the import failed.
    BUSINESS_FEATURES_AVAILABLE = False
    LineItem = None
    LineItemsResult = None
    LineItemsExtractor = None
    TableDetector = None
    VATSummary = None
    VATExtractor = None
    VATValidationResult = None
    VATValidator = None
@dataclass
class CrossValidationResult:
@@ -45,6 +65,10 @@ class InferenceResult:
errors: list[str] = field(default_factory=list)
fallback_used: bool = False
cross_validation: CrossValidationResult | None = None
# Business invoice features (optional)
line_items: Any | None = None # LineItemsResult when available
vat_summary: Any | None = None # VATSummary when available
vat_validation: Any | None = None # VATValidationResult when available
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
@@ -81,8 +105,89 @@ class InferenceResult:
'payment_line_account_type': self.cross_validation.payment_line_account_type,
'details': self.cross_validation.details,
}
# Add business invoice features if present
if self.line_items is not None:
result['line_items'] = self._line_items_to_json()
if self.vat_summary is not None:
result['vat_summary'] = self._vat_summary_to_json()
if self.vat_validation is not None:
result['vat_validation'] = self._vat_validation_to_json()
return result
def _line_items_to_json(self) -> dict | None:
"""Convert LineItemsResult to JSON."""
if self.line_items is None:
return None
li = self.line_items
return {
'items': [
{
'row_index': item.row_index,
'description': item.description,
'quantity': item.quantity,
'unit': item.unit,
'unit_price': item.unit_price,
'amount': item.amount,
'article_number': item.article_number,
'vat_rate': item.vat_rate,
'is_deduction': item.is_deduction,
'confidence': item.confidence,
}
for item in li.items
],
'header_row': li.header_row,
'total_amount': li.total_amount,
}
def _vat_summary_to_json(self) -> dict | None:
"""Convert VATSummary to JSON."""
if self.vat_summary is None:
return None
vs = self.vat_summary
return {
'breakdowns': [
{
'rate': b.rate,
'base_amount': b.base_amount,
'vat_amount': b.vat_amount,
'source': b.source,
}
for b in vs.breakdowns
],
'total_excl_vat': vs.total_excl_vat,
'total_vat': vs.total_vat,
'total_incl_vat': vs.total_incl_vat,
'confidence': vs.confidence,
}
def _vat_validation_to_json(self) -> dict | None:
"""Convert VATValidationResult to JSON."""
if self.vat_validation is None:
return None
vv = self.vat_validation
return {
'is_valid': vv.is_valid,
'confidence_score': vv.confidence_score,
'math_checks': [
{
'rate': mc.rate,
'base_amount': mc.base_amount,
'expected_vat': mc.expected_vat,
'actual_vat': mc.actual_vat,
'is_valid': mc.is_valid,
'tolerance': mc.tolerance,
}
for mc in vv.math_checks
],
'total_check': vv.total_check,
'line_items_vs_summary': vv.line_items_vs_summary,
'amount_consistency': vv.amount_consistency,
'needs_review': vv.needs_review,
'review_reasons': vv.review_reasons,
}
def get_field(self, field_name: str) -> tuple[Any, float]:
    """Return ``(value, confidence)`` for *field_name*.

    Missing fields yield ``(None, 0.0)``.
    """
    value = self.fields.get(field_name)
    score = self.confidence.get(field_name, 0.0)
    return value, score
@@ -107,7 +212,9 @@ class InferencePipeline:
ocr_lang: str = 'en',
use_gpu: bool = False,
dpi: int = 300,
enable_fallback: bool = True
enable_fallback: bool = True,
enable_business_features: bool = False,
vat_tolerance: float = 0.5
):
"""
Initialize inference pipeline.
@@ -119,6 +226,8 @@ class InferencePipeline:
use_gpu: Whether to use GPU
dpi: Resolution for PDF rendering
enable_fallback: Enable fallback to full-page OCR
enable_business_features: Enable line items/VAT extraction
vat_tolerance: Tolerance for VAT math checks (in currency units)
"""
self.detector = YOLODetector(
model_path,
@@ -129,11 +238,34 @@ class InferencePipeline:
self.payment_line_parser = PaymentLineParser()
self.dpi = dpi
self.enable_fallback = enable_fallback
self.enable_business_features = enable_business_features
self.vat_tolerance = vat_tolerance
# Initialize business feature components if enabled and available
self.line_items_extractor = None
self.vat_extractor = None
self.vat_validator = None
self._business_ocr_engine = None # Lazy-initialized for VAT text extraction
self._table_detector = None # Shared TableDetector for line items extraction
if enable_business_features:
if not BUSINESS_FEATURES_AVAILABLE:
raise ImportError(
"Business features require table, vat, and validation modules. "
"Please ensure they are properly installed."
)
# Create shared TableDetector for performance (PP-StructureV3 init is slow)
self._table_detector = TableDetector()
# Pass shared detector to LineItemsExtractor
self.line_items_extractor = LineItemsExtractor(table_detector=self._table_detector)
self.vat_extractor = VATExtractor()
self.vat_validator = VATValidator(tolerance=vat_tolerance)
def process_pdf(
self,
pdf_path: str | Path,
document_id: str | None = None
document_id: str | None = None,
extract_line_items: bool | None = None
) -> InferenceResult:
"""
Process a PDF and extract invoice fields.
@@ -141,6 +273,8 @@ class InferencePipeline:
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
extract_line_items: Whether to extract line items and VAT info.
If None, uses the enable_business_features setting from __init__.
Returns:
InferenceResult with extracted fields
@@ -156,9 +290,16 @@ class InferencePipeline:
document_id=document_id or Path(pdf_path).stem
)
# Determine if business features should be used
use_business_features = (
extract_line_items if extract_line_items is not None
else self.enable_business_features
)
try:
all_detections = []
all_extracted = []
all_ocr_text = [] # Collect OCR text for VAT extraction
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
@@ -175,6 +316,11 @@ class InferencePipeline:
extracted = self.extractor.extract_from_detection(detection, image_array)
all_extracted.append(extracted)
# Collect full-page OCR text for VAT extraction (only if business features enabled)
if use_business_features:
page_text = self._get_full_page_text(image_array)
all_ocr_text.append(page_text)
result.raw_detections = all_detections
result.extracted_fields = all_extracted
@@ -185,6 +331,10 @@ class InferencePipeline:
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
# Extract business invoice features if enabled
if use_business_features:
self._extract_business_features(pdf_path, result, '\n'.join(all_ocr_text))
result.success = len(result.fields) > 0
except Exception as e:
@@ -194,6 +344,78 @@ class InferencePipeline:
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _get_full_page_text(self, image_array) -> str:
    """Extract full page text using OCR for VAT extraction.

    Best-effort helper: returns the page's token texts joined by single
    spaces, or "" on any failure (VAT extraction is optional).

    Args:
        image_array: Rendered page image (numpy array) to OCR.

    Returns:
        Space-joined OCR text, or "" when OCR is unavailable/fails.
    """
    try:
        # Lazy-initialize the OCR engine to avoid repeated model loading.
        # The import is deferred into the guarded branch so a missing
        # shared.ocr degrades to "" instead of raising.
        if self._business_ocr_engine is None:
            from shared.ocr import OCREngine
            self._business_ocr_engine = OCREngine()
        tokens = self._business_ocr_engine.extract_from_image(image_array, page_no=0)
        return ' '.join(t.text for t in tokens)
    except Exception as e:
        # Use the module-level logger (the old local
        # "logger = logging.getLogger(__name__)" shadowed it) with lazy
        # %-style args so the message is only built when emitted.
        logger.warning("OCR extraction for VAT failed: %s", e)
        return ""
def _extract_business_features(
    self,
    pdf_path: str | Path,
    result: InferenceResult,
    full_text: str
) -> None:
    """
    Extract line items, VAT summary, and perform cross-validation.

    Populates ``result.line_items``, ``result.vat_summary`` and
    ``result.vat_validation`` in place; failures are recorded in
    ``result.errors`` instead of being raised.

    Args:
        pdf_path: Path to PDF file
        result: InferenceResult to populate
        full_text: Full OCR text from all pages
    """
    if not BUSINESS_FEATURES_AVAILABLE:
        result.errors.append("Business features not available")
        return
    if not self.line_items_extractor or not self.vat_extractor or not self.vat_validator:
        result.errors.append("Business feature extractors not initialized")
        return

    try:
        # Extract line items from tables.
        # Logging uses lazy %-style args so messages are only formatted
        # when the INFO level is actually enabled.
        logger.info("Extracting line items from PDF: %s", pdf_path)
        line_items_result = self.line_items_extractor.extract_from_pdf(str(pdf_path))
        logger.info(
            "Line items extraction result: %s, items=%d",
            line_items_result is not None,
            len(line_items_result.items) if line_items_result else 0,
        )
        if line_items_result and line_items_result.items:
            result.line_items = line_items_result
            logger.info("Set result.line_items with %d items", len(line_items_result.items))

        # Extract VAT summary from text
        logger.info("Extracting VAT summary from text (%d chars)", len(full_text))
        vat_summary = self.vat_extractor.extract(full_text)
        logger.info("VAT summary extraction result: %s", vat_summary is not None)
        if vat_summary:
            result.vat_summary = vat_summary

            # Cross-validate VAT information (only meaningful with a summary)
            existing_amount = result.fields.get('Amount')
            vat_validation = self.vat_validator.validate(
                vat_summary,
                line_items=line_items_result,
                existing_amount=str(existing_amount) if existing_amount else None
            )
            result.vat_validation = vat_validation
            logger.info(
                "VAT validation completed: is_valid=%s",
                vat_validation.is_valid if vat_validation else None,
            )
    except Exception as e:
        import traceback
        error_detail = f"{type(e).__name__}: {e}"
        logger.error("Business feature extraction failed: %s\n%s",
                     error_detail, traceback.format_exc())
        result.errors.append(f"Business feature extraction error: {error_detail}")
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
field_candidates: dict[str, list[ExtractedField]] = {}

View File

@@ -0,0 +1,32 @@
"""
Table detection and extraction module.
This module provides PP-StructureV3-based table detection for invoices,
and line items extraction from detected tables.
"""
from .structure_detector import (
TableDetectionResult,
TableDetector,
TableDetectorConfig,
)
from .line_items_extractor import (
LineItem,
LineItemsResult,
LineItemsExtractor,
ColumnMapper,
HTMLTableParser,
)
__all__ = [
# Structure detection
"TableDetectionResult",
"TableDetector",
"TableDetectorConfig",
# Line items extraction
"LineItem",
"LineItemsResult",
"LineItemsExtractor",
"ColumnMapper",
"HTMLTableParser",
]

View File

@@ -0,0 +1,970 @@
"""
Line Items Extractor
Extracts structured line items from HTML tables produced by PP-StructureV3.
Handles Swedish invoice formats including reversed tables (header at bottom).
Includes fallback text-based extraction for invoices without detectable table structures.
"""
from dataclasses import dataclass, field
from html.parser import HTMLParser
from decimal import Decimal, InvalidOperation
import re
import logging
logger = logging.getLogger(__name__)
@dataclass
class LineItem:
    """Single line item from invoice."""

    row_index: int  # zero-based position among the extracted data rows
    description: str | None = None  # item/service text from the description column
    quantity: str | None = None  # raw quantity cell text
    unit: str | None = None  # unit of measure cell text (e.g. "st")
    unit_price: str | None = None  # raw per-unit price cell text
    amount: str | None = None  # raw row-total cell text
    article_number: str | None = None  # article/object reference, if present
    vat_rate: str | None = None  # raw VAT rate cell text
    is_deduction: bool = False  # True if this row is a deduction/discount
    confidence: float = 0.9  # extraction confidence (fixed default heuristic)
@dataclass
class LineItemsResult:
    """Result of line items extraction."""

    items: list[LineItem]  # extracted data rows
    header_row: list[str]  # detected header cells
    raw_html: str  # source table HTML ("" when the text fallback was used)
    is_reversed: bool = False  # True when the header was found at the bottom

    @property
    def total_amount(self) -> str | None:
        """Calculate total amount from line items (deduction rows have negative amounts)."""
        # NOTE(review): nothing here negates rows with is_deduction=True;
        # the docstring assumes their `amount` strings already carry a
        # minus sign — confirm against _extract_items, which stores the
        # raw cell text unchanged.
        if not self.items:
            return None
        total = Decimal("0")
        for item in self.items:
            if item.amount:
                try:
                    # Parse Swedish number format (1 234,56)
                    amount_str = item.amount.replace(" ", "").replace(",", ".")
                    total += Decimal(amount_str)
                except InvalidOperation:
                    # Unparseable amounts are skipped (best-effort sum).
                    pass
        if total == 0:
            # Zero is treated as "no total" (this also hides rows that
            # fully cancel each other out).
            return None
        # Format back to Swedish format
        formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
        # Fix the space/comma swap
        parts = formatted.rsplit(",", 1)
        if len(parts) == 2:
            # NOTE(review): this replace looks like a no-op unless the two
            # space characters differ (e.g. NBSP vs ASCII space) — confirm.
            return parts[0].replace(" ", " ") + "," + parts[1]
        return formatted
# Swedish column name mappings
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
# Each key is a LineItem field; each value lists header aliases that map to it.
COLUMN_MAPPINGS = {
    "article_number": [
        "art nummer",
        "artikelnummer",
        "artikel",
        "artnr",
        "art.nr",
        "art nr",
        "objektnummer",  # Rental: property reference
        "objekt",
    ],
    "description": [
        "beskrivning",
        "produktbeskrivning",
        "produkt",
        "tjänst",
        "text",
        "benämning",
        "vara/tjänst",
        "vara",
        # Rental invoice specific
        "specifikation",
        "spec",
        "hyresperiod",  # Rental period
        "period",
        "typ",  # Type of charge
        # Utility bills
        "förbrukning",  # Consumption
        "avläsning",  # Meter reading
    ],
    # NOTE: the previous empty-string alias ("") was removed — it could
    # never map a column, but "" is a substring of every string, so it
    # inflated keyword match counts in header-row detection.
    "quantity": ["antal", "qty", "st", "pcs", "kvantitet", "kvm"],
    "unit": ["enhet", "unit"],
    "unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
    "amount": [
        "belopp",
        "summa",
        "total",
        "netto",
        "rad summa",
        # Rental specific
        "hyra",  # Rent
        "avgift",  # Fee
        "kostnad",  # Cost
        "debitering",  # Charge
        "totalt",  # Total
    ],
    "vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
    # Additional field for rental: deductions/adjustments
    "deduction": [
        "avdrag",  # Deduction
        "rabatt",  # Discount
        "kredit",  # Credit
    ],
}
# Keywords that indicate NOT a line items table
# (footer/summary tables on Swedish invoices: shipping, fees, totals,
# payment references and due dates).
SUMMARY_KEYWORDS = [
    "frakt",  # shipping
    "faktura.avg",  # invoice fee (abbreviated)
    "fakturavg",  # invoice fee
    "exkl.moms",  # excluding VAT
    "att betala",  # amount to pay
    "öresavr",  # öre rounding
    "bankgiro",  # bank giro number
    "plusgiro",  # plus giro number
    "ocr",  # OCR payment reference
    "forfallodatum",  # due date (unaccented OCR variant)
    "förfallodatum",  # due date
]
class _TableHTMLParser(HTMLParser):
    """Internal HTML parser that collects table cells into rows.

    Cells inside ``<thead>`` become ``header_row``; every other ``<tr>``
    is appended to ``rows``. Cell text is stripped of surrounding space.
    """

    def __init__(self):
        super().__init__()
        self.rows: list[list[str]] = []  # data rows (outside <thead>)
        self.header_row: list[str] = []  # cells of the <thead> row
        self._row_buf: list[str] = []  # cells of the row being built
        self._cell_buf: str = ""  # text of the cell being built
        self._capturing = False  # inside a <td>/<th>
        self._in_head = False  # inside <thead>

    def handle_starttag(self, tag, attrs):
        if tag == "tr":
            self._row_buf = []
        elif tag in ("td", "th"):
            self._capturing = True
            self._cell_buf = ""
        elif tag == "thead":
            self._in_head = True

    def handle_endtag(self, tag):
        if tag in ("td", "th"):
            self._capturing = False
            self._row_buf.append(self._cell_buf.strip())
        elif tag == "tr":
            if self._row_buf:
                if self._in_head:
                    self.header_row = self._row_buf
                else:
                    self.rows.append(self._row_buf)
        elif tag == "thead":
            self._in_head = False

    def handle_data(self, data):
        if self._capturing:
            self._cell_buf += data
class HTMLTableParser:
    """Parse HTML tables into structured data."""

    def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
        """
        Parse HTML table and return header and rows.

        Args:
            html: HTML string containing table.

        Returns:
            Tuple of (header_row, data_rows).
        """
        # Delegate the actual tag handling to the internal SAX-style parser.
        worker = _TableHTMLParser()
        worker.feed(html)
        return worker.header_row, worker.rows
class ColumnMapper:
"""Map column headers to field names."""
def __init__(self, mappings: dict[str, list[str]] | None = None):
"""
Initialize column mapper.
Args:
mappings: Custom column mappings. Uses Swedish defaults if None.
"""
self.mappings = mappings or COLUMN_MAPPINGS
def map(self, headers: list[str]) -> dict[int, str]:
"""
Map column indices to field names.
Args:
headers: List of column header strings.
Returns:
Dictionary mapping column index to field name.
"""
mapping = {}
for idx, header in enumerate(headers):
normalized = self._normalize(header)
if not normalized.strip():
continue
best_match = None
best_match_len = 0
for field_name, patterns in self.mappings.items():
for pattern in patterns:
if pattern == normalized:
best_match = field_name
best_match_len = len(pattern) + 100
break
elif pattern in normalized and len(pattern) > best_match_len:
if len(pattern) >= 3:
best_match = field_name
best_match_len = len(pattern)
if best_match_len > 100:
break
if best_match:
mapping[idx] = best_match
return mapping
def _normalize(self, header: str) -> str:
"""Normalize header text for matching."""
return header.lower().strip().replace(".", "").replace("-", " ")
class LineItemsExtractor:
"""Extract structured line items from HTML tables."""
def __init__(
    self,
    column_mapper: ColumnMapper | None = None,
    table_detector: "TableDetector | None" = None,
    enable_text_fallback: bool = True,
):
    """
    Initialize extractor.

    Args:
        column_mapper: Custom column mapper. Uses default if None.
        table_detector: Pre-initialized TableDetector to reuse. Creates new if None.
        enable_text_fallback: Enable text-based fallback extraction when no tables detected.
    """
    self.parser = HTMLTableParser()
    self.mapper = column_mapper or ColumnMapper()
    # Shared/reused detector; created lazily by _get_table_detector() when None.
    self._table_detector = table_detector
    self._enable_text_fallback = enable_text_fallback
    self._text_extractor = None  # Lazy initialized
def extract(self, html: str) -> LineItemsResult:
    """
    Extract line items from HTML table.

    Pipeline: parse HTML -> repair vertically merged cells -> locate the
    header (explicit <thead>, keyword detection, or first non-empty row)
    -> map columns -> build LineItems, with a merged-cell fallback.

    Args:
        html: HTML string containing table.

    Returns:
        LineItemsResult with extracted items.
    """
    header, rows = self.parser.parse(html)
    is_reversed = False

    # Check if cells contain merged multi-line data (PP-StructureV3 issue)
    if rows and self._has_vertically_merged_cells(rows):
        logger.info("Detected vertically merged cells, attempting to split")
        header, rows = self._split_merged_rows(rows)

    if not header:
        # No <thead>: locate the header row heuristically by keyword matches.
        header_idx, detected_header, is_at_end = self._detect_header_row(rows)
        if header_idx >= 0:
            header = detected_header
            if is_at_end:
                # Header near the bottom => data rows are listed above it.
                is_reversed = True
                rows = rows[:header_idx]
            else:
                rows = rows[header_idx + 1 :]
        elif rows:
            # Last resort: treat the first non-empty row as the header.
            for i, row in enumerate(rows):
                if any(cell.strip() for cell in row):
                    header = row
                    rows = rows[i + 1 :]
                    break

    column_map = self.mapper.map(header)
    items = self._extract_items(rows, column_map)

    # If no items extracted but header looks like line items table,
    # try parsing merged cells (common in poorly OCR'd rental invoices)
    if not items and self._has_merged_header(header):
        logger.info(f"Trying merged cell parsing: header={header}, rows={rows}")
        items = self._extract_from_merged_cells(header, rows)
        logger.info(f"Merged cell parsing result: {len(items)} items")

    return LineItemsResult(
        items=items,
        header_row=header,
        raw_html=html,
        is_reversed=is_reversed,
    )
def _get_table_detector(self) -> "TableDetector":
    """Return the shared TableDetector, building and caching it on first use."""
    detector = self._table_detector
    if detector is None:
        # Deferred import: PP-StructureV3 initialization is expensive.
        from .structure_detector import TableDetector
        detector = TableDetector()
        self._table_detector = detector
    return detector
def _get_text_extractor(self) -> "TextLineItemsExtractor":
    """Return the cached TextLineItemsExtractor, building it on first use."""
    extractor = self._text_extractor
    if extractor is None:
        # Deferred import keeps the fallback optional at import time.
        from .text_line_items_extractor import TextLineItemsExtractor
        extractor = TextLineItemsExtractor()
        self._text_extractor = extractor
    return extractor
def extract_from_pdf(self, pdf_path: str) -> "LineItemsResult | None":
    """
    Extract line items from a PDF by detecting tables.

    Uses PP-StructureV3 for table detection and extraction.
    Falls back to text-based extraction if no tables detected.
    Reuses TableDetector instance for performance.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        LineItemsResult if line items are found, None otherwise.
    """
    # Reuse detector instance for performance
    detector = self._get_table_detector()
    tables, parsing_res_list = self._detect_tables_with_parsing(detector, pdf_path)
    # Lazy %-style args: messages are only formatted when INFO is enabled.
    logger.info("LineItemsExtractor: detected %d tables from PDF",
                len(tables) if tables else 0)

    # Try table-based extraction first
    best_result = self._extract_from_tables(tables)

    # If no results from tables and fallback is enabled, try text-based extraction
    if best_result is None and self._enable_text_fallback and parsing_res_list:
        logger.info("LineItemsExtractor: no tables found, trying text-based fallback")
        best_result = self._extract_from_text(parsing_res_list)

    logger.info("LineItemsExtractor: final result has %d items",
                len(best_result.items) if best_result else 0)
    return best_result
def _detect_tables_with_parsing(
    self, detector: "TableDetector", pdf_path: str
) -> tuple[list, list]:
    """
    Detect tables and also return parsing_res_list for fallback.

    Only the FIRST page (page_no == 0) of the PDF is analyzed; the
    function returns as soon as that page has been processed.

    Args:
        detector: TableDetector instance.
        pdf_path: Path to PDF file.

    Returns:
        Tuple of (table_results, parsing_res_list).
    """
    from pathlib import Path
    from shared.pdf.renderer import render_pdf_to_images
    from PIL import Image
    import io
    import numpy as np

    pdf_path = Path(pdf_path)
    if not pdf_path.exists():
        logger.warning(f"PDF not found: {pdf_path}")
        return [], []

    # Ensure detector is initialized
    # NOTE(review): this reaches into TableDetector private members
    # (_ensure_initialized, _pipeline, _parse_results) — brittle coupling;
    # confirm TableDetector has no public equivalents.
    detector._ensure_initialized()

    # Render first page
    parsing_res_list = []
    for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=300):
        if page_no == 0:
            image = Image.open(io.BytesIO(image_bytes))
            image_array = np.array(image)

            # Run PP-StructureV3 and get raw results
            if detector._pipeline is None:
                return [], []
            raw_results = detector._pipeline.predict(image_array)

            # Extract parsing_res_list from raw results; results may be
            # dict-like (.get) or attribute-style objects.
            if raw_results:
                for result in raw_results if isinstance(raw_results, list) else [raw_results]:
                    if hasattr(result, "get"):
                        parsing_res_list = result.get("parsing_res_list", [])
                    elif hasattr(result, "parsing_res_list"):
                        parsing_res_list = result.parsing_res_list or []

            # Parse tables using existing logic
            tables = detector._parse_results(raw_results)
            return tables, parsing_res_list
    return [], []
def _extract_from_tables(self, tables: list) -> "LineItemsResult | None":
    """Extract line items from detected tables.

    Runs ``extract`` on every table that has HTML, keeps only tables
    that look like line-items tables, and returns the one yielding the
    most items (or None).
    """
    if not tables:
        return None

    best_result = None
    best_item_count = 0

    for i, table in enumerate(tables):
        if not table.html:
            logger.debug("Table %d: no HTML content", i)
            continue

        # Lazy %-style logging args avoid building large debug strings
        # when the corresponding level is disabled.
        logger.info("Table %d: html_len=%d, html=%s", i, len(table.html), table.html[:500])
        result = self.extract(table.html)
        logger.info("Table %d: extracted %d items, headers=%s",
                    i, len(result.items), result.header_row)

        # Check if this table has line items
        is_line_items = self.is_line_items_table(result.header_row or [])
        logger.info("Table %d: is_line_items_table=%s", i, is_line_items)
        if result.items and is_line_items:
            if len(result.items) > best_item_count:
                best_item_count = len(result.items)
                best_result = result
                logger.debug("Table %d: selected as best (items=%d)", i, best_item_count)

    return best_result
def _extract_from_text(self, parsing_res_list: list) -> LineItemsResult | None:
    """Extract line items using text-based fallback."""
    from .text_line_items_extractor import convert_text_line_item

    extractor = self._get_text_extractor()
    raw = extractor.extract_from_parsing_res(parsing_res_list)
    if raw is None or not raw.items:
        logger.debug("Text-based extraction found no items")
        return None

    # Convert TextLineItems to LineItems
    converted = [convert_text_line_item(entry) for entry in raw.items]
    logger.info(f"Text-based extraction found {len(converted)} items")
    return LineItemsResult(
        items=converted,
        header_row=raw.header_row,
        raw_html="",  # No HTML for text-based extraction
        is_reversed=False,
    )
def is_line_items_table(self, headers: list[str]) -> bool:
    """
    Check if headers indicate a line items table.

    Args:
        headers: List of column headers.

    Returns:
        True if this appears to be a line items table.
    """
    column_map = self.mapper.map(headers)
    mapped_fields = set(column_map.values())
    # Lazy %-style args: no string building unless DEBUG is enabled.
    logger.debug("is_line_items_table: headers=%s, mapped_fields=%s",
                 headers, mapped_fields)

    # Must have description or article_number OR amount field
    # (rental invoices may have amount columns like "Hyra" without explicit description)
    has_item_identifier = (
        "description" in mapped_fields
        or "article_number" in mapped_fields
    )
    has_amount = "amount" in mapped_fields

    # Check for summary table keywords
    header_text = " ".join(h.lower() for h in headers)
    is_summary = any(kw in header_text for kw in SUMMARY_KEYWORDS)

    # Accept table if it has item identifiers OR has amount columns (and not a summary)
    result = (has_item_identifier or has_amount) and not is_summary
    logger.debug("is_line_items_table: has_item_identifier=%s, has_amount=%s, "
                 "is_summary=%s, result=%s",
                 has_item_identifier, has_amount, is_summary, result)
    return result
def _detect_header_row(
    self, rows: list[list[str]]
) -> tuple[int, list[str], bool]:
    """
    Detect which row is the header based on content patterns.

    Returns:
        Tuple of (header_index, header_row, is_at_end).
    """
    # Collect all known header aliases. Blank patterns are skipped:
    # "" is a substring of every string, so an empty alias would add a
    # spurious match to EVERY row and effectively lower the acceptance
    # threshold from two real keywords to one.
    header_keywords = set()
    for patterns in self.mapper.mappings.values():
        for p in patterns:
            if p.strip():
                header_keywords.add(p.lower())

    best_match = (-1, [], 0)
    for i, row in enumerate(rows):
        if all(not cell.strip() for cell in row):
            continue  # ignore completely blank rows
        row_text = " ".join(cell.lower() for cell in row)
        matches = sum(1 for kw in header_keywords if kw in row_text)
        if matches > best_match[2]:
            best_match = (i, row, matches)

    # Require at least two keyword hits before accepting a header row.
    if best_match[2] >= 2:
        header_idx = best_match[0]
        # A header in the lower half (or on the last row) implies the
        # table is reversed (data rows listed above the header).
        is_at_end = header_idx == len(rows) - 1 or header_idx > len(rows) // 2
        return header_idx, best_match[1], is_at_end
    return -1, [], False
def _extract_items(
    self, rows: list[list[str]], column_map: dict[int, str]
) -> list[LineItem]:
    """Build LineItem objects from data rows using the column mapping.

    Rows that yield neither a description nor an amount are dropped.
    """
    items = []
    for row_idx, row in enumerate(rows):
        # Start from an all-empty record for this row.
        record: dict = {
            "row_index": row_idx,
            "description": None,
            "quantity": None,
            "unit": None,
            "unit_price": None,
            "amount": None,
            "article_number": None,
            "vat_rate": None,
            "is_deduction": False,
        }
        for col_idx, cell in enumerate(row):
            target = column_map.get(col_idx)
            if target is None:
                continue
            if target == "deduction":
                # Deduction columns: store the value as the amount and
                # flag the row ("deduction" is not a LineItem field).
                if cell:
                    record["amount"] = cell
                    record["is_deduction"] = True
            else:
                record[target] = cell if cell else None
        # Keep only rows that carry something meaningful.
        if record["description"] or record["amount"]:
            items.append(LineItem(**record))
    return items
def _has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
    """
    Check if table rows contain vertically merged data in single cells.

    PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
    ["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]
    Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.
    """
    if not rows:
        return False
    # (regex, minimum hit count, label for the debug message) — checked
    # in the same order as the original hand-written cascade.
    checks = (
        (r"\b\d{7}\b", 2, "product numbers"),
        (r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", 3, "prices"),
        (r"\b\d+\s*(?:ST|st|PCS|pcs)\b", 2, "quantities"),
    )
    for row in rows:
        for cell in row:
            # Short cells cannot plausibly hold multiple merged lines.
            if not cell or len(cell) < 20:
                continue
            for pattern, threshold, label in checks:
                hits = re.findall(pattern, cell)
                if len(hits) >= threshold:
                    logger.debug(f"_has_vertically_merged_cells: found {len(hits)} {label} in cell")
                    return True
    return False
def _split_merged_rows(
    self, rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
    """
    Split vertically merged cells back into separate rows.

    Handles complex cases where PP-StructureV3 merges content across
    multiple HTML rows. For example, 5 line items might be spread across
    3 HTML rows with content mixed together.

    Strategy:
    1. Merge all row content per column
    2. Detect how many actual data rows exist (by counting product numbers)
    3. Split each column's content into that many lines

    Returns header and data rows. On failure to split, returns ([], rows)
    so the caller still has the original rows to work with.
    """
    if not rows:
        return [], []
    # Filter out completely empty rows
    non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
    if not non_empty_rows:
        return [], rows
    # Determine column count (rows may be ragged)
    col_count = max(len(r) for r in non_empty_rows)
    # Merge content from all rows for each column
    merged_columns = []
    for col_idx in range(col_count):
        col_content = []
        for row in non_empty_rows:
            if col_idx < len(row) and row[col_idx].strip():
                col_content.append(row[col_idx].strip())
        merged_columns.append(" ".join(col_content))
    logger.debug(f"_split_merged_rows: merged columns = {merged_columns}")
    # Count how many actual data rows we should have
    # Use the column with most product numbers as reference
    expected_rows = self._count_expected_rows(merged_columns)
    logger.info(f"_split_merged_rows: expecting {expected_rows} data rows")
    if expected_rows <= 1:
        # Not enough data for splitting
        return [], rows
    # Split each column based on expected row count
    split_columns = []
    for col_idx, col_text in enumerate(merged_columns):
        if not col_text.strip():
            split_columns.append([""] * (expected_rows + 1))  # +1 for header
            continue
        lines = self._split_cell_content_for_rows(col_text, expected_rows)
        split_columns.append(lines)
    # Ensure all columns have same number of lines (pad short ones)
    max_lines = max(len(col) for col in split_columns)
    for col in split_columns:
        while len(col) < max_lines:
            col.append("")
    logger.info(f"_split_merged_rows: split into {max_lines} lines total")
    # First line is header, rest are data rows
    header = [col[0] for col in split_columns]
    data_rows = []
    for line_idx in range(1, max_lines):
        row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
        if any(cell.strip() for cell in row):
            data_rows.append(row)
    logger.info(f"_split_merged_rows: header={header}, data_rows count={len(data_rows)}")
    return header, data_rows
def _count_expected_rows(self, merged_columns: list[str]) -> int:
    """
    Count how many data rows should exist based on content patterns.

    Returns the maximum count found from:
    - Product numbers (7 digits)
    - Quantity patterns (number + ST/PCS)
    """
    best = 0
    for text in merged_columns:
        if not text:
            continue
        # 7-digit product numbers are the most reliable indicator.
        best = max(best, len(re.findall(r"\b\d{7}\b", text)))
        # Quantity tokens such as "6ST 6ST 1ST".
        best = max(best, len(re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", text)))
    return best
def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
    """
    Split cell content knowing how many data rows we expect.

    This is smarter than _split_cell_content because it knows the target
    count: each splitting strategy is only accepted when it yields exactly
    (or at least) `expected_rows` values. Returns [header, v1, ..., vN] on
    success, or [cell] unchanged when no strategy matches.
    """
    cell = cell.strip()
    # Strategy 1: split on 7-digit product numbers, keeping any trailing
    # description text attached to its product number.
    product_pattern = re.compile(r"(\b\d{7}\b)")
    products = product_pattern.findall(cell)
    if len(products) == expected_rows:
        parts = product_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        # Include description text after each product number
        values = []
        for i in range(1, len(parts), 2):  # Odd indices are product numbers
            if i < len(parts):
                prod_num = parts[i].strip()
                # Check if there's description text after
                desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
                # If description looks like text (not another pattern), include it
                if desc and not re.match(r"^\d{7}$", desc):
                    # Truncate at next product number pattern if any
                    desc_clean = re.split(r"\d{7}", desc)[0].strip()
                    if desc_clean:
                        values.append(f"{prod_num} {desc_clean}")
                    else:
                        values.append(prod_num)
                else:
                    values.append(prod_num)
        if len(values) == expected_rows:
            return [header] + values
    # Strategy 2: split on quantity tokens ("6ST", "2 KG", ...).
    qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
    quantities = qty_pattern.findall(cell)
    if len(quantities) == expected_rows:
        parts = qty_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
        if len(values) == expected_rows:
            return [header] + values
    # Strategy 3: interleaved discount + total columns ("Rabatt ... Totalsumma ...").
    cell_lower = cell.lower()
    has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
    has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])
    if has_discount and has_total:
        # Extract only amounts (3+ digit numbers), skip discount percentages
        amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
        amounts = amount_pattern.findall(cell)
        if len(amounts) >= expected_rows:
            # NOTE(review): the original comment said "take the LAST
            # expected_rows amounts (they are likely the totals)" but the
            # slice takes the FIRST expected_rows. The \d{3,} filter usually
            # drops the small discount percentages so both ends coincide in
            # practice — confirm which end is intended before changing this.
            return ["Totalsumma"] + amounts[:expected_rows]
    # Strategy 4: split on price tokens with decimals (thousands-grouped).
    price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
    prices = price_pattern.findall(cell)
    if len(prices) >= expected_rows:
        parts = price_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
        if len(values) >= expected_rows:
            return [header] + values[:expected_rows]
    # Fall back to original single-value behavior
    return [cell]
def _split_cell_content(self, cell: str) -> list[str]:
"""
Split a cell containing merged multi-line content.
Strategies:
1. Look for product number patterns (7 digits)
2. Look for quantity patterns (number + ST/PCS)
3. Look for price patterns (with decimal)
4. Handle interleaved discount+amount patterns
"""
cell = cell.strip()
# Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
product_pattern = re.compile(r"(\b\d{7}\b)")
products = product_pattern.findall(cell)
if len(products) >= 2:
# Extract header (text before first product number) and values
parts = product_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
return [header] + values
# Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
quantities = qty_pattern.findall(cell)
if len(quantities) >= 2:
parts = qty_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
return [header] + values
# Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
# Check if header contains two keywords indicating merged columns
cell_lower = cell.lower()
has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])
if has_discount_header and has_amount_header:
# Extract all numbers and pair them (discount, amount, discount, amount, ...)
# Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
amounts = amount_pattern.findall(cell)
if len(amounts) >= 2:
# Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
# This avoids the "Rabatt" keyword causing is_deduction=True
header = "Totalsumma"
return [header] + amounts
# Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
prices = price_pattern.findall(cell)
if len(prices) >= 2:
parts = price_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
return [header] + values
# No pattern detected, return as single value
return [cell]
def _has_merged_header(self, header: list[str] | None) -> bool:
"""
Check if header appears to be a merged cell containing multiple column names.
This happens when OCR merges table headers into a single cell, e.g.:
"Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.
Also handles cases where PP-StructureV3 produces headers like:
["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.
"""
if header is None or not header:
return False
# Filter out empty cells to find the actual content
non_empty_cells = [h for h in header if h.strip()]
# Check if we have a single non-empty cell that contains multiple keywords
if len(non_empty_cells) == 1:
header_text = non_empty_cells[0].lower()
# Count how many column keywords are in this single cell
keyword_count = 0
for patterns in self.mapper.mappings.values():
for pattern in patterns:
if pattern in header_text:
keyword_count += 1
break # Only count once per field type
logger.debug(f"_has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
return keyword_count >= 2
return False
def _extract_from_merged_cells(
    self, header: list[str], rows: list[list[str]]
) -> list[LineItem]:
    """
    Extract line items from tables with merged cells.

    For poorly OCR'd tables like:
        Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        Row 1:  ["", "", "", "8159"]    <- amount row
        Row 2:  ["", "", "", "-2 000"]  <- deduction row (separate line item)
    Or:
        Row: ["", "", "", "8159 -2 000"] <- both in same row -> 2 line items

    Each amount becomes its own LineItem. Negative amounts are marked
    is_deduction=True; all items get a fixed confidence of 0.7.
    """
    items = []
    # Amount pattern for Swedish format - match numbers like "8159" or
    # "8 159" or "-2000" or "-2 000" (spaces as thousands separators).
    amount_pattern = re.compile(
        r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
    )
    # Try to parse header cell for description info
    header_text = " ".join(h for h in header if h.strip()) if header else ""
    logger.info(f"_extract_from_merged_cells: header_text='{header_text}'")
    logger.info(f"_extract_from_merged_cells: rows={rows}")
    # Extract description from header
    description = None
    article_number = None
    # Look for object number pattern (e.g., "0218103-1201") — presumably a
    # rental-object/apartment id used as the article number; TODO confirm.
    obj_match = re.search(r"(\d{7}-\d{4})", header_text)
    if obj_match:
        article_number = obj_match.group(1)
        # Look for description text between the object number and the
        # first column keyword (Hyra/Avdrag/Belopp).
        desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
        if desc_match:
            description = desc_match.group(1).strip()
    row_index = 0
    for row in rows:
        # Combine all non-empty cells in the row
        row_text = " ".join(cell.strip() for cell in row if cell.strip())
        logger.info(f"_extract_from_merged_cells: row text='{row_text}'")
        if not row_text:
            continue
        # Find all amounts in the row
        amounts = amount_pattern.findall(row_text)
        logger.info(f"_extract_from_merged_cells: amounts={amounts}")
        for amt_str in amounts:
            # Clean the amount string (drop internal thousand spaces).
            cleaned = amt_str.replace(" ", "").strip()
            if not cleaned or cleaned == "-":
                continue
            is_deduction = cleaned.startswith("-")
            # Skip small positive numbers (< 100) that are likely row
            # numbers/codes rather than amounts; negatives always pass.
            if not is_deduction:
                try:
                    val = float(cleaned.replace(",", "."))
                    if val < 100:
                        continue
                except ValueError:
                    continue
            # Create a line item for each amount. Only the first item
            # inherits the header description/article number; later
            # deductions are labeled "Avdrag".
            item = LineItem(
                row_index=row_index,
                description=description if row_index == 0 else "Avdrag" if is_deduction else None,
                article_number=article_number if row_index == 0 else None,
                amount=cleaned,
                is_deduction=is_deduction,
                confidence=0.7,
            )
            items.append(item)
            # row_index counts emitted items, not physical table rows.
            row_index += 1
            logger.info(f"_extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")
    return items

View File

@@ -0,0 +1,480 @@
"""
PP-StructureV3 Table Detection Wrapper
Provides automatic table detection in invoice images using PaddleOCR's
PP-StructureV3 pipeline. Supports both wired (bordered) and wireless
(borderless) tables commonly found in Swedish invoices.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
import logging
import numpy as np
logger = logging.getLogger(__name__)
@dataclass
class TableDetectorConfig:
    """Configuration for TableDetector.

    Defaults disable document-level preprocessing (orientation
    classification, unwarping, textline orientation) and select the
    SLANeXt model family for table structure recognition.
    """

    # Paddle device string, e.g. "gpu:0" or "cpu".
    device: str = "gpu:0"
    use_doc_orientation_classify: bool = False
    use_doc_unwarping: bool = False
    use_textline_orientation: bool = False
    # Use SLANeXt models for better table recognition accuracy
    # SLANeXt_wireless has ~6% higher accuracy than SLANet for borderless tables
    wired_table_model: str = "SLANeXt_wired"
    wireless_table_model: str = "SLANeXt_wireless"
    layout_model: str = "PP-DocLayout_plus-L"
    # Minimum score for a detected table to be kept (applied on the
    # legacy parsing path in _parse_results).
    min_confidence: float = 0.5
@dataclass
class TableDetectionResult:
    """Result of table detection for a single table."""

    bbox: tuple[float, float, float, float]  # x1, y1, x2, y2 in pixels
    html: str  # Table structure as HTML
    confidence: float  # detection score in [0, 1]
    table_type: str  # 'wired' or 'wireless'
    cells: list[dict[str, Any]] = field(default_factory=list)  # Cell-level data (text/bbox/row/col)
class PPStructureProtocol(Protocol):
    """Structural protocol for a PP-StructureV3-like pipeline.

    Anything with a compatible `predict` method can be injected into
    TableDetector (useful for tests and alternative backends).
    """

    def predict(self, image: str | np.ndarray, **kwargs: Any) -> Any:
        """Run prediction on image."""
        ...
class TableDetector:
"""
Table detector using PP-StructureV3.
Detects tables in invoice images and returns their bounding boxes,
HTML structure, and cell-level data.
"""
def __init__(
self,
config: TableDetectorConfig | None = None,
pipeline: PPStructureProtocol | None = None,
):
"""
Initialize table detector.
Args:
config: Configuration options. Uses defaults if None.
pipeline: Optional pre-initialized PP-StructureV3 pipeline.
If None, will be lazily initialized on first use.
"""
self.config = config or TableDetectorConfig()
self._pipeline = pipeline
self._initialized = pipeline is not None
def _ensure_initialized(self) -> None:
"""Lazily initialize PP-Structure pipeline."""
if self._initialized:
return
# Try PPStructureV3 first (paddleocr >= 3.0.0), fall back to PPStructure (2.x)
try:
from paddleocr import PPStructureV3
self._pipeline = PPStructureV3(
layout_detection_model_name=self.config.layout_model,
wired_table_structure_recognition_model_name=self.config.wired_table_model,
wireless_table_structure_recognition_model_name=self.config.wireless_table_model,
use_doc_orientation_classify=self.config.use_doc_orientation_classify,
use_doc_unwarping=self.config.use_doc_unwarping,
use_textline_orientation=self.config.use_textline_orientation,
device=self.config.device,
)
self._initialized = True
logger.info("PP-StructureV3 pipeline initialized successfully")
except ImportError:
# Fall back to PPStructure (paddleocr 2.x)
try:
from paddleocr import PPStructure
# Map device config to use_gpu for PPStructure 2.x
use_gpu = "gpu" in self.config.device.lower()
self._pipeline = PPStructure(
table=True,
ocr=True,
use_gpu=use_gpu,
show_log=False,
)
self._initialized = True
logger.info("PPStructure (2.x) pipeline initialized successfully")
except ImportError as e:
raise ImportError(
"PPStructure requires paddleocr. "
"Install with: pip install paddleocr"
) from e
def detect(
    self,
    image: np.ndarray | str | Path,
) -> list[TableDetectionResult]:
    """Detect tables in an image.

    Args:
        image: Input image as numpy array, file path, or Path object.

    Returns:
        List of TableDetectionResult for each detected table.

    Raises:
        RuntimeError: If the pipeline could not be initialized.
    """
    self._ensure_initialized()
    pipeline = self._pipeline
    if pipeline is None:
        raise RuntimeError("Pipeline not initialized")
    # The pipeline expects strings, not Path objects.
    target = str(image) if isinstance(image, Path) else image
    raw_results = pipeline.predict(target)
    return self._parse_results(raw_results)
def _parse_results(self, results: Any) -> list[TableDetectionResult]:
    """Parse PP-StructureV3 output into TableDetectionResult list.

    Supports both:
    - PaddleX 3.x API: dict-like LayoutParsingResultV2 with table_res_list
    - Legacy API: objects with layout_elements attribute

    Note: the min_confidence filter is only applied on the legacy path;
    the PaddleX 3.x path assigns a fixed confidence in
    _parse_paddlex_result.
    """
    tables: list[TableDetectionResult] = []
    if results is None:
        logger.warning("PP-StructureV3 returned None results")
        return tables
    # Log raw result type for debugging
    logger.info(f"PP-StructureV3 raw results type: {type(results).__name__}")
    # Handle case where results is a single dict-like object (PaddleX 3.x)
    # rather than a list of results
    if hasattr(results, "get") and not isinstance(results, list):
        # Single result object - wrap in list for uniform processing
        logger.info("Results is dict-like, wrapping in list")
        results = [results]
    elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
        # Iterator or generator - convert to list
        try:
            results = list(results)
            logger.info(f"Converted iterator to list with {len(results)} items")
        except Exception as e:
            logger.warning(f"Failed to convert results to list: {e}")
            return tables
    logger.info(f"Processing {len(results)} result(s)")
    for i, result in enumerate(results):
        try:
            result_type = type(result).__name__
            has_get = hasattr(result, "get")
            has_layout = hasattr(result, "layout_elements")
            logger.info(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
            # Try PaddleX 3.x API first (dict-like with table_res_list)
            if has_get:
                parsed = self._parse_paddlex_result(result)
                logger.info(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
                tables.extend(parsed)
                continue
            # Fall back to legacy API (layout_elements)
            if has_layout:
                legacy_count = 0
                for element in result.layout_elements:
                    if not self._is_table_element(element):
                        continue
                    table_result = self._extract_table_data(element)
                    if table_result and table_result.confidence >= self.config.min_confidence:
                        tables.append(table_result)
                        legacy_count += 1
                logger.info(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
            else:
                logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
        except Exception as e:
            # Best-effort: a malformed result must not kill the whole page.
            logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
            continue
    logger.info(f"Total tables detected: {len(tables)}")
    return tables
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
    """Parse PaddleX 3.x LayoutParsingResultV2.

    Reads `table_res_list` for structure/HTML/cells and pairs each table
    (by order of appearance) with a bbox taken from the "table" entries of
    `parsing_res_list`. Returns an empty list on any unrecognized shape.
    The many len()/try guards exist because several fields may be numpy
    arrays, whose truthiness is ambiguous.
    """
    tables: list[TableDetectionResult] = []
    try:
        # Log result structure for debugging
        result_type = type(result).__name__
        result_keys = []
        if hasattr(result, "keys"):
            result_keys = list(result.keys())
        elif hasattr(result, "__dict__"):
            result_keys = list(result.__dict__.keys())
        logger.info(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
        # Get table results from PaddleX 3.x API
        # Handle both dict.get() and attribute access
        if hasattr(result, "get"):
            table_res_list = result.get("table_res_list")
            parsing_res_list = result.get("parsing_res_list", [])
        else:
            table_res_list = getattr(result, "table_res_list", None)
            parsing_res_list = getattr(result, "parsing_res_list", [])
        logger.info(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
        logger.info(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
        if not table_res_list:
            # Log available keys/attributes for debugging
            logger.warning(f"No table_res_list found in result: {result_type}, available: {result_keys}")
            return tables
        # Get parsing_res_list to find table bounding boxes.
        # Assumes tables appear in parsing_res_list in the same order as in
        # table_res_list — pairing is purely positional.
        table_bboxes = {}
        for elem in parsing_res_list or []:
            try:
                if isinstance(elem, dict):
                    label = elem.get("label", "")
                    bbox = elem.get("bbox", [])
                else:
                    label = getattr(elem, "label", "")
                    bbox = getattr(elem, "bbox", [])
                # Check bbox has items (handles numpy arrays safely)
                has_bbox = False
                try:
                    has_bbox = len(bbox) >= 4 if hasattr(bbox, "__len__") else False
                except (TypeError, ValueError):
                    pass
                if label == "table" and has_bbox:
                    # Map by index (parsing_res_list tables appear in order)
                    idx = len(table_bboxes)
                    table_bboxes[idx] = bbox
            except Exception as e:
                logger.debug(f"Failed to parse parsing_res element: {e}")
                continue
        for i, table_res in enumerate(table_res_list):
            try:
                # Extract from PaddleX 3.x table result format
                # Handle both dict and object access (SingleTableRecognitionResult)
                if isinstance(table_res, dict):
                    cell_boxes = table_res.get("cell_box_list", [])
                    html = table_res.get("pred_html", "")
                    ocr_data = table_res.get("table_ocr_pred", {})
                else:
                    cell_boxes = getattr(table_res, "cell_box_list", [])
                    html = getattr(table_res, "pred_html", "")
                    ocr_data = getattr(table_res, "table_ocr_pred", {})
                # table_ocr_pred can be dict (PaddleOCR 3.x) or list (older versions)
                # For dict format: {"rec_texts": [...], "rec_scores": [...], ...}
                ocr_texts = []
                if isinstance(ocr_data, dict):
                    ocr_texts = ocr_data.get("rec_texts", [])
                elif isinstance(ocr_data, list):
                    ocr_texts = ocr_data
                # Try to get bbox from parsing_res_list; fall back to zeros.
                bbox = table_bboxes.get(i, [0.0, 0.0, 0.0, 0.0])
                # Handle numpy arrays - check length explicitly to avoid boolean ambiguity
                try:
                    bbox_len = len(bbox) if hasattr(bbox, "__len__") else 0
                    if bbox_len < 4:
                        bbox = [0.0, 0.0, 0.0, 0.0]
                except (TypeError, ValueError):
                    bbox = [0.0, 0.0, 0.0, 0.0]
                # Build cells from cell_box_list and OCR text.
                # Assumes rec_texts[j] corresponds to cell_box_list[j] —
                # TODO confirm this pairing against the PaddleX docs.
                cells = []
                # Check cell_boxes length explicitly to avoid numpy array boolean issues
                has_cell_boxes = False
                try:
                    has_cell_boxes = len(cell_boxes) > 0 if hasattr(cell_boxes, "__len__") else bool(cell_boxes)
                except (TypeError, ValueError):
                    pass
                if has_cell_boxes:
                    # Check ocr_texts length safely for numpy arrays
                    ocr_texts_len = 0
                    try:
                        ocr_texts_len = len(ocr_texts) if hasattr(ocr_texts, "__len__") else 0
                    except (TypeError, ValueError):
                        pass
                    for j, cell_bbox in enumerate(cell_boxes):
                        cell_text = ocr_texts[j] if ocr_texts_len > j else ""
                        # Convert cell_bbox to list safely (may be numpy array)
                        cell_bbox_list = []
                        try:
                            cell_bbox_list = list(cell_bbox) if hasattr(cell_bbox, "__iter__") else []
                        except (TypeError, ValueError):
                            pass
                        cells.append({
                            "text": cell_text,
                            "bbox": cell_bbox_list,
                            "row": 0,  # Row/col info not directly available
                            "col": j,
                        })
                # Default confidence for PaddleX 3.x results (the API does
                # not expose a per-table score on this path).
                confidence = 0.9
                logger.info(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
                tables.append(TableDetectionResult(
                    bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
                    html=html,
                    confidence=confidence,
                    table_type="wired",  # PaddleX 3.x handles both types
                    cells=cells,
                ))
            except Exception as e:
                import traceback
                logger.warning(f"Failed to parse table_res {i}: {e}\n{traceback.format_exc()}")
                continue
    except Exception as e:
        logger.warning(f"Failed to parse PaddleX result: {type(e).__name__}: {e}")
    return tables
def _is_table_element(self, element: Any) -> bool:
"""Check if element is a table."""
if hasattr(element, "label"):
return element.label.lower() in ("table", "wired_table", "wireless_table")
if hasattr(element, "type"):
return element.type.lower() in ("table", "wired_table", "wireless_table")
return False
def _extract_table_data(self, element: Any) -> TableDetectionResult | None:
"""Extract table data from PP-StructureV3 element."""
try:
# Get bounding box
bbox = self._get_bbox(element)
if bbox is None:
return None
# Get HTML content
html = self._get_html(element)
# Get confidence
confidence = getattr(element, "score", 0.9)
if isinstance(confidence, (list, tuple)):
confidence = float(confidence[0]) if confidence else 0.9
# Determine table type
table_type = self._get_table_type(element)
# Get cells if available
cells = self._get_cells(element)
return TableDetectionResult(
bbox=bbox,
html=html,
confidence=float(confidence),
table_type=table_type,
cells=cells,
)
except Exception as e:
logger.warning(f"Failed to extract table data: {e}")
return None
def _get_bbox(self, element: Any) -> tuple[float, float, float, float] | None:
"""Extract bounding box from element."""
if hasattr(element, "bbox"):
bbox = element.bbox
if len(bbox) >= 4:
return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3]))
if hasattr(element, "box"):
box = element.box
if len(box) >= 4:
return (float(box[0]), float(box[1]), float(box[2]), float(box[3]))
return None
def _get_html(self, element: Any) -> str:
"""Extract HTML content from element."""
if hasattr(element, "html"):
return str(element.html)
if hasattr(element, "table_html"):
return str(element.table_html)
if hasattr(element, "res") and isinstance(element.res, dict):
return element.res.get("html", "")
return ""
def _get_table_type(self, element: Any) -> str:
"""Determine table type (wired or wireless)."""
label = ""
if hasattr(element, "label"):
label = str(element.label).lower()
elif hasattr(element, "type"):
label = str(element.type).lower()
if "wireless" in label or "borderless" in label:
return "wireless"
return "wired"
def _get_cells(self, element: Any) -> list[dict[str, Any]]:
"""Extract cell-level data from element."""
cells: list[dict[str, Any]] = []
if hasattr(element, "cells"):
for cell in element.cells:
cell_data = {
"text": getattr(cell, "text", ""),
"row": getattr(cell, "row", 0),
"col": getattr(cell, "col", 0),
"row_span": getattr(cell, "row_span", 1),
"col_span": getattr(cell, "col_span", 1),
}
if hasattr(cell, "bbox"):
cell_data["bbox"] = cell.bbox
cells.append(cell_data)
return cells
def detect_from_pdf(
    self,
    pdf_path: str | Path,
    page_number: int = 0,
    dpi: int = 300,
) -> list[TableDetectionResult]:
    """
    Detect tables from a PDF page.

    Args:
        pdf_path: Path to PDF file.
        page_number: Page number (0-indexed).
        dpi: Resolution for rendering.

    Returns:
        List of TableDetectionResult for the specified page.

    Raises:
        FileNotFoundError: If the PDF does not exist.
        ValueError: If the requested page is not produced by the renderer.
    """
    from shared.pdf.renderer import render_pdf_to_images
    from PIL import Image
    import io
    pdf_path = Path(pdf_path)
    if not pdf_path.exists():
        raise FileNotFoundError(f"PDF not found: {pdf_path}")
    logger.info(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
    # Render pages lazily and stop at the requested one.
    # NOTE(review): the generator is abandoned on the early return — if
    # render_pdf_to_images holds resources, it relies on GC to clean up.
    for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
        if page_no == page_number:
            image = Image.open(io.BytesIO(image_bytes))
            image_array = np.array(image)
            logger.info(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
            return self.detect(image_array)
    raise ValueError(f"Page {page_number} not found in PDF")

View File

@@ -0,0 +1,449 @@
"""
Text-Based Line Items Extractor
Fallback extraction for invoices where PP-StructureV3 cannot detect table structures
(e.g., borderless/wireless tables). Uses spatial analysis of OCR text elements to
identify and group line items.
"""
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation
import re
from typing import Any
import logging
logger = logging.getLogger(__name__)
@dataclass
class TextElement:
    """Single text element from OCR with its pixel bounding box."""

    text: str
    bbox: tuple[float, float, float, float]  # x1, y1, x2, y2
    confidence: float = 1.0

    @property
    def center_y(self) -> float:
        """Vertical midpoint of the bounding box."""
        _, top, _, bottom = self.bbox
        return (top + bottom) / 2

    @property
    def center_x(self) -> float:
        """Horizontal midpoint of the bounding box."""
        left, _, right, _ = self.bbox
        return (left + right) / 2

    @property
    def height(self) -> float:
        """Box height in pixels."""
        return self.bbox[3] - self.bbox[1]
@dataclass
class TextLineItem:
    """Line item extracted from text elements (spatial fallback path)."""

    row_index: int  # 0-based index among candidate item rows
    description: str | None = None  # longest non-numeric text on the row
    quantity: str | None = None  # raw quantity token, e.g. "6 st"
    unit: str | None = None  # not populated by the text-based extractor
    unit_price: str | None = None  # second-to-last amount on the row, if any
    amount: str | None = None  # last (rightmost) amount on the row
    article_number: str | None = None  # not populated by the text-based extractor
    vat_rate: str | None = None  # percentage digits from "NN %", e.g. "25"
    is_deduction: bool = False  # True if this row is a deduction/discount
    confidence: float = 0.7  # Lower default confidence for text-based extraction
@dataclass
class TextLineItemsResult:
    """Result of text-based line items extraction."""

    items: list[TextLineItem]
    header_row: list[str]  # always empty for spatial extraction (no explicit header)
    extraction_method: str = "text_spatial"  # provenance marker for downstream consumers
# Swedish amount pattern: 1 234,56 or 1234.56 or 1,234.56.
# Alternation order matters: the Swedish space-grouped form is tried first,
# then US comma-grouping, then a plain number. Lookarounds prevent matching
# inside a longer digit run; the currency suffix is consumed but optional.
AMOUNT_PATTERN = re.compile(
    r"(?<![0-9])(?:"
    r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?"  # Swedish: 1 234,56
    r"|-?\d{1,3}(?:,\d{3})*(?:\.\d{2})?"  # US: 1,234.56
    r"|-?\d+(?:[.,]\d{2})?"  # Simple: 1234,56 or 1234.56
    r")(?:\s*(?:kr|SEK|:-))?"  # Optional currency suffix
    r"(?![0-9])"
)
# Quantity patterns: a full-string match of a number with an optional unit.
QUANTITY_PATTERN = re.compile(
    r"^(?:"
    r"\d+(?:[.,]\d+)?\s*(?:st|pcs|m|kg|l|h|tim|timmar)?"  # Number with optional unit
    r")$",
    re.IGNORECASE,
)
# VAT rate pattern: captures the integer percentage in "25 %".
VAT_RATE_PATTERN = re.compile(r"(\d+)\s*%")
# Keywords indicating a line item (column header) area — Swedish invoice terms.
LINE_ITEM_KEYWORDS = [
    "beskrivning",
    "artikel",
    "produkt",
    "belopp",
    "summa",
    "antal",
    "pris",
    "á-pris",
    "a-pris",
    "moms",
]
# Keywords indicating NOT line items (totals / payment summary area).
SUMMARY_KEYWORDS = [
    "att betala",
    "total",
    "summa att betala",
    "betalningsvillkor",
    "förfallodatum",
    "bankgiro",
    "plusgiro",
    "ocr-nummer",
    "fakturabelopp",
    "exkl. moms",
    "inkl. moms",
    "varav moms",
]
class TextLineItemsExtractor:
"""
Extract line items from text elements using spatial analysis.
This is a fallback for when PP-StructureV3 cannot detect table structures.
It groups text elements by vertical position and identifies patterns
that match line item rows.
"""
def __init__(
self,
row_tolerance: float = 15.0, # Max vertical distance to consider same row
min_items_for_valid: int = 2, # Minimum items to consider extraction valid
):
"""
Initialize extractor.
Args:
row_tolerance: Maximum vertical distance (pixels) between elements
to consider them on the same row.
min_items_for_valid: Minimum number of line items required for
extraction to be considered successful.
"""
self.row_tolerance = row_tolerance
self.min_items_for_valid = min_items_for_valid
def extract_from_parsing_res(
    self, parsing_res_list: list[dict[str, Any]]
) -> TextLineItemsResult | None:
    """Extract line items from PP-StructureV3 parsing_res_list.

    Args:
        parsing_res_list: Parsed layout elements from PP-StructureV3.

    Returns:
        TextLineItemsResult when line items are found, else None.
    """
    if not parsing_res_list:
        logger.debug("No parsing_res_list provided")
        return None
    elements = self._extract_text_elements(parsing_res_list)
    logger.info(f"TextLineItemsExtractor: found {len(elements)} text elements")
    # Fewer than 5 elements cannot plausibly form header + item rows.
    if len(elements) < 5:
        logger.debug("Too few text elements for line item extraction")
        return None
    return self.extract_from_text_elements(elements)
def extract_from_text_elements(
    self, text_elements: list[TextElement]
) -> TextLineItemsResult | None:
    """Extract line items from a list of text elements.

    Groups elements into visual rows, filters down to rows that look
    like line items, parses them, and returns None whenever fewer than
    `min_items_for_valid` rows/items survive either stage.
    """
    grouped = self._group_by_row(text_elements)
    logger.info(f"TextLineItemsExtractor: grouped into {len(grouped)} rows")
    candidate_rows = self._identify_line_item_rows(grouped)
    logger.info(f"TextLineItemsExtractor: identified {len(candidate_rows)} potential item rows")
    if len(candidate_rows) < self.min_items_for_valid:
        logger.debug(f"Found only {len(candidate_rows)} item rows, need at least {self.min_items_for_valid}")
        return None
    parsed_items = self._parse_line_items(candidate_rows)
    logger.info(f"TextLineItemsExtractor: extracted {len(parsed_items)} line items")
    if len(parsed_items) < self.min_items_for_valid:
        return None
    # Text-based extraction has no explicit header row.
    return TextLineItemsResult(
        items=parsed_items,
        header_row=[],
        extraction_method="text_spatial",
    )
def _extract_text_elements(
    self, parsing_res_list: list[dict[str, Any]]
) -> list[TextElement]:
    """Extract TextElement objects from parsing_res_list.

    Accepts both dict entries and LayoutBlock-like objects; keeps only
    textual labels with a valid 4-value bbox and non-empty content.
    Malformed entries are skipped (logged at debug), never raised.
    """
    elements = []
    for elem in parsing_res_list:
        try:
            # Get label and bbox - handle both dict and LayoutBlock objects
            if isinstance(elem, dict):
                label = elem.get("label", "")
                bbox = elem.get("bbox", [])
                # Try both 'text' and 'content' keys
                text = elem.get("text", "") or elem.get("content", "")
            else:
                label = getattr(elem, "label", "")
                bbox = getattr(elem, "bbox", [])
                # LayoutBlock objects use 'content' attribute
                text = getattr(elem, "content", "") or getattr(elem, "text", "")
            # Only process text elements (skip images, tables, etc.)
            if label not in ("text", "paragraph_title", "aside_text"):
                continue
            # Validate bbox
            if not self._valid_bbox(bbox):
                continue
            # Clean text
            text = str(text).strip() if text else ""
            if not text:
                continue
            elements.append(
                TextElement(
                    text=text,
                    bbox=(
                        float(bbox[0]),
                        float(bbox[1]),
                        float(bbox[2]),
                        float(bbox[3]),
                    ),
                )
            )
        except Exception as e:
            logger.debug(f"Failed to parse element: {e}")
            continue
    return elements
def _valid_bbox(self, bbox: Any) -> bool:
"""Check if bbox is valid (has 4 elements)."""
try:
return len(bbox) >= 4 if hasattr(bbox, "__len__") else False
except (TypeError, ValueError):
return False
def _group_by_row(
self, elements: list[TextElement]
) -> list[list[TextElement]]:
"""
Group text elements into rows based on vertical position.
Elements within row_tolerance of each other are considered same row.
"""
if not elements:
return []
# Sort by vertical position
sorted_elements = sorted(elements, key=lambda e: e.center_y)
rows = []
current_row = [sorted_elements[0]]
current_y = sorted_elements[0].center_y
for elem in sorted_elements[1:]:
if abs(elem.center_y - current_y) <= self.row_tolerance:
# Same row
current_row.append(elem)
else:
# New row
if current_row:
# Sort row by horizontal position
current_row.sort(key=lambda e: e.center_x)
rows.append(current_row)
current_row = [elem]
current_y = elem.center_y
# Don't forget last row
if current_row:
current_row.sort(key=lambda e: e.center_x)
rows.append(current_row)
return rows
def _identify_line_item_rows(
    self, rows: list[list[TextElement]]
) -> list[list[TextElement]]:
    """Identify which rows are likely line items.

    A row qualifies when `_looks_like_line_item` accepts it AND it is
    neither a column-header row (contains LINE_ITEM_KEYWORDS) nor part
    of the totals/payment summary (contains SUMMARY_KEYWORDS) — those
    keyword rows are skipped outright.

    Fix: the previous `in_item_section` flag had no observable effect —
    a row was appended iff `_looks_like_line_item(row)` was true
    regardless of the flag — and the (pure) predicate was evaluated
    twice per candidate row. The flag and duplicate call are removed;
    behavior is unchanged.
    """
    item_rows = []
    for row in rows:
        row_text = " ".join(e.text for e in row).lower()
        # Summary/payment rows contain amounts but are not items.
        if any(kw in row_text for kw in SUMMARY_KEYWORDS):
            continue
        # Skip the column-header row itself.
        if any(kw in row_text for kw in LINE_ITEM_KEYWORDS):
            continue
        if self._looks_like_line_item(row):
            item_rows.append(row)
    return item_rows
def _looks_like_line_item(self, row: list[TextElement]) -> bool:
    """Heuristic: a plausible item row has >= 2 elements, at least one
    amount, and at least one descriptive (non-amount) text element."""
    if len(row) < 2:
        return False
    joined = " ".join(e.text for e in row)
    # An item row must carry at least one monetary amount.
    if not AMOUNT_PATTERN.findall(joined):
        return False
    # ...and some text that is more than a bare number.
    for elem in row:
        if len(elem.text) > 3 and not AMOUNT_PATTERN.fullmatch(elem.text.strip()):
            return True
    return False
def _parse_line_items(
self, item_rows: list[list[TextElement]]
) -> list[TextLineItem]:
"""Parse line item rows into structured items."""
items = []
for idx, row in enumerate(item_rows):
item = self._parse_single_row(row, idx)
if item:
items.append(item)
return items
def _parse_single_row(
    self, row: list[TextElement], row_index: int
) -> TextLineItem | None:
    """Parse a single table row into a TextLineItem.

    Returns None when the row is empty or contains no amount-like value.
    Heuristics:
      * the rightmost amount is the line total, the one before it (if any)
        is the unit price;
      * the first cell matching QUANTITY_PATTERN is the quantity;
      * the longest cell that is neither an amount nor a quantity is the
        description.
    """
    if not row:
        return None
    joined = " ".join(cell.text for cell in row)
    amount_matches = list(AMOUNT_PATTERN.finditer(joined))
    if not amount_matches:
        return None

    # Rightmost amount = line total; second-to-last (if present) = unit price.
    amount = amount_matches[-1].group(0).strip()
    unit_price = (
        amount_matches[-2].group(0).strip() if len(amount_matches) >= 2 else None
    )

    # First quantity-shaped cell wins.
    quantity = next(
        (cell.text.strip() for cell in row if QUANTITY_PATTERN.match(cell.text.strip())),
        None,
    )

    # VAT rate anywhere in the row text.
    vat_match = VAT_RATE_PATTERN.search(joined)
    vat_rate = vat_match.group(1) if vat_match else None

    # Description: longest non-numeric cell (first one wins on ties).
    description = None
    best_len = 0
    for cell in row:
        stripped = cell.text.strip()
        if AMOUNT_PATTERN.fullmatch(stripped):
            continue
        if QUANTITY_PATTERN.match(stripped):
            continue
        if len(stripped) > best_len:
            description = stripped
            best_len = len(stripped)

    return TextLineItem(
        row_index=row_index,
        description=description,
        quantity=quantity,
        unit_price=unit_price,
        amount=amount,
        vat_rate=vat_rate,
        confidence=0.7,
    )
def convert_text_line_item(item: TextLineItem) -> "LineItem":
    """Map a text-layer TextLineItem onto the shared LineItem dataclass.

    All fields are copied one-to-one; the import is local to avoid a
    module-level circular dependency.
    """
    from .line_items_extractor import LineItem

    copied_fields = (
        "row_index",
        "description",
        "quantity",
        "unit",
        "unit_price",
        "amount",
        "article_number",
        "vat_rate",
        "is_deduction",
        "confidence",
    )
    return LineItem(**{name: getattr(item, name) for name in copied_fields})

View File

@@ -1,7 +1,19 @@
"""
Cross-validation module for verifying field extraction using LLM.
Cross-validation module for verifying field extraction.
Includes LLM validation and VAT cross-validation.
"""
from .llm_validator import LLMValidator
from .vat_validator import (
VATValidationResult,
VATValidator,
MathCheckResult,
)
__all__ = ['LLMValidator']
__all__ = [
"LLMValidator",
"VATValidationResult",
"VATValidator",
"MathCheckResult",
]

View File

@@ -0,0 +1,267 @@
"""
VAT Validator
Cross-validates VAT information from multiple sources:
- Mathematical verification (base × rate = vat)
- Line items vs VAT summary comparison
- Consistency with existing amount field
"""
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation
from backend.vat.vat_extractor import VATSummary, AmountParser
from backend.table.line_items_extractor import LineItemsResult
@dataclass
class MathCheckResult:
    """Result of a single VAT rate mathematical check.

    Verifies that ``base_amount * rate / 100`` matches the VAT amount
    printed on the invoice, within ``tolerance``.
    """

    rate: float  # VAT rate in percent (e.g. 25.0)
    base_amount: float | None  # parsed tax base; None if missing/unparseable
    expected_vat: float | None  # base_amount * rate / 100; None if base missing
    actual_vat: float  # VAT amount parsed from the invoice
    is_valid: bool  # |expected - actual| <= tolerance (True when base missing)
    tolerance: float  # absolute tolerance used for the comparison
@dataclass
class VATValidationResult:
    """Complete VAT validation result aggregating all cross-checks."""

    is_valid: bool  # overall verdict (math + totals + amount consistency)
    confidence_score: float  # 0.0 - 1.0

    # Mathematical verification
    math_checks: list[MathCheckResult]  # one entry per parseable rate breakdown
    total_check: bool  # incl = excl + total_vat?

    # Source comparison (None means the comparison was not possible)
    line_items_vs_summary: bool | None  # line items total = VAT summary?
    amount_consistency: bool | None  # total_incl_vat = existing amount field?

    # Review flags
    needs_review: bool  # True when any check failed or confidence < 0.7
    review_reasons: list[str] = field(default_factory=list)
class VATValidator:
    """Validates VAT information using multiple cross-checks.

    Checks performed:
      * per-rate math:   base * rate / 100 vs printed VAT amount
      * totals:          total_excl + total_vat vs total_incl
      * line items:      sum of line-item amounts vs VAT summary (excl) total
      * consistency:     total_incl vs the amount field from YOLO extraction
    """

    def __init__(self, tolerance: float = 0.02):
        """
        Initialize validator.

        Args:
            tolerance: Acceptable difference for math checks (default 0.02 = 2 cents)
        """
        self.tolerance = tolerance
        self.amount_parser = AmountParser()

    def validate(
        self,
        vat_summary: VATSummary,
        line_items: LineItemsResult | None = None,
        existing_amount: str | None = None,
    ) -> VATValidationResult:
        """
        Validate VAT information.

        Args:
            vat_summary: Extracted VAT summary.
            line_items: Optional line items for comparison.
            existing_amount: Optional existing amount field from YOLO extraction.

        Returns:
            VATValidationResult with all check results.
        """
        review_reasons: list[str] = []

        # Handle empty summary: nothing to validate at all.
        if not vat_summary.breakdowns and not vat_summary.total_vat:
            return VATValidationResult(
                is_valid=False,
                confidence_score=0.0,
                math_checks=[],
                total_check=False,
                line_items_vs_summary=None,
                amount_consistency=None,
                needs_review=True,
                review_reasons=["No VAT information found"],
            )

        # Run all checks
        math_checks = self._run_math_checks(vat_summary)
        total_check = self._check_totals(vat_summary)
        line_items_check = self._check_line_items(vat_summary, line_items)
        amount_check = self._check_amount_consistency(vat_summary, existing_amount)

        # Collect review reasons
        math_failures = [c for c in math_checks if not c.is_valid]
        if math_failures:
            review_reasons.append(f"Math check failed for {len(math_failures)} VAT rate(s)")
        if not total_check:
            review_reasons.append("Total amount mismatch (excl + vat != incl)")
        if line_items_check is False:
            review_reasons.append("Line items total doesn't match VAT summary")
        if amount_check is False:
            review_reasons.append("VAT total doesn't match existing amount field")

        # Calculate overall validity and confidence.
        # Note: line_items_check contributes to review reasons and confidence
        # but not to is_valid; amount_check only fails validity when it is
        # explicitly False (None = comparison not possible).
        all_math_valid = all(c.is_valid for c in math_checks) if math_checks else True
        is_valid = all_math_valid and total_check and (amount_check is not False)

        confidence_score = self._calculate_confidence(
            vat_summary, math_checks, total_check, line_items_check, amount_check
        )

        needs_review = len(review_reasons) > 0 or confidence_score < 0.7

        return VATValidationResult(
            is_valid=is_valid,
            confidence_score=confidence_score,
            math_checks=math_checks,
            total_check=total_check,
            line_items_vs_summary=line_items_check,
            amount_consistency=amount_check,
            needs_review=needs_review,
            review_reasons=review_reasons,
        )

    def _run_math_checks(self, vat_summary: VATSummary) -> list[MathCheckResult]:
        """Run mathematical verification for each VAT rate breakdown.

        Breakdowns whose VAT amount cannot be parsed are skipped entirely;
        breakdowns without a parseable base amount are recorded as valid
        (the math cannot be disproved without a base).
        """
        results = []
        for breakdown in vat_summary.breakdowns:
            actual_vat = self.amount_parser.parse(breakdown.vat_amount)
            if actual_vat is None:
                continue
            base_amount = None
            expected_vat = None
            is_valid = True
            if breakdown.base_amount:
                base_amount = self.amount_parser.parse(breakdown.base_amount)
                if base_amount is not None:
                    expected_vat = base_amount * (breakdown.rate / 100)
                    is_valid = abs(expected_vat - actual_vat) <= self.tolerance
            results.append(
                MathCheckResult(
                    rate=breakdown.rate,
                    base_amount=base_amount,
                    expected_vat=expected_vat,
                    actual_vat=actual_vat,
                    is_valid=is_valid,
                    tolerance=self.tolerance,
                )
            )
        return results

    def _check_totals(self, vat_summary: VATSummary) -> bool:
        """Check if total_excl + total_vat = total_incl.

        Returns True (benefit of the doubt) whenever any of the required
        values is missing or unparseable.
        """
        if not vat_summary.total_excl_vat or not vat_summary.total_incl_vat:
            # Can't verify without both values
            return True  # Assume ok if we can't check
        excl = self.amount_parser.parse(vat_summary.total_excl_vat)
        incl = self.amount_parser.parse(vat_summary.total_incl_vat)
        if excl is None or incl is None:
            return True  # Can't verify
        # Calculate expected VAT
        if vat_summary.total_vat:
            vat = self.amount_parser.parse(vat_summary.total_vat)
            if vat is not None:
                expected_incl = excl + vat
                return abs(expected_incl - incl) <= self.tolerance
            # Can't verify if vat parsing failed
            return True
        else:
            # Sum up breakdown VAT amounts; unparseable amounts contribute 0
            total_vat = sum(
                self.amount_parser.parse(b.vat_amount) or 0
                for b in vat_summary.breakdowns
            )
            expected_incl = excl + total_vat
            return abs(expected_incl - incl) <= self.tolerance

    def _check_line_items(
        self, vat_summary: VATSummary, line_items: LineItemsResult | None
    ) -> bool | None:
        """Check if the line items total matches the VAT summary (excl) total.

        Returns None when no comparison is possible (no line items, or the
        summary total is missing/unparseable).
        """
        if line_items is None or not line_items.items:
            return None  # No comparison possible
        # Sum line item amounts (skipping unparseable ones)
        line_total = 0.0
        for item in line_items.items:
            if item.amount:
                amount = self.amount_parser.parse(item.amount)
                if amount is not None:
                    line_total += amount
        # Compare with VAT summary total
        if vat_summary.total_excl_vat:
            summary_total = self.amount_parser.parse(vat_summary.total_excl_vat)
            if summary_total is not None:
                # Allow larger tolerance for line items (rounding errors)
                return abs(line_total - summary_total) <= 1.0
        return None

    def _check_amount_consistency(
        self, vat_summary: VATSummary, existing_amount: str | None
    ) -> bool | None:
        """Check if the VAT incl-total matches the existing amount field.

        Returns None when no comparison is possible.
        """
        if existing_amount is None:
            return None  # No comparison possible
        existing = self.amount_parser.parse(existing_amount)
        if existing is None:
            return None
        if vat_summary.total_incl_vat:
            vat_total = self.amount_parser.parse(vat_summary.total_incl_vat)
            if vat_total is not None:
                return abs(existing - vat_total) <= self.tolerance
        return None

    def _calculate_confidence(
        self,
        vat_summary: VATSummary,
        math_checks: list[MathCheckResult],
        total_check: bool,
        line_items_check: bool | None,
        amount_check: bool | None,
    ) -> float:
        """Calculate overall confidence score (0.0 - 1.0, rounded to 2 dp).

        Starts from the extractor's own confidence, then scales it down for
        failed checks and boosts it (capped at 1.0) for corroborating ones.
        """
        score = vat_summary.confidence  # Start with extraction confidence
        # Adjust based on validation results
        if math_checks:
            math_valid_ratio = sum(1 for c in math_checks if c.is_valid) / len(math_checks)
            score = score * (0.5 + 0.5 * math_valid_ratio)
        if not total_check:
            score *= 0.5
        if line_items_check is True:
            score = min(score * 1.1, 1.0)  # Boost if line items match
        elif line_items_check is False:
            score *= 0.7
        if amount_check is True:
            score = min(score * 1.1, 1.0)  # Boost if amount matches
        elif amount_check is False:
            score *= 0.6
        return round(score, 2)

View File

@@ -0,0 +1,19 @@
"""
VAT extraction module.
Extracts VAT (Moms) information from Swedish invoices using regex patterns.
"""
from .vat_extractor import (
VATBreakdown,
VATSummary,
VATExtractor,
AmountParser,
)
__all__ = [
"VATBreakdown",
"VATSummary",
"VATExtractor",
"AmountParser",
]

View File

@@ -0,0 +1,350 @@
"""
VAT Extractor
Extracts VAT (Moms) information from Swedish invoice text using regex patterns.
Supports multiple VAT rates (25%, 12%, 6%, 0%) and various Swedish formats.
"""
from dataclasses import dataclass
import re
from decimal import Decimal, InvalidOperation
@dataclass
class VATBreakdown:
    """Single VAT rate breakdown as printed on the invoice.

    Amounts are kept as raw strings; parsing happens in AmountParser.
    """

    rate: float  # 25.0, 12.0, 6.0, 0.0
    base_amount: str | None  # Tax base (excl VAT)
    vat_amount: str  # VAT amount
    source: str  # 'regex' | 'line_items'
@dataclass
class VATSummary:
    """Complete VAT summary extracted from one invoice.

    Totals are raw strings (or None when not found); ``confidence`` is a
    0.0 - 1.0 score computed by VATExtractor._calculate_confidence.
    """

    breakdowns: list[VATBreakdown]  # one entry per distinct VAT rate
    total_excl_vat: str | None  # total excluding VAT
    total_vat: str | None  # total VAT amount
    total_incl_vat: str | None  # total including VAT
    confidence: float  # extraction confidence, 0.0 - 1.0
class AmountParser:
"""Parse Swedish and European number formats."""
# Patterns to clean amount strings
CURRENCY_PATTERN = re.compile(r"(SEK|kr|:-)\s*", re.IGNORECASE)
def parse(self, amount_str: str) -> float | None:
"""
Parse amount string to float.
Handles:
- Swedish: 1 234,56
- European: 1.234,56
- US: 1,234.56
Args:
amount_str: Amount string to parse.
Returns:
Parsed float value or None if invalid.
"""
if not amount_str or not amount_str.strip():
return None
# Clean the string
cleaned = amount_str.strip()
# Remove currency
cleaned = self.CURRENCY_PATTERN.sub("", cleaned).strip()
cleaned = re.sub(r"^SEK\s*", "", cleaned, flags=re.IGNORECASE)
if not cleaned:
return None
# Check for negative
is_negative = cleaned.startswith("-")
if is_negative:
cleaned = cleaned[1:].strip()
try:
# Remove spaces (Swedish thousands separator)
cleaned = cleaned.replace(" ", "")
# Detect format
# Swedish/European: comma is decimal separator
# US: period is decimal separator
has_comma = "," in cleaned
has_period = "." in cleaned
if has_comma and has_period:
# Both present - check position
comma_pos = cleaned.rfind(",")
period_pos = cleaned.rfind(".")
if comma_pos > period_pos:
# European: 1.234,56
cleaned = cleaned.replace(".", "")
cleaned = cleaned.replace(",", ".")
else:
# US: 1,234.56
cleaned = cleaned.replace(",", "")
elif has_comma:
# Swedish: 1234,56
cleaned = cleaned.replace(",", ".")
# else: US format or integer
value = float(cleaned)
return -value if is_negative else value
except (ValueError, InvalidOperation):
return None
class VATExtractor:
    """Extract VAT information from invoice text.

    Works purely on regex over OCR text; amounts are returned as cleaned
    strings (parsing to float is AmountParser's job).
    """

    # VAT extraction patterns
    # Note: Amount pattern uses [^\n] to avoid crossing line boundaries
    VAT_PATTERNS = [
        # Moms 25%: 2 500,00 or Moms 25% 2 500,00
        re.compile(
            r"[Mm]oms\s*(\d+(?:[,\.]\d+)?)\s*%\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
            re.MULTILINE,
        ),
        # Varav moms 25% 2 500,00
        re.compile(
            r"[Vv]arav\s+moms\s+(\d+(?:[,\.]\d+)?)\s*%\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
            re.MULTILINE,
        ),
        # 25% moms: 2 500,00 (at line start or after whitespace)
        re.compile(
            r"(?:^|\s)(\d+(?:[,\.]\d+)?)\s*%\s*moms\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
            re.MULTILINE,
        ),
        # Moms (25%): 2 500,00
        re.compile(
            r"[Mm]oms\s*\((\d+(?:[,\.]\d+)?)\s*%\)\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
            re.MULTILINE,
        ),
    ]

    # Pattern with base amount (Underlag)
    VAT_WITH_BASE_PATTERN = re.compile(
        r"[Mm]oms\s*(\d+(?:[,\.]\d+)?)\s*%\s*:?\s*([\d\s,\.]+)"
        r"(?:.*?[Uu]nderlag\s*([\d\s,\.]+))?",
        re.MULTILINE | re.DOTALL,
    )

    # Total patterns
    # NOTE(review): these three TOTAL_* class attributes are not referenced
    # by the _extract_total_* methods below (they compile their own, narrower
    # patterns) — confirm whether these are dead code that can be removed.
    TOTAL_EXCL_PATTERN = re.compile(
        r"(?:[Ss]umma|[Tt]otal(?:t)?|[Nn]etto)\s*(?:exkl\.?\s*)?(?:moms)?\s*:?\s*([\d\s,\.]+)",
        re.MULTILINE,
    )
    TOTAL_VAT_PATTERN = re.compile(
        r"(?:[Ss]umma|[Tt]otal(?:t)?)\s*moms\s*:?\s*([\d\s,\.]+)",
        re.MULTILINE,
    )
    TOTAL_INCL_PATTERN = re.compile(
        r"(?:[Ss]umma|[Tt]otal(?:t)?|[Bb]rutto)\s*(?:inkl\.?\s*)?(?:moms|att\s*betala)?\s*:?\s*([\d\s,\.]+)",
        re.MULTILINE,
    )

    def __init__(self):
        self.amount_parser = AmountParser()

    def extract(self, text: str) -> VATSummary:
        """
        Extract VAT information from text.

        Args:
            text: Invoice text (OCR output).

        Returns:
            VATSummary with extracted information.
        """
        # Empty / whitespace-only input yields an empty, zero-confidence summary.
        if not text or not text.strip():
            return VATSummary(
                breakdowns=[],
                total_excl_vat=None,
                total_vat=None,
                total_incl_vat=None,
                confidence=0.0,
            )
        breakdowns = self._extract_breakdowns(text)
        total_excl = self._extract_total_excl(text)
        total_vat = self._extract_total_vat(text)
        total_incl = self._extract_total_incl(text)
        confidence = self._calculate_confidence(
            breakdowns, total_excl, total_vat, total_incl
        )
        return VATSummary(
            breakdowns=breakdowns,
            total_excl_vat=total_excl,
            total_vat=total_vat,
            total_incl_vat=total_incl,
            confidence=confidence,
        )

    def _extract_breakdowns(self, text: str) -> list[VATBreakdown]:
        """Extract individual VAT rate breakdowns.

        The pattern with a base amount (Underlag) is tried first; only the
        first occurrence of each rate is kept (seen_rates dedup).
        """
        breakdowns = []
        seen_rates = set()
        # Try pattern with base amount first
        for match in self.VAT_WITH_BASE_PATTERN.finditer(text):
            rate = self._parse_rate(match.group(1))
            vat_amount = self._clean_amount(match.group(2))
            base_amount = (
                self._clean_amount(match.group(3)) if match.group(3) else None
            )
            if rate is not None and vat_amount and rate not in seen_rates:
                seen_rates.add(rate)
                breakdowns.append(
                    VATBreakdown(
                        rate=rate,
                        base_amount=base_amount,
                        vat_amount=vat_amount,
                        source="regex",
                    )
                )
        # Try other patterns (no base amount captured)
        for pattern in self.VAT_PATTERNS:
            for match in pattern.finditer(text):
                rate = self._parse_rate(match.group(1))
                vat_amount = self._clean_amount(match.group(2))
                if rate is not None and vat_amount and rate not in seen_rates:
                    seen_rates.add(rate)
                    breakdowns.append(
                        VATBreakdown(
                            rate=rate,
                            base_amount=None,
                            vat_amount=vat_amount,
                            source="regex",
                        )
                    )
        return breakdowns

    def _extract_total_excl(self, text: str) -> str | None:
        """Extract total excluding VAT (first matching pattern wins)."""
        # Look for specific patterns first
        patterns = [
            re.compile(r"[Ss]umma\s+exkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
            re.compile(r"[Nn]etto\s*:?\s*([\d\s,\.]+)"),
            re.compile(r"[Ee]xkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
        ]
        for pattern in patterns:
            match = pattern.search(text)
            if match:
                return self._clean_amount(match.group(1))
        return None

    def _extract_total_vat(self, text: str) -> str | None:
        """Extract total VAT amount (first matching pattern wins)."""
        patterns = [
            re.compile(r"[Ss]umma\s+moms\s*:?\s*([\d\s,\.]+)"),
            re.compile(r"[Tt]otal(?:t)?\s+moms\s*:?\s*([\d\s,\.]+)"),
            # Generic "Moms:" without percentage
            re.compile(r"^[Mm]oms\s*:?\s*([\d\s,\.]+)", re.MULTILINE),
        ]
        for pattern in patterns:
            match = pattern.search(text)
            if match:
                return self._clean_amount(match.group(1))
        return None

    def _extract_total_incl(self, text: str) -> str | None:
        """Extract total including VAT (first matching pattern wins)."""
        patterns = [
            re.compile(r"[Ss]umma\s+inkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
            re.compile(r"[Tt]otal(?:t)?\s+att\s+betala\s*:?\s*([\d\s,\.]+)"),
            re.compile(r"[Bb]rutto\s*:?\s*([\d\s,\.]+)"),
            re.compile(r"[Aa]tt\s+betala\s*:?\s*([\d\s,\.]+)"),
        ]
        for pattern in patterns:
            match = pattern.search(text)
            if match:
                return self._clean_amount(match.group(1))
        return None

    def _parse_rate(self, rate_str: str) -> float | None:
        """Parse a VAT rate string (e.g. '25' or '12,5') to float."""
        try:
            rate_str = rate_str.replace(",", ".")
            return float(rate_str)
        except (ValueError, TypeError):
            return None

    def _clean_amount(self, amount_str: str) -> str | None:
        """Clean an amount string and validate that it parses as a number.

        Returns the cleaned string (not a float) or None when invalid.
        """
        if not amount_str:
            return None
        cleaned = amount_str.strip()
        # Remove trailing non-numeric chars (except comma/period)
        cleaned = re.sub(r"[^\d\s,\.]+$", "", cleaned).strip()
        if not cleaned:
            return None
        # Validate it parses as a number
        if self.amount_parser.parse(cleaned) is None:
            return None
        return cleaned

    def _calculate_confidence(
        self,
        breakdowns: list[VATBreakdown],
        total_excl: str | None,
        total_vat: str | None,
        total_incl: str | None,
    ) -> float:
        """Calculate confidence score (0.0 - 1.0) based on extracted data.

        Additive scoring: breakdowns 0.3, excl total 0.2, vat total 0.2,
        incl total 0.15, and 0.15 more when excl + vat = incl within 2 cents.
        """
        score = 0.0
        # Has VAT breakdowns
        if breakdowns:
            score += 0.3
        # Has total excluding VAT
        if total_excl:
            score += 0.2
        # Has total VAT
        if total_vat:
            score += 0.2
        # Has total including VAT
        if total_incl:
            score += 0.15
        # Mathematical consistency check
        if total_excl and total_vat and total_incl:
            excl = self.amount_parser.parse(total_excl)
            vat = self.amount_parser.parse(total_vat)
            incl = self.amount_parser.parse(total_incl)
            # NOTE(review): truthiness test skips the bonus when any parsed
            # value is exactly 0.0 (e.g. a 0% VAT invoice) — confirm intended.
            if excl and vat and incl:
                expected = excl + vat
                if abs(expected - incl) < 0.02:  # Allow 2 cent tolerance
                    score += 0.15
        return min(score, 1.0)

View File

@@ -12,7 +12,7 @@ import uuid
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi import APIRouter, File, Form, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from backend.web.schemas.inference import (
@@ -20,6 +20,12 @@ from backend.web.schemas.inference import (
HealthResponse,
InferenceResponse,
InferenceResult,
LineItemSchema,
LineItemsResultSchema,
MathCheckResultSchema,
VATBreakdownSchema,
VATSummarySchema,
VATValidationResultSchema,
)
from backend.web.schemas.common import ErrorResponse
from backend.web.services.storage_helpers import get_storage_helper
@@ -67,12 +73,21 @@ def create_inference_router(
)
async def infer_document(
file: UploadFile = File(..., description="PDF or image file to process"),
extract_line_items: bool = Form(
default=False,
description="Extract line items and VAT information (business features)",
),
) -> InferenceResponse:
"""
Process a document and extract invoice fields.
Accepts PDF or image files (PNG, JPG, JPEG).
Returns extracted field values with confidence scores.
When extract_line_items=True, also extracts:
- Line items (products/services with quantities and amounts)
- VAT summary (multiple tax rates breakdown)
- VAT validation (cross-validation results)
"""
# Validate file extension
if not file.filename:
@@ -116,7 +131,9 @@ def create_inference_router(
# Process based on file type
if file_ext == ".pdf":
service_result = inference_service.process_pdf(
upload_path, document_id=doc_id
upload_path,
document_id=doc_id,
extract_line_items=extract_line_items,
)
else:
service_result = inference_service.process_image(
@@ -128,6 +145,39 @@ def create_inference_router(
if service_result.visualization_path:
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
# Build business features schemas if present
line_items_schema = None
vat_summary_schema = None
vat_validation_schema = None
if service_result.line_items:
line_items_schema = LineItemsResultSchema(
items=[LineItemSchema(**item) for item in service_result.line_items.get("items", [])],
header_row=service_result.line_items.get("header_row", []),
total_amount=service_result.line_items.get("total_amount"),
)
if service_result.vat_summary:
vat_summary_schema = VATSummarySchema(
breakdowns=[VATBreakdownSchema(**b) for b in service_result.vat_summary.get("breakdowns", [])],
total_excl_vat=service_result.vat_summary.get("total_excl_vat"),
total_vat=service_result.vat_summary.get("total_vat"),
total_incl_vat=service_result.vat_summary.get("total_incl_vat"),
confidence=service_result.vat_summary.get("confidence", 0.0),
)
if service_result.vat_validation:
vat_validation_schema = VATValidationResultSchema(
is_valid=service_result.vat_validation.get("is_valid", False),
confidence_score=service_result.vat_validation.get("confidence_score", 0.0),
math_checks=[MathCheckResultSchema(**m) for m in service_result.vat_validation.get("math_checks", [])],
total_check=service_result.vat_validation.get("total_check", False),
line_items_vs_summary=service_result.vat_validation.get("line_items_vs_summary"),
amount_consistency=service_result.vat_validation.get("amount_consistency"),
needs_review=service_result.vat_validation.get("needs_review", False),
review_reasons=service_result.vat_validation.get("review_reasons", []),
)
inference_result = InferenceResult(
document_id=service_result.document_id,
success=service_result.success,
@@ -140,6 +190,9 @@ def create_inference_router(
processing_time_ms=service_result.processing_time_ms,
visualization_url=viz_url,
errors=service_result.errors,
line_items=line_items_schema,
vat_summary=vat_summary_schema,
vat_validation=vat_validation_schema,
)
return InferenceResponse(

View File

@@ -69,6 +69,17 @@ class InferenceResult(BaseModel):
)
errors: list[str] = Field(default_factory=list, description="Error messages")
# Business features (optional, only when extract_line_items=True)
line_items: "LineItemsResultSchema | None" = Field(
None, description="Extracted line items (when extract_line_items=True)"
)
vat_summary: "VATSummarySchema | None" = Field(
None, description="VAT summary (when extract_line_items=True)"
)
vat_validation: "VATValidationResultSchema | None" = Field(
None, description="VAT validation result (when extract_line_items=True)"
)
class InferenceResponse(BaseModel):
"""API response for inference endpoint."""
@@ -194,3 +205,90 @@ class RateLimitInfo(BaseModel):
limit: int = Field(..., description="Maximum requests per minute")
remaining: int = Field(..., description="Remaining requests in current window")
reset_at: datetime = Field(..., description="Time when limit resets")
# =============================================================================
# Business Features Schemas (Line Items, VAT)
# =============================================================================
class LineItemSchema(BaseModel):
    """Single line item from an invoice (one product/service row).

    All monetary values are raw strings as extracted from the document.
    """

    row_index: int = Field(..., description="Row index in the table")
    description: str | None = Field(None, description="Product/service description")
    quantity: str | None = Field(None, description="Quantity")
    unit: str | None = Field(None, description="Unit (st, pcs, etc.)")
    unit_price: str | None = Field(None, description="Price per unit")
    amount: str | None = Field(None, description="Line total amount")
    article_number: str | None = Field(None, description="Article/product number")
    vat_rate: str | None = Field(None, description="VAT rate (e.g., '25')")
    is_deduction: bool = Field(default=False, description="True if this row is a deduction/discount (avdrag/rabatt)")
    confidence: float = Field(default=0.0, ge=0, le=1, description="Extraction confidence")
class LineItemsResultSchema(BaseModel):
    """Line items extraction result for one document."""

    items: list[LineItemSchema] = Field(default_factory=list, description="Extracted line items")
    header_row: list[str] = Field(default_factory=list, description="Table header row")
    total_amount: str | None = Field(None, description="Calculated total from line items")
class VATBreakdownSchema(BaseModel):
    """Single VAT rate breakdown (one row of the invoice's VAT summary)."""

    rate: float = Field(..., description="VAT rate (e.g., 25.0, 12.0, 6.0)")
    base_amount: str | None = Field(None, description="Tax base amount (excluding VAT)")
    vat_amount: str | None = Field(None, description="VAT amount")
    source: str = Field(default="regex", description="Extraction source (regex or line_items)")
class VATSummarySchema(BaseModel):
    """VAT summary information (totals plus per-rate breakdowns)."""

    breakdowns: list[VATBreakdownSchema] = Field(
        default_factory=list, description="VAT breakdowns by rate"
    )
    total_excl_vat: str | None = Field(None, description="Total excluding VAT")
    total_vat: str | None = Field(None, description="Total VAT amount")
    total_incl_vat: str | None = Field(None, description="Total including VAT")
    confidence: float = Field(default=0.0, ge=0, le=1, description="Extraction confidence")
class MathCheckResultSchema(BaseModel):
    """Single math validation check result (one VAT rate)."""

    rate: float = Field(..., description="VAT rate checked")
    base_amount: float | None = Field(None, description="Base amount")
    expected_vat: float | None = Field(None, description="Expected VAT (base * rate)")
    actual_vat: float | None = Field(None, description="Actual VAT from invoice")
    is_valid: bool = Field(..., description="Whether math check passed")
    tolerance: float = Field(..., description="Tolerance used for comparison")
class VATValidationResultSchema(BaseModel):
    """VAT cross-validation result (math checks plus source comparisons)."""

    is_valid: bool = Field(..., description="Overall validation status")
    confidence_score: float = Field(
        ..., ge=0, le=1, description="Validation confidence score"
    )
    math_checks: list[MathCheckResultSchema] = Field(
        default_factory=list, description="Math check results per VAT rate"
    )
    total_check: bool = Field(default=False, description="Whether total calculation is valid")
    line_items_vs_summary: bool | None = Field(
        None, description="Whether line items match VAT summary"
    )
    amount_consistency: bool | None = Field(
        None, description="Whether total matches detected amount field"
    )
    needs_review: bool = Field(default=False, description="Whether manual review is recommended")
    review_reasons: list[str] = Field(
        default_factory=list, description="Reasons for manual review"
    )
# Rebuild models to resolve forward references
InferenceResult.model_rebuild()

View File

@@ -42,6 +42,11 @@ class ServiceResult:
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
# Business features (optional, populated when extract_line_items=True)
line_items: dict | None = None
vat_summary: dict | None = None
vat_validation: dict | None = None
class InferenceService:
"""
@@ -74,6 +79,7 @@ class InferenceService:
self._detector = None
self._is_initialized = False
self._current_model_path: Path | None = None
self._business_features_enabled = False
def _resolve_model_path(self) -> Path:
"""Resolve the model path to use for inference.
@@ -95,12 +101,16 @@ class InferenceService:
return self.model_config.model_path
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
def initialize(self, enable_business_features: bool = False) -> None:
"""Initialize the inference pipeline (lazy loading).
Args:
enable_business_features: Whether to enable line items and VAT extraction
"""
if self._is_initialized:
return
logger.info("Initializing inference service...")
logger.info(f"Initializing inference service (business_features={enable_business_features})...")
start_time = time.time()
try:
@@ -118,16 +128,18 @@ class InferenceService:
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
# Initialize full pipeline with optional business features
self._pipeline = InferencePipeline(
model_path=str(model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
enable_fallback=True,
enable_business_features=enable_business_features,
)
self._is_initialized = True
self._business_features_enabled = enable_business_features
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
@@ -242,6 +254,7 @@ class InferenceService:
pdf_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
extract_line_items: bool = False,
) -> ServiceResult:
"""
Process a PDF file and extract invoice fields.
@@ -250,12 +263,17 @@ class InferenceService:
pdf_path: Path to PDF file
document_id: Optional document ID
save_visualization: Whether to save visualization
extract_line_items: Whether to extract line items and VAT info
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
self.initialize(enable_business_features=extract_line_items)
elif extract_line_items and not self._business_features_enabled:
# Reinitialize with business features if needed
self._is_initialized = False
self.initialize(enable_business_features=True)
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
@@ -263,8 +281,12 @@ class InferenceService:
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
# Run inference pipeline with optional business features
pipeline_result = self._pipeline.process_pdf(
pdf_path,
document_id=doc_id,
extract_line_items=extract_line_items,
)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
@@ -288,6 +310,12 @@ class InferenceService:
for d in pipeline_result.raw_detections
]
# Include business features if extracted
if extract_line_items:
result.line_items = pipeline_result._line_items_to_json() if pipeline_result.line_items else None
result.vat_summary = pipeline_result._vat_summary_to_json() if pipeline_result.vat_summary else None
result.vat_validation = pipeline_result._vat_validation_to_json() if pipeline_result.vat_validation else None
# Save visualization (render first page)
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_pdf_visualization(pdf_path, doc_id)

View File

@@ -4,7 +4,7 @@ setup(
name="invoice-backend",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
python_requires=">=3.10", # 3.10 for RTX 50 series SM120 wheel
install_requires=[
"invoice-shared",
"fastapi>=0.104.0",