Initial commit: Invoice field extraction system using YOLO + OCR
Features: - Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations - Flexible date matching: year-month match, nearby date tolerance - PDF text extraction with PyMuPDF - OCR support for scanned documents (PaddleOCR) - YOLO training and inference pipeline - 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
382
src/inference/field_extractor.py
Normal file
382
src/inference/field_extractor.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
Field Extractor Module
|
||||
|
||||
Extracts and validates field values from detected regions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import re
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .yolo_detector import Detection, CLASS_TO_FIELD
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedField:
|
||||
"""Represents an extracted field value."""
|
||||
field_name: str
|
||||
raw_text: str
|
||||
normalized_value: str | None
|
||||
confidence: float
|
||||
detection_confidence: float
|
||||
ocr_confidence: float
|
||||
bbox: tuple[float, float, float, float]
|
||||
page_no: int
|
||||
is_valid: bool = True
|
||||
validation_error: str | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'field_name': self.field_name,
|
||||
'value': self.normalized_value,
|
||||
'raw_text': self.raw_text,
|
||||
'confidence': self.confidence,
|
||||
'bbox': list(self.bbox),
|
||||
'page_no': self.page_no,
|
||||
'is_valid': self.is_valid,
|
||||
'validation_error': self.validation_error
|
||||
}
|
||||
|
||||
|
||||
class FieldExtractor:
|
||||
"""Extracts field values from detected regions using OCR or PDF text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ocr_lang: str = 'en',
|
||||
use_gpu: bool = False,
|
||||
bbox_padding: float = 0.1,
|
||||
dpi: int = 300
|
||||
):
|
||||
"""
|
||||
Initialize field extractor.
|
||||
|
||||
Args:
|
||||
ocr_lang: Language for OCR
|
||||
use_gpu: Whether to use GPU for OCR
|
||||
bbox_padding: Padding to add around bboxes (as fraction)
|
||||
dpi: DPI used for rendering (for coordinate conversion)
|
||||
"""
|
||||
self.ocr_lang = ocr_lang
|
||||
self.use_gpu = use_gpu
|
||||
self.bbox_padding = bbox_padding
|
||||
self.dpi = dpi
|
||||
self._ocr_engine = None # Lazy init
|
||||
|
||||
@property
|
||||
def ocr_engine(self):
|
||||
"""Lazy-load OCR engine only when needed."""
|
||||
if self._ocr_engine is None:
|
||||
from ..ocr import OCREngine
|
||||
self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu)
|
||||
return self._ocr_engine
|
||||
|
||||
def extract_from_detection_with_pdf(
|
||||
self,
|
||||
detection: Detection,
|
||||
pdf_tokens: list,
|
||||
image_width: int,
|
||||
image_height: int
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Extract field value using PDF text tokens (faster and more accurate for text PDFs).
|
||||
|
||||
Args:
|
||||
detection: Detection object with bbox in pixel coordinates
|
||||
pdf_tokens: List of Token objects from PDF text extraction
|
||||
image_width: Width of rendered image in pixels
|
||||
image_height: Height of rendered image in pixels
|
||||
|
||||
Returns:
|
||||
ExtractedField object
|
||||
"""
|
||||
# Convert detection bbox from pixels to PDF points
|
||||
scale = 72 / self.dpi # points per pixel
|
||||
x0_pdf = detection.bbox[0] * scale
|
||||
y0_pdf = detection.bbox[1] * scale
|
||||
x1_pdf = detection.bbox[2] * scale
|
||||
y1_pdf = detection.bbox[3] * scale
|
||||
|
||||
# Add padding in points
|
||||
pad = 3 # Small padding in points
|
||||
|
||||
# Find tokens that overlap with detection bbox
|
||||
matching_tokens = []
|
||||
for token in pdf_tokens:
|
||||
if token.page_no != detection.page_no:
|
||||
continue
|
||||
tx0, ty0, tx1, ty1 = token.bbox
|
||||
# Check overlap
|
||||
if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
|
||||
ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
|
||||
# Calculate overlap ratio to prioritize better matches
|
||||
overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
|
||||
overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
|
||||
if overlap_x > 0 and overlap_y > 0:
|
||||
token_area = (tx1 - tx0) * (ty1 - ty0)
|
||||
overlap_area = overlap_x * overlap_y
|
||||
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
|
||||
matching_tokens.append((token, overlap_ratio))
|
||||
|
||||
# Sort by overlap ratio and combine text
|
||||
matching_tokens.sort(key=lambda x: -x[1])
|
||||
raw_text = ' '.join(t[0].text for t in matching_tokens)
|
||||
|
||||
# Get field name
|
||||
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||
|
||||
# Normalize and validate
|
||||
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||
field_name, raw_text
|
||||
)
|
||||
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
raw_text=raw_text,
|
||||
normalized_value=normalized_value,
|
||||
confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
|
||||
detection_confidence=detection.confidence,
|
||||
ocr_confidence=1.0, # PDF text is always accurate
|
||||
bbox=detection.bbox,
|
||||
page_no=detection.page_no,
|
||||
is_valid=is_valid,
|
||||
validation_error=validation_error
|
||||
)
|
||||
|
||||
def extract_from_detection(
|
||||
self,
|
||||
detection: Detection,
|
||||
image: np.ndarray | Image.Image
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Extract field value from a detection region using OCR.
|
||||
|
||||
Args:
|
||||
detection: Detection object
|
||||
image: Full page image
|
||||
|
||||
Returns:
|
||||
ExtractedField object
|
||||
"""
|
||||
if isinstance(image, Image.Image):
|
||||
image = np.array(image)
|
||||
|
||||
# Get padded bbox
|
||||
h, w = image.shape[:2]
|
||||
bbox = detection.get_padded_bbox(self.bbox_padding, w, h)
|
||||
|
||||
# Crop region
|
||||
x0, y0, x1, y1 = [int(v) for v in bbox]
|
||||
region = image[y0:y1, x0:x1]
|
||||
|
||||
# Run OCR on region
|
||||
ocr_tokens = self.ocr_engine.extract_from_image(region)
|
||||
|
||||
# Combine all OCR text
|
||||
raw_text = ' '.join(t.text for t in ocr_tokens)
|
||||
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
|
||||
|
||||
# Get field name
|
||||
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||
|
||||
# Normalize and validate
|
||||
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||
field_name, raw_text
|
||||
)
|
||||
|
||||
# Combined confidence
|
||||
confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5
|
||||
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
raw_text=raw_text,
|
||||
normalized_value=normalized_value,
|
||||
confidence=confidence,
|
||||
detection_confidence=detection.confidence,
|
||||
ocr_confidence=ocr_confidence,
|
||||
bbox=detection.bbox,
|
||||
page_no=detection.page_no,
|
||||
is_valid=is_valid,
|
||||
validation_error=validation_error
|
||||
)
|
||||
|
||||
def _normalize_and_validate(
|
||||
self,
|
||||
field_name: str,
|
||||
raw_text: str
|
||||
) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
Normalize and validate extracted text for a field.
|
||||
|
||||
Returns:
|
||||
(normalized_value, is_valid, validation_error)
|
||||
"""
|
||||
text = raw_text.strip()
|
||||
|
||||
if not text:
|
||||
return None, False, "Empty text"
|
||||
|
||||
if field_name == 'InvoiceNumber':
|
||||
return self._normalize_invoice_number(text)
|
||||
|
||||
elif field_name == 'OCR':
|
||||
return self._normalize_ocr_number(text)
|
||||
|
||||
elif field_name == 'Bankgiro':
|
||||
return self._normalize_bankgiro(text)
|
||||
|
||||
elif field_name == 'Plusgiro':
|
||||
return self._normalize_plusgiro(text)
|
||||
|
||||
elif field_name == 'Amount':
|
||||
return self._normalize_amount(text)
|
||||
|
||||
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
return self._normalize_date(text)
|
||||
|
||||
else:
|
||||
return text, True, None
|
||||
|
||||
def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize invoice number."""
|
||||
# Extract digits only
|
||||
digits = re.sub(r'\D', '', text)
|
||||
|
||||
if len(digits) < 3:
|
||||
return None, False, f"Too few digits: {len(digits)}"
|
||||
|
||||
return digits, True, None
|
||||
|
||||
def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize OCR number."""
|
||||
digits = re.sub(r'\D', '', text)
|
||||
|
||||
if len(digits) < 5:
|
||||
return None, False, f"Too few digits for OCR: {len(digits)}"
|
||||
|
||||
return digits, True, None
|
||||
|
||||
def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize Bankgiro number."""
|
||||
digits = re.sub(r'\D', '', text)
|
||||
|
||||
if len(digits) == 8:
|
||||
# Format as XXXX-XXXX
|
||||
formatted = f"{digits[:4]}-{digits[4:]}"
|
||||
return formatted, True, None
|
||||
elif len(digits) == 7:
|
||||
# Format as XXX-XXXX
|
||||
formatted = f"{digits[:3]}-{digits[3:]}"
|
||||
return formatted, True, None
|
||||
elif 6 <= len(digits) <= 9:
|
||||
return digits, True, None
|
||||
else:
|
||||
return None, False, f"Invalid Bankgiro length: {len(digits)}"
|
||||
|
||||
def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize Plusgiro number."""
|
||||
digits = re.sub(r'\D', '', text)
|
||||
|
||||
if len(digits) >= 6:
|
||||
# Format as XXXXXXX-X
|
||||
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||
return formatted, True, None
|
||||
else:
|
||||
return None, False, f"Invalid Plusgiro length: {len(digits)}"
|
||||
|
||||
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize monetary amount."""
|
||||
# Remove currency and common suffixes
|
||||
text = re.sub(r'[SEK|kr|:-]+', '', text, flags=re.IGNORECASE)
|
||||
text = text.replace(' ', '').replace('\xa0', '')
|
||||
|
||||
# Handle comma as decimal separator
|
||||
if ',' in text and '.' not in text:
|
||||
text = text.replace(',', '.')
|
||||
|
||||
# Try to parse as float
|
||||
try:
|
||||
amount = float(text)
|
||||
return f"{amount:.2f}", True, None
|
||||
except ValueError:
|
||||
return None, False, f"Cannot parse amount: {text}"
|
||||
|
||||
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize date."""
|
||||
from datetime import datetime
|
||||
|
||||
# Common date patterns
|
||||
patterns = [
|
||||
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m[1]}-{int(m[2]):02d}-{int(m[3]):02d}"),
|
||||
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
|
||||
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
|
||||
(r'(\d{4})(\d{2})(\d{2})', lambda m: f"{m[1]}-{m[2]}-{m[3]}"),
|
||||
]
|
||||
|
||||
for pattern, formatter in patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
try:
|
||||
date_str = formatter(match)
|
||||
# Validate date
|
||||
datetime.strptime(date_str, '%Y-%m-%d')
|
||||
return date_str, True, None
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return None, False, f"Cannot parse date: {text}"
|
||||
|
||||
def extract_all_fields(
|
||||
self,
|
||||
detections: list[Detection],
|
||||
image: np.ndarray | Image.Image
|
||||
) -> list[ExtractedField]:
|
||||
"""
|
||||
Extract fields from all detections.
|
||||
|
||||
Args:
|
||||
detections: List of detections
|
||||
image: Full page image
|
||||
|
||||
Returns:
|
||||
List of ExtractedField objects
|
||||
"""
|
||||
fields = []
|
||||
|
||||
for detection in detections:
|
||||
field = self.extract_from_detection(detection, image)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
@staticmethod
|
||||
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Infer OCR field from InvoiceNumber if not detected.
|
||||
|
||||
In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
|
||||
When OCR is not detected but InvoiceNumber is, we can infer OCR value.
|
||||
|
||||
Args:
|
||||
fields: Dict of field_name -> normalized_value
|
||||
|
||||
Returns:
|
||||
Updated fields dict with inferred OCR if applicable
|
||||
"""
|
||||
# If OCR already exists, no need to infer
|
||||
if fields.get('OCR'):
|
||||
return fields
|
||||
|
||||
# If InvoiceNumber exists and is numeric, use it as OCR
|
||||
invoice_number = fields.get('InvoiceNumber')
|
||||
if invoice_number:
|
||||
# Check if it's mostly digits (valid OCR reference)
|
||||
digits_only = re.sub(r'\D', '', invoice_number)
|
||||
if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
|
||||
fields['OCR'] = invoice_number
|
||||
|
||||
return fields
|
||||
Reference in New Issue
Block a user