Files
invoice-master-poc-v2/packages/inference/inference/pipeline/pipeline.py
Yaojia Wang a516de4320 WIP
2026-02-01 00:08:40 +01:00

500 lines
20 KiB
Python

"""
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import time
import re
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor, ExtractedField
from .payment_line_parser import PaymentLineParser
@dataclass
class CrossValidationResult:
    """Result of cross-validation between payment_line and other fields.

    The ``*_match`` flags are tri-state: True/False when the payment line
    value was compared against a separately detected field, None when no
    comparison was possible.
    """
    # Overall verdict of the cross-check.
    is_valid: bool = False
    ocr_match: bool | None = None  # None if not comparable
    amount_match: bool | None = None
    bankgiro_match: bool | None = None
    plusgiro_match: bool | None = None
    # Values parsed out of the payment line itself.
    payment_line_ocr: str | None = None
    payment_line_amount: str | None = None
    payment_line_account: str | None = None
    payment_line_account_type: str | None = None  # 'bankgiro' or 'plusgiro'
    # Free-text notes explaining each comparison outcome.
    details: list[str] = field(default_factory=list)
@dataclass
class InferenceResult:
    """Result of invoice processing.

    ``fields`` maps a field name to its extracted value, ``confidence`` to the
    score that value came with and ``bboxes`` to its bounding box in pixels.
    ``success`` and ``errors`` report the outcome; ``cross_validation`` holds
    the payment line check when one was performed.
    """

    document_id: str | None = None
    success: bool = False
    fields: dict[str, Any] = field(default_factory=dict)
    confidence: dict[str, float] = field(default_factory=dict)
    # Field bboxes in pixels, keyed like ``fields``.
    bboxes: dict[str, tuple[float, float, float, float]] = field(default_factory=dict)
    raw_detections: list[Detection] = field(default_factory=list)
    extracted_fields: list[ExtractedField] = field(default_factory=list)
    processing_time_ms: float = 0.0
    errors: list[str] = field(default_factory=list)
    fallback_used: bool = False
    cross_validation: CrossValidationResult | None = None

    def to_json(self) -> dict:
        """Convert to a dictionary intended for JSON serialization.

        Always exports the document id, a fixed set of field values (``None``
        for fields that were not extracted), the confidence map, and the
        success / fallback flags. ``bboxes`` (tuples turned into lists) and
        ``cross_validation`` are only added when present. Processing time,
        errors and raw detections are not exported. The ``confidence`` dict
        and the cross-validation ``details`` list are returned by reference,
        not copied.
        """
        exported_fields = (
            'InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR',
            'Bankgiro', 'Plusgiro', 'Amount', 'supplier_org_number',
            'customer_number', 'payment_line',
        )
        payload: dict = {'DocumentId': self.document_id}
        payload.update((name, self.fields.get(name)) for name in exported_fields)
        payload.update(
            confidence=self.confidence,
            success=self.success,
            fallback_used=self.fallback_used,
        )

        if self.bboxes:
            payload['bboxes'] = {name: list(box) for name, box in self.bboxes.items()}

        validation = self.cross_validation
        if validation:
            exported_checks = (
                'is_valid', 'ocr_match', 'amount_match', 'bankgiro_match',
                'plusgiro_match', 'payment_line_ocr', 'payment_line_amount',
                'payment_line_account', 'payment_line_account_type', 'details',
            )
            payload['cross_validation'] = {
                key: getattr(validation, key) for key in exported_checks
            }

        return payload

    def get_field(self, field_name: str) -> tuple[Any, float]:
        """Return ``(value, confidence)`` for a field; ``(None, 0.0)`` if it is missing."""
        value = self.fields.get(field_name)
        score = self.confidence.get(field_name, 0.0)
        return value, score
class InferencePipeline:
    """
    Complete inference pipeline for invoice data extraction.

    Pipeline flow:
    1. PDF -> Image rendering
    2. YOLO detection of field regions
    3. OCR extraction from detected regions
    4. Field normalization and validation (incl. payment line cross-check)
    5. Fallback to full-page OCR + regex patterns when key fields are missing
    """

    def __init__(
        self,
        model_path: str | Path,
        confidence_threshold: float = 0.5,
        ocr_lang: str = 'en',
        use_gpu: bool = False,
        dpi: int = 300,
        enable_fallback: bool = True
    ):
        """
        Initialize inference pipeline.

        Args:
            model_path: Path to trained YOLO model
            confidence_threshold: Detection confidence threshold
            ocr_lang: Language for OCR
            use_gpu: Run YOLO on 'cuda' instead of 'cpu'; also forwarded to
                the field extractor
            dpi: Resolution for PDF rendering (also used by the fallback)
            enable_fallback: Enable fallback to full-page OCR
        """
        self.detector = YOLODetector(
            model_path,
            confidence_threshold=confidence_threshold,
            device='cuda' if use_gpu else 'cpu'
        )
        self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
        self.payment_line_parser = PaymentLineParser()
        self.dpi = dpi
        self.enable_fallback = enable_fallback

    def process_pdf(
        self,
        pdf_path: str | Path,
        document_id: str | None = None
    ) -> InferenceResult:
        """
        Process a PDF and extract invoice fields.

        Every rendered page goes through YOLO detection and per-detection
        field extraction. Candidates are then merged (highest confidence
        wins) and cross-validated against the payment line. If enabled and
        needed, the full-page OCR fallback fills in missing fields.

        Exceptions are not propagated: the message is appended to
        ``result.errors`` and ``result.success`` is set to False.

        Args:
            pdf_path: Path to PDF file
            document_id: Optional document ID (defaults to the file stem)

        Returns:
            InferenceResult with extracted fields
        """
        from shared.pdf.renderer import render_pdf_to_images
        from PIL import Image
        import io
        import numpy as np

        # perf_counter is monotonic, unlike time.time() which can jump.
        start_time = time.perf_counter()
        result = InferenceResult(
            document_id=document_id or Path(pdf_path).stem
        )

        try:
            all_detections = []
            all_extracted = []

            for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
                # np.array() copies the pixel data, so the PIL image can be
                # closed as soon as the conversion is done.
                with Image.open(io.BytesIO(image_bytes)) as image:
                    image_array = np.array(image)

                detections = self.detector.detect(image_array, page_no=page_no)
                all_detections.extend(detections)
                all_extracted.extend(
                    self.extractor.extract_from_detection(detection, image_array)
                    for detection in detections
                )

            result.raw_detections = all_detections
            result.extracted_fields = all_extracted

            # Merge extracted fields (prefer highest confidence)
            self._merge_fields(result)

            if self.enable_fallback and self._needs_fallback(result):
                self._run_fallback(pdf_path, result)

            result.success = len(result.fields) > 0

        except Exception as e:
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.perf_counter() - start_time) * 1000
        return result

    def _merge_fields(self, result: InferenceResult) -> None:
        """
        Merge extracted fields, keeping the highest-confidence candidate per field.

        Candidates that are invalid or have an empty normalized value are
        skipped. For every remaining field the winner's value, confidence and
        bbox are written to ``result`` (ties keep the first candidate seen).
        Finally the payment line cross-validation runs, which may override
        OCR and Amount.
        """
        field_candidates: dict[str, list[ExtractedField]] = {}
        for extracted in result.extracted_fields:
            if not extracted.is_valid or not extracted.normalized_value:
                continue
            field_candidates.setdefault(extracted.field_name, []).append(extracted)

        for field_name, candidates in field_candidates.items():
            best = max(candidates, key=lambda x: x.confidence)
            result.fields[field_name] = best.normalized_value
            result.confidence[field_name] = best.confidence
            result.bboxes[field_name] = best.bbox

        self._cross_validate_payment_line(result)

    def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
        """
        Parse machine-readable Swedish payment line format using unified PaymentLineParser.

        Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
        Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"

        Returns:
            (ocr, amount, account) tuple; all None if the parser rejects the line
        """
        parsed = self.payment_line_parser.parse(payment_line)

        if not parsed.is_valid:
            return None, None, None

        return parsed.ocr_number, parsed.amount, parsed.account_number

    @staticmethod
    def _format_payment_line_account(account: str) -> tuple[str, str]:
        """
        Classify a payment line account by digit count and format it for display.

        Returns (account_type, formatted_account):
        - 7 digits -> ('bankgiro', 'XXX-XXXX')
        - 8+ digits -> ('bankgiro', 'XXXX-XXXX...')
        - fewer digits -> ('plusgiro', 'XXXXXX-X'), i.e. the last digit split off

        This is a digit-count heuristic only, which is why payment line
        accounts are compared but never applied to the result.
        """
        if len(account) >= 7:
            if len(account) == 7:
                return 'bankgiro', f"{account[:3]}-{account[3:]}"
            return 'bankgiro', f"{account[:4]}-{account[4:]}"
        if len(account) < 2:
            # No separate check digit to split off; avoid producing "-X".
            return 'plusgiro', account
        return 'plusgiro', f"{account[:-1]}-{account[-1]}"

    def _cross_validate_payment_line(self, result: InferenceResult) -> None:
        """
        Cross-validate payment_line data against other detected fields.

        Payment line values take PRIORITY over individually detected fields
        for OCR and Amount.

        Swedish payment line (Betalningsrad) contains:
        - OCR reference number
        - Amount (kronor and öre)
        - Bankgiro or Plusgiro account number

        This method:
        1. Parses payment_line to extract OCR, Amount, Account
        2. Compares with separately detected fields for validation
        3. OVERWRITES detected OCR/Amount with payment_line values
           (accounts are only compared, never overwritten)

        Does nothing if no payment_line field was extracted; otherwise stores
        a CrossValidationResult in ``result.cross_validation``.
        """
        payment_line = result.fields.get('payment_line')
        if not payment_line:
            return

        cv = CrossValidationResult()

        ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
        cv.payment_line_ocr = ocr
        cv.payment_line_amount = amount
        if ocr is None and amount is None and account is None:
            cv.details.append("payment_line could not be parsed")

        if account:
            cv.payment_line_account_type, cv.payment_line_account = (
                self._format_payment_line_account(account)
            )

        # OCR: payment_line takes priority
        detected_ocr = result.fields.get('OCR')
        if cv.payment_line_ocr:
            pl_ocr_digits = re.sub(r'\D', '', cv.payment_line_ocr)
            if detected_ocr:
                detected_ocr_digits = re.sub(r'\D', '', str(detected_ocr))
                cv.ocr_match = pl_ocr_digits == detected_ocr_digits
                if cv.ocr_match:
                    cv.details.append(f"OCR match: {cv.payment_line_ocr}")
                else:
                    cv.details.append(f"OCR: payment_line={cv.payment_line_ocr} (override detected={detected_ocr})")
            else:
                cv.details.append(f"OCR: {cv.payment_line_ocr} (from payment_line)")
            # OVERRIDE: use payment_line OCR
            result.fields['OCR'] = cv.payment_line_ocr
            result.confidence['OCR'] = 0.95  # High confidence for payment_line

        # Amount: payment_line takes priority
        detected_amount = result.fields.get('Amount')
        if cv.payment_line_amount:
            if detected_amount:
                pl_amount = self._normalize_amount_for_compare(cv.payment_line_amount)
                det_amount = self._normalize_amount_for_compare(str(detected_amount))
                # Two unparseable amounts (None == None) must not count as a match.
                cv.amount_match = pl_amount is not None and pl_amount == det_amount
                if cv.amount_match:
                    cv.details.append(f"Amount match: {cv.payment_line_amount}")
                else:
                    cv.details.append(f"Amount: payment_line={cv.payment_line_amount} (override detected={detected_amount})")
            else:
                cv.details.append(f"Amount: {cv.payment_line_amount} (from payment_line)")
            # OVERRIDE: use payment_line Amount
            result.fields['Amount'] = cv.payment_line_amount
            result.confidence['Amount'] = 0.95

        # Bankgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_bankgiro = result.fields.get('Bankgiro')
        if cv.payment_line_account_type == 'bankgiro' and cv.payment_line_account and detected_bankgiro:
            pl_bg_digits = re.sub(r'\D', '', cv.payment_line_account)
            det_bg_digits = re.sub(r'\D', '', str(detected_bankgiro))
            cv.bankgiro_match = pl_bg_digits == det_bg_digits
            if cv.bankgiro_match:
                cv.details.append(f"Bankgiro match confirmed: {detected_bankgiro}")
            else:
                cv.details.append(f"Bankgiro mismatch: detected={detected_bankgiro}, payment_line={cv.payment_line_account}")

        # Plusgiro: compare only, do NOT override (payment_line account detection is unreliable)
        detected_plusgiro = result.fields.get('Plusgiro')
        if cv.payment_line_account_type == 'plusgiro' and cv.payment_line_account and detected_plusgiro:
            pl_pg_digits = re.sub(r'\D', '', cv.payment_line_account)
            det_pg_digits = re.sub(r'\D', '', str(detected_plusgiro))
            cv.plusgiro_match = pl_pg_digits == det_pg_digits
            if cv.plusgiro_match:
                cv.details.append(f"Plusgiro match confirmed: {detected_plusgiro}")
            else:
                cv.details.append(f"Plusgiro mismatch: detected={detected_plusgiro}, payment_line={cv.payment_line_account}")

        # Overall validity. payment_line carries only ONE account (BG or PG), so
        # when the invoice has both, the other one cannot be matched - expected.
        # Only the account type that payment_line actually has is counted.
        matches = [cv.ocr_match, cv.amount_match]
        if cv.payment_line_account_type == 'bankgiro':
            matches.append(cv.bankgiro_match)
        elif cv.payment_line_account_type == 'plusgiro':
            matches.append(cv.plusgiro_match)

        valid_matches = [m for m in matches if m is not None]
        if valid_matches:
            match_count = sum(1 for m in valid_matches if m)
            # With 2+ comparisons, 2 agreeing fields suffice; with 1, it must agree.
            cv.is_valid = match_count >= min(2, len(valid_matches))
            cv.details.append(f"Validation: {match_count}/{len(valid_matches)} fields match")
        else:
            # No comparison possible
            cv.is_valid = True
            cv.details.append("No comparison available from payment_line")

        result.cross_validation = cv

    def _normalize_amount_for_compare(self, amount: str) -> float | None:
        """
        Normalize an amount string to a float rounded to 2 decimals.

        Handles:
        - any whitespace used as thousands separator, including non-breaking
          spaces ("1 234,56")
        - comma as decimal separator ("1234,56")
        - mixed separators, where the right-most of ',' / '.' is the decimal
          separator and the other one a thousands separator
          ("1.234,56", "1,234.56")

        Returns:
            The rounded float, or None if the value cannot be parsed.
        """
        try:
            # \s in a str pattern matches Unicode whitespace (e.g. U+00A0).
            cleaned = re.sub(r'\s+', '', amount)
            if ',' in cleaned and '.' in cleaned:
                if cleaned.rfind(',') > cleaned.rfind('.'):
                    cleaned = cleaned.replace('.', '').replace(',', '.')
                else:
                    cleaned = cleaned.replace(',', '')
            else:
                cleaned = cleaned.replace(',', '.')
            return round(float(cleaned), 2)
        except (TypeError, ValueError, AttributeError):
            return None

    def _needs_fallback(self, result: InferenceResult) -> bool:
        """Return True when at least 2 of Amount, InvoiceNumber and OCR are missing."""
        key_fields = ['Amount', 'InvoiceNumber', 'OCR']
        missing = sum(1 for f in key_fields if f not in result.fields)
        return missing >= 2

    def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
        """
        Run full-page OCR fallback.

        Renders every page again, OCRs the whole page, and feeds the joined
        text to the regex patterns, which only fill fields still missing.
        Errors during rendering/OCR are recorded in ``result.errors``.
        """
        from shared.pdf.renderer import render_pdf_to_images
        from shared.ocr import OCREngine
        from PIL import Image
        import io
        import numpy as np

        result.fallback_used = True
        ocr_engine = OCREngine()

        try:
            for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
                with Image.open(io.BytesIO(image_bytes)) as image:
                    image_array = np.array(image)

                tokens = ocr_engine.extract_from_image(image_array, page_no)
                full_text = ' '.join(t.text for t in tokens)

                self._extract_with_patterns(full_text, result)

        except Exception as e:
            result.errors.append(f"Fallback OCR error: {e}")

    def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
        """
        Extract missing fields from raw page text using regex patterns (fallback).

        Fields already present in ``result.fields`` are left untouched. For
        each missing field the patterns are tried in order; the first one
        producing a usable value wins and is stored with confidence 0.5.
        Amounts are normalized to "1234.56". Bankgiro values with 7 or 8
        digits are formatted as XXX-XXXX / XXXX-XXXX.
        """
        patterns = {
            'Amount': [
                r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
                r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
            ],
            'Bankgiro': [
                r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
                r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
            ],
            'OCR': [
                r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
            ],
            'InvoiceNumber': [
                r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
            ],
        }

        for field_name, field_patterns in patterns.items():
            if field_name in result.fields:
                continue

            for pattern in field_patterns:
                match = re.search(pattern, text, re.IGNORECASE)
                if not match:
                    continue

                value = match.group(1).strip()

                if field_name == 'Amount':
                    # Same normalization as the payment line comparison.
                    amount = self._normalize_amount_for_compare(value)
                    if amount is None:
                        continue
                    value = f"{amount:.2f}"
                elif field_name == 'Bankgiro':
                    digits = re.sub(r'\D', '', value)
                    if len(digits) == 8:
                        value = f"{digits[:4]}-{digits[4:]}"
                    elif len(digits) == 7:
                        value = f"{digits[:3]}-{digits[3:]}"

                result.fields[field_name] = value
                result.confidence[field_name] = 0.5  # Lower confidence for regex
                break

    def process_image(
        self,
        image_path: str | Path,
        document_id: str | None = None
    ) -> InferenceResult:
        """
        Process a single image (for pre-rendered pages).

        Runs detection (as page 0), extraction, merging and payment line
        cross-validation. The full-page OCR fallback is NOT run here.
        Exceptions are recorded in ``result.errors`` instead of raised.

        Args:
            image_path: Path to image file
            document_id: Optional document ID (defaults to the file stem)

        Returns:
            InferenceResult with extracted fields
        """
        from PIL import Image
        import numpy as np

        start_time = time.perf_counter()
        result = InferenceResult(
            document_id=document_id or Path(image_path).stem
        )

        try:
            # Close the file handle once pixels are copied into the array.
            with Image.open(image_path) as image:
                image_array = np.array(image)

            detections = self.detector.detect(image_array, page_no=0)
            result.raw_detections = detections

            result.extracted_fields.extend(
                self.extractor.extract_from_detection(detection, image_array)
                for detection in detections
            )

            self._merge_fields(result)

            result.success = len(result.fields) > 0

        except Exception as e:
            result.errors.append(str(e))
            result.success = False

        result.processing_time_ms = (time.perf_counter() - start_time) * 1000
        return result