# Project targets Swedish invoice extraction. The PaddleOCR 'sv' model provides
# better recognition of Swedish-specific characters (å, ä, ö).
"""
|
|
Inference Pipeline
|
|
|
|
Complete pipeline for extracting invoice data from PDFs.
|
|
Supports both basic field extraction and business invoice features
|
|
(line items, VAT extraction, cross-validation).
|
|
"""
|
|
|
|
import logging
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from shared.fields import CLASS_TO_FIELD

from .field_extractor import ExtractedField, FieldExtractor
from .payment_line_parser import PaymentLineParser
from .yolo_detector import Detection, YOLODetector

logger = logging.getLogger(__name__)
# Business invoice feature imports (optional - for extract_line_items mode)
try:
    from ..table.line_items_extractor import LineItem, LineItemsResult, LineItemsExtractor
    from ..table.structure_detector import TableDetector
    from ..vat.vat_extractor import VATSummary, VATExtractor
    from ..validation.vat_validator import VATValidationResult, VATValidator
    BUSINESS_FEATURES_AVAILABLE = True
except ImportError:
    # Define ALL optional names so any later reference fails via the
    # BUSINESS_FEATURES_AVAILABLE check instead of raising NameError.
    # (Previously LineItemsExtractor, VATExtractor and VATValidator were
    # left undefined when the optional modules were missing.)
    BUSINESS_FEATURES_AVAILABLE = False
    LineItem = None
    LineItemsResult = None
    LineItemsExtractor = None
    TableDetector = None
    VATSummary = None
    VATExtractor = None
    VATValidationResult = None
    VATValidator = None
@dataclass
class CrossValidationResult:
    """Result of cross-validation between payment_line and other fields.

    Match flags are tri-state: True/False when a comparison was possible,
    None when the corresponding value was missing on one side.
    """
    is_valid: bool = False  # Overall verdict (see _cross_validate_payment_line)
    ocr_match: bool | None = None  # None if not comparable
    amount_match: bool | None = None
    bankgiro_match: bool | None = None
    plusgiro_match: bool | None = None
    # Values parsed out of the machine-readable payment line
    payment_line_ocr: str | None = None
    payment_line_amount: str | None = None
    payment_line_account: str | None = None  # Formatted with hyphen (e.g. "1234-5678")
    payment_line_account_type: str | None = None  # 'bankgiro' or 'plusgiro'
    # Human-readable notes describing each comparison/override performed
    details: list[str] = field(default_factory=list)
