Update paddle, and support invoice line items
This commit is contained in:
@@ -1,5 +1,18 @@
|
||||
from .pipeline import InferencePipeline, InferenceResult
|
||||
from .pipeline import (
|
||||
InferencePipeline,
|
||||
InferenceResult,
|
||||
CrossValidationResult,
|
||||
BUSINESS_FEATURES_AVAILABLE,
|
||||
)
|
||||
from .yolo_detector import YOLODetector, Detection
|
||||
from .field_extractor import FieldExtractor
|
||||
|
||||
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
|
||||
__all__ = [
|
||||
'InferencePipeline',
|
||||
'InferenceResult',
|
||||
'CrossValidationResult',
|
||||
'YOLODetector',
|
||||
'Detection',
|
||||
'FieldExtractor',
|
||||
'BUSINESS_FEATURES_AVAILABLE',
|
||||
]
|
||||
|
||||
@@ -2,19 +2,39 @@
|
||||
Inference Pipeline
|
||||
|
||||
Complete pipeline for extracting invoice data from PDFs.
|
||||
Supports both basic field extraction and business invoice features
|
||||
(line items, VAT extraction, cross-validation).
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import logging
|
||||
import time
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from shared.fields import CLASS_TO_FIELD
|
||||
from .yolo_detector import YOLODetector, Detection
|
||||
from .field_extractor import FieldExtractor, ExtractedField
|
||||
from .payment_line_parser import PaymentLineParser
|
||||
|
||||
# Business invoice feature imports (optional - for extract_line_items mode)
try:
    from ..table.line_items_extractor import LineItem, LineItemsResult, LineItemsExtractor
    from ..table.structure_detector import TableDetector
    from ..vat.vat_extractor import VATSummary, VATExtractor
    from ..validation.vat_validator import VATValidationResult, VATValidator
    BUSINESS_FEATURES_AVAILABLE = True
except ImportError:
    # Degrade gracefully: flag the features as unavailable and null out EVERY
    # name imported above so `is None` checks (and accidental references)
    # behave consistently. The original omitted LineItemsExtractor,
    # VATExtractor and VATValidator here, so touching one of those names after
    # a failed import raised NameError instead of failing cleanly.
    BUSINESS_FEATURES_AVAILABLE = False
    LineItem = None
    LineItemsResult = None
    LineItemsExtractor = None
    TableDetector = None
    VATSummary = None
    VATExtractor = None
    VATValidationResult = None
    VATValidator = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrossValidationResult:
