631 lines
23 KiB
Python
631 lines
23 KiB
Python
"""
|
|
Field Extractor Module
|
|
|
|
Extracts and validates field values from detected regions.
|
|
|
|
This module is used during inference to extract values from OCR text.
|
|
It uses shared utilities from shared.utils for text cleaning and validation.
|
|
|
|
Enhanced features:
|
|
- Multi-source fusion with confidence weighting
|
|
- Smart amount parsing with multiple strategies
|
|
- Enhanced date format unification
|
|
- OCR error correction integration
|
|
|
|
Refactored to use modular normalizers for each field type.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from collections import defaultdict
|
|
import re
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from shared.fields import CLASS_TO_FIELD
|
|
from .yolo_detector import Detection
|
|
|
|
# Import shared utilities for text cleaning and validation
|
|
from shared.utils.validators import FieldValidators
|
|
from shared.utils.ocr_corrections import OCRCorrections
|
|
|
|
# Import new unified parsers
|
|
from .payment_line_parser import PaymentLineParser
|
|
from .customer_number_parser import CustomerNumberParser
|
|
|
|
# Import normalizers
|
|
from .normalizers import (
|
|
BaseNormalizer,
|
|
NormalizationResult,
|
|
create_normalizer_registry,
|
|
EnhancedAmountNormalizer,
|
|
EnhancedDateNormalizer,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class ExtractedField:
|
|
"""Represents an extracted field value."""
|
|
field_name: str
|
|
raw_text: str
|
|
normalized_value: str | None
|
|
confidence: float
|
|
detection_confidence: float
|
|
ocr_confidence: float
|
|
bbox: tuple[float, float, float, float]
|
|
page_no: int
|
|
is_valid: bool = True
|
|
validation_error: str | None = None
|
|
# Multi-source fusion fields
|
|
alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...]
|
|
extraction_method: str = 'single' # 'single', 'fused', 'corrected'
|
|
ocr_corrections_applied: list[str] = field(default_factory=list)
|
|
|
|
def to_dict(self) -> dict:
|
|
"""Convert to dictionary."""
|
|
result = {
|
|
'field_name': self.field_name,
|
|
'value': self.normalized_value,
|
|
'raw_text': self.raw_text,
|
|
'confidence': self.confidence,
|
|
'bbox': list(self.bbox),
|
|
'page_no': self.page_no,
|
|
'is_valid': self.is_valid,
|
|
'validation_error': self.validation_error
|
|
}
|
|
if self.alternative_values:
|
|
result['alternatives'] = self.alternative_values
|
|
if self.extraction_method != 'single':
|
|
result['extraction_method'] = self.extraction_method
|
|
return result
|
|
|
|
|
|
class FieldExtractor:
    """Extracts field values from detected regions using OCR or PDF text."""

    def __init__(
        self,
        ocr_lang: str = 'en',
        use_gpu: bool = False,
        bbox_padding: float = 0.1,
        dpi: int = 300,
        use_enhanced_parsing: bool = False
    ):
        """
        Set up the extractor and its helper parsers.

        Args:
            ocr_lang: Language for OCR
            use_gpu: Whether to use GPU for OCR
            bbox_padding: Fractional padding added around detection bboxes
            dpi: Rendering DPI (used for pixel <-> PDF point conversion)
            use_enhanced_parsing: Select the enhanced normalizer variants
        """
        self.ocr_lang = ocr_lang
        self.use_gpu = use_gpu
        self.bbox_padding = bbox_padding
        self.dpi = dpi
        self.use_enhanced_parsing = use_enhanced_parsing

        # OCR engine is created lazily on first access (see `ocr_engine`).
        self._ocr_engine = None

        # Dedicated parsers for the two fields with bespoke formats.
        self.payment_line_parser = PaymentLineParser()
        self.customer_number_parser = CustomerNumberParser()

        # Per-field normalizers, keyed by field name.
        self._normalizers = create_normalizer_registry(use_enhanced=use_enhanced_parsing)

    @property
    def ocr_engine(self):
        """OCR engine instance, constructed on first use to avoid startup cost."""
        if self._ocr_engine is None:
            from shared.ocr import OCREngine
            self._ocr_engine = OCREngine(lang=self.ocr_lang)
        return self._ocr_engine
|
|
|
|
    def extract_from_detection_with_pdf(
        self,
        detection: Detection,
        pdf_tokens: list,
        image_width: int,
        image_height: int
    ) -> ExtractedField:
        """
        Extract field value using PDF text tokens (faster and more accurate for text PDFs).

        Args:
            detection: Detection object with bbox in pixel coordinates
            pdf_tokens: List of Token objects from PDF text extraction
            image_width: Width of rendered image in pixels
            image_height: Height of rendered image in pixels

        Returns:
            ExtractedField object

        NOTE(review): image_width/image_height are not used in this body —
        presumably kept for interface symmetry with the OCR path; confirm
        before removing.
        """
        # Convert detection bbox from pixels to PDF points.
        # Rendered pixels = points * dpi / 72, so the inverse is 72 / dpi.
        scale = 72 / self.dpi  # points per pixel
        x0_pdf = detection.bbox[0] * scale
        y0_pdf = detection.bbox[1] * scale
        x1_pdf = detection.bbox[2] * scale
        y1_pdf = detection.bbox[3] * scale

        # Add padding in points
        pad = 3  # Small padding in points

        # Find tokens that overlap with detection bbox
        matching_tokens = []
        for token in pdf_tokens:
            # Only consider tokens on the same page as the detection.
            if token.page_no != detection.page_no:
                continue
            tx0, ty0, tx1, ty1 = token.bbox
            # Check overlap: padded axis-separation test — true when the
            # token and the (padded) detection box overlap on both axes.
            if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
                    ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
                # Calculate overlap ratio to prioritize better matches
                overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
                overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
                if overlap_x > 0 and overlap_y > 0:
                    # Ratio is relative to the token's own area, so a token
                    # fully inside the detection box scores 1.0.
                    token_area = (tx1 - tx0) * (ty1 - ty0)
                    overlap_area = overlap_x * overlap_y
                    overlap_ratio = overlap_area / token_area if token_area > 0 else 0
                    matching_tokens.append((token, overlap_ratio))

        # Sort by overlap ratio and combine text.
        # NOTE(review): joining in descending-overlap order means raw_text is
        # not necessarily in reading order; verify downstream normalizers
        # tolerate this before relying on word order.
        matching_tokens.sort(key=lambda x: -x[1])
        raw_text = ' '.join(t[0].text for t in matching_tokens)

        # Get field name (fall back to the raw class name when unmapped).
        field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)

        # Normalize and validate
        normalized_value, is_valid, validation_error = self._normalize_and_validate(
            field_name, raw_text
        )

        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized_value,
            # Halve the confidence when normalization produced no value.
            confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
            detection_confidence=detection.confidence,
            ocr_confidence=1.0,  # PDF text is always accurate
            bbox=detection.bbox,
            page_no=detection.page_no,
            is_valid=is_valid,
            validation_error=validation_error
        )
|
|
|
|
def extract_from_detection(
|
|
self,
|
|
detection: Detection,
|
|
image: np.ndarray | Image.Image
|
|
) -> ExtractedField:
|
|
"""
|
|
Extract field value from a detection region using OCR.
|
|
|
|
Args:
|
|
detection: Detection object
|
|
image: Full page image
|
|
|
|
Returns:
|
|
ExtractedField object
|
|
"""
|
|
if isinstance(image, Image.Image):
|
|
image = np.array(image)
|
|
|
|
# Get padded bbox
|
|
h, w = image.shape[:2]
|
|
bbox = detection.get_padded_bbox(self.bbox_padding, w, h)
|
|
|
|
# Crop region
|
|
x0, y0, x1, y1 = [int(v) for v in bbox]
|
|
region = image[y0:y1, x0:x1]
|
|
|
|
# Run OCR on region
|
|
ocr_tokens = self.ocr_engine.extract_from_image(region)
|
|
|
|
# Combine all OCR text
|
|
raw_text = ' '.join(t.text for t in ocr_tokens)
|
|
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
|
|
|
|
# Get field name
|
|
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
|
|
|
# Normalize and validate
|
|
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
|
field_name, raw_text
|
|
)
|
|
|
|
# Combined confidence
|
|
confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5
|
|
|
|
return ExtractedField(
|
|
field_name=field_name,
|
|
raw_text=raw_text,
|
|
normalized_value=normalized_value,
|
|
confidence=confidence,
|
|
detection_confidence=detection.confidence,
|
|
ocr_confidence=ocr_confidence,
|
|
bbox=detection.bbox,
|
|
page_no=detection.page_no,
|
|
is_valid=is_valid,
|
|
validation_error=validation_error
|
|
)
|
|
|
|
def _normalize_and_validate(
|
|
self,
|
|
field_name: str,
|
|
raw_text: str
|
|
) -> tuple[str | None, bool, str | None]:
|
|
"""
|
|
Normalize and validate extracted text for a field.
|
|
|
|
Uses modular normalizers for each field type.
|
|
Falls back to legacy methods for payment_line and customer_number.
|
|
|
|
Returns:
|
|
(normalized_value, is_valid, validation_error)
|
|
"""
|
|
text = raw_text.strip()
|
|
|
|
if not text:
|
|
return None, False, "Empty text"
|
|
|
|
# Special handling for payment_line and customer_number (use unified parsers)
|
|
if field_name == 'payment_line':
|
|
return self._normalize_payment_line(text)
|
|
|
|
if field_name == 'customer_number':
|
|
return self._normalize_customer_number(text)
|
|
|
|
# Use normalizer registry for other fields
|
|
normalizer = self._normalizers.get(field_name)
|
|
if normalizer:
|
|
result = normalizer.normalize(text)
|
|
return result.to_tuple()
|
|
|
|
# Fallback for unknown fields
|
|
return text, True, None
|
|
|
|
def _normalize_payment_line(self, text: str) -> tuple[str | None, bool, str | None]:
|
|
"""
|
|
Normalize payment line region text using unified PaymentLineParser.
|
|
|
|
Extracts the machine-readable payment line format from OCR text.
|
|
Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
|
|
|
Examples:
|
|
- "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
|
|
- "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
|
|
|
|
Returns normalized format preserving ALL components including Amount.
|
|
This allows downstream cross-validation to extract fields properly.
|
|
"""
|
|
# Use unified payment line parser
|
|
return self.payment_line_parser.format_for_field_extractor(
|
|
self.payment_line_parser.parse(text)
|
|
)
|
|
|
|
def _normalize_customer_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
|
"""
|
|
Normalize customer number text using unified CustomerNumberParser.
|
|
|
|
Supports various Swedish customer number formats:
|
|
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R'
|
|
- Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
|
|
- Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'
|
|
- Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R'
|
|
"""
|
|
return self.customer_number_parser.parse(text)
|
|
|
|
def extract_all_fields(
|
|
self,
|
|
detections: list[Detection],
|
|
image: np.ndarray | Image.Image
|
|
) -> list[ExtractedField]:
|
|
"""
|
|
Extract fields from all detections.
|
|
|
|
Args:
|
|
detections: List of detections
|
|
image: Full page image
|
|
|
|
Returns:
|
|
List of ExtractedField objects
|
|
"""
|
|
fields = []
|
|
|
|
for detection in detections:
|
|
field = self.extract_from_detection(detection, image)
|
|
fields.append(field)
|
|
|
|
return fields
|
|
|
|
@staticmethod
|
|
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
|
|
"""
|
|
Infer OCR field from InvoiceNumber if not detected.
|
|
|
|
In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
|
|
When OCR is not detected but InvoiceNumber is, we can infer OCR value.
|
|
|
|
Args:
|
|
fields: Dict of field_name -> normalized_value
|
|
|
|
Returns:
|
|
Updated fields dict with inferred OCR if applicable
|
|
"""
|
|
# If OCR already exists, no need to infer
|
|
if fields.get('OCR'):
|
|
return fields
|
|
|
|
# If InvoiceNumber exists and is numeric, use it as OCR
|
|
invoice_number = fields.get('InvoiceNumber')
|
|
if invoice_number:
|
|
# Check if it's mostly digits (valid OCR reference)
|
|
digits_only = re.sub(r'\D', '', invoice_number)
|
|
if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
|
|
fields['OCR'] = invoice_number
|
|
|
|
return fields
|
|
|
|
# =========================================================================
|
|
# Multi-Source Fusion with Confidence Weighting
|
|
# =========================================================================
|
|
|
|
def fuse_multiple_detections(
|
|
self,
|
|
extracted_fields: list[ExtractedField]
|
|
) -> list[ExtractedField]:
|
|
"""
|
|
Fuse multiple detections of the same field using confidence-weighted voting.
|
|
|
|
When YOLO detects the same field type multiple times (e.g., multiple Amount boxes),
|
|
this method selects the best value or combines them intelligently.
|
|
|
|
Strategies:
|
|
1. For numeric fields (Amount, OCR): prefer values that pass validation
|
|
2. For date fields: prefer values in expected range
|
|
3. For giro numbers: prefer values with valid Luhn checksum
|
|
4. General: weighted vote by confidence scores
|
|
|
|
Args:
|
|
extracted_fields: List of all extracted fields (may have duplicates)
|
|
|
|
Returns:
|
|
List with duplicates resolved to single best value per field
|
|
"""
|
|
# Group fields by name
|
|
fields_by_name: dict[str, list[ExtractedField]] = defaultdict(list)
|
|
for field in extracted_fields:
|
|
fields_by_name[field.field_name].append(field)
|
|
|
|
fused_fields = []
|
|
|
|
for field_name, candidates in fields_by_name.items():
|
|
if len(candidates) == 1:
|
|
# No fusion needed
|
|
fused_fields.append(candidates[0])
|
|
else:
|
|
# Multiple candidates - fuse them
|
|
fused = self._fuse_field_candidates(field_name, candidates)
|
|
fused_fields.append(fused)
|
|
|
|
return fused_fields
|
|
|
|
def _fuse_field_candidates(
|
|
self,
|
|
field_name: str,
|
|
candidates: list[ExtractedField]
|
|
) -> ExtractedField:
|
|
"""
|
|
Fuse multiple candidates for a single field.
|
|
|
|
Returns the best candidate with alternatives recorded.
|
|
"""
|
|
# Sort by confidence (descending)
|
|
sorted_candidates = sorted(candidates, key=lambda x: x.confidence, reverse=True)
|
|
|
|
# Collect all unique values with their max confidence
|
|
value_scores: dict[str, tuple[float, ExtractedField]] = {}
|
|
for c in sorted_candidates:
|
|
if c.normalized_value:
|
|
if c.normalized_value not in value_scores:
|
|
value_scores[c.normalized_value] = (c.confidence, c)
|
|
else:
|
|
# Keep the higher confidence one
|
|
if c.confidence > value_scores[c.normalized_value][0]:
|
|
value_scores[c.normalized_value] = (c.confidence, c)
|
|
|
|
if not value_scores:
|
|
# No valid values, return the highest confidence candidate
|
|
return sorted_candidates[0]
|
|
|
|
# Field-specific fusion strategy
|
|
best_value, best_field = self._select_best_value(field_name, value_scores)
|
|
|
|
# Record alternatives
|
|
alternatives = [
|
|
(v, score) for v, (score, _) in value_scores.items()
|
|
if v != best_value
|
|
]
|
|
|
|
# Create fused result
|
|
result = ExtractedField(
|
|
field_name=field_name,
|
|
raw_text=best_field.raw_text,
|
|
normalized_value=best_value,
|
|
confidence=value_scores[best_value][0],
|
|
detection_confidence=best_field.detection_confidence,
|
|
ocr_confidence=best_field.ocr_confidence,
|
|
bbox=best_field.bbox,
|
|
page_no=best_field.page_no,
|
|
is_valid=best_field.is_valid,
|
|
validation_error=best_field.validation_error,
|
|
alternative_values=alternatives,
|
|
extraction_method='fused' if len(value_scores) > 1 else 'single'
|
|
)
|
|
|
|
return result
|
|
|
|
def _select_best_value(
|
|
self,
|
|
field_name: str,
|
|
value_scores: dict[str, tuple[float, ExtractedField]]
|
|
) -> tuple[str, ExtractedField]:
|
|
"""
|
|
Select the best value for a field using field-specific logic.
|
|
|
|
Returns (best_value, best_field)
|
|
"""
|
|
items = list(value_scores.items())
|
|
|
|
# Field-specific selection
|
|
if field_name in ('Bankgiro', 'Plusgiro', 'OCR'):
|
|
# Prefer values with valid Luhn checksum
|
|
for value, (score, field) in items:
|
|
digits = re.sub(r'\D', '', value)
|
|
if FieldValidators.luhn_checksum(digits):
|
|
return value, field
|
|
|
|
elif field_name == 'Amount':
|
|
# Prefer larger amounts (usually the total, not subtotals)
|
|
amounts = []
|
|
for value, (score, field) in items:
|
|
try:
|
|
amt = float(value.replace(',', '.'))
|
|
amounts.append((amt, value, field))
|
|
except ValueError:
|
|
continue
|
|
if amounts:
|
|
# Return the largest amount
|
|
amounts.sort(reverse=True)
|
|
return amounts[0][1], amounts[0][2]
|
|
|
|
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
|
# Prefer dates in reasonable range
|
|
from datetime import datetime
|
|
for value, (score, field) in items:
|
|
try:
|
|
dt = datetime.strptime(value, '%Y-%m-%d')
|
|
# Prefer recent dates (within last 2 years and next 1 year)
|
|
now = datetime.now()
|
|
if now.year - 2 <= dt.year <= now.year + 1:
|
|
return value, field
|
|
except ValueError:
|
|
continue
|
|
|
|
# Default: return highest confidence value
|
|
best = max(items, key=lambda x: x[1][0])
|
|
return best[0], best[1][1]
|
|
|
|
# =========================================================================
|
|
# Apply OCR Corrections to Raw Text
|
|
# =========================================================================
|
|
|
|
def apply_ocr_corrections(
|
|
self,
|
|
field_name: str,
|
|
raw_text: str
|
|
) -> tuple[str, list[str]]:
|
|
"""
|
|
Apply OCR corrections to raw text based on field type.
|
|
|
|
Returns (corrected_text, list_of_corrections_applied)
|
|
"""
|
|
corrections_applied = []
|
|
|
|
if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_org_number'):
|
|
# Aggressive correction for numeric fields
|
|
result = OCRCorrections.correct_digits(raw_text, aggressive=True)
|
|
if result.corrections_applied:
|
|
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
|
return result.corrected, corrections_applied
|
|
|
|
elif field_name == 'Amount':
|
|
# Conservative correction for amounts (preserve decimal separators)
|
|
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
|
|
if result.corrections_applied:
|
|
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
|
return result.corrected, corrections_applied
|
|
|
|
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
|
# Conservative correction for dates
|
|
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
|
|
if result.corrections_applied:
|
|
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
|
return result.corrected, corrections_applied
|
|
|
|
# No correction for other fields
|
|
return raw_text, []
|
|
|
|
# =========================================================================
|
|
# Extraction with All Enhancements
|
|
# =========================================================================
|
|
|
|
    def extract_with_enhancements(
        self,
        detection: Detection,
        pdf_tokens: list,
        image_width: int,
        image_height: int,
        use_enhanced_parsing: bool = True
    ) -> ExtractedField:
        """
        Extract field value with all enhancements enabled.

        Combines:
        1. OCR error correction
        2. Enhanced amount/date parsing
        3. Multi-strategy extraction

        Args:
            detection: Detection object
            pdf_tokens: PDF text tokens
            image_width: Image width in pixels
            image_height: Image height in pixels
            use_enhanced_parsing: Whether to use enhanced parsing methods

        Returns:
            ExtractedField with enhancements applied (the base result object,
            mutated in place when a better value was found)
        """
        # First, extract using standard method
        base_result = self.extract_from_detection_with_pdf(
            detection, pdf_tokens, image_width, image_height
        )

        if not use_enhanced_parsing:
            return base_result

        # Apply OCR corrections to the raw text (field-type dependent).
        corrected_text, corrections = self.apply_ocr_corrections(
            base_result.field_name, base_result.raw_text
        )

        # Re-normalize only when a correction changed something or the base
        # pass failed to produce a value; otherwise keep the base result.
        if corrections or base_result.normalized_value is None:
            # Use enhanced normalizers for Amount and Date fields
            if base_result.field_name == 'Amount':
                enhanced_normalizer = EnhancedAmountNormalizer()
                result = enhanced_normalizer.normalize(corrected_text)
                normalized, is_valid, error = result.to_tuple()
            elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
                enhanced_normalizer = EnhancedDateNormalizer()
                result = enhanced_normalizer.normalize(corrected_text)
                normalized, is_valid, error = result.to_tuple()
            else:
                # Re-run standard normalization with corrected text
                normalized, is_valid, error = self._normalize_and_validate(
                    base_result.field_name, corrected_text
                )

            # Overwrite the base result only when re-normalization produced a
            # value AND either the base pass had none or the new value
            # validated — an invalid new value never replaces an existing one.
            if normalized and (not base_result.normalized_value or is_valid):
                base_result.normalized_value = normalized
                base_result.is_valid = is_valid
                base_result.validation_error = error
                base_result.ocr_corrections_applied = corrections
                if corrections:
                    base_result.extraction_method = 'corrected'

        return base_result
|