@dataclass
class InferenceResult:
    """Result of invoice processing.

    Aggregates the merged field values, per-field confidences and pixel
    bounding boxes, the raw detector/OCR artifacts they were derived from,
    and optional business-feature results (line items, VAT summary/validation).
    """
    document_id: str | None = None
    success: bool = False  # True when at least one field was extracted
    fields: dict[str, Any] = field(default_factory=dict)  # field name -> normalized value
    confidence: dict[str, float] = field(default_factory=dict)  # field name -> confidence in [0, 1]
    bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict)  # Field bboxes in pixels
    raw_detections: list[Detection] = field(default_factory=list)  # all YOLO detections, all pages
    extracted_fields: list[ExtractedField] = field(default_factory=list)  # per-detection OCR candidates
    processing_time_ms: float = 0.0
    errors: list[str] = field(default_factory=list)
    fallback_used: bool = False  # True when regex/full-page-OCR fallback ran
    cross_validation: CrossValidationResult | None = None
    # Business invoice features (optional)
    line_items: Any | None = None  # LineItemsResult when available
    vat_summary: Any | None = None  # VATSummary when available
    vat_validation: Any | None = None  # VATValidationResult when available

    def to_json(self) -> dict:
        """Convert to JSON-serializable dictionary.

        Core fields are always present (None when missing). Bboxes,
        cross-validation and business-feature sections are only included
        when populated.
        """
        result = {
            'DocumentId': self.document_id,
            'InvoiceNumber': self.fields.get('InvoiceNumber'),
            'InvoiceDate': self.fields.get('InvoiceDate'),
            'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
            'OCR': self.fields.get('OCR'),
            'Bankgiro': self.fields.get('Bankgiro'),
            'Plusgiro': self.fields.get('Plusgiro'),
            'Amount': self.fields.get('Amount'),
            'supplier_org_number': self.fields.get('supplier_org_number'),
            'customer_number': self.fields.get('customer_number'),
            'payment_line': self.fields.get('payment_line'),
            'confidence': self.confidence,
            'success': self.success,
            'fallback_used': self.fallback_used
        }
        # Add bboxes if present (tuples become lists for JSON)
        if self.bboxes:
            result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
        # Add cross-validation results if present
        if self.cross_validation:
            result['cross_validation'] = {
                'is_valid': self.cross_validation.is_valid,
                'ocr_match': self.cross_validation.ocr_match,
                'amount_match': self.cross_validation.amount_match,
                'bankgiro_match': self.cross_validation.bankgiro_match,
                'plusgiro_match': self.cross_validation.plusgiro_match,
                'payment_line_ocr': self.cross_validation.payment_line_ocr,
                'payment_line_amount': self.cross_validation.payment_line_amount,
                'payment_line_account': self.cross_validation.payment_line_account,
                'payment_line_account_type': self.cross_validation.payment_line_account_type,
                'details': self.cross_validation.details,
            }

        # Add business invoice features if present
        if self.line_items is not None:
            result['line_items'] = self._line_items_to_json()
        if self.vat_summary is not None:
            result['vat_summary'] = self._vat_summary_to_json()
        if self.vat_validation is not None:
            result['vat_validation'] = self._vat_validation_to_json()

        return result

    def _line_items_to_json(self) -> dict | None:
        """Convert LineItemsResult to JSON (None when no line items were set)."""
        if self.line_items is None:
            return None
        li = self.line_items
        return {
            'items': [
                {
                    'row_index': item.row_index,
                    'description': item.description,
                    'quantity': item.quantity,
                    'unit': item.unit,
                    'unit_price': item.unit_price,
                    'amount': item.amount,
                    'article_number': item.article_number,
                    'vat_rate': item.vat_rate,
                    'is_deduction': item.is_deduction,
                    'confidence': item.confidence,
                }
                for item in li.items
            ],
            'header_row': li.header_row,
            'total_amount': li.total_amount,
        }

    def _vat_summary_to_json(self) -> dict | None:
        """Convert VATSummary to JSON (None when no summary was set)."""
        if self.vat_summary is None:
            return None
        vs = self.vat_summary
        return {
            'breakdowns': [
                {
                    'rate': b.rate,
                    'base_amount': b.base_amount,
                    'vat_amount': b.vat_amount,
                    'source': b.source,
                }
                for b in vs.breakdowns
            ],
            'total_excl_vat': vs.total_excl_vat,
            'total_vat': vs.total_vat,
            'total_incl_vat': vs.total_incl_vat,
            'confidence': vs.confidence,
        }

    def _vat_validation_to_json(self) -> dict | None:
        """Convert VATValidationResult to JSON (None when no validation was set)."""
        if self.vat_validation is None:
            return None
        vv = self.vat_validation
        return {
            'is_valid': vv.is_valid,
            'confidence_score': vv.confidence_score,
            'math_checks': [
                {
                    'rate': mc.rate,
                    'base_amount': mc.base_amount,
                    'expected_vat': mc.expected_vat,
                    'actual_vat': mc.actual_vat,
                    'is_valid': mc.is_valid,
                    'tolerance': mc.tolerance,
                }
                for mc in vv.math_checks
            ],
            'total_check': vv.total_check,
            'line_items_vs_summary': vv.line_items_vs_summary,
            'amount_consistency': vv.amount_consistency,
            'needs_review': vv.needs_review,
            'review_reasons': vv.review_reasons,
        }

    def get_field(self, field_name: str) -> tuple[Any, float]:
        """Get field value and confidence.

        Returns:
            (value, confidence) where value is None and confidence 0.0
            when the field was not extracted.
        """
        return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
