# Files
# invoice-master-poc-v2/packages/inference/inference/pipeline/field_extractor.py
# Yaojia Wang a564ac9d70 WIP
# 2026-02-01 18:51:54 +01:00
#
# 631 lines
# 23 KiB
# Python

"""
Field Extractor Module
Extracts and validates field values from detected regions.
This module is used during inference to extract values from OCR text.
It uses shared utilities from shared.utils for text cleaning and validation.
Enhanced features:
- Multi-source fusion with confidence weighting
- Smart amount parsing with multiple strategies
- Enhanced date format unification
- OCR error correction integration
Refactored to use modular normalizers for each field type.
"""
from dataclasses import dataclass, field
from collections import defaultdict
import re
import numpy as np
from PIL import Image
from shared.fields import CLASS_TO_FIELD
from .yolo_detector import Detection
# Import shared utilities for text cleaning and validation
from shared.utils.validators import FieldValidators
from shared.utils.ocr_corrections import OCRCorrections
# Import new unified parsers
from .payment_line_parser import PaymentLineParser
from .customer_number_parser import CustomerNumberParser
# Import normalizers
from .normalizers import (
BaseNormalizer,
NormalizationResult,
create_normalizer_registry,
EnhancedAmountNormalizer,
EnhancedDateNormalizer,
)
@dataclass
class ExtractedField:
"""Represents an extracted field value."""
field_name: str
raw_text: str
normalized_value: str | None
confidence: float
detection_confidence: float
ocr_confidence: float
bbox: tuple[float, float, float, float]
page_no: int
is_valid: bool = True
validation_error: str | None = None
# Multi-source fusion fields
alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...]
extraction_method: str = 'single' # 'single', 'fused', 'corrected'
ocr_corrections_applied: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary."""
result = {
'field_name': self.field_name,
'value': self.normalized_value,
'raw_text': self.raw_text,
'confidence': self.confidence,
'bbox': list(self.bbox),
'page_no': self.page_no,
'is_valid': self.is_valid,
'validation_error': self.validation_error
}
if self.alternative_values:
result['alternatives'] = self.alternative_values
if self.extraction_method != 'single':
result['extraction_method'] = self.extraction_method
return result
class FieldExtractor:
"""Extracts field values from detected regions using OCR or PDF text."""
def __init__(
    self,
    ocr_lang: str = 'en',
    use_gpu: bool = False,
    bbox_padding: float = 0.1,
    dpi: int = 300,
    use_enhanced_parsing: bool = False
):
    """Set up the extractor and its helper parsers.

    Args:
        ocr_lang: Language code for OCR.
        use_gpu: Whether OCR should run on GPU.
        bbox_padding: Fractional padding added around detection boxes.
        dpi: Render DPI, used to convert pixel coords to PDF points.
        use_enhanced_parsing: Enable the enhanced normalizer variants.
    """
    self.ocr_lang = ocr_lang
    self.use_gpu = use_gpu
    self.bbox_padding = bbox_padding
    self.dpi = dpi
    self.use_enhanced_parsing = use_enhanced_parsing
    # The OCR engine is expensive to build; created on first use
    # (see the ocr_engine property).
    self._ocr_engine = None
    # Unified parsers for the two fields with bespoke handling.
    self.payment_line_parser = PaymentLineParser()
    self.customer_number_parser = CustomerNumberParser()
    # Per-field normalizers, keyed by field name.
    self._normalizers = create_normalizer_registry(use_enhanced=use_enhanced_parsing)
@property
def ocr_engine(self):
    """Lazily construct and cache the OCR engine.

    The import lives inside the property so that text-PDF pipelines,
    which never OCR, do not pay the cost of loading the OCR stack.
    """
    if self._ocr_engine is None:
        from shared.ocr import OCREngine
        self._ocr_engine = OCREngine(lang=self.ocr_lang)
    return self._ocr_engine
def extract_from_detection_with_pdf(
    self,
    detection: Detection,
    pdf_tokens: list,
    image_width: int,
    image_height: int
) -> ExtractedField:
    """Extract a field value from PDF text tokens instead of OCR.

    The detection bbox (pixels) is converted to PDF points; tokens on
    the same page that overlap it are collected, ordered by how much
    of each token lies inside the box, and their texts joined to form
    the raw value.

    Args:
        detection: Detection with bbox in pixel coordinates.
        pdf_tokens: Token objects from PDF text extraction.
        image_width: Rendered image width in pixels.
        image_height: Rendered image height in pixels.

    Returns:
        ExtractedField; ocr_confidence is 1.0 since PDF text is exact.
    """
    # Pixel -> PDF point conversion (page was rendered at self.dpi).
    pts_per_px = 72 / self.dpi
    box_x0, box_y0, box_x1, box_y1 = (v * pts_per_px for v in detection.bbox)
    # Loose margin (points) for the candidate pre-filter below.
    # NOTE(review): the final keep-test requires strict overlap, so this
    # margin only widens the pre-filter, never the actual match —
    # confirm whether padded matching was intended.
    margin = 3
    scored_tokens = []
    for token in pdf_tokens:
        if token.page_no != detection.page_no:
            continue
        tx0, ty0, tx1, ty1 = token.bbox
        # Quick reject: token entirely outside the padded box.
        if not (tx0 < box_x1 + margin and tx1 > box_x0 - margin
                and ty0 < box_y1 + margin and ty1 > box_y0 - margin):
            continue
        # Strict intersection with the unpadded box.
        inter_w = min(tx1, box_x1) - max(tx0, box_x0)
        inter_h = min(ty1, box_y1) - max(ty0, box_y0)
        if inter_w <= 0 or inter_h <= 0:
            continue
        token_area = (tx1 - tx0) * (ty1 - ty0)
        ratio = (inter_w * inter_h) / token_area if token_area > 0 else 0
        scored_tokens.append((token, ratio))
    # Highest-coverage tokens first (stable for ties), then join text.
    scored_tokens.sort(key=lambda pair: pair[1], reverse=True)
    raw_text = ' '.join(tok.text for tok, _ in scored_tokens)
    field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
    normalized_value, is_valid, validation_error = self._normalize_and_validate(
        field_name, raw_text
    )
    # Halve the confidence when normalization produced nothing usable.
    final_conf = detection.confidence if normalized_value else detection.confidence * 0.5
    return ExtractedField(
        field_name=field_name,
        raw_text=raw_text,
        normalized_value=normalized_value,
        confidence=final_conf,
        detection_confidence=detection.confidence,
        ocr_confidence=1.0,  # PDF-extracted text carries no OCR uncertainty
        bbox=detection.bbox,
        page_no=detection.page_no,
        is_valid=is_valid,
        validation_error=validation_error
    )
def extract_from_detection(
    self,
    detection: Detection,
    image: np.ndarray | Image.Image
) -> ExtractedField:
    """Extract a field value by running OCR on the detected region.

    Args:
        detection: Detection object.
        image: Full page image (numpy array or PIL image).

    Returns:
        ExtractedField built from the OCR result.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    img_h, img_w = image.shape[:2]
    # Pad the detection box, then crop the region of interest.
    padded = detection.get_padded_bbox(self.bbox_padding, img_w, img_h)
    x0, y0, x1, y1 = (int(v) for v in padded)
    crop = image[y0:y1, x0:x1]
    # OCR the crop and merge the recognized tokens into one string.
    tokens = self.ocr_engine.extract_from_image(crop)
    raw_text = ' '.join(tok.text for tok in tokens)
    if tokens:
        ocr_conf = sum(tok.confidence for tok in tokens) / len(tokens)
    else:
        ocr_conf = 0.0
    field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
    normalized_value, is_valid, validation_error = self._normalize_and_validate(
        field_name, raw_text
    )
    # Blend detector and OCR confidence; penalize when OCR saw nothing.
    if tokens:
        combined_conf = (detection.confidence + ocr_conf) / 2
    else:
        combined_conf = detection.confidence * 0.5
    return ExtractedField(
        field_name=field_name,
        raw_text=raw_text,
        normalized_value=normalized_value,
        confidence=combined_conf,
        detection_confidence=detection.confidence,
        ocr_confidence=ocr_conf,
        bbox=detection.bbox,
        page_no=detection.page_no,
        is_valid=is_valid,
        validation_error=validation_error
    )
