Files
invoice-master-poc-v2/packages/backend/backend/pipeline/pipeline.py
Yaojia Wang d8f2acb762 fix: change default OCR language from English to Swedish
Project targets Swedish invoice extraction. PaddleOCR sv model provides
better recognition of Swedish-specific characters (å, ä, ö).
2026-02-12 23:19:51 +01:00

838 lines
34 KiB
Python

"""
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
Supports both basic field extraction and business invoice features
(line items, VAT extraction, cross-validation).
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import logging
import time
import re
logger = logging.getLogger(__name__)
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
# Business invoice feature imports (optional - for extract_line_items mode)
try:
    from ..table.line_items_extractor import LineItem, LineItemsResult, LineItemsExtractor
    from ..table.structure_detector import TableDetector
    from ..vat.vat_extractor import VATSummary, VATExtractor
    from ..validation.vat_validator import VATValidationResult, VATValidator
    BUSINESS_FEATURES_AVAILABLE = True
except ImportError:
    # Optional modules are missing: keep this module importable and null out
    # EVERY imported name.  The previous fallback omitted LineItemsExtractor,
    # VATExtractor and VATValidator, so referencing them after a failed import
    # raised NameError instead of behaving as "feature unavailable".
    BUSINESS_FEATURES_AVAILABLE = False
    LineItem = None
    LineItemsResult = None
    LineItemsExtractor = None
    TableDetector = None
    VATSummary = None
    VATExtractor = None
    VATValidator = None
    VATValidationResult = None
@dataclass
class CrossValidationResult:
    """Result of cross-validation between payment_line and other fields.

    Match flags are tri-state: True/False when both sides were available for
    comparison, None when the field could not be compared at all.
    """
    is_valid: bool = False
    ocr_match: bool | None = None  # None if not comparable
    amount_match: bool | None = None
    bankgiro_match: bool | None = None
    plusgiro_match: bool | None = None
    # Values parsed out of the machine-readable payment line itself
    payment_line_ocr: str | None = None
    payment_line_amount: str | None = None
    payment_line_account: str | None = None
    payment_line_account_type: str | None = None  # 'bankgiro' or 'plusgiro'
    # Human-readable log of every comparison/override performed
    details: list[str] = field(default_factory=list)
@dataclass
class InferenceResult:
    """Result of invoice processing."""
    document_id: str | None = None
    success: bool = False
    # Canonical field name -> normalized string value (e.g. 'Amount' -> '1200.00')
    fields: dict[str, Any] = field(default_factory=dict)
    # Canonical field name -> extraction confidence in [0, 1]
    confidence: dict[str, float] = field(default_factory=dict)
    bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict)  # Field bboxes in pixels
    raw_detections: list[Detection] = field(default_factory=list)
    extracted_fields: list[ExtractedField] = field(default_factory=list)
    processing_time_ms: float = 0.0
    errors: list[str] = field(default_factory=list)
    fallback_used: bool = False  # True when full-page OCR fallback was run
    cross_validation: CrossValidationResult | None = None
    # Business invoice features (optional)
    line_items: Any | None = None  # LineItemsResult when available
    vat_summary: Any | None = None  # VATSummary when available
    vat_validation: Any | None = None  # VATValidationResult when available
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
result = {
'DocumentId': self.document_id,
'InvoiceNumber': self.fields.get('InvoiceNumber'),
'InvoiceDate': self.fields.get('InvoiceDate'),
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
'OCR': self.fields.get('OCR'),
'Bankgiro': self.fields.get('Bankgiro'),
'Plusgiro': self.fields.get('Plusgiro'),
'Amount': self.fields.get('Amount'),
'supplier_org_number': self.fields.get('supplier_org_number'),
'customer_number': self.fields.get('customer_number'),
'payment_line': self.fields.get('payment_line'),
'confidence': self.confidence,
'success': self.success,
'fallback_used': self.fallback_used
}
# Add bboxes if present
if self.bboxes:
result['bboxes'] = {k: list(v) for k, v in self.bboxes.items()}
# Add cross-validation results if present
if self.cross_validation:
result['cross_validation'] = {
'is_valid': self.cross_validation.is_valid,
'ocr_match': self.cross_validation.ocr_match,
'amount_match': self.cross_validation.amount_match,
'bankgiro_match': self.cross_validation.bankgiro_match,
'plusgiro_match': self.cross_validation.plusgiro_match,
'payment_line_ocr': self.cross_validation.payment_line_ocr,
'payment_line_amount': self.cross_validation.payment_line_amount,
'payment_line_account': self.cross_validation.payment_line_account,
'payment_line_account_type': self.cross_validation.payment_line_account_type,
'details': self.cross_validation.details,
}
# Add business invoice features if present
if self.line_items is not None:
result['line_items'] = self._line_items_to_json()
if self.vat_summary is not None:
result['vat_summary'] = self._vat_summary_to_json()
if self.vat_validation is not None:
result['vat_validation'] = self._vat_validation_to_json()
return result
def _line_items_to_json(self) -> dict | None:
"""Convert LineItemsResult to JSON."""
if self.line_items is None:
return None
li = self.line_items
return {
'items': [
{
'row_index': item.row_index,
'description': item.description,
'quantity': item.quantity,
'unit': item.unit,
'unit_price': item.unit_price,
'amount': item.amount,
'article_number': item.article_number,
'vat_rate': item.vat_rate,
'is_deduction': item.is_deduction,
'confidence': item.confidence,
}
for item in li.items
],
'header_row': li.header_row,
'total_amount': li.total_amount,
}
def _vat_summary_to_json(self) -> dict | None:
"""Convert VATSummary to JSON."""
if self.vat_summary is None:
return None
vs = self.vat_summary
return {
'breakdowns': [
{
'rate': b.rate,
'base_amount': b.base_amount,
'vat_amount': b.vat_amount,
'source': b.source,
}
for b in vs.breakdowns
],
'total_excl_vat': vs.total_excl_vat,
'total_vat': vs.total_vat,
'total_incl_vat': vs.total_incl_vat,
'confidence': vs.confidence,
}
def _vat_validation_to_json(self) -> dict | None:
"""Convert VATValidationResult to JSON."""
if self.vat_validation is None:
return None
vv = self.vat_validation
return {
'is_valid': vv.is_valid,
'confidence_score': vv.confidence_score,
'math_checks': [
{
'rate': mc.rate,
'base_amount': mc.base_amount,
'expected_vat': mc.expected_vat,
'actual_vat': mc.actual_vat,
'is_valid': mc.is_valid,
'tolerance': mc.tolerance,
}
for mc in vv.math_checks
],
'total_check': vv.total_check,
'line_items_vs_summary': vv.line_items_vs_summary,
'amount_consistency': vv.amount_consistency,
'needs_review': vv.needs_review,
'review_reasons': vv.review_reasons,
}
def get_field(self, field_name: str) -> tuple[Any, float]:
"""Get field value and confidence."""
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
class InferencePipeline:
    """
    Complete inference pipeline for invoice data extraction.
    Pipeline flow:
    1. PDF -> Image rendering
    2. YOLO detection of field regions
    3. OCR extraction from detected regions
    4. Field normalization and validation
    5. Fallback to full-page OCR if YOLO fails
    """
    def __init__(
        self,
        model_path: str | Path,
        confidence_threshold: float = 0.5,
        ocr_lang: str = 'sv',
        use_gpu: bool = False,
        dpi: int = 300,
        enable_fallback: bool = True,
        enable_business_features: bool = False,
        vat_tolerance: float = 0.5
    ):
        """
        Initialize inference pipeline.
        Args:
            model_path: Path to trained YOLO model
            confidence_threshold: Detection confidence threshold
            ocr_lang: Language for OCR (default 'sv', Swedish invoices)
            use_gpu: Whether to use GPU
            dpi: Resolution for PDF rendering
            enable_fallback: Enable fallback to full-page OCR
            enable_business_features: Enable line items/VAT extraction
            vat_tolerance: Tolerance for VAT math checks (in currency units)
        Raises:
            ImportError: if enable_business_features is True but the optional
                table/vat/validation modules could not be imported.
        """
        self.detector = YOLODetector(
            model_path,
            confidence_threshold=confidence_threshold,
            device='cuda' if use_gpu else 'cpu'
        )
        self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu, dpi=dpi)
        self.payment_line_parser = PaymentLineParser()
        self.dpi = dpi
        self.enable_fallback = enable_fallback
        self.enable_business_features = enable_business_features
        self.vat_tolerance = vat_tolerance
        # Initialize business feature components if enabled and available
        self.line_items_extractor = None
        self.vat_extractor = None
        self.vat_validator = None
        self._business_ocr_engine = None  # Lazy-initialized for VAT text extraction
        self._table_detector = None  # Shared TableDetector for line items extraction
        if enable_business_features:
            if not BUSINESS_FEATURES_AVAILABLE:
                raise ImportError(
                    "Business features require table, vat, and validation modules. "
                    "Please ensure they are properly installed."
                )
            # Create shared TableDetector for performance (PP-StructureV3 init is slow)
            self._table_detector = TableDetector()
            # Pass shared detector to LineItemsExtractor
            self.line_items_extractor = LineItemsExtractor(table_detector=self._table_detector)
            self.vat_extractor = VATExtractor()
            self.vat_validator = VATValidator(tolerance=vat_tolerance)
    def process_pdf(
        self,
        pdf_path: str | Path,
        document_id: str | None = None,
        extract_line_items: bool | None = None
    ) -> InferenceResult:
        """
        Process a PDF and extract invoice fields.
        Args:
            pdf_path: Path to PDF file
            document_id: Optional document ID (defaults to the PDF file stem)
            extract_line_items: Whether to extract line items and VAT info.
                If None, uses the enable_business_features setting from __init__.
        Returns:
            InferenceResult with extracted fields; errors are recorded on the
            result rather than raised.
        """
        # Imported lazily so the module can be loaded without the render stack.
        from shared.pdf.renderer import render_pdf_to_images
        from PIL import Image
        import io
        import numpy as np
        start_time = time.time()
        result = InferenceResult(
            document_id=document_id or Path(pdf_path).stem
        )
        # Determine if business features should be used
        use_business_features = (
            extract_line_items if extract_line_items is not None
            else self.enable_business_features
        )
        try:
            all_detections = []
            all_extracted = []
            all_ocr_text = []  # Collect OCR text for VAT extraction
            # Check if PDF has readable text layer (avoids OCR for text PDFs)
            from shared.pdf.extractor import PDFDocument
            is_text_pdf = False
            pdf_tokens_by_page: dict[int, list] = {}
            try:
                with PDFDocument(pdf_path) as pdf_doc:
                    is_text_pdf = pdf_doc.is_text_pdf()
                    if is_text_pdf:
                        for pg in range(pdf_doc.page_count):
                            pdf_tokens_by_page[pg] = list(
                                pdf_doc.extract_text_tokens(pg)
                            )
                        logger.info(
                            "Text PDF detected, extracted %d tokens from %d pages",
                            sum(len(t) for t in pdf_tokens_by_page.values()),
                            len(pdf_tokens_by_page),
                        )
            except Exception as e:
                # Best-effort: any failure here just means we OCR instead
                logger.warning("PDF text detection failed, falling back to OCR: %s", e)
                is_text_pdf = False
            # Process each page
            for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
                # Convert to numpy array
                image = Image.open(io.BytesIO(image_bytes))
                image_array = np.array(image)
                # Run YOLO detection
                detections = self.detector.detect(image_array, page_no=page_no)
                all_detections.extend(detections)
                # Extract fields from detections
                for detection in detections:
                    if is_text_pdf and page_no in pdf_tokens_by_page:
                        # Text layer available: read PDF tokens directly (no OCR)
                        extracted = self.extractor.extract_from_detection_with_pdf(
                            detection,
                            pdf_tokens_by_page[page_no],
                            image.width,
                            image.height,
                        )
                    else:
                        extracted = self.extractor.extract_from_detection(
                            detection, image_array
                        )
                    all_extracted.append(extracted)
                # Collect full-page OCR text for VAT extraction (only if business features enabled)
                if use_business_features:
                    page_text = self._get_full_page_text(image_array)
                    all_ocr_text.append(page_text)
            result.raw_detections = all_detections
            result.extracted_fields = all_extracted
            # Merge extracted fields (prefer highest confidence)
            self._merge_fields(result)
            # Fallback if key fields are missing
            if self.enable_fallback and self._needs_fallback(result):
                self._run_fallback(pdf_path, result)
                # Re-check dedup since fallback may have re-added InvoiceNumber
                self._dedup_invoice_number(result)
            # Extract business invoice features if enabled
            if use_business_features:
                self._extract_business_features(pdf_path, result, '\n'.join(all_ocr_text))
            result.success = len(result.fields) > 0
        except Exception as e:
            # Top-level boundary: record the failure instead of raising
            result.errors.append(str(e))
            result.success = False
        result.processing_time_ms = (time.time() - start_time) * 1000
        return result
def _get_full_page_text(self, image_array) -> str:
"""Extract full page text using OCR for VAT extraction."""
from shared.ocr import OCREngine
import logging
logger = logging.getLogger(__name__)
try:
# Lazy initialize OCR engine to avoid repeated model loading
if self._business_ocr_engine is None:
self._business_ocr_engine = OCREngine()
tokens = self._business_ocr_engine.extract_from_image(image_array, page_no=0)
return ' '.join(t.text for t in tokens)
except Exception as e:
logger.warning(f"OCR extraction for VAT failed: {e}")
return ""
    def _extract_business_features(
        self,
        pdf_path: str | Path,
        result: InferenceResult,
        full_text: str
    ) -> None:
        """
        Extract line items, VAT summary, and perform cross-validation.
        Failures are recorded in ``result.errors`` and never propagate, so the
        core field extraction result is preserved.
        Args:
            pdf_path: Path to PDF file
            result: InferenceResult to populate
            full_text: Full OCR text from all pages
        """
        # Guard: both module availability and per-instance initialization are
        # required; record the reason instead of raising.
        if not BUSINESS_FEATURES_AVAILABLE:
            result.errors.append("Business features not available")
            return
        if not self.line_items_extractor or not self.vat_extractor or not self.vat_validator:
            result.errors.append("Business feature extractors not initialized")
            return
        try:
            # Extract line items from tables
            logger.info(f"Extracting line items from PDF: {pdf_path}")
            line_items_result = self.line_items_extractor.extract_from_pdf(str(pdf_path))
            logger.info(f"Line items extraction result: {line_items_result is not None}, items={len(line_items_result.items) if line_items_result else 0}")
            if line_items_result and line_items_result.items:
                result.line_items = line_items_result
                logger.info(f"Set result.line_items with {len(line_items_result.items)} items")
            # Extract VAT summary from text
            logger.info(f"Extracting VAT summary from text ({len(full_text)} chars)")
            vat_summary = self.vat_extractor.extract(full_text)
            logger.info(f"VAT summary extraction result: {vat_summary is not None}")
            if vat_summary:
                result.vat_summary = vat_summary
                # Cross-validate VAT information against the merged Amount field
                existing_amount = result.fields.get('Amount')
                vat_validation = self.vat_validator.validate(
                    vat_summary,
                    line_items=line_items_result,
                    existing_amount=str(existing_amount) if existing_amount else None
                )
                result.vat_validation = vat_validation
                logger.info(f"VAT validation completed: is_valid={vat_validation.is_valid if vat_validation else None}")
        except Exception as e:
            # Log with traceback and record on the result; do not re-raise
            import traceback
            error_detail = f"{type(e).__name__}: {e}"
            logger.error(f"Business feature extraction failed: {error_detail}\n{traceback.format_exc()}")
            result.errors.append(f"Business feature extraction error: {error_detail}")
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping best candidate for each field.
Selection priority:
1. Prefer candidates without validation errors
2. Among equal validity, prefer higher confidence
"""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
if not extracted.is_valid or not extracted.normalized_value:
continue
if extracted.field_name not in field_candidates:
field_candidates[extracted.field_name] = []
field_candidates[extracted.field_name].append(extracted)
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
# Sort by: (no validation error, confidence) - descending
# This prefers candidates without errors, then by confidence
best = max(
candidates,
key=lambda x: (x.validation_error is None, x.confidence)
)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields)
result.bboxes[field_name] = best.bbox
# Validate date consistency
self._validate_dates(result)
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
# Remove InvoiceNumber if it duplicates OCR or Bankgiro
self._dedup_invoice_number(result)
def _validate_dates(self, result: InferenceResult) -> None:
"""Remove InvoiceDueDate if it is earlier than InvoiceDate."""
invoice_date = result.fields.get('InvoiceDate')
due_date = result.fields.get('InvoiceDueDate')
if invoice_date and due_date and due_date < invoice_date:
del result.fields['InvoiceDueDate']
result.confidence.pop('InvoiceDueDate', None)
result.bboxes.pop('InvoiceDueDate', None)
def _dedup_invoice_number(self, result: InferenceResult) -> None:
"""Remove InvoiceNumber if it duplicates OCR or Bankgiro digits."""
inv_num = result.fields.get('InvoiceNumber')
if not inv_num:
return
inv_digits = re.sub(r'\D', '', str(inv_num))
# Check against OCR
ocr = result.fields.get('OCR')
if ocr and inv_digits == re.sub(r'\D', '', str(ocr)):
del result.fields['InvoiceNumber']
result.confidence.pop('InvoiceNumber', None)
result.bboxes.pop('InvoiceNumber', None)
return
# Check against Bankgiro (exact or substring match)
bg = result.fields.get('Bankgiro')
if bg:
bg_digits = re.sub(r'\D', '', str(bg))
if inv_digits == bg_digits or inv_digits in bg_digits:
del result.fields['InvoiceNumber']
result.confidence.pop('InvoiceNumber', None)
result.bboxes.pop('InvoiceNumber', None)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
Returns: (ocr, amount, account) tuple
"""
parsed = self.payment_line_parser.parse(payment_line)
if not parsed.is_valid:
return None, None, None
return parsed.ocr_number, parsed.amount, parsed.account_number
    def _cross_validate_payment_line(self, result: InferenceResult) -> None:
        """
        Cross-validate payment_line data against other detected fields.
        Payment line values take PRIORITY over individually detected fields.
        Swedish payment line (Betalningsrad) contains:
        - OCR reference number
        - Amount (kronor and öre)
        - Bankgiro or Plusgiro account number
        This method:
        1. Parses payment_line to extract OCR, Amount, Account
        2. Compares with separately detected fields for validation
        3. OVERWRITES detected OCR/Amount with payment_line values
           (payment_line is authoritative); account numbers are compared only.
        The outcome is stored in ``result.cross_validation``.
        """
        payment_line = result.fields.get('payment_line')
        if not payment_line:
            return
        cv = CrossValidationResult()
        cv.details = []
        # Parse machine-readable payment line format
        ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
        cv.payment_line_ocr = ocr
        cv.payment_line_amount = amount
        # Determine account type based on digit count
        if account:
            # Bankgiro: 7-8 digits, Plusgiro: typically fewer
            if len(account) >= 7:
                cv.payment_line_account_type = 'bankgiro'
                # Format: XXX-XXXX or XXXX-XXXX
                if len(account) == 7:
                    cv.payment_line_account = f"{account[:3]}-{account[3:]}"
                else:
                    cv.payment_line_account = f"{account[:4]}-{account[4:]}"
            else:
                cv.payment_line_account_type = 'plusgiro'
                # Format: XXXXXXX-X (check digit after the dash)
                cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
        # Cross-validate and OVERRIDE with payment_line values
        # OCR: payment_line takes priority
        detected_ocr = result.fields.get('OCR')
        if cv.payment_line_ocr:
            # Compare digits only so formatting differences don't matter
            pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
            if detected_ocr:
                detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
                cv.ocr_match = pl_ocr_digits == detected_ocr_digits
                if cv.ocr_match:
                    cv.details.append(f"OCR match: {cv.payment_line_ocr}")
                else:
                    cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
            else:
                cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
            # OVERRIDE: use payment_line OCR
            result.fields['OCR'] = cv.payment_line_ocr
            result.confidence['OCR'] = 0.95  # High confidence for payment_line
        # Amount: payment_line takes priority
        detected_amount = result.fields.get('Amount')
        if cv.payment_line_amount:
            if detected_amount:
                # Compare as normalized floats (handles comma/space variants)
                pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
                det_amount = self._normalize_amount_for_compare(str(detected_amount))
                cv.amount_match = pl_amount == det_amount
                if cv.amount_match:
                    cv.details.append(f"Amount match: {cv.payment_line_amount}")
                else:
                    cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
            else:
                cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
            # OVERRIDE: use payment_line Amount
            result.fields['Amount'] = cv.payment_line_amount
            result.confidence['Amount'] = 0.95
        # Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_bankgiro = result.fields.get('Bankgiro')
        if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
            pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
            if detected_bankgiro:
                det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
                cv.bankgiro_match = pl_bg_digits == det_bg_digits
                if cv.bankgiro_match:
                    cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
                else:
                    cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
            # Do NOT override - keep detected value
        # Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_plusgiro = result.fields.get('Plusgiro')
        if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
            pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
            if detected_plusgiro:
                det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
                cv.plusgiro_match = pl_pg_digits == det_pg_digits
                if cv.plusgiro_match:
                    cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
                else:
                    cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
            # Do NOT override - keep detected value
        # Determine overall validity
        # Note: payment_line only contains ONE account (either BG or PG), so when invoice
        # has both accounts, the other one cannot be matched - this is expected and OK.
        # Only count the account type that payment_line actually has.
        matches = [cv.ocr_match, cv.amount_match]
        # Only include account match if payment_line has that account type
        if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
            matches.append(cv.bankgiro_match)
        elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
            matches.append(cv.plusgiro_match)
        valid_matches = [m for m in matches if m is not None]
        if valid_matches:
            match_count = sum(1 for m in valid_matches if m)
            # Valid when at least 2 comparable fields match (or all, if fewer)
            cv.is_valid = match_count >= min(2, len(valid_matches))
            cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
        else:
            # No comparison possible
            cv.is_valid = True
            cv.details.append("No comparison available from payment_line")
        result.cross_validation = cv
def _normalize_amount_for_compare(self, amount: str) -> float | None:
"""Normalize amount string to float for comparison."""
try:
# Remove spaces, convert comma to dot
cleaned = amount.replace(' ', '').replace(',', '.')
# Handle Swedish format with space as thousands separator
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
return round(float(cleaned), 2)
except (ValueError, AttributeError):
return None
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
important_fields = ['InvoiceDate', 'InvoiceDueDate', 'supplier_organisation_number']
key_missing = sum(1 for f in key_fields if f not in result.fields)
important_missing = sum(1 for f in important_fields if f not in result.fields)
# Fallback if any key field missing OR 2+ important fields missing
return key_missing >= 1 or important_missing >= 2
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
from shared.pdf.renderer import render_pdf_to_images
from shared.ocr import OCREngine
from PIL import Image
import io
import numpy as np
result.fallback_used = True
ocr_engine = OCREngine()
try:
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Full page OCR
tokens = ocr_engine.extract_from_image(image_array, page_no)
full_text = ' '.join(t.text for t in tokens)
# Try to extract missing fields with regex patterns
self._extract_with_patterns(full_text, result)
except Exception as e:
result.errors.append(f"Fallback OCR error: {e}")
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s+betala)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'(?:summa|total|belopp)\s*[:.]?\s*([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)?',
r'([\d\s\.]*\d+[,\.]\d{2})\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(?<!\d)(\d{3,4}[-\s]\d{4})(?!\d)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
],
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
'InvoiceDate': [
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:fakturadatum|invoice\s*date)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'InvoiceDueDate': [
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{4}[-/]\d{2}[-/]\d{2})',
r'(?:f[oö]rfallo(?:dag|datum)?|due\s*date|betala\s*senast)\s*[:.]?\s*(\d{2}[-/]\d{2}[-/]\d{4})',
],
'supplier_organisation_number': [
r'(?:org\.?\s*n[ru]|organisationsnummer)\s*[:.]?\s*(\d{6}[-\s]?\d{4})',
],
'Plusgiro': [
r'(?:plusgiro|pg)\s*[:.]?\s*(\d[\d\s-]{4,12}\d)',
],
}
for field_name, field_patterns in patterns.items():
if field_name in result.fields:
continue
for pattern in field_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
value = match.group(1).strip()
# Normalize the value
if field_name == 'Amount':
value = value.replace(' ', '').replace(',', '.')
try:
value = f"{float(value):.2f}"
except ValueError:
continue
elif field_name == 'Bankgiro':
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Normalize DD/MM/YYYY to YYYY-MM-DD
date_match = re.match(r'(\d{2})[-/](\d{2})[-/](\d{4})', value)
if date_match:
value = f"{date_match.group(3)}-{date_match.group(2)}-{date_match.group(1)}"
# Replace / with -
value = value.replace('/', '-')
elif field_name == 'InvoiceNumber':
# Skip year-like values (2024, 2025, 2026, etc.)
if re.match(r'^20\d{2}$', value):
continue
elif field_name == 'supplier_organisation_number':
# Ensure NNNNNN-NNNN format
digits = re.sub(r'\D', '', value)
if len(digits) == 10:
value = f"{digits[:6]}-{digits[6:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex
break
def process_image(
self,
image_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a single image (for pre-rendered pages).
Args:
image_path: Path to image file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from PIL import Image
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(image_path).stem
)
try:
image = Image.open(image_path)
image_array = np.array(image)
# Run detection
detections = self.detector.detect(image_array, page_no=0)
result.raw_detections = detections
# Extract fields
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
result.extracted_fields.append(extracted)
# Merge fields
self._merge_fields(result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result