class InferencePipeline:
    """
    Complete inference pipeline for invoice data extraction.

    Pipeline flow:
    1. PDF -> Image rendering
    2. YOLO detection of field regions
    3. OCR extraction from detected regions
    4. Field normalization and validation
    5. Fallback to full-page OCR if YOLO fails
    """

    def __init__(
        self,
        model_path: str | Path,
        confidence_threshold: float = 0.5,
        ocr_lang: str = 'sv',
        use_gpu: bool = False,
        dpi: int = 300,
        enable_fallback: bool = True,
        enable_business_features: bool = False,
        vat_tolerance: float = 0.5
    ):
        """
        Initialize inference pipeline.

        Args:
            model_path: Path to trained YOLO model
            confidence_threshold: Detection confidence threshold
            ocr_lang: Language for OCR
            use_gpu: Whether to use GPU
            dpi: Resolution for PDF rendering
            enable_fallback: Enable fallback to full-page OCR
            enable_business_features: Enable line items/VAT extraction
            vat_tolerance: Tolerance for VAT math checks (in currency units)

        Raises:
            ImportError: If enable_business_features is True but the optional
                table/vat/validation modules could not be imported.
        """
        self.detector = YOLODetector(
            model_path,
            confidence_threshold=confidence_threshold,
            device='cuda' if use_gpu else 'cpu'
        )
        self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
        self.payment_line_parser = PaymentLineParser()
        self.dpi = dpi
        self.enable_fallback = enable_fallback
        self.enable_business_features = enable_business_features
        self.vat_tolerance = vat_tolerance

        # Initialize business feature components if enabled and available
        self.line_items_extractor = None
        self.vat_extractor = None
        self.vat_validator = None
        self._business_ocr_engine = None  # Lazy-initialized for VAT text extraction
        self._table_detector = None  # Shared TableDetector for line items extraction

        if enable_business_features:
            if not BUSINESS_FEATURES_AVAILABLE:
                raise ImportError(
                    "Business features require table, vat, and validation modules. "
                    "Please ensure they are properly installed."
                )
            # Create shared TableDetector for performance (PP-StructureV3 init is slow)
            self._table_detector = TableDetector()
            # Pass shared detector to LineItemsExtractor
            self.line_items_extractor = LineItemsExtractor(table_detector=self._table_detector)
            self.vat_extractor = VATExtractor()
            self.vat_validator = VATValidator(tolerance=vat_tolerance)
    def process_pdf(
        self,
        pdf_path: str | Path,
        document_id: str | None = None,
        extract_line_items: bool | None = None
    ) -> InferenceResult:
        """
        Process a PDF and extract invoice fields.

        Steps: probe for a text layer, render pages, detect field regions,
        extract each region (PDF tokens when available, OCR otherwise),
        merge candidates, optionally run the regex fallback and the
        business-feature extraction. Never raises: failures are recorded
        in result.errors and success is set to False.

        Args:
            pdf_path: Path to PDF file
            document_id: Optional document ID
            extract_line_items: Whether to extract line items and VAT info.
                If None, uses the enable_business_features setting from __init__.

        Returns:
            InferenceResult with extracted fields
        """
        from shared.pdf.renderer import render_pdf_to_images
        from PIL import Image
        import io
        import numpy as np

        start_time = time.time()

        result = InferenceResult(
            document_id=document_id or Path(pdf_path).stem
        )

        # Determine if business features should be used
        use_business_features = (
            extract_line_items if extract_line_items is not None
            else self.enable_business_features
        )

        try:
            all_detections = []
            all_extracted = []
            all_ocr_text = []  # Collect OCR text for VAT extraction

            # Check if PDF has readable text layer (avoids OCR for text PDFs)
            from shared.pdf.extractor import PDFDocument
            is_text_pdf = False
            pdf_tokens_by_page: dict[int, list] = {}
            try:
                with PDFDocument(pdf_path) as pdf_doc:
                    is_text_pdf = pdf_doc.is_text_pdf()
                    if is_text_pdf:
                        # Pre-extract tokens for every page so the detection
                        # loop below can map regions to PDF text directly.
                        for pg in range(pdf_doc.page_count):
                            pdf_tokens_by_page[pg] = list(
                                pdf_doc.extract_text_tokens(pg)
                            )
                        logger.info(
                            "Text PDF detected, extracted %d tokens from %d pages",
                            sum(len(t) for t in pdf_tokens_by_page.values()),
                            len(pdf_tokens_by_page),
                        )
            except Exception as e:
                # Text-layer probing is best-effort; OCR path still works.
                logger.warning("PDF text detection failed, falling back to OCR: %s", e)
                is_text_pdf = False

            # Process each page
            for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
                # Convert to numpy array
                image = Image.open(io.BytesIO(image_bytes))
                image_array = np.array(image)

                # Run YOLO detection
                detections = self.detector.detect(image_array, page_no=page_no)
                all_detections.extend(detections)

                # Extract fields from detections
                for detection in detections:
                    if is_text_pdf and page_no in pdf_tokens_by_page:
                        # Use the PDF's own text layer (exact, no OCR noise)
                        extracted = self.extractor.extract_from_detection_with_pdf(
                            detection,
                            pdf_tokens_by_page[page_no],
                            image.width,
                            image.height,
                        )
                    else:
                        extracted = self.extractor.extract_from_detection(
                            detection, image_array
                        )
                    all_extracted.append(extracted)

                # Collect full-page OCR text for VAT extraction (only if business features enabled)
                if use_business_features:
                    page_text = self._get_full_page_text(image_array)
                    all_ocr_text.append(page_text)

            result.raw_detections = all_detections
            result.extracted_fields = all_extracted

            # Merge extracted fields (prefer highest confidence)
            self._merge_fields(result)

            # Fallback if key fields are missing
            if self.enable_fallback and self._needs_fallback(result):
                self._run_fallback(pdf_path, result)
                # Fallback may add an InvoiceNumber duplicating OCR/Bankgiro
                self._dedup_invoice_number(result)

            # Extract business invoice features if enabled
            if use_business_features:
                self._extract_business_features(pdf_path, result, '\n'.join(all_ocr_text))

            result.success = len(result.fields) > 0

        except Exception as e:
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.time() - start_time) * 1000
        return result