def _normalize_and_validate(
    self,
    field_name: str,
    raw_text: str
) -> tuple[str | None, bool, str | None]:
    """Normalize and validate raw text for a given field.

    Dispatch order: the two fields with dedicated unified parsers
    (payment_line, customer_number) come first, then the normalizer
    registry, then a pass-through fallback for unknown fields.

    Returns:
        (normalized_value, is_valid, validation_error)
    """
    cleaned = raw_text.strip()
    if not cleaned:
        return None, False, "Empty text"
    # Fields handled by the dedicated unified parsers.
    if field_name == 'payment_line':
        return self._normalize_payment_line(cleaned)
    if field_name == 'customer_number':
        return self._normalize_customer_number(cleaned)
    # Registered per-field normalizer, if one exists.
    normalizer = self._normalizers.get(field_name)
    if normalizer:
        return normalizer.normalize(cleaned).to_tuple()
    # Unknown field: accept the stripped text as-is.
    return cleaned, True, None
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
    """Normalize a Swedish machine-readable payment line.

    Delegates to the unified PaymentLineParser. Standard format:
    # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
    e.g. "# 94228110015950070 # 15658 00 8 > 48666036#14#" carries
    amount 15658.00.

    The normalized output preserves ALL components, including the
    amount, so downstream cross-validation can re-extract fields.
    """
    parsed = self.payment_line_parser.parse(text)
    return self.payment_line_parser.format_for_field_extractor(parsed)
