"""
Inference Pipeline

Complete pipeline for extracting invoice data from PDFs.
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
import time
|
|
import re
|
|
|
|
from shared.fields import CLASS_TO_FIELD
|
|
from .yolo_detector import YOLODetector, Detection
|
|
from .field_extractor import FieldExtractor, ExtractedField
|
|
from .payment_line_parser import PaymentLineParser
|
|
|
|
|
|
@dataclass
class CrossValidationResult:
    """Result of cross-validation between payment_line and other fields."""

    # Overall verdict: True when enough comparable fields agree
    # (see InferencePipeline._cross_validate_payment_line for the rule).
    is_valid: bool = False
    # Per-field comparison outcomes. None means the field was not comparable
    # (absent from the payment line or not detected separately).
    ocr_match: bool | None = None  # None if not comparable
    amount_match: bool | None = None
    bankgiro_match: bool | None = None
    plusgiro_match: bool | None = None
    # Values parsed out of the machine-readable payment line.
    payment_line_ocr: str | None = None
    payment_line_amount: str | None = None
    payment_line_account: str | None = None  # dash-formatted, e.g. "XXXX-XXXX"
    payment_line_account_type: str | None = None  # 'bankgiro' or 'plusgiro'
    # Human-readable log of every comparison/override performed.
    details: list[str] = field(default_factory=list)
|
@dataclass
class InferenceResult:
    """Result of invoice processing."""

    document_id: str | None = None
    success: bool = False
    fields: dict[str, Any] = field(default_factory=dict)
    confidence: dict[str, float] = field(default_factory=dict)
    bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict)  # Field bboxes in pixels
    raw_detections: list[Detection] = field(default_factory=list)
    extracted_fields: list[ExtractedField] = field(default_factory=list)
    processing_time_ms: float = 0.0
    errors: list[str] = field(default_factory=list)
    fallback_used: bool = False
    cross_validation: CrossValidationResult | None = None

    def to_json(self) -> dict:
        """Serialize the result into a plain, JSON-compatible dictionary."""
        # Field keys exported in a fixed order; missing fields come out as None.
        exported = (
            'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
            'Bankgiro', 'Plusgiro', 'Amount', 'supplier_org_number',
            'customer_number', 'payment_line',
        )
        payload: dict = {'DocumentId': self.document_id}
        for key in exported:
            payload[key] = self.fields.get(key)
        payload['confidence'] = self.confidence
        payload['success'] = self.success
        payload['fallback_used'] = self.fallback_used

        # Bounding boxes are tuples internally; emit JSON-friendly lists.
        if self.bboxes:
            payload['bboxes'] = {name: list(box) for name, box in self.bboxes.items()}

        # Flatten the cross-validation dataclass, when present.
        cv = self.cross_validation
        if cv:
            cv_attrs = (
                'is_valid', 'ocr_match', 'amount_match', 'bankgiro_match',
                'plusgiro_match', 'payment_line_ocr', 'payment_line_amount',
                'payment_line_account', 'payment_line_account_type', 'details',
            )
            payload['cross_validation'] = {attr: getattr(cv, attr) for attr in cv_attrs}

        return payload

    def get_field(self, field_name: str) -> tuple[Any, float]:
        """Return (value, confidence) for *field_name*; (None, 0.0) when absent."""
        value = self.fields.get(field_name)
        score = self.confidence.get(field_name, 0.0)
        return value, score
|
class InferencePipeline:
    """
    Complete inference pipeline for invoice data extraction.

    Pipeline flow:
    1. PDF -> Image rendering
    2. YOLO detection of field regions
    3. OCR extraction from detected regions
    4. Field normalization and validation
    5. Fallback to full-page OCR if YOLO fails
    """

    def __init__(
        self,
        model_path: str | Path,
        confidence_threshold: float = 0.5,
        ocr_lang: str = 'en',
        use_gpu: bool = False,
        dpi: int = 300,
        enable_fallback: bool = True
    ):
        """
        Initialize inference pipeline.

        Args:
            model_path: Path to trained YOLO model
            confidence_threshold: Detection confidence threshold
            ocr_lang: Language for OCR
            use_gpu: Whether to use GPU (selects device for detector and OCR)
            dpi: Resolution for PDF rendering
            enable_fallback: Enable fallback to full-page OCR
        """
        self.dpi = dpi
        self.enable_fallback = enable_fallback

        # One flag controls both the YOLO device and the OCR engine's GPU use.
        device = 'cuda' if use_gpu else 'cpu'
        self.detector = YOLODetector(
            model_path,
            confidence_threshold=confidence_threshold,
            device=device,
        )
        self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
        self.payment_line_parser = PaymentLineParser()
|
def process_pdf(
|
|
self,
|
|
pdf_path: str | Path,
|
|
document_id: str | None = None
|
|
) -> InferenceResult:
|
|
"""
|
|
Process a PDF and extract invoice fields.
|
|
|
|
Args:
|
|
pdf_path: Path to PDF file
|
|
document_id: Optional document ID
|
|
|
|
Returns:
|
|
InferenceResult with extracted fields
|
|
"""
|
|
from shared.pdf.renderer import render_pdf_to_images
|
|
from PIL import Image
|
|
import io
|
|
import numpy as np
|
|
|
|
start_time = time.time()
|
|
|
|
result = InferenceResult(
|
|
document_id=document_id or Path(pdf_path).stem
|
|
)
|
|
|
|
try:
|
|
all_detections = []
|
|
all_extracted = []
|
|
|
|
# Process each page
|
|
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
|
# Convert to numpy array
|
|
image = Image.open(io.BytesIO(image_bytes))
|
|
image_array = np.array(image)
|
|
|
|
# Run YOLO detection
|
|
detections = self.detector.detect(image_array, page_no=page_no)
|
|
all_detections.extend(detections)
|
|
|
|
# Extract fields from detections
|
|
for detection in detections:
|
|
extracted = self.extractor.extract_from_detection(detection, image_array)
|
|
all_extracted.append(extracted)
|
|
|
|
result.raw_detections = all_detections
|
|
result.extracted_fields = all_extracted
|
|
|
|
# Merge extracted fields (prefer highest confidence)
|
|
self._merge_fields(result)
|
|
|
|
# Fallback if key fields are missing
|
|
if self.enable_fallback and self._needs_fallback(result):
|
|
self._run_fallback(pdf_path, result)
|
|
|
|
result.success = len(result.fields) > 0
|
|
|
|
except Exception as e:
|
|
result.errors.append(str(e))
|
|
result.success = False
|
|
|
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
|
return result
|
|
|
|
def _merge_fields(self, result: InferenceResult) -> None:
|
|
"""Merge extracted fields, keeping highest confidence for each field."""
|
|
field_candidates: dict[str, list[ExtractedField]] = {}
|
|
|
|
for extracted in result.extracted_fields:
|
|
if not extracted.is_valid or not extracted.normalized_value:
|
|
continue
|
|
|
|
if extracted.field_name not in field_candidates:
|
|
field_candidates[extracted.field_name] = []
|
|
field_candidates[extracted.field_name].append(extracted)
|
|
|
|
# Select best candidate for each field
|
|
for field_name, candidates in field_candidates.items():
|
|
best = max(candidates, key=lambda x: x.confidence)
|
|
result.fields[field_name] = best.normalized_value
|
|
result.confidence[field_name] = best.confidence
|
|
# Store bbox for each field (useful for payment_line and other fields)
|
|
result.bboxes[field_name] = best.bbox
|
|
|
|
# Perform cross-validation if payment_line is detected
|
|
self._cross_validate_payment_line(result)
|
|
|
|
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
|
|
"""
|
|
Parse machine-readable Swedish payment line format using unified PaymentLineParser.
|
|
|
|
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
|
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
|
|
|
|
Returns: (ocr, amount, account) tuple
|
|
"""
|
|
parsed = self.payment_line_parser.parse(payment_line)
|
|
|
|
if not parsed.is_valid:
|
|
return None, None, None
|
|
|
|
return parsed.ocr_number, parsed.amount, parsed.account_number
|
|
|
|
    def _cross_validate_payment_line(self, result: InferenceResult) -> None:
        """
        Cross-validate payment_line data against other detected fields.
        Payment line values take PRIORITY over individually detected fields.

        Swedish payment line (Betalningsrad) contains:
        - OCR reference number
        - Amount (kronor and öre)
        - Bankgiro or Plusgiro account number

        This method:
        1. Parses payment_line to extract OCR, Amount, Account
        2. Compares with separately detected fields for validation
        3. OVERWRITES detected fields with payment_line values (payment_line is authoritative)

        Args:
            result: InferenceResult whose ``fields``/``confidence`` are read and
                mutated in place; ``result.cross_validation`` is set on success.
        """
        # Nothing to validate unless the detection stage produced a payment_line.
        payment_line = result.fields.get('payment_line')
        if not payment_line:
            return

        cv = CrossValidationResult()
        cv.details = []

        # Parse machine-readable payment line format
        ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))

        cv.payment_line_ocr = ocr
        cv.payment_line_amount = amount

        # Determine account type based on digit count
        if account:
            # Bankgiro: 7-8 digits, Plusgiro: typically fewer
            # NOTE(review): heuristic — plusgiro numbers can also reach 7-8
            # digits; confirm against the parser's account typing if available.
            if len(account) >= 7:
                cv.payment_line_account_type = 'bankgiro'
                # Format: XXX-XXXX or XXXX-XXXX
                if len(account) == 7:
                    cv.payment_line_account = f"{account[:3]}-{account[3:]}"
                else:
                    cv.payment_line_account = f"{account[:4]}-{account[4:]}"
            else:
                cv.payment_line_account_type = 'plusgiro'
                # Format: XXXXXXX-X
                cv.payment_line_account = f"{account[:-1]}-{account[-1]}"

        # Cross-validate and OVERRIDE with payment_line values

        # OCR: payment_line takes priority
        detected_ocr = result.fields.get('OCR')
        if cv.payment_line_ocr:
            # Compare digits only, so formatting differences don't cause mismatches.
            pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
            if detected_ocr:
                detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
                cv.ocr_match = pl_ocr_digits == detected_ocr_digits
                if cv.ocr_match:
                    cv.details.append(f"OCR match: {cv.payment_line_ocr}")
                else:
                    cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
            else:
                cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
            # OVERRIDE: use payment_line OCR
            result.fields['OCR'] = cv.payment_line_ocr
            result.confidence['OCR'] = 0.95  # High confidence for payment_line

        # Amount: payment_line takes priority
        detected_amount = result.fields.get('Amount')
        if cv.payment_line_amount:
            if detected_amount:
                # Numeric comparison, tolerant of formatting differences.
                pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
                det_amount = self._normalize_amount_for_compare(str(detected_amount))
                cv.amount_match = pl_amount == det_amount
                if cv.amount_match:
                    cv.details.append(f"Amount match: {cv.payment_line_amount}")
                else:
                    cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
            else:
                cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
            # OVERRIDE: use payment_line Amount
            result.fields['Amount'] = cv.payment_line_amount
            result.confidence['Amount'] = 0.95

        # Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_bankgiro = result.fields.get('Bankgiro')
        if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account:
            pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
            if detected_bankgiro:
                det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
                cv.bankgiro_match = pl_bg_digits == det_bg_digits
                if cv.bankgiro_match:
                    cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
                else:
                    cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")
                # Do NOT override - keep detected value

        # Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_plusgiro = result.fields.get('Plusgiro')
        if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account:
            pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
            if detected_plusgiro:
                det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
                cv.plusgiro_match = pl_pg_digits == det_pg_digits
                if cv.plusgiro_match:
                    cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
                else:
                    cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")
                # Do NOT override - keep detected value

        # Determine overall validity
        # Note: payment_line only contains ONE account (either BG or PG), so when invoice
        # has both accounts, the other one cannot be matched - this is expected and OK.
        # Only count the account type that payment_line actually has.
        matches = [cv.ocr_match, cv.amount_match]

        # Only include account match if payment_line has that account type
        if cv.payment_line_account_type == 'bankgiro' and cv.bankgiro_match is not None:
            matches.append(cv.bankgiro_match)
        elif cv.payment_line_account_type == 'plusgiro' and cv.plusgiro_match is not None:
            matches.append(cv.plusgiro_match)

        valid_matches = [m for m in matches if m is not None]
        if valid_matches:
            match_count = sum(1 for m in valid_matches if m)
            # Valid when at least 2 comparable fields agree (or all of them,
            # when fewer than 2 were comparable at all).
            cv.is_valid = match_count >= min(2, len(valid_matches))
            cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
        else:
            # No comparison possible
            cv.is_valid = True
            cv.details.append("No comparison available from payment_line")

        result.cross_validation = cv
|
def _normalize_amount_for_compare(self, amount: str) -> float | None:
|
|
"""Normalize amount string to float for comparison."""
|
|
try:
|
|
# Remove spaces, convert comma to dot
|
|
cleaned = amount.replace(' ', '').replace(',', '.')
|
|
# Handle Swedish format with space as thousands separator
|
|
cleaned = re.sub(r'(\d)\s+(\d)', r'\1\2', cleaned)
|
|
return round(float(cleaned), 2)
|
|
except (ValueError, AttributeError):
|
|
return None
|
|
|
|
def _needs_fallback(self, result: InferenceResult) -> bool:
|
|
"""Check if fallback OCR is needed."""
|
|
# Check for key fields
|
|
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
|
|
missing = sum(1 for f in key_fields if f not in result.fields)
|
|
return missing >= 2 # Fallback if 2+ key fields missing
|
|
|
|
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
|
|
"""Run full-page OCR fallback."""
|
|
from shared.pdf.renderer import render_pdf_to_images
|
|
from shared.ocr import OCREngine
|
|
from PIL import Image
|
|
import io
|
|
import numpy as np
|
|
|
|
result.fallback_used = True
|
|
ocr_engine = OCREngine()
|
|
|
|
try:
|
|
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
|
image = Image.open(io.BytesIO(image_bytes))
|
|
image_array = np.array(image)
|
|
|
|
# Full page OCR
|
|
tokens = ocr_engine.extract_from_image(image_array, page_no)
|
|
full_text = ' '.join(t.text for t in tokens)
|
|
|
|
# Try to extract missing fields with regex patterns
|
|
self._extract_with_patterns(full_text, result)
|
|
|
|
except Exception as e:
|
|
result.errors.append(f"Fallback OCR error: {e}")
|
|
|
|
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
|
|
"""Extract fields using regex patterns (fallback)."""
|
|
patterns = {
|
|
'Amount': [
|
|
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
|
|
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
|
|
],
|
|
'Bankgiro': [
|
|
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
|
|
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
|
|
],
|
|
'OCR': [
|
|
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
|
|
],
|
|
'InvoiceNumber': [
|
|
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
|
|
],
|
|
}
|
|
|
|
for field_name, field_patterns in patterns.items():
|
|
if field_name in result.fields:
|
|
continue
|
|
|
|
for pattern in field_patterns:
|
|
match = re.search(pattern, text, re.IGNORECASE)
|
|
if match:
|
|
value = match.group(1).strip()
|
|
|
|
# Normalize the value
|
|
if field_name == 'Amount':
|
|
value = value.replace(' ', '').replace(',', '.')
|
|
try:
|
|
value = f"{float(value):.2f}"
|
|
except ValueError:
|
|
continue
|
|
elif field_name == 'Bankgiro':
|
|
digits = re.sub(r'\D', '', value)
|
|
if len(digits) == 8:
|
|
value = f"{digits[:4]}-{digits[4:]}"
|
|
|
|
result.fields[field_name] = value
|
|
result.confidence[field_name] = 0.5 # Lower confidence for regex
|
|
break
|
|
|
|
def process_image(
|
|
self,
|
|
image_path: str | Path,
|
|
document_id: str | None = None
|
|
) -> InferenceResult:
|
|
"""
|
|
Process a single image (for pre-rendered pages).
|
|
|
|
Args:
|
|
image_path: Path to image file
|
|
document_id: Optional document ID
|
|
|
|
Returns:
|
|
InferenceResult with extracted fields
|
|
"""
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
start_time = time.time()
|
|
|
|
result = InferenceResult(
|
|
document_id=document_id or Path(image_path).stem
|
|
)
|
|
|
|
try:
|
|
image = Image.open(image_path)
|
|
image_array = np.array(image)
|
|
|
|
# Run detection
|
|
detections = self.detector.detect(image_array, page_no=0)
|
|
result.raw_detections = detections
|
|
|
|
# Extract fields
|
|
for detection in detections:
|
|
extracted = self.extractor.extract_from_detection(detection, image_array)
|
|
result.extracted_fields.append(extracted)
|
|
|
|
# Merge fields
|
|
self._merge_fields(result)
|
|
result.success = len(result.fields) > 0
|
|
|
|
except Exception as e:
|
|
result.errors.append(str(e))
|
|
result.success = False
|
|
|
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
|
return result
|