def _get_full_page_text(self, image_array) -> str:
|
|
"""Extract full page text using OCR for VAT extraction."""
|
|
from shared.ocr import OCREngine
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
try:
|
|
# Lazy initialize OCR engine to avoid repeated model loading
|
|
if self._business_ocr_engine is None:
|
|
self._business_ocr_engine = OCREngine()
|
|
|
|
tokens = self._business_ocr_engine.extract_from_image(image_array, page_no=0)
|
|
return ' '.join(t.text for t in tokens)
|
|
except Exception as e:
|
|
logger.warning(f"OCR extraction for VAT failed: {e}")
|
|
return ""
|
|
|
|
    def _extract_business_features(
        self,
        pdf_path: str | Path,
        result: InferenceResult,
        full_text: str
    ) -> None:
        """
        Extract line items, VAT summary, and perform cross-validation.

        Populates result.line_items, result.vat_summary and
        result.vat_validation in place. Best-effort: any exception is
        logged and appended to result.errors instead of propagating.

        Args:
            pdf_path: Path to PDF file
            result: InferenceResult to populate
            full_text: Full OCR text from all pages
        """
        if not BUSINESS_FEATURES_AVAILABLE:
            result.errors.append("Business features not available")
            return

        if not self.line_items_extractor or not self.vat_extractor or not self.vat_validator:
            result.errors.append("Business feature extractors not initialized")
            return

        try:
            # Extract line items from tables
            logger.info(f"Extracting line items from PDF: {pdf_path}")
            line_items_result = self.line_items_extractor.extract_from_pdf(str(pdf_path))
            logger.info(f"Line items extraction result: {line_items_result is not None}, items={len(line_items_result.items) if line_items_result else 0}")
            if line_items_result and line_items_result.items:
                result.line_items = line_items_result
                logger.info(f"Set result.line_items with {len(line_items_result.items)} items")

            # Extract VAT summary from text
            logger.info(f"Extracting VAT summary from text ({len(full_text)} chars)")
            vat_summary = self.vat_extractor.extract(full_text)
            logger.info(f"VAT summary extraction result: {vat_summary is not None}")
            if vat_summary:
                result.vat_summary = vat_summary

                # Cross-validate VAT information (only possible with a summary)
                existing_amount = result.fields.get('Amount')
                vat_validation = self.vat_validator.validate(
                    vat_summary,
                    line_items=line_items_result,
                    existing_amount=str(existing_amount) if existing_amount else None
                )
                result.vat_validation = vat_validation
                logger.info(f"VAT validation completed: is_valid={vat_validation.is_valid if vat_validation else None}")

        except Exception as e:
            import traceback
            error_detail = f"{type(e).__name__}: {e}"
            logger.error(f"Business feature extraction failed: {error_detail}\n{traceback.format_exc()}")
            result.errors.append(f"Business feature extraction error: {error_detail}")