def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
    """Normalize a Swedish customer number via the unified parser.

    Handles separated forms ('JTY 576-3', 'FFL 019N'), compact forms
    ('JTY5763'), and values embedded in surrounding text
    ('VIKSTRÖM, ELIAS CH FFL 01' -> 'FFL 01',
    'Umj 436-R Billo' -> 'UMJ 436-R').
    """
    return self.customer_number_parser.parse(text)
def extract_all_fields(
    self,
    detections: list[Detection],
    image: np.ndarray | Image.Image
) -> list[ExtractedField]:
    """Extract field values for every detection on a page.

    Args:
        detections: Detections to process.
        image: Full page image shared by all detections.

    Returns:
        One ExtractedField per detection, in input order.
    """
    # Comprehension replaces the manual append loop (same order, same
    # calls); this also drops a local named `field` that shadowed the
    # `dataclasses.field` import used at module level.
    return [self.extract_from_detection(det, image) for det in detections]
@staticmethod
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
    """Infer the OCR reference from InvoiceNumber when OCR is missing.

    Swedish invoices often use the invoice number itself as the OCR
    payment reference, so a purely numeric InvoiceNumber of at least
    five digits is copied into the OCR slot.

    Args:
        fields: Mapping of field_name -> normalized_value; note this
            dict is updated in place.

    Returns:
        The same dict, with 'OCR' filled in when it could be inferred.
    """
    if fields.get('OCR'):
        return fields  # already present, nothing to infer
    candidate = fields.get('InvoiceNumber')
    # Accept only all-digit values of plausible OCR-reference length.
    if candidate and re.fullmatch(r'\d{5,}', candidate):
        fields['OCR'] = candidate
    return fields