|
||||
@@ -45,6 +65,10 @@ class InferenceResult:
|
||||
errors: list[str] = field(default_factory=list)
|
||||
fallback_used: bool = False
|
||||
cross_validation: CrossValidationResult | None = None
|
||||
# Business invoice features (optional)
|
||||
line_items: Any | None = None # LineItemsResult when available
|
||||
vat_summary: Any | None = None # VATSummary when available
|
||||
vat_validation: Any | None = None # VATValidationResult when available
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""Convert to JSON-serializable dictionary."""
|
||||
@@ -81,8 +105,89 @@ class InferenceResult:
|
||||
'payment_line_account_type': self.cross_validation.payment_line_account_type,
|
||||
'details': self.cross_validation.details,
|
||||
}
|
||||
|
||||
# Add business invoice features if present
|
||||
if self.line_items is not None:
|
||||
result['line_items'] = self._line_items_to_json()
|
||||
if self.vat_summary is not None:
|
||||
result['vat_summary'] = self._vat_summary_to_json()
|
||||
if self.vat_validation is not None:
|
||||
result['vat_validation'] = self._vat_validation_to_json()
|
||||
|
||||
return result
|
||||
|
||||
def _line_items_to_json(self) -> dict | None:
|
||||
"""Convert LineItemsResult to JSON."""
|
||||
if self.line_items is None:
|
||||
return None
|
||||
li = self.line_items
|
||||
return {
|
||||
'items': [
|
||||
{
|
||||
'row_index': item.row_index,
|
||||
'description': item.description,
|
||||
'quantity': item.quantity,
|
||||
'unit': item.unit,
|
||||
'unit_price': item.unit_price,
|
||||
'amount': item.amount,
|
||||
'article_number': item.article_number,
|
||||
'vat_rate': item.vat_rate,
|
||||
'is_deduction': item.is_deduction,
|
||||
'confidence': item.confidence,
|
||||
}
|
||||
for item in li.items
|
||||
],
|
||||
'header_row': li.header_row,
|
||||
'total_amount': li.total_amount,
|
||||
}
|
||||
|
||||
def _vat_summary_to_json(self) -> dict | None:
|
||||
"""Convert VATSummary to JSON."""
|
||||
if self.vat_summary is None:
|
||||
return None
|
||||
vs = self.vat_summary
|
||||
return {
|
||||
'breakdowns': [
|
||||
{
|
||||
'rate': b.rate,
|
||||
'base_amount': b.base_amount,
|
||||
'vat_amount': b.vat_amount,
|
||||
'source': b.source,
|
||||
}
|
||||
for b in vs.breakdowns
|
||||
],
|
||||
'total_excl_vat': vs.total_excl_vat,
|
||||
'total_vat': vs.total_vat,
|
||||
'total_incl_vat': vs.total_incl_vat,
|
||||
'confidence': vs.confidence,
|
||||
}
|
||||
|
||||
def _vat_validation_to_json(self) -> dict | None:
|
||||
"""Convert VATValidationResult to JSON."""
|
||||
if self.vat_validation is None:
|
||||
return None
|
||||
vv = self.vat_validation
|
||||
return {
|
||||
'is_valid': vv.is_valid,
|
||||
'confidence_score': vv.confidence_score,
|
||||
'math_checks': [
|
||||
{
|
||||
'rate': mc.rate,
|
||||
'base_amount': mc.base_amount,
|
||||
'expected_vat': mc.expected_vat,
|
||||
'actual_vat': mc.actual_vat,
|
||||
'is_valid': mc.is_valid,
|
||||
'tolerance': mc.tolerance,
|
||||
}
|
||||
for mc in vv.math_checks
|
||||
],
|
||||
'total_check': vv.total_check,
|
||||
'line_items_vs_summary': vv.line_items_vs_summary,
|
||||
'amount_consistency': vv.amount_consistency,
|
||||
'needs_review': vv.needs_review,
|
||||
'review_reasons': vv.review_reasons,
|
||||
}
|
||||
|
||||
def get_field(self, field_name: str) -> tuple[Any, float]:
    """Return (value, confidence) for *field_name*; (None, 0.0) when absent."""
    value = self.fields.get(field_name)
    score = self.confidence.get(field_name, 0.0)
    return value, score
|
||||
@@ -107,7 +212,9 @@ class InferencePipeline:
|
||||
ocr_lang: str = 'en',
|
||||
use_gpu: bool = False,
|
||||
dpi: int = 300,
|
||||
enable_fallback: bool = True
|
||||
enable_fallback: bool = True,
|
||||
enable_business_features: bool = False,
|
||||
vat_tolerance: float = 0.5
|
||||
):
|
||||
"""
|
||||
Initialize inference pipeline.
|
||||
@@ -119,6 +226,8 @@ class InferencePipeline:
|
||||
use_gpu: Whether to use GPU
|
||||
dpi: Resolution for PDF rendering
|
||||
enable_fallback: Enable fallback to full-page OCR
|
||||
enable_business_features: Enable line items/VAT extraction
|
||||
vat_tolerance: Tolerance for VAT math checks (in currency units)
|
||||
"""
|
||||
self.detector = YOLODetector(
|
||||
model_path,
|
||||
@@ -129,11 +238,34 @@ class InferencePipeline:
|
||||
self.payment_line_parser = PaymentLineParser()
|
||||
self.dpi = dpi
|
||||
self.enable_fallback = enable_fallback
|
||||
self.enable_business_features = enable_business_features
|
||||
self.vat_tolerance = vat_tolerance
|
||||
|
||||
# Initialize business feature components if enabled and available
|
||||
self.line_items_extractor = None
|
||||
self.vat_extractor = None
|
||||
self.vat_validator = None
|
||||
self._business_ocr_engine = None # Lazy-initialized for VAT text extraction
|
||||
self._table_detector = None # Shared TableDetector for line items extraction
|
||||
|
||||
if enable_business_features:
|
||||
if not BUSINESS_FEATURES_AVAILABLE:
|
||||
raise ImportError(
|
||||
"Business features require table, vat, and validation modules. "
|
||||
"Please ensure they are properly installed."
|
||||
)
|
||||
# Create shared TableDetector for performance (PP-StructureV3 init is slow)
|
||||
self._table_detector = TableDetector()
|
||||
# Pass shared detector to LineItemsExtractor
|
||||
self.line_items_extractor = LineItemsExtractor(table_detector=self._table_detector)
|
||||
self.vat_extractor = VATExtractor()
|
||||
self.vat_validator = VATValidator(tolerance=vat_tolerance)
|
||||
|
||||
def process_pdf(
|
||||
self,
|
||||
pdf_path: str | Path,
|
||||
document_id: str | None = None
|
||||
document_id: str | None = None,
|
||||
extract_line_items: bool | None = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Process a PDF and extract invoice fields.
|
||||
@@ -141,6 +273,8 @@ class InferencePipeline:
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
document_id: Optional document ID
|
||||
extract_line_items: Whether to extract line items and VAT info.
|
||||
If None, uses the enable_business_features setting from __init__.
|
||||
|
||||
Returns:
|
||||
InferenceResult with extracted fields
|
||||
@@ -156,9 +290,16 @@ class InferencePipeline:
|
||||
document_id=document_id or Path(pdf_path).stem
|
||||
)
|
||||
|
||||
# Determine if business features should be used
|
||||
use_business_features = (
|
||||
extract_line_items if extract_line_items is not None
|
||||
else self.enable_business_features
|
||||
)
|
||||
|
||||
try:
|
||||
all_detections = []
|
||||
all_extracted = []
|
||||
all_ocr_text = [] # Collect OCR text for VAT extraction
|
||||
|
||||
# Process each page
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||
@@ -175,6 +316,11 @@ class InferencePipeline:
|
||||
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||
all_extracted.append(extracted)
|
||||
|
||||
# Collect full-page OCR text for VAT extraction (only if business features enabled)
|
||||
if use_business_features:
|
||||
page_text = self._get_full_page_text(image_array)
|
||||
all_ocr_text.append(page_text)
|
||||
|
||||
result.raw_detections = all_detections
|
||||
result.extracted_fields = all_extracted
|
||||
|
||||
@@ -185,6 +331,10 @@ class InferencePipeline:
|
||||
if self.enable_fallback and self._needs_fallback(result):
|
||||
self._run_fallback(pdf_path, result)
|
||||
|
||||
# Extract business invoice features if enabled
|
||||
if use_business_features:
|
||||
self._extract_business_features(pdf_path, result, '\n'.join(all_ocr_text))
|
||||
|
||||
result.success = len(result.fields) > 0
|
||||
|
||||
except Exception as e:
|
||||
@@ -194,6 +344,78 @@ class InferencePipeline:
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _get_full_page_text(self, image_array) -> str:
    """Run full-page OCR on *image_array* and return the text as one string.

    Used to collect text for VAT extraction. Returns "" on any OCR failure
    so the caller can proceed without VAT text.

    Args:
        image_array: Rendered page image (presumably a numpy array — the
            exact type is whatever OCREngine.extract_from_image accepts).

    Returns:
        Space-joined token text, or "" when OCR fails.
    """
    from shared.ocr import OCREngine

    try:
        # Lazy initialize OCR engine to avoid repeated model loading.
        if self._business_ocr_engine is None:
            self._business_ocr_engine = OCREngine()

        tokens = self._business_ocr_engine.extract_from_image(image_array, page_no=0)
        return ' '.join(t.text for t in tokens)
    except Exception as e:
        # Use the module-level logger; the original re-imported logging and
        # built a local logger here, redundantly shadowing it.
        logger.warning(f"OCR extraction for VAT failed: {e}")
        return ""
|
||||
|
||||
def _extract_business_features(
    self,
    pdf_path: str | Path,
    result: InferenceResult,
    full_text: str
) -> None:
    """
    Extract line items, VAT summary, and perform cross-validation.

    Mutates *result* in place: sets line_items / vat_summary /
    vat_validation when found, and appends to result.errors on failure.
    Never raises — all exceptions are caught and recorded.

    Args:
        pdf_path: Path to PDF file
        result: InferenceResult to populate
        full_text: Full OCR text from all pages
    """
    # Guard: the optional table/vat/validation modules must have imported.
    if not BUSINESS_FEATURES_AVAILABLE:
        result.errors.append("Business features not available")
        return

    # Guard: extractors are only built when enable_business_features was set.
    if not self.line_items_extractor or not self.vat_extractor or not self.vat_validator:
        result.errors.append("Business feature extractors not initialized")
        return

    try:
        # Extract line items from tables
        logger.info(f"Extracting line items from PDF: {pdf_path}")
        line_items_result = self.line_items_extractor.extract_from_pdf(str(pdf_path))
        logger.info(f"Line items extraction result: {line_items_result is not None}, items={len(line_items_result.items) if line_items_result else 0}")
        if line_items_result and line_items_result.items:
            result.line_items = line_items_result
            logger.info(f"Set result.line_items with {len(line_items_result.items)} items")

        # Extract VAT summary from text
        logger.info(f"Extracting VAT summary from text ({len(full_text)} chars)")
        vat_summary = self.vat_extractor.extract(full_text)
        logger.info(f"VAT summary extraction result: {vat_summary is not None}")
        if vat_summary:
            result.vat_summary = vat_summary

        # Cross-validate VAT information
        # NOTE(review): vat_summary may be None here — presumably
        # VATValidator.validate tolerates that; confirm against its contract.
        existing_amount = result.fields.get('Amount')
        vat_validation = self.vat_validator.validate(
            vat_summary,
            line_items=line_items_result,
            existing_amount=str(existing_amount) if existing_amount else None
        )
        result.vat_validation = vat_validation
        logger.info(f"VAT validation completed: is_valid={vat_validation.is_valid if vat_validation else None}")

    except Exception as e:
        # Best-effort feature: record the failure on the result instead of
        # propagating, so core field extraction still succeeds.
        import traceback
        error_detail = f"{type(e).__name__}: {e}"
        logger.error(f"Business feature extraction failed: {error_detail}\n{traceback.format_exc()}")
        result.errors.append(f"Business feature extraction error: {error_detail}")
|
||||
|
||||
def _merge_fields(self, result: InferenceResult) -> None:
|
||||
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||
|
||||
32
packages/backend/backend/table/__init__.py
Normal file
32
packages/backend/backend/table/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Table detection and extraction module.
|
||||
|
||||
This module provides PP-StructureV3-based table detection for invoices,
|
||||
and line items extraction from detected tables.
|
||||
"""
|
||||
|
||||
from .structure_detector import (
|
||||
TableDetectionResult,
|
||||
TableDetector,
|
||||
TableDetectorConfig,
|
||||
)
|
||||
from .line_items_extractor import (
|
||||
LineItem,
|
||||
LineItemsResult,
|
||||
LineItemsExtractor,
|
||||
ColumnMapper,
|
||||
HTMLTableParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Structure detection
|
||||
"TableDetectionResult",
|
||||
"TableDetector",
|
||||
"TableDetectorConfig",
|
||||
# Line items extraction
|
||||
"LineItem",
|
||||
"LineItemsResult",
|
||||
"LineItemsExtractor",
|
||||
"ColumnMapper",
|
||||
"HTMLTableParser",
|
||||
]
|
||||
970
packages/backend/backend/table/line_items_extractor.py
Normal file
970
packages/backend/backend/table/line_items_extractor.py
Normal file
@@ -0,0 +1,970 @@
|
||||
"""
|
||||
Line Items Extractor
|
||||
|
||||
Extracts structured line items from HTML tables produced by PP-StructureV3.
|
||||
Handles Swedish invoice formats including reversed tables (header at bottom).
|
||||
Includes fallback text-based extraction for invoices without detectable table structures.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from html.parser import HTMLParser
|
||||
from decimal import Decimal, InvalidOperation
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class LineItem:
    """Single line item from invoice."""

    # Position of the row within the parsed table (0-based).
    row_index: int
    # All value fields are kept as raw OCR strings (Swedish number
    # formatting such as "1 234,56" is preserved, not parsed here).
    description: str | None = None
    quantity: str | None = None
    unit: str | None = None
    unit_price: str | None = None
    amount: str | None = None
    article_number: str | None = None
    vat_rate: str | None = None
    is_deduction: bool = False  # True if this row is a deduction/discount
    confidence: float = 0.9
|
||||
|
||||
|
||||
@dataclass
class LineItemsResult:
    """Result of line items extraction."""

    items: list[LineItem]          # extracted rows, in table order
    header_row: list[str]          # raw header cells as parsed/detected
    raw_html: str                  # source HTML ("" for text-based fallback)
    is_reversed: bool = False      # header was found at the table's bottom

    @property
    def total_amount(self) -> str | None:
        """Calculate total amount from line items (deduction rows have negative amounts)."""
        if not self.items:
            return None

        total = Decimal("0")
        for item in self.items:
            if item.amount:
                try:
                    # Parse Swedish number format (1 234,56)
                    amount_str = item.amount.replace(" ", "").replace(",", ".")
                    total += Decimal(amount_str)
                except InvalidOperation:
                    # Unparseable amounts are skipped rather than failing the sum.
                    pass

        if total == 0:
            return None

        # Format back to Swedish format: "1,234.56" -> "1 234,56"
        formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
        # Fix the space/comma swap
        # NOTE(review): both replace() arguments below render as a plain
        # space — a non-breaking/thin space was presumably intended on one
        # side; verify against the original encoding.
        parts = formatted.rsplit(",", 1)
        if len(parts) == 2:
            return parts[0].replace(" ", " ") + "," + parts[1]
        return formatted
|
||||
|
||||
|
||||
# Swedish column name mappings
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
# Keys are LineItem field names ("deduction" is special-cased by
# LineItemsExtractor._extract_items); values are header patterns matched by
# ColumnMapper against normalized headers (lowercased, dots stripped,
# hyphens turned into spaces). Exact matches beat substring matches.
COLUMN_MAPPINGS = {
    "article_number": [
        "art nummer",
        "artikelnummer",
        "artikel",
        "artnr",
        "art.nr",
        "art nr",
        "objektnummer",  # Rental: property reference
        "objekt",
    ],
    "description": [
        "beskrivning",
        "produktbeskrivning",
        "produkt",
        "tjänst",
        "text",
        "benämning",
        "vara/tjänst",
        "vara",
        # Rental invoice specific
        "specifikation",
        "spec",
        "hyresperiod",  # Rental period
        "period",
        "typ",  # Type of charge
        # Utility bills
        "förbrukning",  # Consumption
        "avläsning",  # Meter reading
    ],
    "quantity": ["antal", "qty", "st", "pcs", "kvantitet", "m²", "kvm"],
    "unit": ["enhet", "unit"],
    "unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
    "amount": [
        "belopp",
        "summa",
        "total",
        "netto",
        "rad summa",
        # Rental specific
        "hyra",  # Rent
        "avgift",  # Fee
        "kostnad",  # Cost
        "debitering",  # Charge
        "totalt",  # Total
    ],
    "vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
    # Additional field for rental: deductions/adjustments
    "deduction": [
        "avdrag",  # Deduction
        "rabatt",  # Discount
        "kredit",  # Credit
    ],
}
|
||||
|
||||
# Keywords that indicate NOT a line items table
# (shipping/fee/total/giro/due-date vocabulary typical of invoice footers).
# Used by LineItemsExtractor.is_line_items_table: a header containing any of
# these substrings is rejected as a summary table.
SUMMARY_KEYWORDS = [
    "frakt",
    "faktura.avg",
    "fakturavg",
    "exkl.moms",
    "att betala",
    "öresavr",
    "bankgiro",
    "plusgiro",
    "ocr",
    "forfallodatum",
    "förfallodatum",
]
|
||||
|
||||
|
||||
class _TableHTMLParser(HTMLParser):
    """Minimal HTML parser that collects table cells into rows.

    After feed(), `header_row` holds the cells of the <thead> row (if any)
    and `rows` holds every other non-empty row.
    """

    def __init__(self):
        super().__init__()
        self.rows: list[list[str]] = []
        self.header_row: list[str] = []
        self.current_row: list[str] = []
        self.current_cell: str = ""
        self.in_td = False
        self.in_thead = False

    def handle_starttag(self, tag, attrs):
        if tag in ("td", "th"):
            # Start accumulating text for a new cell.
            self.in_td = True
            self.current_cell = ""
        elif tag == "tr":
            self.current_row = []
        elif tag == "thead":
            self.in_thead = True

    def handle_endtag(self, tag):
        if tag in ("td", "th"):
            # Cell finished: strip surrounding whitespace and store it.
            self.in_td = False
            self.current_row.append(self.current_cell.strip())
        elif tag == "tr":
            if not self.current_row:
                return
            # Rows inside <thead> become the header; everything else is data.
            if self.in_thead:
                self.header_row = self.current_row
            else:
                self.rows.append(self.current_row)
        elif tag == "thead":
            self.in_thead = False

    def handle_data(self, data):
        if not self.in_td:
            return
        self.current_cell += data
|
||||
|
||||
|
||||
class HTMLTableParser:
    """Turn an HTML table string into a (header, rows) pair."""

    def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
        """
        Parse HTML table and return header and rows.

        Args:
            html: HTML string containing table.

        Returns:
            Tuple of (header_row, data_rows).
        """
        # Delegate the actual tag handling to the internal SAX-style parser.
        worker = _TableHTMLParser()
        worker.feed(html)
        return worker.header_row, worker.rows
|
||||
|
||||
|
||||
class ColumnMapper:
|
||||
"""Map column headers to field names."""
|
||||
|
||||
def __init__(self, mappings: dict[str, list[str]] | None = None):
|
||||
"""
|
||||
Initialize column mapper.
|
||||
|
||||
Args:
|
||||
mappings: Custom column mappings. Uses Swedish defaults if None.
|
||||
"""
|
||||
self.mappings = mappings or COLUMN_MAPPINGS
|
||||
|
||||
def map(self, headers: list[str]) -> dict[int, str]:
|
||||
"""
|
||||
Map column indices to field names.
|
||||
|
||||
Args:
|
||||
headers: List of column header strings.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping column index to field name.
|
||||
"""
|
||||
mapping = {}
|
||||
for idx, header in enumerate(headers):
|
||||
normalized = self._normalize(header)
|
||||
|
||||
if not normalized.strip():
|
||||
continue
|
||||
|
||||
best_match = None
|
||||
best_match_len = 0
|
||||
|
||||
for field_name, patterns in self.mappings.items():
|
||||
for pattern in patterns:
|
||||
if pattern == normalized:
|
||||
best_match = field_name
|
||||
best_match_len = len(pattern) + 100
|
||||
break
|
||||
elif pattern in normalized and len(pattern) > best_match_len:
|
||||
if len(pattern) >= 3:
|
||||
best_match = field_name
|
||||
best_match_len = len(pattern)
|
||||
|
||||
if best_match_len > 100:
|
||||
break
|
||||
|
||||
if best_match:
|
||||
mapping[idx] = best_match
|
||||
|
||||
return mapping
|
||||
|
||||
def _normalize(self, header: str) -> str:
|
||||
"""Normalize header text for matching."""
|
||||
return header.lower().strip().replace(".", "").replace("-", " ")
|
||||
|
||||
|
||||
class LineItemsExtractor:
|
||||
"""Extract structured line items from HTML tables."""
|
||||
|
||||
def __init__(
    self,
    column_mapper: ColumnMapper | None = None,
    table_detector: "TableDetector | None" = None,
    enable_text_fallback: bool = True,
):
    """
    Initialize extractor.

    Args:
        column_mapper: Custom column mapper. Uses default if None.
        table_detector: Pre-initialized TableDetector to reuse. Creates new if None.
        enable_text_fallback: Enable text-based fallback extraction when no tables detected.
    """
    self.parser = HTMLTableParser()
    self.mapper = column_mapper or ColumnMapper()
    # May be injected (shared for performance) or lazily created by
    # _get_table_detector() on first use.
    self._table_detector = table_detector
    self._enable_text_fallback = enable_text_fallback
    self._text_extractor = None  # Lazy initialized
|
||||
|
||||
def extract(self, html: str) -> LineItemsResult:
    """
    Extract line items from HTML table.

    Args:
        html: HTML string containing table.

    Returns:
        LineItemsResult with extracted items.
    """
    header, rows = self.parser.parse(html)
    is_reversed = False

    # Check if cells contain merged multi-line data (PP-StructureV3 issue)
    if rows and self._has_vertically_merged_cells(rows):
        logger.info("Detected vertically merged cells, attempting to split")
        header, rows = self._split_merged_rows(rows)

    # No <thead>: locate a header row heuristically among the data rows.
    if not header:
        header_idx, detected_header, is_at_end = self._detect_header_row(rows)
        if header_idx >= 0:
            header = detected_header
            if is_at_end:
                # Header at the bottom means a reversed table; data precedes it.
                is_reversed = True
                rows = rows[:header_idx]
            else:
                rows = rows[header_idx + 1 :]
        elif rows:
            # Last resort: first non-empty row is treated as the header.
            for i, row in enumerate(rows):
                if any(cell.strip() for cell in row):
                    header = row
                    rows = rows[i + 1 :]
                    break

    column_map = self.mapper.map(header)
    items = self._extract_items(rows, column_map)

    # If no items extracted but header looks like line items table,
    # try parsing merged cells (common in poorly OCR'd rental invoices)
    if not items and self._has_merged_header(header):
        logger.info(f"Trying merged cell parsing: header={header}, rows={rows}")
        items = self._extract_from_merged_cells(header, rows)
        logger.info(f"Merged cell parsing result: {len(items)} items")

    return LineItemsResult(
        items=items,
        header_row=header,
        raw_html=html,
        is_reversed=is_reversed,
    )
|
||||
|
||||
def _get_table_detector(self) -> "TableDetector":
    """Return the shared TableDetector, creating it on first use (lazy)."""
    detector = self._table_detector
    if detector is None:
        # Import here to keep the heavy dependency lazy.
        from .structure_detector import TableDetector
        detector = TableDetector()
        self._table_detector = detector
    return detector
|
||||
|
||||
def _get_text_extractor(self) -> "TextLineItemsExtractor":
    """Return the text-fallback extractor, creating it on first use (lazy)."""
    extractor = self._text_extractor
    if extractor is None:
        # Import here to keep the dependency lazy.
        from .text_line_items_extractor import TextLineItemsExtractor
        extractor = TextLineItemsExtractor()
        self._text_extractor = extractor
    return extractor
|
||||
|
||||
def extract_from_pdf(self, pdf_path: str) -> LineItemsResult | None:
    """
    Extract line items from a PDF by detecting tables.

    Tries PP-StructureV3 table detection first; when that yields nothing
    and text fallback is enabled, falls back to text-based extraction.
    Reuses the shared TableDetector instance for performance.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        LineItemsResult if line items are found, None otherwise.
    """
    detector = self._get_table_detector()  # shared instance, avoids slow re-init
    tables, parsing_res_list = self._detect_tables_with_parsing(detector, pdf_path)

    logger.info(f"LineItemsExtractor: detected {len(tables) if tables else 0} tables from PDF")

    # Table-based extraction is the primary path.
    outcome = self._extract_from_tables(tables)

    # Text fallback only when tables produced nothing and we have parsed text.
    if outcome is None and self._enable_text_fallback and parsing_res_list:
        logger.info("LineItemsExtractor: no tables found, trying text-based fallback")
        outcome = self._extract_from_text(parsing_res_list)

    logger.info(f"LineItemsExtractor: final result has {len(outcome.items) if outcome else 0} items")
    return outcome
|
||||
|
||||
def _detect_tables_with_parsing(
    self, detector: "TableDetector", pdf_path: str
) -> tuple[list, list]:
    """
    Detect tables and also return parsing_res_list for fallback.

    Only the FIRST page of the PDF is processed; the function returns from
    inside the render loop once page 0 has been handled.

    Args:
        detector: TableDetector instance.
        pdf_path: Path to PDF file.

    Returns:
        Tuple of (table_results, parsing_res_list). Both empty when the PDF
        is missing or the detector pipeline is unavailable.
    """
    from pathlib import Path
    from shared.pdf.renderer import render_pdf_to_images
    from PIL import Image
    import io
    import numpy as np

    # Note: rebinds the str parameter as a Path from here on.
    pdf_path = Path(pdf_path)
    if not pdf_path.exists():
        logger.warning(f"PDF not found: {pdf_path}")
        return [], []

    # Ensure detector is initialized
    detector._ensure_initialized()

    # Render first page
    parsing_res_list = []
    for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=300):
        if page_no == 0:
            image = Image.open(io.BytesIO(image_bytes))
            image_array = np.array(image)

            # Run PP-StructureV3 and get raw results
            if detector._pipeline is None:
                return [], []

            raw_results = detector._pipeline.predict(image_array)

            # Extract parsing_res_list from raw results
            # NOTE(review): if predict() returns several results, each loop
            # iteration overwrites parsing_res_list — only the last result's
            # list survives; confirm that is intended.
            if raw_results:
                for result in raw_results if isinstance(raw_results, list) else [raw_results]:
                    if hasattr(result, "get"):
                        parsing_res_list = result.get("parsing_res_list", [])
                    elif hasattr(result, "parsing_res_list"):
                        parsing_res_list = result.parsing_res_list or []

            # Parse tables using existing logic
            tables = detector._parse_results(raw_results)
            return tables, parsing_res_list

    return [], []
|
||||
|
||||
def _extract_from_tables(self, tables: list) -> LineItemsResult | None:
    """Pick the best line-items table (most items) among detected tables."""
    if not tables:
        return None

    best = None
    best_count = 0

    for i, table in enumerate(tables):
        if not table.html:
            logger.debug(f"Table {i}: no HTML content")
            continue

        logger.info(f"Table {i}: html_len={len(table.html)}, html={table.html[:500]}")
        parsed = self.extract(table.html)
        logger.info(f"Table {i}: extracted {len(parsed.items)} items, headers={parsed.header_row}")

        # Only tables whose headers look like line items are candidates.
        is_line_items = self.is_line_items_table(parsed.header_row or [])
        logger.info(f"Table {i}: is_line_items_table={is_line_items}")

        if parsed.items and is_line_items and len(parsed.items) > best_count:
            best_count = len(parsed.items)
            best = parsed
            logger.debug(f"Table {i}: selected as best (items={best_count})")

    return best
|
||||
|
||||
def _extract_from_text(self, parsing_res_list: list) -> LineItemsResult | None:
    """Extract line items using text-based fallback.

    Args:
        parsing_res_list: Raw parsing results from PP-StructureV3, as
            returned by _detect_tables_with_parsing.

    Returns:
        LineItemsResult with converted items, or None when the text
        extractor finds nothing.
    """
    from .text_line_items_extractor import convert_text_line_item

    text_extractor = self._get_text_extractor()
    text_result = text_extractor.extract_from_parsing_res(parsing_res_list)

    if text_result is None or not text_result.items:
        logger.debug("Text-based extraction found no items")
        return None

    # Convert TextLineItems to LineItems
    converted_items = [convert_text_line_item(item) for item in text_result.items]

    logger.info(f"Text-based extraction found {len(converted_items)} items")
    return LineItemsResult(
        items=converted_items,
        header_row=text_result.header_row,
        raw_html="",  # No HTML for text-based extraction
        is_reversed=False,
    )
|
||||
|
||||
def is_line_items_table(self, headers: list[str]) -> bool:
    """
    Check if headers indicate a line items table.

    Args:
        headers: List of column headers.

    Returns:
        True if this appears to be a line items table.
    """
    mapped_fields = set(self.mapper.map(headers).values())
    logger.debug(f"is_line_items_table: headers={headers}, mapped_fields={mapped_fields}")

    # Must have description or article_number OR amount field
    # (rental invoices may have amount columns like "Hyra" without
    # an explicit description column).
    has_item_identifier = bool({"description", "article_number"} & mapped_fields)
    has_amount = "amount" in mapped_fields

    # Reject summary/footer tables (totals, giro numbers, due dates).
    header_text = " ".join(h.lower() for h in headers)
    is_summary = any(kw in header_text for kw in SUMMARY_KEYWORDS)

    result = (has_item_identifier or has_amount) and not is_summary
    logger.debug(f"is_line_items_table: has_item_identifier={has_item_identifier}, has_amount={has_amount}, is_summary={is_summary}, result={result}")
    return result
|
||||
|
||||
def _detect_header_row(
    self, rows: list[list[str]]
) -> tuple[int, list[str], bool]:
    """
    Detect which row is the header based on content patterns.

    Scores each non-empty row by how many known column keywords it
    contains; a row needs at least two keyword hits to count as a header.

    Returns:
        Tuple of (header_index, header_row, is_at_end); (-1, [], False)
        when no convincing header is found.
    """
    # Flatten every mapping pattern into one lowercase keyword set.
    keywords = {
        pattern.lower()
        for patterns in self.mapper.mappings.values()
        for pattern in patterns
    }

    best_idx, best_row, best_score = -1, [], 0

    for idx, row in enumerate(rows):
        if not any(cell.strip() for cell in row):
            continue  # skip fully blank rows

        joined = " ".join(cell.lower() for cell in row)
        score = sum(1 for kw in keywords if kw in joined)

        if score > best_score:
            best_idx, best_row, best_score = idx, row, score

    if best_score >= 2:
        # Header near/at the bottom suggests a reversed table.
        at_end = best_idx == len(rows) - 1 or best_idx > len(rows) // 2
        return best_idx, best_row, at_end

    return -1, [], False
|
||||
|
||||
def _extract_items(
    self, rows: list[list[str]], column_map: dict[int, str]
) -> list[LineItem]:
    """Build LineItem objects from data rows using the column->field map."""
    extracted: list[LineItem] = []

    for idx, row in enumerate(rows):
        data: dict = {
            "row_index": idx,
            "description": None,
            "quantity": None,
            "unit": None,
            "unit_price": None,
            "amount": None,
            "article_number": None,
            "vat_rate": None,
            "is_deduction": False,
        }

        for col, value in enumerate(row):
            field_name = column_map.get(col)
            if field_name is None:
                continue
            if field_name == "deduction":
                # "deduction" is not a LineItem field: store the cell value
                # as the amount and flag the item as a deduction instead.
                if value:
                    data["amount"] = value
                    data["is_deduction"] = True
            else:
                # Normalize empty strings to None.
                data[field_name] = value if value else None

        # Keep only rows that carry a description or an amount.
        if data["description"] or data["amount"]:
            extracted.append(LineItem(**data))

    return extracted
|
||||
|
||||
def _has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
|
||||
"""
|
||||
Check if table rows contain vertically merged data in single cells.
|
||||
|
||||
PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
|
||||
["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]
|
||||
|
||||
Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.
|
||||
"""
|
||||
if not rows:
|
||||
return False
|
||||
|
||||
for row in rows:
|
||||
for cell in row:
|
||||
if not cell or len(cell) < 20:
|
||||
continue
|
||||
|
||||
# Check for multiple product numbers (7+ digit patterns)
|
||||
product_nums = re.findall(r"\b\d{7}\b", cell)
|
||||
if len(product_nums) >= 2:
|
||||
logger.debug(f"_has_vertically_merged_cells: found {len(product_nums)} product numbers in cell")
|
||||
return True
|
||||
|
||||
# Check for multiple prices (Swedish format: 123,45 or 1 234,56)
|
||||
prices = re.findall(r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", cell)
|
||||
if len(prices) >= 3:
|
||||
logger.debug(f"_has_vertically_merged_cells: found {len(prices)} prices in cell")
|
||||
return True
|
||||
|
||||
# Check for multiple quantity patterns (e.g., "6ST 6ST 1ST")
|
||||
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", cell)
|
||||
if len(quantities) >= 2:
|
||||
logger.debug(f"_has_vertically_merged_cells: found {len(quantities)} quantities in cell")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _split_merged_rows(
    self, rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
    """
    Split vertically merged cells back into separate rows.

    Handles complex cases where PP-StructureV3 merges content across
    multiple HTML rows. For example, 5 line items might be spread across
    3 HTML rows with content mixed together.

    Strategy:
    1. Merge all row content per column
    2. Detect how many actual data rows exist (by counting product numbers)
    3. Split each column's content into that many lines

    Returns:
        Tuple of (header, data_rows). Returns ([], rows) unchanged when
        there is nothing to split (empty input, all-blank rows, or fewer
        than two detectable data rows).
    """
    if not rows:
        return [], []

    # Filter out completely empty rows
    non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
    if not non_empty_rows:
        return [], rows

    # Determine column count (rows may be ragged, so take the widest).
    col_count = max(len(r) for r in non_empty_rows)

    # Merge content from all rows for each column into one string per column.
    merged_columns = []
    for col_idx in range(col_count):
        col_content = []
        for row in non_empty_rows:
            if col_idx < len(row) and row[col_idx].strip():
                col_content.append(row[col_idx].strip())
        merged_columns.append(" ".join(col_content))

    logger.debug(f"_split_merged_rows: merged columns = {merged_columns}")

    # Count how many actual data rows we should have.
    # Use the column with most product numbers as reference.
    expected_rows = self._count_expected_rows(merged_columns)
    logger.info(f"_split_merged_rows: expecting {expected_rows} data rows")

    if expected_rows <= 1:
        # Not enough data for splitting — return the input untouched.
        return [], rows

    # Split each column based on expected row count.
    split_columns = []
    for col_idx, col_text in enumerate(merged_columns):
        if not col_text.strip():
            split_columns.append([""] * (expected_rows + 1))  # +1 for header
            continue
        lines = self._split_cell_content_for_rows(col_text, expected_rows)
        split_columns.append(lines)

    # Ensure all columns have same number of lines (pad short ones).
    max_lines = max(len(col) for col in split_columns)
    for col in split_columns:
        while len(col) < max_lines:
            col.append("")

    logger.info(f"_split_merged_rows: split into {max_lines} lines total")

    # First line is header, rest are data rows; drop all-blank data rows.
    header = [col[0] for col in split_columns]
    data_rows = []
    for line_idx in range(1, max_lines):
        row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
        if any(cell.strip() for cell in row):
            data_rows.append(row)

    logger.info(f"_split_merged_rows: header={header}, data_rows count={len(data_rows)}")
    return header, data_rows
|
||||
|
||||
def _count_expected_rows(self, merged_columns: list[str]) -> int:
|
||||
"""
|
||||
Count how many data rows should exist based on content patterns.
|
||||
|
||||
Returns the maximum count found from:
|
||||
- Product numbers (7 digits)
|
||||
- Quantity patterns (number + ST/PCS)
|
||||
- Amount patterns (in columns likely to be totals)
|
||||
"""
|
||||
max_count = 0
|
||||
|
||||
for col_text in merged_columns:
|
||||
if not col_text:
|
||||
continue
|
||||
|
||||
# Count product numbers (most reliable indicator)
|
||||
product_nums = re.findall(r"\b\d{7}\b", col_text)
|
||||
max_count = max(max_count, len(product_nums))
|
||||
|
||||
# Count quantities (e.g., "6ST 6ST 1ST 1ST 1ST")
|
||||
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", col_text)
|
||||
max_count = max(max_count, len(quantities))
|
||||
|
||||
return max_count
|
||||
|
||||
def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
    """
    Split cell content knowing how many data rows we expect.

    This is smarter than _split_cell_content because it knows the target
    count: each strategy is only accepted when it yields exactly (or at
    least) `expected_rows` values.

    Returns:
        [header] + values on success, or [cell] unchanged when no
        splitting strategy matches the expected count.
    """
    cell = cell.strip()

    # Try product number split first (7-digit article numbers).
    product_pattern = re.compile(r"(\b\d{7}\b)")
    products = product_pattern.findall(cell)
    if len(products) == expected_rows:
        parts = product_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        # Include description text after each product number
        values = []
        for i in range(1, len(parts), 2):  # Odd indices are product numbers
            if i < len(parts):
                prod_num = parts[i].strip()
                # Check if there's description text after
                desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
                # If description looks like text (not another pattern), include it
                if desc and not re.match(r"^\d{7}$", desc):
                    # Truncate at next product number pattern if any
                    desc_clean = re.split(r"\d{7}", desc)[0].strip()
                    if desc_clean:
                        values.append(f"{prod_num} {desc_clean}")
                    else:
                        values.append(prod_num)
                else:
                    values.append(prod_num)
        if len(values) == expected_rows:
            return [header] + values

    # Try quantity split (e.g. "6ST", "2 pcs", "3 kg").
    qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
    quantities = qty_pattern.findall(cell)
    if len(quantities) == expected_rows:
        parts = qty_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
        if len(values) == expected_rows:
            return [header] + values

    # Try amount split for discount+totalsumma columns
    cell_lower = cell.lower()
    has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
    has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])

    if has_discount and has_total:
        # Extract only amounts (3+ digit numbers), skip discount percentages
        amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
        amounts = amount_pattern.findall(cell)
        if len(amounts) >= expected_rows:
            # Keep the FIRST expected_rows amounts; the \d{3,} filter above
            # already excludes short discount percentages like "10,0".
            # NOTE(review): an earlier comment said "last" — confirm which
            # slice is intended if discounts can also match \d{3,}.
            return ["Totalsumma"] + amounts[:expected_rows]

    # Try price split (Swedish format: "127,20" or "1 234,56").
    price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
    prices = price_pattern.findall(cell)
    if len(prices) >= expected_rows:
        parts = price_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
        if len(values) >= expected_rows:
            return [header] + values[:expected_rows]

    # Fall back to original single-value behavior
    return [cell]
|
||||
|
||||
def _split_cell_content(self, cell: str) -> list[str]:
    """
    Split a cell containing merged multi-line content.

    Unlike _split_cell_content_for_rows, this variant does not know the
    target row count: each strategy fires as soon as it finds at least two
    matches.

    Strategies (tried in order):
    1. Look for product number patterns (7 digits)
    2. Look for quantity patterns (number + ST/PCS)
    3. Handle interleaved discount+amount patterns
    4. Look for price patterns (with decimal)

    Returns:
        [header] + values on success, or [cell] unchanged.
    """
    cell = cell.strip()

    # Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
    product_pattern = re.compile(r"(\b\d{7}\b)")
    products = product_pattern.findall(cell)
    if len(products) >= 2:
        # Extract header (text before first product number) and values
        parts = product_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
        return [header] + values

    # Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
    qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
    quantities = qty_pattern.findall(cell)
    if len(quantities) >= 2:
        parts = qty_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
        return [header] + values

    # Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
    # Check if header contains two keywords indicating merged columns
    cell_lower = cell.lower()
    has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
    has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])

    if has_discount_header and has_amount_header:
        # Extract all numbers and pair them (discount, amount, discount, amount, ...)
        # Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
        amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
        amounts = amount_pattern.findall(cell)

        if len(amounts) >= 2:
            # Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
            # This avoids the "Rabatt" keyword causing is_deduction=True
            header = "Totalsumma"
            return [header] + amounts

    # Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
    price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
    prices = price_pattern.findall(cell)
    if len(prices) >= 2:
        parts = price_pattern.split(cell)
        header = parts[0].strip() if parts else ""
        values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
        return [header] + values

    # No pattern detected, return as single value
    return [cell]
|
||||
|
||||
def _has_merged_header(self, header: list[str] | None) -> bool:
|
||||
"""
|
||||
Check if header appears to be a merged cell containing multiple column names.
|
||||
|
||||
This happens when OCR merges table headers into a single cell, e.g.:
|
||||
"Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.
|
||||
|
||||
Also handles cases where PP-StructureV3 produces headers like:
|
||||
["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.
|
||||
"""
|
||||
if header is None or not header:
|
||||
return False
|
||||
|
||||
# Filter out empty cells to find the actual content
|
||||
non_empty_cells = [h for h in header if h.strip()]
|
||||
|
||||
# Check if we have a single non-empty cell that contains multiple keywords
|
||||
if len(non_empty_cells) == 1:
|
||||
header_text = non_empty_cells[0].lower()
|
||||
# Count how many column keywords are in this single cell
|
||||
keyword_count = 0
|
||||
for patterns in self.mapper.mappings.values():
|
||||
for pattern in patterns:
|
||||
if pattern in header_text:
|
||||
keyword_count += 1
|
||||
break # Only count once per field type
|
||||
|
||||
logger.debug(f"_has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
|
||||
return keyword_count >= 2
|
||||
|
||||
return False
|
||||
|
||||
def _extract_from_merged_cells(
    self, header: list[str], rows: list[list[str]]
) -> list[LineItem]:
    """
    Extract line items from tables with merged cells.

    For poorly OCR'd tables like:
        Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        Row 1: ["", "", "", "8159"]       <- amount row
        Row 2: ["", "", "", "-2 000"]     <- deduction row (separate line item)

    Or:
        Row: ["", "", "", "8159 -2 000"]  <- both in same row -> 2 line items

    Each amount becomes its own line item. Negative amounts are marked as
    is_deduction=True. All items carry a fixed confidence of 0.7 because
    this path is heuristic.
    """
    items = []

    # Amount pattern for Swedish format - match numbers like "8159" or
    # "8 159" or "-2000" or "-2 000" (spaces act as thousand separators).
    amount_pattern = re.compile(
        r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
    )

    # Try to parse header cell for description info.
    header_text = " ".join(h for h in header if h.strip()) if header else ""
    logger.info(f"_extract_from_merged_cells: header_text='{header_text}'")
    logger.info(f"_extract_from_merged_cells: rows={rows}")

    # Extract description from header.
    description = None
    article_number = None

    # Look for object number pattern (e.g., "0218103-1201").
    obj_match = re.search(r"(\d{7}-\d{4})", header_text)
    if obj_match:
        article_number = obj_match.group(1)

    # Look for description after object number (text up to a column keyword).
    desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
    if desc_match:
        description = desc_match.group(1).strip()

    row_index = 0
    for row in rows:
        # Combine all non-empty cells in the row.
        row_text = " ".join(cell.strip() for cell in row if cell.strip())
        logger.info(f"_extract_from_merged_cells: row text='{row_text}'")

        if not row_text:
            continue

        # Find all amounts in the row.
        amounts = amount_pattern.findall(row_text)
        logger.info(f"_extract_from_merged_cells: amounts={amounts}")

        for amt_str in amounts:
            # Clean the amount string (drop thousand-separator spaces).
            cleaned = amt_str.replace(" ", "").strip()
            if not cleaned or cleaned == "-":
                continue

            is_deduction = cleaned.startswith("-")

            # Skip small positive numbers that are likely not amounts
            # (e.g. apartment sizes or counts picked up by the regex).
            if not is_deduction:
                try:
                    val = float(cleaned.replace(",", "."))
                    if val < 100:
                        continue
                except ValueError:
                    continue

            # Create a line item for each amount. Header-derived description
            # and article number are attached only to the first item.
            item = LineItem(
                row_index=row_index,
                description=description if row_index == 0 else "Avdrag" if is_deduction else None,
                article_number=article_number if row_index == 0 else None,
                amount=cleaned,
                is_deduction=is_deduction,
                confidence=0.7,
            )
            items.append(item)
            row_index += 1
            logger.info(f"_extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")

    return items
|
||||
480
packages/backend/backend/table/structure_detector.py
Normal file
480
packages/backend/backend/table/structure_detector.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
PP-StructureV3 Table Detection Wrapper
|
||||
|
||||
Provides automatic table detection in invoice images using PaddleOCR's
|
||||
PP-StructureV3 pipeline. Supports both wired (bordered) and wireless
|
||||
(borderless) tables commonly found in Swedish invoices.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TableDetectorConfig:
    """Configuration for TableDetector."""

    # Paddle device string, e.g. "gpu:0" or "cpu".
    device: str = "gpu:0"
    # Document preprocessing toggles — disabled by default.
    use_doc_orientation_classify: bool = False
    use_doc_unwarping: bool = False
    use_textline_orientation: bool = False
    # Use SLANeXt models for better table recognition accuracy
    # SLANeXt_wireless has ~6% higher accuracy than SLANet for borderless tables
    wired_table_model: str = "SLANeXt_wired"
    wireless_table_model: str = "SLANeXt_wireless"
    # Layout analysis model used to locate tables on the page.
    layout_model: str = "PP-DocLayout_plus-L"
    # Detections scoring below this threshold are discarded (legacy path).
    min_confidence: float = 0.5
|
||||
|
||||
|
||||
@dataclass
class TableDetectionResult:
    """Result of table detection for a single table."""

    bbox: tuple[float, float, float, float]  # x1, y1, x2, y2 in pixels
    html: str  # Table structure as HTML
    confidence: float  # Detection confidence score
    table_type: str  # 'wired' or 'wireless'
    cells: list[dict[str, Any]] = field(default_factory=list)  # Cell-level data
|
||||
|
||||
|
||||
class PPStructureProtocol(Protocol):
    """Structural protocol for the PP-StructureV3 pipeline interface.

    Allows injecting a fake pipeline in tests without importing paddleocr.
    """

    def predict(self, image: str | np.ndarray, **kwargs: Any) -> Any:
        """Run prediction on image; returns pipeline-specific result objects."""
        ...
|
||||
|
||||
|
||||
class TableDetector:
|
||||
"""
|
||||
Table detector using PP-StructureV3.
|
||||
|
||||
Detects tables in invoice images and returns their bounding boxes,
|
||||
HTML structure, and cell-level data.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    config: TableDetectorConfig | None = None,
    pipeline: PPStructureProtocol | None = None,
):
    """
    Create a table detector.

    Args:
        config: Detector configuration; defaults are used when None.
        pipeline: Pre-built PP-StructureV3 pipeline. When omitted, the
            pipeline is created lazily on first detection call.
    """
    self.config = TableDetectorConfig() if config is None else config
    self._pipeline = pipeline
    # Lazy initialization is skipped when a ready pipeline was injected.
    self._initialized = pipeline is not None
|
||||
|
||||
def _ensure_initialized(self) -> None:
    """Lazily initialize the PP-Structure pipeline on first use.

    Tries the PaddleOCR 3.x ``PPStructureV3`` API first and falls back to
    the 2.x ``PPStructure`` API. Sets ``self._initialized`` once a pipeline
    exists; subsequent calls are no-ops.

    Raises:
        ImportError: If paddleocr is not installed at all.
    """
    if self._initialized:
        return

    # Try PPStructureV3 first (paddleocr >= 3.0.0), fall back to PPStructure (2.x)
    try:
        from paddleocr import PPStructureV3

        self._pipeline = PPStructureV3(
            layout_detection_model_name=self.config.layout_model,
            wired_table_structure_recognition_model_name=self.config.wired_table_model,
            wireless_table_structure_recognition_model_name=self.config.wireless_table_model,
            use_doc_orientation_classify=self.config.use_doc_orientation_classify,
            use_doc_unwarping=self.config.use_doc_unwarping,
            use_textline_orientation=self.config.use_textline_orientation,
            device=self.config.device,
        )
        self._initialized = True
        logger.info("PP-StructureV3 pipeline initialized successfully")
    except ImportError:
        # Fall back to PPStructure (paddleocr 2.x)
        try:
            from paddleocr import PPStructure

            # Map device config to use_gpu for PPStructure 2.x
            use_gpu = "gpu" in self.config.device.lower()
            self._pipeline = PPStructure(
                table=True,
                ocr=True,
                use_gpu=use_gpu,
                show_log=False,
            )
            self._initialized = True
            logger.info("PPStructure (2.x) pipeline initialized successfully")
        except ImportError as e:
            raise ImportError(
                "PPStructure requires paddleocr. "
                "Install with: pip install paddleocr"
            ) from e
|
||||
|
||||
def detect(
    self,
    image: np.ndarray | str | Path,
) -> list[TableDetectionResult]:
    """
    Run table detection on a single image.

    Args:
        image: Image as a numpy array, a path string, or a Path object.

    Returns:
        One TableDetectionResult per detected table.

    Raises:
        RuntimeError: If the pipeline is still None after initialization.
    """
    self._ensure_initialized()

    if self._pipeline is None:
        raise RuntimeError("Pipeline not initialized")

    # The underlying pipeline expects string paths, not Path objects.
    if isinstance(image, Path):
        image = str(image)

    raw = self._pipeline.predict(image)
    return self._parse_results(raw)
|
||||
|
||||
def _parse_results(self, results: Any) -> list[TableDetectionResult]:
    """Parse PP-StructureV3 output into TableDetectionResult list.

    Supports both:
    - PaddleX 3.x API: dict-like LayoutParsingResultV2 with table_res_list
    - Legacy API: objects with layout_elements attribute

    Results may arrive as a single dict-like object, a list, or a
    generator; all shapes are normalized to a list before parsing.
    Individual result parse failures are logged and skipped.
    """
    tables: list[TableDetectionResult] = []

    if results is None:
        logger.warning("PP-StructureV3 returned None results")
        return tables

    # Log raw result type for debugging
    logger.info(f"PP-StructureV3 raw results type: {type(results).__name__}")

    # Handle case where results is a single dict-like object (PaddleX 3.x)
    # rather than a list of results
    if hasattr(results, "get") and not isinstance(results, list):
        # Single result object - wrap in list for uniform processing
        logger.info("Results is dict-like, wrapping in list")
        results = [results]
    elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
        # Iterator or generator - convert to list
        try:
            results = list(results)
            logger.info(f"Converted iterator to list with {len(results)} items")
        except Exception as e:
            logger.warning(f"Failed to convert results to list: {e}")
            return tables

    logger.info(f"Processing {len(results)} result(s)")

    for i, result in enumerate(results):
        try:
            result_type = type(result).__name__
            has_get = hasattr(result, "get")
            has_layout = hasattr(result, "layout_elements")
            logger.info(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")

            # Try PaddleX 3.x API first (dict-like with table_res_list)
            if has_get:
                parsed = self._parse_paddlex_result(result)
                logger.info(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
                tables.extend(parsed)
                continue

            # Fall back to legacy API (layout_elements)
            if has_layout:
                legacy_count = 0
                for element in result.layout_elements:
                    if not self._is_table_element(element):
                        continue
                    table_result = self._extract_table_data(element)
                    # Legacy path is the only one that applies min_confidence.
                    if table_result and table_result.confidence >= self.config.min_confidence:
                        tables.append(table_result)
                        legacy_count += 1
                logger.info(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
            else:
                logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
        except Exception as e:
            logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
            continue

    logger.info(f"Total tables detected: {len(tables)}")
    return tables
|
||||
|
||||
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
    """Parse PaddleX 3.x LayoutParsingResultV2 into TableDetectionResult list.

    Reads ``table_res_list`` for per-table HTML/cells and pairs each entry
    with a bounding box from ``parsing_res_list`` (matched by order of
    appearance). Many values may be numpy arrays, so all truthiness/length
    checks are done defensively to avoid ambiguous-boolean errors.
    Failures are logged and yield an empty list rather than raising.
    """
    tables: list[TableDetectionResult] = []

    try:
        # Log result structure for debugging
        result_type = type(result).__name__
        result_keys = []
        if hasattr(result, "keys"):
            result_keys = list(result.keys())
        elif hasattr(result, "__dict__"):
            result_keys = list(result.__dict__.keys())
        logger.info(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")

        # Get table results from PaddleX 3.x API
        # Handle both dict.get() and attribute access
        if hasattr(result, "get"):
            table_res_list = result.get("table_res_list")
            parsing_res_list = result.get("parsing_res_list", [])
        else:
            table_res_list = getattr(result, "table_res_list", None)
            parsing_res_list = getattr(result, "parsing_res_list", [])

        logger.info(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
        logger.info(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")

        if not table_res_list:
            # Log available keys/attributes for debugging
            logger.warning(f"No table_res_list found in result: {result_type}, available: {result_keys}")
            return tables

        # Get parsing_res_list to find table bounding boxes
        table_bboxes = {}
        for elem in parsing_res_list or []:
            try:
                if isinstance(elem, dict):
                    label = elem.get("label", "")
                    bbox = elem.get("bbox", [])
                else:
                    label = getattr(elem, "label", "")
                    bbox = getattr(elem, "bbox", [])
                # Check bbox has items (handles numpy arrays safely)
                has_bbox = False
                try:
                    has_bbox = len(bbox) >= 4 if hasattr(bbox, "__len__") else False
                except (TypeError, ValueError):
                    pass
                if label == "table" and has_bbox:
                    # Map by index (parsing_res_list tables appear in order)
                    idx = len(table_bboxes)
                    table_bboxes[idx] = bbox
            except Exception as e:
                logger.debug(f"Failed to parse parsing_res element: {e}")
                continue

        for i, table_res in enumerate(table_res_list):
            try:
                # Extract from PaddleX 3.x table result format
                # Handle both dict and object access (SingleTableRecognitionResult)
                if isinstance(table_res, dict):
                    cell_boxes = table_res.get("cell_box_list", [])
                    html = table_res.get("pred_html", "")
                    ocr_data = table_res.get("table_ocr_pred", {})
                else:
                    cell_boxes = getattr(table_res, "cell_box_list", [])
                    html = getattr(table_res, "pred_html", "")
                    ocr_data = getattr(table_res, "table_ocr_pred", {})

                # table_ocr_pred can be dict (PaddleOCR 3.x) or list (older versions)
                # For dict format: {"rec_texts": [...], "rec_scores": [...], ...}
                ocr_texts = []
                if isinstance(ocr_data, dict):
                    ocr_texts = ocr_data.get("rec_texts", [])
                elif isinstance(ocr_data, list):
                    ocr_texts = ocr_data

                # Try to get bbox from parsing_res_list
                bbox = table_bboxes.get(i, [0.0, 0.0, 0.0, 0.0])
                # Handle numpy arrays - check length explicitly to avoid boolean ambiguity
                try:
                    bbox_len = len(bbox) if hasattr(bbox, "__len__") else 0
                    if bbox_len < 4:
                        bbox = [0.0, 0.0, 0.0, 0.0]
                except (TypeError, ValueError):
                    bbox = [0.0, 0.0, 0.0, 0.0]

                # Build cells from cell_box_list and OCR text
                cells = []
                # Check cell_boxes length explicitly to avoid numpy array boolean issues
                has_cell_boxes = False
                try:
                    has_cell_boxes = len(cell_boxes) > 0 if hasattr(cell_boxes, "__len__") else bool(cell_boxes)
                except (TypeError, ValueError):
                    pass
                if has_cell_boxes:
                    # Check ocr_texts length safely for numpy arrays
                    ocr_texts_len = 0
                    try:
                        ocr_texts_len = len(ocr_texts) if hasattr(ocr_texts, "__len__") else 0
                    except (TypeError, ValueError):
                        pass
                    for j, cell_bbox in enumerate(cell_boxes):
                        # Pair each cell box with its OCR text by index.
                        cell_text = ocr_texts[j] if ocr_texts_len > j else ""
                        # Convert cell_bbox to list safely (may be numpy array)
                        cell_bbox_list = []
                        try:
                            cell_bbox_list = list(cell_bbox) if hasattr(cell_bbox, "__iter__") else []
                        except (TypeError, ValueError):
                            pass
                        cells.append({
                            "text": cell_text,
                            "bbox": cell_bbox_list,
                            "row": 0,  # Row/col info not directly available
                            "col": j,
                        })

                # Default confidence for PaddleX 3.x results
                confidence = 0.9

                logger.info(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
                tables.append(TableDetectionResult(
                    bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
                    html=html,
                    confidence=confidence,
                    table_type="wired",  # PaddleX 3.x handles both types
                    cells=cells,
                ))
            except Exception as e:
                import traceback
                logger.warning(f"Failed to parse table_res {i}: {e}\n{traceback.format_exc()}")
                continue

    except Exception as e:
        logger.warning(f"Failed to parse PaddleX result: {type(e).__name__}: {e}")

    return tables
|
||||
|
||||
def _is_table_element(self, element: Any) -> bool:
|
||||
"""Check if element is a table."""
|
||||
if hasattr(element, "label"):
|
||||
return element.label.lower() in ("table", "wired_table", "wireless_table")
|
||||
if hasattr(element, "type"):
|
||||
return element.type.lower() in ("table", "wired_table", "wireless_table")
|
||||
return False
|
||||
|
||||
def _extract_table_data(self, element: Any) -> TableDetectionResult | None:
    """
    Convert a legacy PP-Structure layout element into a TableDetectionResult.

    Returns None when the element has no usable bounding box or when any
    attribute access fails (failure is logged, not raised).
    """
    try:
        bounds = self._get_bbox(element)
        if bounds is None:
            return None

        markup = self._get_html(element)

        # Score may be a scalar or a (possibly empty) sequence.
        score = getattr(element, "score", 0.9)
        if isinstance(score, (list, tuple)):
            score = float(score[0]) if score else 0.9

        kind = self._get_table_type(element)
        cell_data = self._get_cells(element)

        return TableDetectionResult(
            bbox=bounds,
            html=markup,
            confidence=float(score),
            table_type=kind,
            cells=cell_data,
        )
    except Exception as e:
        logger.warning(f"Failed to extract table data: {e}")
        return None
|
||||
|
||||
def _get_bbox(self, element: Any) -> tuple[float, float, float, float] | None:
|
||||
"""Extract bounding box from element."""
|
||||
if hasattr(element, "bbox"):
|
||||
bbox = element.bbox
|
||||
if len(bbox) >= 4:
|
||||
return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3]))
|
||||
if hasattr(element, "box"):
|
||||
box = element.box
|
||||
if len(box) >= 4:
|
||||
return (float(box[0]), float(box[1]), float(box[2]), float(box[3]))
|
||||
return None
|
||||
|
||||
def _get_html(self, element: Any) -> str:
    """Return the table's HTML markup, or an empty string when absent.

    Tries the ``html`` and ``table_html`` attributes in order, then an
    ``res`` dict with an "html" key — covering the result layouts of
    several PaddleX versions.
    """
    for attr in ("html", "table_html"):
        if hasattr(element, attr):
            return str(getattr(element, attr))
    res = getattr(element, "res", None)
    if isinstance(res, dict):
        return res.get("html", "")
    return ""
|
||||
|
||||
def _get_table_type(self, element: Any) -> str:
    """Classify the table as "wired" (ruled lines) or "wireless" (borderless).

    The first of ``label``/``type`` present wins; anything that does not
    mention "wireless" or "borderless" is treated as wired.
    """
    descriptor = ""
    for attr in ("label", "type"):
        if hasattr(element, attr):
            descriptor = str(getattr(element, attr)).lower()
            break

    if "wireless" in descriptor or "borderless" in descriptor:
        return "wireless"
    return "wired"
|
||||
|
||||
def _get_cells(self, element: Any) -> list[dict[str, Any]]:
    """Return per-cell dicts (text, grid position, spans, optional bbox).

    Elements without a ``cells`` attribute yield an empty list; missing
    cell attributes fall back to sensible defaults (empty text, position
    0, span 1).
    """
    extracted: list[dict[str, Any]] = []

    for cell in getattr(element, "cells", ()):
        entry: dict[str, Any] = {
            "text": getattr(cell, "text", ""),
            "row": getattr(cell, "row", 0),
            "col": getattr(cell, "col", 0),
            "row_span": getattr(cell, "row_span", 1),
            "col_span": getattr(cell, "col_span", 1),
        }
        # bbox is optional — only include it when the cell carries one.
        if hasattr(cell, "bbox"):
            entry["bbox"] = cell.bbox
        extracted.append(entry)

    return extracted
|
||||
|
||||
def detect_from_pdf(
    self,
    pdf_path: str | Path,
    page_number: int = 0,
    dpi: int = 300,
) -> list[TableDetectionResult]:
    """
    Detect tables from a PDF page.

    Renders the requested page to an image and runs table detection on it.

    Args:
        pdf_path: Path to PDF file.
        page_number: Page number (0-indexed).
        dpi: Resolution for rendering.

    Returns:
        List of TableDetectionResult for the specified page.

    Raises:
        FileNotFoundError: If the PDF file does not exist.
        ValueError: If the requested page number is not present in the PDF.
    """
    from shared.pdf.renderer import render_pdf_to_images
    from PIL import Image
    import io

    pdf_path = Path(pdf_path)
    if not pdf_path.exists():
        raise FileNotFoundError(f"PDF not found: {pdf_path}")

    logger.info(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")

    # Render specific page. Iteration stops (via return) once the target
    # page is reached; pages before it are still rendered, so cost grows
    # with page_number — presumably acceptable for typical short invoices.
    for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
        if page_no == page_number:
            image = Image.open(io.BytesIO(image_bytes))
            # np is the module-level numpy import; detect() consumes an array.
            image_array = np.array(image)
            logger.info(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
            return self.detect(image_array)

    raise ValueError(f"Page {page_number} not found in PDF")
|
||||
449
packages/backend/backend/table/text_line_items_extractor.py
Normal file
449
packages/backend/backend/table/text_line_items_extractor.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Text-Based Line Items Extractor
|
||||
|
||||
Fallback extraction for invoices where PP-StructureV3 cannot detect table structures
|
||||
(e.g., borderless/wireless tables). Uses spatial analysis of OCR text elements to
|
||||
identify and group line items.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal, InvalidOperation
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TextElement:
    """Single text element from OCR."""

    text: str
    bbox: tuple[float, float, float, float]  # x1, y1, x2, y2
    confidence: float = 1.0

    @property
    def center_y(self) -> float:
        """Vertical center of the element."""
        _, top, _, bottom = self.bbox
        return (top + bottom) / 2

    @property
    def center_x(self) -> float:
        """Horizontal center of the element."""
        left, _, right, _ = self.bbox
        return (left + right) / 2

    @property
    def height(self) -> float:
        """Height of the element."""
        _, top, _, bottom = self.bbox
        return bottom - top
|
||||
|
||||
|
||||
@dataclass
class TextLineItem:
    """Line item extracted from text elements.

    Numeric values are kept as raw strings exactly as they appeared in
    the OCR text; parsing is deferred to downstream consumers.
    """

    row_index: int  # 0-based position within the detected item rows
    description: str | None = None  # longest non-numeric cell on the row
    quantity: str | None = None  # raw quantity text, e.g. "2 st"
    unit: str | None = None  # not populated by the spatial parser; kept for LineItem parity
    unit_price: str | None = None  # second-to-last amount on the row, if any
    amount: str | None = None  # rightmost amount on the row (line total)
    article_number: str | None = None  # not populated by the spatial parser; kept for LineItem parity
    vat_rate: str | None = None  # bare percentage string, e.g. "25"
    is_deduction: bool = False  # True if this row is a deduction/discount
    confidence: float = 0.7  # Lower default confidence for text-based extraction
|
||||
|
||||
|
||||
@dataclass
class TextLineItemsResult:
    """Result of text-based line items extraction."""

    items: list[TextLineItem]  # extracted line items, in top-to-bottom order
    header_row: list[str]  # spatial extraction passes [] (no explicit header row)
    extraction_method: str = "text_spatial"  # tags the fallback path for downstream reporting
|
||||
|
||||
|
||||
# Swedish amount pattern: 1 234,56 or 1234.56 or 1,234.56.
# The (?<![0-9]) / (?![0-9]) guards stop a match from starting or ending
# inside a longer digit run (e.g. an OCR number or phone number).
AMOUNT_PATTERN = re.compile(
    r"(?<![0-9])(?:"
    r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?"  # Swedish: 1 234,56
    r"|-?\d{1,3}(?:,\d{3})*(?:\.\d{2})?"  # US: 1,234.56
    r"|-?\d+(?:[.,]\d{2})?"  # Simple: 1234,56 or 1234.56
    r")(?:\s*(?:kr|SEK|:-))?"  # Optional currency suffix
    r"(?![0-9])"
)

# Quantity patterns: a bare number, optionally followed by a short
# Swedish/English unit; anchored so it only matches a whole cell.
QUANTITY_PATTERN = re.compile(
    r"^(?:"
    r"\d+(?:[.,]\d+)?\s*(?:st|pcs|m|kg|l|h|tim|timmar)?"  # Number with optional unit
    r")$",
    re.IGNORECASE,
)

# VAT rate patterns: captures the integer percentage, e.g. "25" from "25 %".
VAT_RATE_PATTERN = re.compile(r"(\d+)\s*%")

# Keywords indicating a line item area (Swedish invoice column headers).
LINE_ITEM_KEYWORDS = [
    "beskrivning",
    "artikel",
    "produkt",
    "belopp",
    "summa",
    "antal",
    "pris",
    "á-pris",
    "a-pris",
    "moms",
]

# Keywords indicating NOT line items (summary area) — totals, payment
# details and VAT recap lines below the item table.
SUMMARY_KEYWORDS = [
    "att betala",
    "total",
    "summa att betala",
    "betalningsvillkor",
    "förfallodatum",
    "bankgiro",
    "plusgiro",
    "ocr-nummer",
    "fakturabelopp",
    "exkl. moms",
    "inkl. moms",
    "varav moms",
]
|
||||
|
||||
|
||||
class TextLineItemsExtractor:
    """
    Extract line items from text elements using spatial analysis.

    This is a fallback for when PP-StructureV3 cannot detect table structures
    (e.g. borderless/wireless tables). It groups text elements by vertical
    position and identifies patterns that match line item rows.
    """

    def __init__(
        self,
        row_tolerance: float = 15.0,  # Max vertical distance to consider same row
        min_items_for_valid: int = 2,  # Minimum items to consider extraction valid
    ):
        """
        Initialize extractor.

        Args:
            row_tolerance: Maximum vertical distance (pixels) between elements
                to consider them on the same row.
            min_items_for_valid: Minimum number of line items required for
                extraction to be considered successful.
        """
        self.row_tolerance = row_tolerance
        self.min_items_for_valid = min_items_for_valid

    def extract_from_parsing_res(
        self, parsing_res_list: list[dict[str, Any]]
    ) -> TextLineItemsResult | None:
        """
        Extract line items from PP-StructureV3 parsing_res_list.

        Args:
            parsing_res_list: List of parsed elements from PP-StructureV3.

        Returns:
            TextLineItemsResult if line items found, None otherwise.
        """
        if not parsing_res_list:
            logger.debug("No parsing_res_list provided")
            return None

        # Extract text elements from parsing results
        text_elements = self._extract_text_elements(parsing_res_list)
        logger.info(f"TextLineItemsExtractor: found {len(text_elements)} text elements")

        if len(text_elements) < 5:  # Need at least a few elements
            logger.debug("Too few text elements for line item extraction")
            return None

        return self.extract_from_text_elements(text_elements)

    def extract_from_text_elements(
        self, text_elements: list[TextElement]
    ) -> TextLineItemsResult | None:
        """
        Extract line items from a list of text elements.

        Args:
            text_elements: List of TextElement objects.

        Returns:
            TextLineItemsResult if line items found, None otherwise.
        """
        # Group elements by row
        rows = self._group_by_row(text_elements)
        logger.info(f"TextLineItemsExtractor: grouped into {len(rows)} rows")

        # Find the line items section
        item_rows = self._identify_line_item_rows(rows)
        logger.info(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")

        if len(item_rows) < self.min_items_for_valid:
            logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
            return None

        # Extract structured items
        items = self._parse_line_items(item_rows)
        logger.info(f"TextLineItemsExtractor: extracted {len(items)} line items")

        if len(items) < self.min_items_for_valid:
            return None

        return TextLineItemsResult(
            items=items,
            header_row=[],  # No explicit header in text-based extraction
            extraction_method="text_spatial",
        )

    def _extract_text_elements(
        self, parsing_res_list: list[dict[str, Any]]
    ) -> list[TextElement]:
        """Extract TextElement objects from parsing_res_list.

        Accepts both plain dicts and LayoutBlock-style objects. Elements
        with a non-text label, an invalid bbox, or empty text are skipped;
        a malformed element is logged and dropped rather than aborting.
        """
        elements = []

        for elem in parsing_res_list:
            try:
                # Get label and bbox - handle both dict and LayoutBlock objects
                if isinstance(elem, dict):
                    label = elem.get("label", "")
                    bbox = elem.get("bbox", [])
                    # Try both 'text' and 'content' keys
                    text = elem.get("text", "") or elem.get("content", "")
                else:
                    label = getattr(elem, "label", "")
                    bbox = getattr(elem, "bbox", [])
                    # LayoutBlock objects use 'content' attribute
                    text = getattr(elem, "content", "") or getattr(elem, "text", "")

                # Only process text elements (skip images, tables, etc.)
                if label not in ("text", "paragraph_title", "aside_text"):
                    continue

                # Validate bbox
                if not self._valid_bbox(bbox):
                    continue

                # Clean text
                text = str(text).strip() if text else ""
                if not text:
                    continue

                elements.append(
                    TextElement(
                        text=text,
                        bbox=(
                            float(bbox[0]),
                            float(bbox[1]),
                            float(bbox[2]),
                            float(bbox[3]),
                        ),
                    )
                )
            except Exception as e:
                logger.debug(f"Failed to parse element: {e}")
                continue

        return elements

    def _valid_bbox(self, bbox: Any) -> bool:
        """Check if bbox is valid (has 4 elements)."""
        try:
            return len(bbox) >= 4 if hasattr(bbox, "__len__") else False
        except (TypeError, ValueError):
            return False

    def _group_by_row(
        self, elements: list[TextElement]
    ) -> list[list[TextElement]]:
        """
        Group text elements into rows based on vertical position.

        Elements within row_tolerance of each other are considered same row.
        Each returned row is sorted left-to-right; rows are returned
        top-to-bottom.
        """
        if not elements:
            return []

        # Sort by vertical position
        sorted_elements = sorted(elements, key=lambda e: e.center_y)

        rows = []
        current_row = [sorted_elements[0]]
        current_y = sorted_elements[0].center_y

        for elem in sorted_elements[1:]:
            # Compare against the anchor element of the current row, not
            # the previous element, so a tall row cannot drift downward.
            if abs(elem.center_y - current_y) <= self.row_tolerance:
                # Same row
                current_row.append(elem)
            else:
                # New row
                if current_row:
                    # Sort row by horizontal position
                    current_row.sort(key=lambda e: e.center_x)
                    rows.append(current_row)
                current_row = [elem]
                current_y = elem.center_y

        # Don't forget last row
        if current_row:
            current_row.sort(key=lambda e: e.center_x)
            rows.append(current_row)

        return rows

    def _identify_line_item_rows(
        self, rows: list[list[TextElement]]
    ) -> list[list[TextElement]]:
        """
        Identify which rows are likely line items.

        Line item rows typically have:
        - Multiple elements per row
        - At least one amount-like value
        - Description text

        Header rows (column-name keywords) and summary rows (totals,
        payment details) are skipped outright.
        """
        item_rows = []

        for row in rows:
            row_text = " ".join(e.text for e in row).lower()

            # Skip summary-area rows (totals, payment info)
            if any(kw in row_text for kw in SUMMARY_KEYWORDS):
                continue

            # Skip header rows (column names)
            if any(kw in row_text for kw in LINE_ITEM_KEYWORDS):
                continue

            # Fix: the previous implementation kept an `in_item_section`
            # flag that never influenced the outcome and evaluated
            # _looks_like_line_item(row) twice per row; a single check
            # is equivalent and avoids the duplicate work.
            if self._looks_like_line_item(row):
                item_rows.append(row)

        return item_rows

    def _looks_like_line_item(self, row: list[TextElement]) -> bool:
        """Check if a row looks like a line item.

        Requires at least two cells, at least one amount-like value, and
        at least one cell of description text that is not itself a number.
        """
        if len(row) < 2:
            return False

        row_text = " ".join(e.text for e in row)

        # Must have at least one amount
        amounts = AMOUNT_PATTERN.findall(row_text)
        if not amounts:
            return False

        # Should have some description text (not just numbers)
        has_description = any(
            len(e.text) > 3 and not AMOUNT_PATTERN.fullmatch(e.text.strip())
            for e in row
        )

        return has_description

    def _parse_line_items(
        self, item_rows: list[list[TextElement]]
    ) -> list[TextLineItem]:
        """Parse line item rows into structured items."""
        items = []

        for idx, row in enumerate(item_rows):
            item = self._parse_single_row(row, idx)
            if item:
                items.append(item)

        return items

    def _parse_single_row(
        self, row: list[TextElement], row_index: int
    ) -> TextLineItem | None:
        """Parse a single row into a line item.

        Heuristics: the rightmost amount on the row is the line total,
        the one before it (if any) the unit price; the longest
        non-numeric cell is the description.
        """
        if not row:
            return None

        # Combine all text for analysis
        all_text = " ".join(e.text for e in row)

        # Find amounts (rightmost is usually the total)
        amounts = list(AMOUNT_PATTERN.finditer(all_text))
        if not amounts:
            return None

        # Last amount is typically line total
        amount_match = amounts[-1]
        amount = amount_match.group(0).strip()

        # Second to last might be unit price
        unit_price = None
        if len(amounts) >= 2:
            unit_price = amounts[-2].group(0).strip()

        # Look for quantity
        quantity = None
        for elem in row:
            text = elem.text.strip()
            if QUANTITY_PATTERN.match(text):
                quantity = text
                break

        # Look for VAT rate
        vat_rate = None
        vat_match = VAT_RATE_PATTERN.search(all_text)
        if vat_match:
            vat_rate = vat_match.group(1)

        # Description is typically the longest non-numeric text
        description = None
        max_len = 0
        for elem in row:
            text = elem.text.strip()
            # Skip if it looks like a number/amount
            if AMOUNT_PATTERN.fullmatch(text):
                continue
            if QUANTITY_PATTERN.match(text):
                continue
            if len(text) > max_len:
                description = text
                max_len = len(text)

        return TextLineItem(
            row_index=row_index,
            description=description,
            quantity=quantity,
            unit_price=unit_price,
            amount=amount,
            vat_rate=vat_rate,
            confidence=0.7,
        )
|
||||
|
||||
|
||||
def convert_text_line_item(item: TextLineItem) -> "LineItem":
    """Convert TextLineItem to standard LineItem dataclass.

    Bridges the fallback spatial extractor's result type to the LineItem
    type produced by the table-based extractor, so downstream code can
    treat both extraction paths uniformly. The mapping is field-for-field.
    """
    # Local import — presumably avoids a circular dependency with
    # line_items_extractor; confirm before hoisting to module level.
    from .line_items_extractor import LineItem

    return LineItem(
        row_index=item.row_index,
        description=item.description,
        quantity=item.quantity,
        unit=item.unit,
        unit_price=item.unit_price,
        amount=item.amount,
        article_number=item.article_number,
        vat_rate=item.vat_rate,
        is_deduction=item.is_deduction,
        confidence=item.confidence,
    )
|
||||
@@ -1,7 +1,19 @@
|
||||
"""
|
||||
Cross-validation module for verifying field extraction using LLM.
|
||||
Cross-validation module for verifying field extraction.
|
||||
|
||||
Includes LLM validation and VAT cross-validation.
|
||||
"""
|
||||
|
||||
from .llm_validator import LLMValidator
|
||||
from .vat_validator import (
|
||||
VATValidationResult,
|
||||
VATValidator,
|
||||
MathCheckResult,
|
||||
)
|
||||
|
||||
__all__ = ['LLMValidator']
|
||||
__all__ = [
|
||||
"LLMValidator",
|
||||
"VATValidationResult",
|
||||
"VATValidator",
|
||||
"MathCheckResult",
|
||||
]
|
||||
|
||||
267
packages/backend/backend/validation/vat_validator.py
Normal file
267
packages/backend/backend/validation/vat_validator.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
VAT Validator
|
||||
|
||||
Cross-validates VAT information from multiple sources:
|
||||
- Mathematical verification (base × rate = vat)
|
||||
- Line items vs VAT summary comparison
|
||||
- Consistency with existing amount field
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal, InvalidOperation
|
||||
|
||||
from backend.vat.vat_extractor import VATSummary, AmountParser
|
||||
from backend.table.line_items_extractor import LineItemsResult
|
||||
|
||||
|
||||
@dataclass
class MathCheckResult:
    """Result of a single VAT rate mathematical check.

    The check verifies base_amount * rate / 100 against the actual VAT
    amount, within an absolute tolerance.
    """

    rate: float  # VAT percentage, e.g. 25.0
    base_amount: float | None  # parsed tax base; None when missing or unparseable
    expected_vat: float | None  # base_amount * rate / 100; None when base missing
    actual_vat: float  # parsed VAT amount from the breakdown
    is_valid: bool  # True when no base was available, or expected ≈ actual
    tolerance: float  # absolute tolerance used for the comparison
|
||||
|
||||
|
||||
@dataclass
class VATValidationResult:
    """Complete VAT validation result.

    Tri-state fields use None to mean "comparison not possible" (the
    required inputs were missing), as opposed to a failed check (False).
    """

    is_valid: bool  # overall verdict from math, totals and amount checks
    confidence_score: float  # 0.0 - 1.0

    # Mathematical verification
    math_checks: list[MathCheckResult]  # one entry per parseable VAT rate
    total_check: bool  # incl = excl + total_vat?

    # Source comparison
    line_items_vs_summary: bool | None  # line items total = VAT summary?
    amount_consistency: bool | None  # total_incl_vat = existing amount field?

    # Review flags
    needs_review: bool  # True when any check failed or confidence is low
    review_reasons: list[str] = field(default_factory=list)  # human-readable failure notes
|
||||
|
||||
|
||||
class VATValidator:
    """Validates VAT information using multiple cross-checks.

    Checks performed:
    - per-rate math (base * rate / 100 vs extracted VAT amount)
    - totals consistency (excl + vat vs incl)
    - line-items total vs VAT summary (when line items are available)
    - consistency with an independently extracted amount field
    """

    def __init__(self, tolerance: float = 0.02):
        """
        Initialize validator.

        Args:
            tolerance: Acceptable difference for math checks (default 0.02 = 2 cents)
        """
        self.tolerance = tolerance
        self.amount_parser = AmountParser()

    def validate(
        self,
        vat_summary: VATSummary,
        line_items: LineItemsResult | None = None,
        existing_amount: str | None = None,
    ) -> VATValidationResult:
        """
        Validate VAT information.

        Args:
            vat_summary: Extracted VAT summary.
            line_items: Optional line items for comparison.
            existing_amount: Optional existing amount field from YOLO extraction.

        Returns:
            VATValidationResult with all check results.
        """
        review_reasons: list[str] = []

        # Handle empty summary — nothing to validate, flag for review.
        if not vat_summary.breakdowns and not vat_summary.total_vat:
            return VATValidationResult(
                is_valid=False,
                confidence_score=0.0,
                math_checks=[],
                total_check=False,
                line_items_vs_summary=None,
                amount_consistency=None,
                needs_review=True,
                review_reasons=["No VAT information found"],
            )

        # Run all checks
        math_checks = self._run_math_checks(vat_summary)
        total_check = self._check_totals(vat_summary)
        line_items_check = self._check_line_items(vat_summary, line_items)
        amount_check = self._check_amount_consistency(vat_summary, existing_amount)

        # Collect review reasons (None results mean "not comparable" and
        # deliberately do not add a reason).
        math_failures = [c for c in math_checks if not c.is_valid]
        if math_failures:
            review_reasons.append(f"Math check failed for {len(math_failures)} VAT rate(s)")

        if not total_check:
            review_reasons.append("Total amount mismatch (excl + vat != incl)")

        if line_items_check is False:
            review_reasons.append("Line items total doesn't match VAT summary")

        if amount_check is False:
            review_reasons.append("VAT total doesn't match existing amount field")

        # Calculate overall validity and confidence. A failed line-items
        # comparison lowers confidence but does not by itself invalidate.
        all_math_valid = all(c.is_valid for c in math_checks) if math_checks else True
        is_valid = all_math_valid and total_check and (amount_check is not False)

        confidence_score = self._calculate_confidence(
            vat_summary, math_checks, total_check, line_items_check, amount_check
        )

        needs_review = len(review_reasons) > 0 or confidence_score < 0.7

        return VATValidationResult(
            is_valid=is_valid,
            confidence_score=confidence_score,
            math_checks=math_checks,
            total_check=total_check,
            line_items_vs_summary=line_items_check,
            amount_consistency=amount_check,
            needs_review=needs_review,
            review_reasons=review_reasons,
        )

    def _run_math_checks(self, vat_summary: VATSummary) -> list[MathCheckResult]:
        """Run mathematical verification for each VAT rate.

        Breakdowns whose VAT amount cannot be parsed are skipped entirely;
        a breakdown without a parseable base is recorded as valid (there
        is nothing to verify against).
        """
        results = []

        for breakdown in vat_summary.breakdowns:
            actual_vat = self.amount_parser.parse(breakdown.vat_amount)
            if actual_vat is None:
                continue

            base_amount = None
            expected_vat = None
            is_valid = True

            if breakdown.base_amount:
                base_amount = self.amount_parser.parse(breakdown.base_amount)
                if base_amount is not None:
                    expected_vat = base_amount * (breakdown.rate / 100)
                    is_valid = abs(expected_vat - actual_vat) <= self.tolerance

            results.append(
                MathCheckResult(
                    rate=breakdown.rate,
                    base_amount=base_amount,
                    expected_vat=expected_vat,
                    actual_vat=actual_vat,
                    is_valid=is_valid,
                    tolerance=self.tolerance,
                )
            )

        return results

    def _check_totals(self, vat_summary: VATSummary) -> bool:
        """Check if total_excl + total_vat = total_incl.

        Lenient by design: returns True whenever the inputs needed for
        the comparison are missing or unparseable.
        """
        if not vat_summary.total_excl_vat or not vat_summary.total_incl_vat:
            # Can't verify without both values
            return True  # Assume ok if we can't check

        excl = self.amount_parser.parse(vat_summary.total_excl_vat)
        incl = self.amount_parser.parse(vat_summary.total_incl_vat)

        if excl is None or incl is None:
            return True  # Can't verify

        # Calculate expected VAT
        if vat_summary.total_vat:
            vat = self.amount_parser.parse(vat_summary.total_vat)
            if vat is not None:
                expected_incl = excl + vat
                return abs(expected_incl - incl) <= self.tolerance
            # Can't verify if vat parsing failed
            return True
        else:
            # Sum up breakdown VAT amounts (unparseable entries count as 0)
            total_vat = sum(
                self.amount_parser.parse(b.vat_amount) or 0
                for b in vat_summary.breakdowns
            )
            expected_incl = excl + total_vat
            return abs(expected_incl - incl) <= self.tolerance

    def _check_line_items(
        self, vat_summary: VATSummary, line_items: LineItemsResult | None
    ) -> bool | None:
        """Check if line items total matches VAT summary.

        Returns None when no line items or no summary total are available.
        """
        if line_items is None or not line_items.items:
            return None  # No comparison possible

        # Sum line item amounts (skipping unparseable ones)
        line_total = 0.0
        for item in line_items.items:
            if item.amount:
                amount = self.amount_parser.parse(item.amount)
                if amount is not None:
                    line_total += amount

        # Compare with VAT summary total
        if vat_summary.total_excl_vat:
            summary_total = self.amount_parser.parse(vat_summary.total_excl_vat)
            if summary_total is not None:
                # Allow larger tolerance for line items (rounding errors)
                return abs(line_total - summary_total) <= 1.0

        return None

    def _check_amount_consistency(
        self, vat_summary: VATSummary, existing_amount: str | None
    ) -> bool | None:
        """Check if VAT total matches existing amount field.

        Returns None when either side is missing or unparseable.
        """
        if existing_amount is None:
            return None  # No comparison possible

        existing = self.amount_parser.parse(existing_amount)
        if existing is None:
            return None

        if vat_summary.total_incl_vat:
            vat_total = self.amount_parser.parse(vat_summary.total_incl_vat)
            if vat_total is not None:
                return abs(existing - vat_total) <= self.tolerance

        return None

    def _calculate_confidence(
        self,
        vat_summary: VATSummary,
        math_checks: list[MathCheckResult],
        total_check: bool,
        line_items_check: bool | None,
        amount_check: bool | None,
    ) -> float:
        """Calculate overall confidence score.

        Starts from the extraction confidence and scales it by the
        outcome of each check; passing cross-checks give a small boost
        (capped at 1.0), failing ones a penalty, None results no change.
        """
        score = vat_summary.confidence  # Start with extraction confidence

        # Adjust based on validation results
        if math_checks:
            math_valid_ratio = sum(1 for c in math_checks if c.is_valid) / len(math_checks)
            score = score * (0.5 + 0.5 * math_valid_ratio)

        if not total_check:
            score *= 0.5

        if line_items_check is True:
            score = min(score * 1.1, 1.0)  # Boost if line items match
        elif line_items_check is False:
            score *= 0.7

        if amount_check is True:
            score = min(score * 1.1, 1.0)  # Boost if amount matches
        elif amount_check is False:
            score *= 0.6

        return round(score, 2)
|
||||
19
packages/backend/backend/vat/__init__.py
Normal file
19
packages/backend/backend/vat/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
VAT extraction module.
|
||||
|
||||
Extracts VAT (Moms) information from Swedish invoices using regex patterns.
|
||||
"""
|
||||
|
||||
from .vat_extractor import (
|
||||
VATBreakdown,
|
||||
VATSummary,
|
||||
VATExtractor,
|
||||
AmountParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VATBreakdown",
|
||||
"VATSummary",
|
||||
"VATExtractor",
|
||||
"AmountParser",
|
||||
]
|
||||
350
packages/backend/backend/vat/vat_extractor.py
Normal file
350
packages/backend/backend/vat/vat_extractor.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
VAT Extractor
|
||||
|
||||
Extracts VAT (Moms) information from Swedish invoice text using regex patterns.
|
||||
Supports multiple VAT rates (25%, 12%, 6%, 0%) and various Swedish formats.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
from decimal import Decimal, InvalidOperation
|
||||
|
||||
|
||||
@dataclass
class VATBreakdown:
    """Single VAT rate breakdown.

    Amounts are kept as raw strings exactly as extracted; parse them
    with AmountParser when numeric values are needed.
    """

    rate: float  # 25.0, 12.0, 6.0, 0.0
    base_amount: str | None  # Tax base (excl VAT), raw extracted string
    vat_amount: str  # VAT amount, raw extracted string
    source: str  # 'regex' | 'line_items'
|
||||
|
||||
|
||||
@dataclass
class VATSummary:
    """Complete VAT summary.

    Totals are raw strings as found in the invoice text (or None when
    not found); parse with AmountParser for numeric comparisons.
    """

    breakdowns: list[VATBreakdown]  # one entry per distinct VAT rate
    total_excl_vat: str | None  # invoice total excluding VAT
    total_vat: str | None  # total VAT amount
    total_incl_vat: str | None  # invoice total including VAT
    confidence: float  # heuristic 0.0-1.0 extraction confidence
|
||||
|
||||
|
||||
class AmountParser:
|
||||
"""Parse Swedish and European number formats."""
|
||||
|
||||
# Patterns to clean amount strings
|
||||
CURRENCY_PATTERN = re.compile(r"(SEK|kr|:-)\s*", re.IGNORECASE)
|
||||
|
||||
def parse(self, amount_str: str) -> float | None:
|
||||
"""
|
||||
Parse amount string to float.
|
||||
|
||||
Handles:
|
||||
- Swedish: 1 234,56
|
||||
- European: 1.234,56
|
||||
- US: 1,234.56
|
||||
|
||||
Args:
|
||||
amount_str: Amount string to parse.
|
||||
|
||||
Returns:
|
||||
Parsed float value or None if invalid.
|
||||
"""
|
||||
if not amount_str or not amount_str.strip():
|
||||
return None
|
||||
|
||||
# Clean the string
|
||||
cleaned = amount_str.strip()
|
||||
|
||||
# Remove currency
|
||||
cleaned = self.CURRENCY_PATTERN.sub("", cleaned).strip()
|
||||
cleaned = re.sub(r"^SEK\s*", "", cleaned, flags=re.IGNORECASE)
|
||||
|
||||
if not cleaned:
|
||||
return None
|
||||
|
||||
# Check for negative
|
||||
is_negative = cleaned.startswith("-")
|
||||
if is_negative:
|
||||
cleaned = cleaned[1:].strip()
|
||||
|
||||
try:
|
||||
# Remove spaces (Swedish thousands separator)
|
||||
cleaned = cleaned.replace(" ", "")
|
||||
|
||||
# Detect format
|
||||
# Swedish/European: comma is decimal separator
|
||||
# US: period is decimal separator
|
||||
has_comma = "," in cleaned
|
||||
has_period = "." in cleaned
|
||||
|
||||
if has_comma and has_period:
|
||||
# Both present - check position
|
||||
comma_pos = cleaned.rfind(",")
|
||||
period_pos = cleaned.rfind(".")
|
||||
|
||||
if comma_pos > period_pos:
|
||||
# European: 1.234,56
|
||||
cleaned = cleaned.replace(".", "")
|
||||
cleaned = cleaned.replace(",", ".")
|
||||
else:
|
||||
# US: 1,234.56
|
||||
cleaned = cleaned.replace(",", "")
|
||||
elif has_comma:
|
||||
# Swedish: 1234,56
|
||||
cleaned = cleaned.replace(",", ".")
|
||||
# else: US format or integer
|
||||
|
||||
value = float(cleaned)
|
||||
return -value if is_negative else value
|
||||
|
||||
except (ValueError, InvalidOperation):
|
||||
return None
|
||||
|
||||
|
||||
class VATExtractor:
|
||||
"""Extract VAT information from invoice text."""
|
||||
|
||||
# VAT extraction patterns.
# Note: each amount group uses a lazy character class terminated by
# end-of-line (MULTILINE $) or a following letter, so a match cannot
# swallow unrelated trailing text.
VAT_PATTERNS = [
    # Moms 25%: 2 500,00 or Moms 25% 2 500,00
    re.compile(
        r"[Mm]oms\s*(\d+(?:[,\.]\d+)?)\s*%\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
        re.MULTILINE,
    ),
    # Varav moms 25% 2 500,00
    re.compile(
        r"[Vv]arav\s+moms\s+(\d+(?:[,\.]\d+)?)\s*%\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
        re.MULTILINE,
    ),
    # 25% moms: 2 500,00 (at line start or after whitespace)
    re.compile(
        r"(?:^|\s)(\d+(?:[,\.]\d+)?)\s*%\s*moms\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
        re.MULTILINE,
    ),
    # Moms (25%): 2 500,00
    re.compile(
        r"[Mm]oms\s*\((\d+(?:[,\.]\d+)?)\s*%\)\s*:?\s*([\d ,\.]+?)(?:\s*$|\s+[a-zA-Z])",
        re.MULTILINE,
    ),
]

# Pattern with base amount (Underlag); group 3 is optional and DOTALL
# lets the "Underlag" part sit on a later line.
VAT_WITH_BASE_PATTERN = re.compile(
    r"[Mm]oms\s*(\d+(?:[,\.]\d+)?)\s*%\s*:?\s*([\d\s,\.]+)"
    r"(?:.*?[Uu]nderlag\s*([\d\s,\.]+))?",
    re.MULTILINE | re.DOTALL,
)

# Total patterns (Summa/Totalt/Netto/Brutto variants).
TOTAL_EXCL_PATTERN = re.compile(
    r"(?:[Ss]umma|[Tt]otal(?:t)?|[Nn]etto)\s*(?:exkl\.?\s*)?(?:moms)?\s*:?\s*([\d\s,\.]+)",
    re.MULTILINE,
)
TOTAL_VAT_PATTERN = re.compile(
    r"(?:[Ss]umma|[Tt]otal(?:t)?)\s*moms\s*:?\s*([\d\s,\.]+)",
    re.MULTILINE,
)
TOTAL_INCL_PATTERN = re.compile(
    r"(?:[Ss]umma|[Tt]otal(?:t)?|[Bb]rutto)\s*(?:inkl\.?\s*)?(?:moms|att\s*betala)?\s*:?\s*([\d\s,\.]+)",
    re.MULTILINE,
)
|
||||
|
||||
def __init__(self):
    """Initialize the extractor with a shared amount parser."""
    self.amount_parser = AmountParser()
|
||||
|
||||
def extract(self, text: str) -> VATSummary:
|
||||
"""
|
||||
Extract VAT information from text.
|
||||
|
||||
Args:
|
||||
text: Invoice text (OCR output).
|
||||
|
||||
Returns:
|
||||
VATSummary with extracted information.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return VATSummary(
|
||||
breakdowns=[],
|
||||
total_excl_vat=None,
|
||||
total_vat=None,
|
||||
total_incl_vat=None,
|
||||
confidence=0.0,
|
||||
)
|
||||
|
||||
breakdowns = self._extract_breakdowns(text)
|
||||
total_excl = self._extract_total_excl(text)
|
||||
total_vat = self._extract_total_vat(text)
|
||||
total_incl = self._extract_total_incl(text)
|
||||
|
||||
confidence = self._calculate_confidence(
|
||||
breakdowns, total_excl, total_vat, total_incl
|
||||
)
|
||||
|
||||
return VATSummary(
|
||||
breakdowns=breakdowns,
|
||||
total_excl_vat=total_excl,
|
||||
total_vat=total_vat,
|
||||
total_incl_vat=total_incl,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
def _extract_breakdowns(self, text: str) -> list[VATBreakdown]:
|
||||
"""Extract individual VAT rate breakdowns."""
|
||||
breakdowns = []
|
||||
seen_rates = set()
|
||||
|
||||
# Try pattern with base amount first
|
||||
for match in self.VAT_WITH_BASE_PATTERN.finditer(text):
|
||||
rate = self._parse_rate(match.group(1))
|
||||
vat_amount = self._clean_amount(match.group(2))
|
||||
base_amount = (
|
||||
self._clean_amount(match.group(3)) if match.group(3) else None
|
||||
)
|
||||
|
||||
if rate is not None and vat_amount and rate not in seen_rates:
|
||||
seen_rates.add(rate)
|
||||
breakdowns.append(
|
||||
VATBreakdown(
|
||||
rate=rate,
|
||||
base_amount=base_amount,
|
||||
vat_amount=vat_amount,
|
||||
source="regex",
|
||||
)
|
||||
)
|
||||
|
||||
# Try other patterns
|
||||
for pattern in self.VAT_PATTERNS:
|
||||
for match in pattern.finditer(text):
|
||||
rate = self._parse_rate(match.group(1))
|
||||
vat_amount = self._clean_amount(match.group(2))
|
||||
|
||||
if rate is not None and vat_amount and rate not in seen_rates:
|
||||
seen_rates.add(rate)
|
||||
breakdowns.append(
|
||||
VATBreakdown(
|
||||
rate=rate,
|
||||
base_amount=None,
|
||||
vat_amount=vat_amount,
|
||||
source="regex",
|
||||
)
|
||||
)
|
||||
|
||||
return breakdowns
|
||||
|
||||
def _extract_total_excl(self, text: str) -> str | None:
|
||||
"""Extract total excluding VAT."""
|
||||
# Look for specific patterns first
|
||||
patterns = [
|
||||
re.compile(r"[Ss]umma\s+exkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
|
||||
re.compile(r"[Nn]etto\s*:?\s*([\d\s,\.]+)"),
|
||||
re.compile(r"[Ee]xkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = pattern.search(text)
|
||||
if match:
|
||||
return self._clean_amount(match.group(1))
|
||||
|
||||
return None
|
||||
|
||||
def _extract_total_vat(self, text: str) -> str | None:
|
||||
"""Extract total VAT amount."""
|
||||
patterns = [
|
||||
re.compile(r"[Ss]umma\s+moms\s*:?\s*([\d\s,\.]+)"),
|
||||
re.compile(r"[Tt]otal(?:t)?\s+moms\s*:?\s*([\d\s,\.]+)"),
|
||||
# Generic "Moms:" without percentage
|
||||
re.compile(r"^[Mm]oms\s*:?\s*([\d\s,\.]+)", re.MULTILINE),
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = pattern.search(text)
|
||||
if match:
|
||||
return self._clean_amount(match.group(1))
|
||||
|
||||
return None
|
||||
|
||||
def _extract_total_incl(self, text: str) -> str | None:
|
||||
"""Extract total including VAT."""
|
||||
patterns = [
|
||||
re.compile(r"[Ss]umma\s+inkl\.?\s*moms\s*:?\s*([\d\s,\.]+)"),
|
||||
re.compile(r"[Tt]otal(?:t)?\s+att\s+betala\s*:?\s*([\d\s,\.]+)"),
|
||||
re.compile(r"[Bb]rutto\s*:?\s*([\d\s,\.]+)"),
|
||||
re.compile(r"[Aa]tt\s+betala\s*:?\s*([\d\s,\.]+)"),
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = pattern.search(text)
|
||||
if match:
|
||||
return self._clean_amount(match.group(1))
|
||||
|
||||
return None
|
||||
|
||||
def _parse_rate(self, rate_str: str) -> float | None:
|
||||
"""Parse VAT rate string to float."""
|
||||
try:
|
||||
rate_str = rate_str.replace(",", ".")
|
||||
return float(rate_str)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _clean_amount(self, amount_str: str) -> str | None:
|
||||
"""Clean and validate amount string."""
|
||||
if not amount_str:
|
||||
return None
|
||||
|
||||
cleaned = amount_str.strip()
|
||||
|
||||
# Remove trailing non-numeric chars (except comma/period)
|
||||
cleaned = re.sub(r"[^\d\s,\.]+$", "", cleaned).strip()
|
||||
|
||||
if not cleaned:
|
||||
return None
|
||||
|
||||
# Validate it parses as a number
|
||||
if self.amount_parser.parse(cleaned) is None:
|
||||
return None
|
||||
|
||||
return cleaned
|
||||
|
||||
def _calculate_confidence(
|
||||
self,
|
||||
breakdowns: list[VATBreakdown],
|
||||
total_excl: str | None,
|
||||
total_vat: str | None,
|
||||
total_incl: str | None,
|
||||
) -> float:
|
||||
"""Calculate confidence score based on extracted data."""
|
||||
score = 0.0
|
||||
|
||||
# Has VAT breakdowns
|
||||
if breakdowns:
|
||||
score += 0.3
|
||||
|
||||
# Has total excluding VAT
|
||||
if total_excl:
|
||||
score += 0.2
|
||||
|
||||
# Has total VAT
|
||||
if total_vat:
|
||||
score += 0.2
|
||||
|
||||
# Has total including VAT
|
||||
if total_incl:
|
||||
score += 0.15
|
||||
|
||||
# Mathematical consistency check
|
||||
if total_excl and total_vat and total_incl:
|
||||
excl = self.amount_parser.parse(total_excl)
|
||||
vat = self.amount_parser.parse(total_vat)
|
||||
incl = self.amount_parser.parse(total_incl)
|
||||
|
||||
if excl and vat and incl:
|
||||
expected = excl + vat
|
||||
if abs(expected - incl) < 0.02: # Allow 2 cent tolerance
|
||||
score += 0.15
|
||||
|
||||
return min(score, 1.0)
|
||||
@@ -12,7 +12,7 @@ import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from backend.web.schemas.inference import (
|
||||
@@ -20,6 +20,12 @@ from backend.web.schemas.inference import (
|
||||
HealthResponse,
|
||||
InferenceResponse,
|
||||
InferenceResult,
|
||||
LineItemSchema,
|
||||
LineItemsResultSchema,
|
||||
MathCheckResultSchema,
|
||||
VATBreakdownSchema,
|
||||
VATSummarySchema,
|
||||
VATValidationResultSchema,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
@@ -67,12 +73,21 @@ def create_inference_router(
|
||||
)
|
||||
async def infer_document(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
extract_line_items: bool = Form(
|
||||
default=False,
|
||||
description="Extract line items and VAT information (business features)",
|
||||
),
|
||||
) -> InferenceResponse:
|
||||
"""
|
||||
Process a document and extract invoice fields.
|
||||
|
||||
Accepts PDF or image files (PNG, JPG, JPEG).
|
||||
Returns extracted field values with confidence scores.
|
||||
|
||||
When extract_line_items=True, also extracts:
|
||||
- Line items (products/services with quantities and amounts)
|
||||
- VAT summary (multiple tax rates breakdown)
|
||||
- VAT validation (cross-validation results)
|
||||
"""
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
@@ -116,7 +131,9 @@ def create_inference_router(
|
||||
# Process based on file type
|
||||
if file_ext == ".pdf":
|
||||
service_result = inference_service.process_pdf(
|
||||
upload_path, document_id=doc_id
|
||||
upload_path,
|
||||
document_id=doc_id,
|
||||
extract_line_items=extract_line_items,
|
||||
)
|
||||
else:
|
||||
service_result = inference_service.process_image(
|
||||
@@ -128,6 +145,39 @@ def create_inference_router(
|
||||
if service_result.visualization_path:
|
||||
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
|
||||
|
||||
# Build business features schemas if present
|
||||
line_items_schema = None
|
||||
vat_summary_schema = None
|
||||
vat_validation_schema = None
|
||||
|
||||
if service_result.line_items:
|
||||
line_items_schema = LineItemsResultSchema(
|
||||
items=[LineItemSchema(**item) for item in service_result.line_items.get("items", [])],
|
||||
header_row=service_result.line_items.get("header_row", []),
|
||||
total_amount=service_result.line_items.get("total_amount"),
|
||||
)
|
||||
|
||||
if service_result.vat_summary:
|
||||
vat_summary_schema = VATSummarySchema(
|
||||
breakdowns=[VATBreakdownSchema(**b) for b in service_result.vat_summary.get("breakdowns", [])],
|
||||
total_excl_vat=service_result.vat_summary.get("total_excl_vat"),
|
||||
total_vat=service_result.vat_summary.get("total_vat"),
|
||||
total_incl_vat=service_result.vat_summary.get("total_incl_vat"),
|
||||
confidence=service_result.vat_summary.get("confidence", 0.0),
|
||||
)
|
||||
|
||||
if service_result.vat_validation:
|
||||
vat_validation_schema = VATValidationResultSchema(
|
||||
is_valid=service_result.vat_validation.get("is_valid", False),
|
||||
confidence_score=service_result.vat_validation.get("confidence_score", 0.0),
|
||||
math_checks=[MathCheckResultSchema(**m) for m in service_result.vat_validation.get("math_checks", [])],
|
||||
total_check=service_result.vat_validation.get("total_check", False),
|
||||
line_items_vs_summary=service_result.vat_validation.get("line_items_vs_summary"),
|
||||
amount_consistency=service_result.vat_validation.get("amount_consistency"),
|
||||
needs_review=service_result.vat_validation.get("needs_review", False),
|
||||
review_reasons=service_result.vat_validation.get("review_reasons", []),
|
||||
)
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=service_result.document_id,
|
||||
success=service_result.success,
|
||||
@@ -140,6 +190,9 @@ def create_inference_router(
|
||||
processing_time_ms=service_result.processing_time_ms,
|
||||
visualization_url=viz_url,
|
||||
errors=service_result.errors,
|
||||
line_items=line_items_schema,
|
||||
vat_summary=vat_summary_schema,
|
||||
vat_validation=vat_validation_schema,
|
||||
)
|
||||
|
||||
return InferenceResponse(
|
||||
|
||||
@@ -69,6 +69,17 @@ class InferenceResult(BaseModel):
|
||||
)
|
||||
errors: list[str] = Field(default_factory=list, description="Error messages")
|
||||
|
||||
# Business features (optional, only when extract_line_items=True)
|
||||
line_items: "LineItemsResultSchema | None" = Field(
|
||||
None, description="Extracted line items (when extract_line_items=True)"
|
||||
)
|
||||
vat_summary: "VATSummarySchema | None" = Field(
|
||||
None, description="VAT summary (when extract_line_items=True)"
|
||||
)
|
||||
vat_validation: "VATValidationResultSchema | None" = Field(
|
||||
None, description="VAT validation result (when extract_line_items=True)"
|
||||
)
|
||||
|
||||
|
||||
class InferenceResponse(BaseModel):
|
||||
"""API response for inference endpoint."""
|
||||
@@ -194,3 +205,90 @@ class RateLimitInfo(BaseModel):
|
||||
limit: int = Field(..., description="Maximum requests per minute")
|
||||
remaining: int = Field(..., description="Remaining requests in current window")
|
||||
reset_at: datetime = Field(..., description="Time when limit resets")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Business Features Schemas (Line Items, VAT)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LineItemSchema(BaseModel):
    """Single line item from invoice.

    Numeric-looking values (quantity, unit_price, amount, vat_rate) are
    kept as strings — the extracted text form is preserved rather than
    parsed; numeric interpretation is left to the consumer.
    """

    row_index: int = Field(..., description="Row index in the table")
    description: str | None = Field(None, description="Product/service description")
    quantity: str | None = Field(None, description="Quantity")
    unit: str | None = Field(None, description="Unit (st, pcs, etc.)")
    unit_price: str | None = Field(None, description="Price per unit")
    amount: str | None = Field(None, description="Line total amount")
    article_number: str | None = Field(None, description="Article/product number")
    vat_rate: str | None = Field(None, description="VAT rate (e.g., '25')")
    # Deduction rows (Swedish: avdrag/rabatt) carry negative meaning even
    # though the amount string itself may be unsigned.
    is_deduction: bool = Field(default=False, description="True if this row is a deduction/discount (avdrag/rabatt)")
    confidence: float = Field(default=0.0, ge=0, le=1, description="Extraction confidence")
|
||||
|
||||
|
||||
class LineItemsResultSchema(BaseModel):
    """Line items extraction result.

    Bundles the extracted rows with the detected table header and a
    total derived from the item amounts.
    """

    items: list[LineItemSchema] = Field(default_factory=list, description="Extracted line items")
    header_row: list[str] = Field(default_factory=list, description="Table header row")
    total_amount: str | None = Field(None, description="Calculated total from line items")
|
||||
|
||||
|
||||
class VATBreakdownSchema(BaseModel):
    """Single VAT rate breakdown.

    One entry per distinct VAT rate found on the invoice; amounts are
    kept as strings in their extracted form.
    """

    rate: float = Field(..., description="VAT rate (e.g., 25.0, 12.0, 6.0)")
    base_amount: str | None = Field(None, description="Tax base amount (excluding VAT)")
    vat_amount: str | None = Field(None, description="VAT amount")
    # "regex" = matched from raw text; "line_items" = derived from table rows.
    source: str = Field(default="regex", description="Extraction source (regex or line_items)")
|
||||
|
||||
|
||||
class VATSummarySchema(BaseModel):
    """VAT summary information.

    Per-rate breakdowns plus the three invoice totals (excluding VAT,
    VAT only, including VAT), each optional since extraction may fail
    for any of them.
    """

    breakdowns: list[VATBreakdownSchema] = Field(
        default_factory=list, description="VAT breakdowns by rate"
    )
    total_excl_vat: str | None = Field(None, description="Total excluding VAT")
    total_vat: str | None = Field(None, description="Total VAT amount")
    total_incl_vat: str | None = Field(None, description="Total including VAT")
    confidence: float = Field(default=0.0, ge=0, le=1, description="Extraction confidence")
|
||||
|
||||
|
||||
class MathCheckResultSchema(BaseModel):
    """Single math validation check result.

    Records one per-rate check: whether the invoice's stated VAT amount
    matches base * rate within the given tolerance.
    """

    rate: float = Field(..., description="VAT rate checked")
    base_amount: float | None = Field(None, description="Base amount")
    expected_vat: float | None = Field(None, description="Expected VAT (base * rate)")
    actual_vat: float | None = Field(None, description="Actual VAT from invoice")
    is_valid: bool = Field(..., description="Whether math check passed")
    tolerance: float = Field(..., description="Tolerance used for comparison")
|
||||
|
||||
|
||||
class VATValidationResultSchema(BaseModel):
    """VAT cross-validation result.

    Aggregates per-rate math checks with cross-consistency flags; the
    tri-state checks (line_items_vs_summary, amount_consistency) use
    None to mean "could not be evaluated" as opposed to False ("failed").
    """

    is_valid: bool = Field(..., description="Overall validation status")
    confidence_score: float = Field(
        ..., ge=0, le=1, description="Validation confidence score"
    )
    math_checks: list[MathCheckResultSchema] = Field(
        default_factory=list, description="Math check results per VAT rate"
    )
    total_check: bool = Field(default=False, description="Whether total calculation is valid")
    line_items_vs_summary: bool | None = Field(
        None, description="Whether line items match VAT summary"
    )
    amount_consistency: bool | None = Field(
        None, description="Whether total matches detected amount field"
    )
    needs_review: bool = Field(default=False, description="Whether manual review is recommended")
    review_reasons: list[str] = Field(
        default_factory=list, description="Reasons for manual review"
    )
|
||||
|
||||
|
||||
# Rebuild models to resolve forward references
|
||||
InferenceResult.model_rebuild()
|
||||
|
||||
@@ -42,6 +42,11 @@ class ServiceResult:
|
||||
visualization_path: Path | None = None
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
# Business features (optional, populated when extract_line_items=True)
|
||||
line_items: dict | None = None
|
||||
vat_summary: dict | None = None
|
||||
vat_validation: dict | None = None
|
||||
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
@@ -74,6 +79,7 @@ class InferenceService:
|
||||
self._detector = None
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
self._business_features_enabled = False
|
||||
|
||||
def _resolve_model_path(self) -> Path:
|
||||
"""Resolve the model path to use for inference.
|
||||
@@ -95,12 +101,16 @@ class InferenceService:
|
||||
|
||||
return self.model_config.model_path
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize the inference pipeline (lazy loading)."""
|
||||
def initialize(self, enable_business_features: bool = False) -> None:
|
||||
"""Initialize the inference pipeline (lazy loading).
|
||||
|
||||
Args:
|
||||
enable_business_features: Whether to enable line items and VAT extraction
|
||||
"""
|
||||
if self._is_initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing inference service...")
|
||||
logger.info(f"Initializing inference service (business_features={enable_business_features})...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
@@ -118,16 +128,18 @@ class InferenceService:
|
||||
device="cuda" if self.model_config.use_gpu else "cpu",
|
||||
)
|
||||
|
||||
# Initialize full pipeline
|
||||
# Initialize full pipeline with optional business features
|
||||
self._pipeline = InferencePipeline(
|
||||
model_path=str(model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
use_gpu=self.model_config.use_gpu,
|
||||
dpi=self.model_config.dpi,
|
||||
enable_fallback=True,
|
||||
enable_business_features=enable_business_features,
|
||||
)
|
||||
|
||||
self._is_initialized = True
|
||||
self._business_features_enabled = enable_business_features
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Inference service initialized in {elapsed:.2f}s with model: {model_path}")
|
||||
|
||||
@@ -242,6 +254,7 @@ class InferenceService:
|
||||
pdf_path: Path,
|
||||
document_id: str | None = None,
|
||||
save_visualization: bool = True,
|
||||
extract_line_items: bool = False,
|
||||
) -> ServiceResult:
|
||||
"""
|
||||
Process a PDF file and extract invoice fields.
|
||||
@@ -250,12 +263,17 @@ class InferenceService:
|
||||
pdf_path: Path to PDF file
|
||||
document_id: Optional document ID
|
||||
save_visualization: Whether to save visualization
|
||||
extract_line_items: Whether to extract line items and VAT info
|
||||
|
||||
Returns:
|
||||
ServiceResult with extracted fields
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self.initialize()
|
||||
self.initialize(enable_business_features=extract_line_items)
|
||||
elif extract_line_items and not self._business_features_enabled:
|
||||
# Reinitialize with business features if needed
|
||||
self._is_initialized = False
|
||||
self.initialize(enable_business_features=True)
|
||||
|
||||
doc_id = document_id or str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
@@ -263,8 +281,12 @@ class InferenceService:
|
||||
result = ServiceResult(document_id=doc_id)
|
||||
|
||||
try:
|
||||
# Run inference pipeline
|
||||
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
|
||||
# Run inference pipeline with optional business features
|
||||
pipeline_result = self._pipeline.process_pdf(
|
||||
pdf_path,
|
||||
document_id=doc_id,
|
||||
extract_line_items=extract_line_items,
|
||||
)
|
||||
|
||||
result.fields = pipeline_result.fields
|
||||
result.confidence = pipeline_result.confidence
|
||||
@@ -288,6 +310,12 @@ class InferenceService:
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
|
||||
# Include business features if extracted
|
||||
if extract_line_items:
|
||||
result.line_items = pipeline_result._line_items_to_json() if pipeline_result.line_items else None
|
||||
result.vat_summary = pipeline_result._vat_summary_to_json() if pipeline_result.vat_summary else None
|
||||
result.vat_validation = pipeline_result._vat_validation_to_json() if pipeline_result.vat_validation else None
|
||||
|
||||
# Save visualization (render first page)
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ setup(
|
||||
name="invoice-backend",
|
||||
version="0.1.0",
|
||||
packages=find_packages(),
|
||||
python_requires=">=3.11",
|
||||
python_requires=">=3.10", # 3.10 for RTX 50 series SM120 wheel
|
||||
install_requires=[
|
||||
"invoice-shared",
|
||||
"fastapi>=0.104.0",
|
||||
|
||||
Reference in New Issue
Block a user