def _merge_fields(self, result: InferenceResult) -> None:
|
|
"""Merge extracted fields, keeping best candidate for each field.
|
|
|
|
Selection priority:
|
|
1. Prefer candidates without validation errors
|
|
2. Among equal validity, prefer higher confidence
|
|
"""
|
|
field_candidates: dict[str, list[ExtractedField]] = {}
|
|
|
|
for extracted in result.extracted_fields:
|
|
if not extracted.is_valid or not extracted.normalized_value:
|
|
continue
|
|
|
|
if extracted.field_name not in field_candidates:
|
|
field_candidates[extracted.field_name] = []
|
|
field_candidates[extracted.field_name].append(extracted)
|
|
|
|
# Select best candidate for each field
|
|
for field_name, candidates in field_candidates.items():
|
|
# Sort by: (no validation error, confidence) - descending
|
|
# This prefers candidates without errors, then by confidence
|
|
best = max(
|
|
candidates,
|
|
key=lambda x: (x.validation_error is None, x.confidence)
|
|
)
|
|
result.fields[field_name] = best.normalized_value
|
|
result.confidence[field_name] = best.confidence
|
|
# Store bbox for each field (useful for payment_line and other fields)
|
|
result.bboxes[field_name] = best.bbox
|
|
|
|
# Validate date consistency
|
|
self._validate_dates(result)
|
|
|
|
# Perform cross-validation if payment_line is detected
|
|
self._cross_validate_payment_line(result)
|
|
|
|
# Remove InvoiceNumber if it duplicates OCR or Bankgiro
|
|
self._dedup_invoice_number(result)
|
|
|
|
def _validate_dates(self, result: InferenceResult) -> None:
|
|
"""Remove InvoiceDueDate if it is earlier than InvoiceDate."""
|
|
invoice_date = result.fields.get('InvoiceDate')
|
|
due_date = result.fields.get('InvoiceDueDate')
|
|
if invoice_date and due_date and due_date < invoice_date:
|
|
del result.fields['InvoiceDueDate']
|
|
result.confidence.pop('InvoiceDueDate', None)
|
|
result.bboxes.pop('InvoiceDueDate', None)
|
|
|
|
def _dedup_invoice_number(self, result: InferenceResult) -> None:
|
|
"""Remove InvoiceNumber if it duplicates OCR or Bankgiro digits."""
|
|
inv_num = result.fields.get('InvoiceNumber')
|
|
if not inv_num:
|
|
return
|
|
inv_digits = re.sub(r'\D', '', str(inv_num))
|
|
|
|
# Check against OCR
|
|
ocr = result.fields.get('OCR')
|
|
if ocr and inv_digits == re.sub(r'\D', '', str(ocr)):
|
|
del result.fields['InvoiceNumber']
|
|
result.confidence.pop('InvoiceNumber', None)
|
|
result.bboxes.pop('InvoiceNumber', None)
|
|
return
|
|
|
|
# Check against Bankgiro (exact or substring match)
|
|
bg = result.fields.get('Bankgiro')
|
|
if bg:
|
|
bg_digits = re.sub(r'\D', '', str(bg))
|
|
if inv_digits == bg_digits or inv_digits in bg_digits:
|
|
del result.fields['InvoiceNumber']
|
|
result.confidence.pop('InvoiceNumber', None)
|
|
result.bboxes.pop('InvoiceNumber', None)
|
|
|
|
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
|
|
"""
|
|
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
|
|
|
|
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
|
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
|
|
|
|
Returns: (ocr, amount, account) tuple
|
|
"""
|
|
parsed = self.payment_line_parser.parse(payment_line)
|
|
|
|
if not parsed.is_valid:
|
|
return None, None, None
|
|
|
|
return parsed.ocr_number, parsed.amount, parsed.account_number
|
|
|
|
    def _cross_validate_payment_line(self, result: InferenceResult) -> None:
        """
        Cross-validate payment_line data against other detected fields.
        Payment line values take PRIORITY over individually detected fields.

        Swedish payment line (Betalningsrad) contains:
        - OCR reference number
        - Amount (kronor and öre)
        - Bankgiro or Plusgiro account number

        This method:
        1. Parses payment_line to extract OCR, Amount, Account
        2. Compares with separately detected fields for validation
        3. OVERWRITES detected fields with payment_line values (payment_line is authoritative)

        Note: OCR and Amount are overridden; account numbers are compared
        only (payment_line account detection is unreliable). No-op when no
        payment_line field was detected.
        """
        payment_line = result.fields.get('payment_line')
        if not payment_line:
            return

        cv = CrossValidationResult()
        cv.details = []

        # Parse machine-readable payment line format
        ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))

        cv.payment_line_ocr = ocr
        cv.payment_line_amount = amount

        # Determine account type based on digit count
        if account:
            # Bankgiro: 7-8 digits, Plusgiro: typically fewer
            if len(account) >= 7:
                cv.payment_line_account_type = 'bankgiro'
                # Format: XXX-XXXX or XXXX-XXXX
                if len(account) == 7:
                    cv.payment_line_account = f"{account[:3]}-{account[3:]}"
                else:
                    cv.payment_line_account = f"{account[:4]}-{account[4:]}"
            else:
                cv.payment_line_account_type = 'plusgiro'
                # Format: XXXXXXX-X
                cv.payment_line_account = f"{account[:-1]}-{account[-1]}"

        # Cross-validate and OVERRIDE with payment_line values

        # OCR: payment_line takes priority
        detected_ocr = result.fields.get('OCR')
        if cv.payment_line_ocr:
            # Compare digit-only to ignore formatting differences
            pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
            if detected_ocr:
                detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
                cv.ocr_match = pl_ocr_digits == detected_ocr_digits
                if cv.ocr_match:
                    cv.details.append(f"OCR match: {cv.payment_line_ocr}")
                else:
                    cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
            else:
                cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
            # OVERRIDE: use payment_line OCR
            result.fields['OCR'] = cv.payment_line_ocr
            result.confidence['OCR'] = 0.95  # High confidence for payment_line

        # Amount: payment_line takes priority
        detected_amount = result.fields.get('Amount')
        if cv.payment_line_amount:
            if detected_amount:
                pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
                det_amount = self._normalize_amount_for_compare(str(detected_amount))
                cv.amount_match = pl_amount == det_amount
                if cv.amount_match:
                    cv.details.append(f"Amount match: {cv.payment_line_amount}")
                else:
                    cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
            else:
                cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
            # OVERRIDE: use payment_line Amount
            result.fields['Amount'] = cv.payment_line_amount
            result.confidence['Amount'] = 0.95

        # Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_bankgiro = result.fields.get('Bankgiro')
        if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
            pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
            if detected_bankgiro:
                det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
                cv.bankgiro_match = pl_bg_digits == det_bg_digits
                if cv.bankgiro_match:
                    cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
                else:
                    cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
                    # Do NOT override - keep detected value

        # Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_plusgiro = result.fields.get('Plusgiro')
        if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
            pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
            if detected_plusgiro:
                det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
                cv.plusgiro_match = pl_pg_digits == det_pg_digits
                if cv.plusgiro_match:
                    cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
                else:
                    cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
                    # Do NOT override - keep detected value

        # Determine overall validity
        # Note: payment_line only contains ONE account (either BG or PG), so when invoice
        # has both accounts, the other one cannot be matched - this is expected and OK.
        # Only count the account type that payment_line actually has.
        matches = [cv.ocr_match, cv.amount_match]

        # Only include account match if payment_line has that account type
        if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
            matches.append(cv.bankgiro_match)
        elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
            matches.append(cv.plusgiro_match)

        valid_matches = [m for m in matches if m is not None]
        if valid_matches:
            match_count = sum(1 for m in valid_matches if m)
            # Valid when at least 2 fields match (or all of them, if fewer
            # than 2 comparisons were possible)
            cv.is_valid = match_count >= min(2, len(valid_matches))
            cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
        else:
            # No comparison possible
            cv.is_valid = True
            cv.details.append("No comparison available from payment_line")

        result.cross_validation = cv