# =========================================================================
# Multi-Source Fusion with Confidence Weighting
# =========================================================================
def fuse_multiple_detections(
    self,
    extracted_fields: list[ExtractedField]
) -> list[ExtractedField]:
    """Resolve duplicate detections of the same field to one value each.

    When YOLO produces several boxes for one field type (e.g. multiple
    Amount boxes), candidates are grouped by field name and fused via
    confidence-weighted selection (_fuse_field_candidates); singleton
    groups pass through untouched.

    Args:
        extracted_fields: All extracted fields, possibly with duplicates.

    Returns:
        One ExtractedField per field name.
    """
    grouped: dict[str, list[ExtractedField]] = defaultdict(list)
    for extracted in extracted_fields:
        grouped[extracted.field_name].append(extracted)
    # Singletons need no fusion; larger groups are reduced to one winner.
    return [
        candidates[0] if len(candidates) == 1
        else self._fuse_field_candidates(name, candidates)
        for name, candidates in grouped.items()
    ]
def _fuse_field_candidates(
    self,
    field_name: str,
    candidates: list[ExtractedField]
) -> ExtractedField:
    """Pick the best of several candidates for one field.

    Candidates are ranked by confidence, deduplicated by normalized
    value (keeping the highest-confidence instance of each), and the
    winner chosen via field-specific rules; losing values are recorded
    as alternatives on the returned field.
    """
    ranked = sorted(candidates, key=lambda c: c.confidence, reverse=True)
    # Map each distinct normalized value to its best-scoring candidate.
    value_scores: dict[str, tuple[float, ExtractedField]] = {}
    for cand in ranked:
        value = cand.normalized_value
        if not value:
            continue
        current = value_scores.get(value)
        if current is None or cand.confidence > current[0]:
            value_scores[value] = (cand.confidence, cand)
    if not value_scores:
        # Nothing normalized successfully; fall back to top confidence.
        return ranked[0]
    best_value, best_field = self._select_best_value(field_name, value_scores)
    # Everything that lost the vote is kept as an alternative.
    alternatives = [
        (value, score)
        for value, (score, _) in value_scores.items()
        if value != best_value
    ]
    return ExtractedField(
        field_name=field_name,
        raw_text=best_field.raw_text,
        normalized_value=best_value,
        confidence=value_scores[best_value][0],
        detection_confidence=best_field.detection_confidence,
        ocr_confidence=best_field.ocr_confidence,
        bbox=best_field.bbox,
        page_no=best_field.page_no,
        is_valid=best_field.is_valid,
        validation_error=best_field.validation_error,
        alternative_values=alternatives,
        extraction_method='fused' if len(value_scores) > 1 else 'single'
    )
def _select_best_value(
    self,
    field_name: str,
    value_scores: dict[str, tuple[float, ExtractedField]]
) -> tuple[str, ExtractedField]:
    """Choose the winning value among deduplicated candidates.

    Field-specific preferences (first match wins):
      - giro/OCR numbers: first value with a valid Luhn checksum
      - Amount: the largest parseable amount (totals beat subtotals)
      - dates: first ISO date within roughly the last 2 / next 1 years
    If no rule fires, the highest-confidence value wins.

    Returns:
        (best_value, best_field)
    """
    entries = list(value_scores.items())
    if field_name in ('Bankgiro', 'Plusgiro', 'OCR'):
        # A passing Luhn checksum is strong evidence of a correct read.
        for value, (_, cand) in entries:
            if FieldValidators.luhn_checksum(re.sub(r'\D', '', value)):
                return value, cand
    elif field_name == 'Amount':
        parsed = []
        for value, (_, cand) in entries:
            try:
                parsed.append((float(value.replace(',', '.')), value, cand))
            except ValueError:
                continue
        if parsed:
            # The invoice total is normally the largest amount present.
            parsed.sort(reverse=True)
            _, top_value, top_cand = parsed[0]
            return top_value, top_cand
    elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
        from datetime import datetime
        for value, (_, cand) in entries:
            try:
                parsed_date = datetime.strptime(value, '%Y-%m-%d')
            except ValueError:
                continue
            # Plausibility window: last two years through next year.
            current = datetime.now()
            if current.year - 2 <= parsed_date.year <= current.year + 1:
                return value, cand
    # No field-specific rule fired: fall back to highest confidence.
    top = max(entries, key=lambda item: item[1][0])
    return top[0], top[1][1]