def _normalize_amount_for_compare(self, amount: str) -> float | None:
|
|
"""Normalize amount string to float for comparison."""
|
|
try:
|
|
# Remove spaces, convert comma to dot
|
|
cleaned = amount.replace(' ', '').replace(',', '.')
|
|
# Handle Swedish format with space as thousands separator
|
|
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
|
|
return round(float(cleaned), 2)
|
|
except (ValueError, AttributeError):
|
|
return None
|
|
|
|
def _needs_fallback(self, result: InferenceResult) -> bool:
|
|
"""Check if fallback OCR is needed."""
|
|
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
|
|
important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
|
|
|
|
key_missing = sum(1 for f in key_fields if f not in result.fields)
|
|
important_missing = sum(1 for f in important_fields if f not in result.fields)
|
|
|
|
# Fallback if any key field missing OR 2+ important fields missing
|
|
return key_missing >= 1 or important_missing >= 2
|
|
|
|
    def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
        """Run full-page OCR fallback.

        Re-renders every page, OCRs the whole page, and attempts regex-based
        extraction for fields still missing from the result. Best-effort:
        errors are appended to result.errors and already-found fields are
        kept. Sets result.fallback_used unconditionally.

        Args:
            pdf_path: Path to the PDF file.
            result: Result updated in place with any newly found fields.
        """
        from shared.pdf.renderer import render_pdf_to_images
        from shared.ocr import OCREngine
        from PIL import Image
        import io
        import numpy as np

        result.fallback_used = True
        ocr_engine = OCREngine()

        try:
            for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
                image = Image.open(io.BytesIO(image_bytes))
                image_array = np.array(image)

                # Full page OCR
                tokens = ocr_engine.extract_from_image(image_array, page_no)
                full_text = ' '.join(t.text for t in tokens)

                # Try to extract missing fields with regex patterns
                self._extract_with_patterns(full_text, result)

        except Exception as e:
            # Keep whatever was extracted so far; just record the failure.
            result.errors.append(f"Fallback OCR error: {e}")
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
|
|
"""Extract fields using regex patterns (fallback)."""
|
|
patterns = {
|
|
'Amount': [
|
|
r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
|
|
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
|
|
r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
|
|
],
|
|
'Bankgiro': [
|
|
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
|
|
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
|
|
],
|
|
'OCR': [
|
|
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
|
|
],
|
|
'InvoiceNumber': [
|
|
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
|
|
],
|
|
'InvoiceDate': [
|
|
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
|
|
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
|
|
],
|
|
'InvoiceDueDate': [
|
|
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
|
|
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
|
|
],
|
|
'supplier_organisation_number': [
|
|
r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
|
|
],
|
|
'Plusgiro': [
|
|
r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
|
|
],
|
|
}
|
|
|
|
for field_name, field_patterns in patterns.items():
|
|
if field_name in result.fields:
|
|
continue
|
|
|
|
for pattern in field_patterns:
|
|
match = re.search(pattern, text, re.IGNORECASE)
|
|
if match:
|
|
value = match.group(1).strip()
|
|
|
|
# Normalize the value
|
|
if field_name == 'Amount':
|
|
value = value.replace(' ', '').replace(',', '.')
|
|
try:
|
|
value = f"{float(value):.2f}"
|
|
except ValueError:
|
|
continue
|
|
elif field_name == 'Bankgiro':
|
|
digits = re.sub(r'\D', '', value)
|
|
if len(digits) == 8:
|
|
value = f"{digits[:4]}-{digits[4:]}"
|
|
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
|
# Normalize DD/MM/YYYY to YYYY-MM-DD
|
|
date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
|
|
if date_match:
|
|
value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
|
|
# Replace / with -
|
|
value = value.replace('/', '-')
|
|
elif field_name == 'InvoiceNumber':
|
|
# Skip year-like values (2024, 2025, 2026, etc.)
|
|
if re.match(r'^20\d{2}$', value):
|
|
continue
|
|
elif field_name == 'supplier_organisation_number':
|
|
# Ensure NNNNNN-NNNN format
|
|
digits = re.sub(r'\D', '', value)
|
|
if len(digits) == 10:
|
|
value = f"{digits[:6]}-{digits[6:]}"
|
|
|
|
result.fields[field_name] = value
|
|
result.confidence[field_name] = 0.5 # Lower confidence for regex
|
|
break
|
|
|
|
def process_image(
|
|
self,
|
|
image_path: str | Path,
|
|
document_id: str | None = None
|
|
) -> InferenceResult:
|
|
"""
|
|
Process a single image (for pre-rendered pages).
|
|
|
|
Args:
|
|
image_path: Path to image file
|
|
document_id: Optional document ID
|
|
|
|
Returns:
|
|
InferenceResult with extracted fields
|
|
"""
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
start_time = time.time()
|
|
|
|
result = InferenceResult(
|
|
document_id=document_id or Path(image_path).stem
|
|
)
|
|
|
|
try:
|
|
image = Image.open(image_path)
|
|
image_array = np.array(image)
|
|
|
|
# Run detection
|
|
detections = self.detector.detect(image_array, page_no=0)
|
|
result.raw_detections = detections
|
|
|
|
# Extract fields
|
|
for detection in detections:
|
|
extracted = self.extractor.extract_from_detection(detection, image_array)
|
|
result.extracted_fields.append(extracted)
|
|
|
|
# Merge fields
|
|
self._merge_fields(result)
|
|
result.success = len(result.fields) > 0
|
|
|
|
except Exception as e:
|
|
result.errors.append(str(e))
|
|
result.success = False
|
|
|
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
|
return result
|