# =========================================================================
# Apply OCR Corrections to Raw Text
# =========================================================================
def apply_ocr_corrections(
    self,
    field_name: str,
    raw_text: str
) -> tuple[str, list[str]]:
    """Apply digit-level OCR corrections appropriate to the field type.

    Numeric identifier fields (OCR, Bankgiro, Plusgiro,
    supplier_org_number) get aggressive correction; Amount and date
    fields get conservative correction (preserving decimal/date
    separators); all other fields are returned unchanged.

    Args:
        field_name: Name of the field being corrected.
        raw_text: Text to correct.

    Returns:
        (corrected_text, list of "old->new" correction descriptions)
    """
    # The three original branches were identical apart from the
    # `aggressive` flag — collapse them into one policy lookup + call.
    if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_org_number'):
        aggressive = True
    elif field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
        # Conservative mode preserves decimal/date separators.
        aggressive = False
    else:
        # No correction policy for this field type.
        return raw_text, []
    result = OCRCorrections.correct_digits(raw_text, aggressive=aggressive)
    corrections_applied = []
    if result.corrections_applied:
        corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
    return result.corrected, corrections_applied
# =========================================================================
# Extraction with All Enhancements
# =========================================================================
def extract_with_enhancements(
    self,
    detection: Detection,
    pdf_tokens: list,
    image_width: int,
    image_height: int,
    use_enhanced_parsing: bool = True
) -> ExtractedField:
    """
    Extract field value with all enhancements enabled.
    Combines:
    1. OCR error correction
    2. Enhanced amount/date parsing
    3. Multi-strategy extraction
    Args:
        detection: Detection object
        pdf_tokens: PDF text tokens
        image_width: Image width in pixels
        image_height: Image height in pixels
        use_enhanced_parsing: Whether to use enhanced parsing methods
    Returns:
        ExtractedField with enhancements applied
    """
    # First, extract using standard method
    base_result = self.extract_from_detection_with_pdf(
        detection, pdf_tokens, image_width, image_height
    )
    if not use_enhanced_parsing:
        return base_result
    # Apply OCR corrections
    corrected_text, corrections = self.apply_ocr_corrections(
        base_result.field_name, base_result.raw_text
    )
    # Re-normalize with enhanced methods if corrections were applied
    # (or if the first normalization pass produced no value at all).
    if corrections or base_result.normalized_value is None:
        # Use enhanced normalizers for Amount and Date fields
        if base_result.field_name == 'Amount':
            enhanced_normalizer = EnhancedAmountNormalizer()
            result = enhanced_normalizer.normalize(corrected_text)
            normalized, is_valid, error = result.to_tuple()
        elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
            enhanced_normalizer = EnhancedDateNormalizer()
            result = enhanced_normalizer.normalize(corrected_text)
            normalized, is_valid, error = result.to_tuple()
        else:
            # Re-run standard normalization with corrected text
            normalized, is_valid, error = self._normalize_and_validate(
                base_result.field_name, corrected_text
            )
        # Update result if we got a better value
        if normalized and (not base_result.normalized_value or is_valid):
            base_result.normalized_value = normalized
            base_result.is_valid = is_valid
            base_result.validation_error = error
    # NOTE(review): source indentation was lost; these trailing updates
    # are placed at function level (corrections recorded whenever they
    # were applied, regardless of re-normalization outcome) — confirm
    # against the original file's indentation.
    base_result.ocr_corrections_applied = corrections
    if corrections:
        base_result.extraction_method = 'corrected'
    return base_result