Initial commit: Invoice field extraction system using YOLO + OCR

Features:
- Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations
- Flexible date matching: year-month match, nearby date tolerance
- PDF text extraction with PyMuPDF
- OCR support for scanned documents (PaddleOCR)
- YOLO training and inference pipeline
- 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Yaojia Wang
2026-01-10 17:44:14 +01:00
commit 8938661850
35 changed files with 5020 additions and 0 deletions

View File

@@ -0,0 +1,382 @@
"""
Field Extractor Module
Extracts and validates field values from detected regions.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import re
import numpy as np
from PIL import Image
from .yolo_detector import Detection, CLASS_TO_FIELD
@dataclass
class ExtractedField:
"""Represents an extracted field value."""
field_name: str
raw_text: str
normalized_value: str | None
confidence: float
detection_confidence: float
ocr_confidence: float
bbox: tuple[float, float, float, float]
page_no: int
is_valid: bool = True
validation_error: str | None = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
'field_name': self.field_name,
'value': self.normalized_value,
'raw_text': self.raw_text,
'confidence': self.confidence,
'bbox': list(self.bbox),
'page_no': self.page_no,
'is_valid': self.is_valid,
'validation_error': self.validation_error
}
class FieldExtractor:
    """Extracts field values from detected regions using OCR or PDF text.

    Text is read either from PDF text tokens overlapping a detection
    (`extract_from_detection_with_pdf`) or by running OCR on the cropped
    image region (`extract_from_detection`). The raw text is then normalized
    and validated per field type.
    """

    # Field name -> name of the method that normalizes it. Fields that are
    # not listed are passed through unchanged and considered valid.
    _NORMALIZERS = {
        'InvoiceNumber': '_normalize_invoice_number',
        'OCR': '_normalize_ocr_number',
        'Bankgiro': '_normalize_bankgiro',
        'Plusgiro': '_normalize_plusgiro',
        'Amount': '_normalize_amount',
        'InvoiceDate': '_normalize_date',
        'InvoiceDueDate': '_normalize_date',
    }

    # Supported date layouts in priority order, each paired with the regex
    # group numbers of (year, month, day). The digit look-arounds stop a
    # pattern from matching inside a longer run of digits (e.g. an OCR
    # reference number that merely contains eight date-like digits).
    _DATE_PATTERNS = (
        (re.compile(r'(?<!\d)(\d{4})-(\d{1,2})-(\d{1,2})(?!\d)'), (1, 2, 3)),
        (re.compile(r'(?<!\d)(\d{1,2})/(\d{1,2})/(\d{4})(?!\d)'), (3, 2, 1)),
        (re.compile(r'(?<!\d)(\d{1,2})\.(\d{1,2})\.(\d{4})(?!\d)'), (3, 2, 1)),
        (re.compile(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)'), (1, 2, 3)),
    )

    # "SEK" / "kr" as a standalone token (not as part of a longer word).
    _CURRENCY_RE = re.compile(r'(?<![^\W\d_])(?:SEK|kr)(?![^\W\d_])', re.IGNORECASE)
    # Trailing "whole amount" marker such as "100:-" or "100,-".
    _AMOUNT_SUFFIX_RE = re.compile(r'[:,.]?-+$')
    # A plain decimal number; deliberately excludes "inf", "nan" and exponents.
    _AMOUNT_RE = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)')

    def __init__(
        self,
        ocr_lang: str = 'en',
        use_gpu: bool = False,
        bbox_padding: float = 0.1,
        dpi: int = 300
    ):
        """
        Initialize field extractor.

        Args:
            ocr_lang: Language for OCR
            use_gpu: Whether to use GPU for OCR
            bbox_padding: Padding to add around bboxes (as fraction)
            dpi: DPI used for rendering (for coordinate conversion)
        """
        self.ocr_lang = ocr_lang
        self.use_gpu = use_gpu
        self.bbox_padding = bbox_padding
        self.dpi = dpi
        self._ocr_engine = None  # Created on first access of `ocr_engine`

    @property
    def ocr_engine(self):
        """Lazy-load OCR engine only when needed."""
        if self._ocr_engine is None:
            # Imported here so the OCR package is only loaded when OCR is used.
            from ..ocr import OCREngine
            self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu)
        return self._ocr_engine

    @staticmethod
    def _reading_order(tokens: list) -> list:
        """Return tokens sorted top-to-bottom, then left-to-right per line.

        A token joins the current line when its vertical centre lies above
        the bottom edge of the token that opened that line. Assumes the
        token bboxes are (x0, y0, x1, y1) with y growing downwards, which is
        the orientation the pixel-to-point conversion in
        `extract_from_detection_with_pdf` already relies on.
        """
        lines: list[list] = []
        line_bottom = float('-inf')
        for token in sorted(tokens, key=lambda t: (t.bbox[1] + t.bbox[3]) / 2):
            _, ty0, _, ty1 = token.bbox
            if not lines or (ty0 + ty1) / 2 >= line_bottom:
                lines.append([])
                line_bottom = ty1
            lines[-1].append(token)
        return [
            token
            for line in lines
            for token in sorted(line, key=lambda t: t.bbox[0])
        ]

    def extract_from_detection_with_pdf(
        self,
        detection: Detection,
        pdf_tokens: list,
        image_width: int,
        image_height: int
    ) -> ExtractedField:
        """
        Extract field value using PDF text tokens (faster and more accurate for text PDFs).

        Args:
            detection: Detection object with bbox in pixel coordinates
            pdf_tokens: List of Token objects from PDF text extraction
            image_width: Width of rendered image in pixels (currently unused;
                the conversion relies on `self.dpi` only)
            image_height: Height of rendered image in pixels (currently unused)

        Returns:
            ExtractedField object
        """
        # Convert detection bbox from pixels to PDF points
        scale = 72 / self.dpi  # points per pixel
        x0_pdf, y0_pdf, x1_pdf, y1_pdf = (v * scale for v in detection.bbox[:4])

        # Keep every token on the same page that truly overlaps the box.
        # (An earlier padded pre-check was always implied by this strict
        # overlap test, so it never changed the selection and was dropped.)
        matching_tokens = []
        for token in pdf_tokens:
            if token.page_no != detection.page_no:
                continue
            tx0, ty0, tx1, ty1 = token.bbox
            overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
            overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
            if overlap_x > 0 and overlap_y > 0:
                matching_tokens.append(token)

        # Join in reading order. Ordering by overlap ratio would scramble a
        # value split over several tokens (e.g. "1 234,00") whenever the
        # detection box clips its first token.
        raw_text = ' '.join(t.text for t in self._reading_order(matching_tokens))

        # Get field name
        field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)

        # Normalize and validate
        normalized_value, is_valid, validation_error = self._normalize_and_validate(
            field_name, raw_text
        )

        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized_value,
            # Halve the confidence when no usable value came out of the text.
            confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
            detection_confidence=detection.confidence,
            ocr_confidence=1.0,  # PDF text is always accurate
            bbox=detection.bbox,
            page_no=detection.page_no,
            is_valid=is_valid,
            validation_error=validation_error
        )

    def extract_from_detection(
        self,
        detection: Detection,
        image: np.ndarray | Image.Image
    ) -> ExtractedField:
        """
        Extract field value from a detection region using OCR.

        Args:
            detection: Detection object
            image: Full page image

        Returns:
            ExtractedField object
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Get padded bbox
        h, w = image.shape[:2]
        bbox = detection.get_padded_bbox(self.bbox_padding, w, h)

        # Clamp to the image: a negative index would silently wrap around in
        # NumPy slicing and crop the wrong part of the page.
        x0, y0, x1, y1 = (int(v) for v in bbox)
        x0, y0 = max(x0, 0), max(y0, 0)
        x1, y1 = min(x1, w), min(y1, h)

        # Run OCR on the region; a degenerate (empty) crop yields no tokens.
        if x1 > x0 and y1 > y0:
            ocr_tokens = self.ocr_engine.extract_from_image(image[y0:y1, x0:x1])
        else:
            ocr_tokens = []

        # Combine all OCR text
        raw_text = ' '.join(t.text for t in ocr_tokens)
        ocr_confidence = (
            sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
        )

        # Get field name
        field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)

        # Normalize and validate
        normalized_value, is_valid, validation_error = self._normalize_and_validate(
            field_name, raw_text
        )

        # Combined confidence: mean of detection and OCR confidence, or half
        # the detection confidence when OCR found nothing.
        if ocr_tokens:
            confidence = (detection.confidence + ocr_confidence) / 2
        else:
            confidence = detection.confidence * 0.5

        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized_value,
            confidence=confidence,
            detection_confidence=detection.confidence,
            ocr_confidence=ocr_confidence,
            bbox=detection.bbox,
            page_no=detection.page_no,
            is_valid=is_valid,
            validation_error=validation_error
        )

    def _normalize_and_validate(
        self,
        field_name: str,
        raw_text: str
    ) -> tuple[str | None, bool, str | None]:
        """
        Normalize and validate extracted text for a field.

        Empty (or whitespace-only) text is always invalid. Field names
        without a dedicated normalizer are returned stripped and valid.

        Returns:
            (normalized_value, is_valid, validation_error)
        """
        text = raw_text.strip()
        if not text:
            return None, False, "Empty text"

        normalizer_name = self._NORMALIZERS.get(field_name)
        if normalizer_name is None:
            return text, True, None
        return getattr(self, normalizer_name)(text)

    def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize invoice number: keep digits only, require at least 3."""
        digits = re.sub(r'\D', '', text)
        if len(digits) < 3:
            return None, False, f"Too few digits: {len(digits)}"
        return digits, True, None

    def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize OCR number: keep digits only, require at least 5."""
        digits = re.sub(r'\D', '', text)
        if len(digits) < 5:
            return None, False, f"Too few digits for OCR: {len(digits)}"
        return digits, True, None

    def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize Bankgiro number.

        8 digits become ``XXXX-XXXX`` and 7 digits ``XXX-XXXX``; 6 or 9
        digits are accepted unformatted; any other length is invalid.
        """
        digits = re.sub(r'\D', '', text)
        count = len(digits)
        if count in (7, 8):
            # The last four digits always form the second group.
            return f"{digits[:-4]}-{digits[-4:]}", True, None
        if 6 <= count <= 9:
            return digits, True, None
        return None, False, f"Invalid Bankgiro length: {count}"

    def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize Plusgiro number to ``XXXXXXX-X`` (at least 6 digits)."""
        digits = re.sub(r'\D', '', text)
        if len(digits) < 6:
            return None, False, f"Invalid Plusgiro length: {len(digits)}"
        # The final digit is separated from the rest by a dash.
        return f"{digits[:-1]}-{digits[-1]}", True, None

    def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize monetary amount to a string with two decimals.

        Handles the currency tokens "SEK"/"kr", whitespace (including
        non-breaking spaces) used as grouping, the trailing ":-" / ",-"
        marker, and both "1.234,56" and "1,234.56" grouping styles. A leading
        minus sign is kept. Anything that is not a plain decimal number
        afterwards (including "inf"/"nan") is rejected.
        """
        cleaned = self._CURRENCY_RE.sub('', text)
        # `\s` also covers non-breaking spaces; U+2212 is the typographic minus.
        cleaned = re.sub(r'\s+', '', cleaned).replace('\u2212', '-')
        cleaned = self._AMOUNT_SUFFIX_RE.sub('', cleaned).strip(':')

        has_comma = ',' in cleaned
        has_dot = '.' in cleaned
        if has_comma and has_dot:
            # Whichever separator comes last is the decimal mark; the other
            # one is a thousands separator.
            if cleaned.rfind(',') > cleaned.rfind('.'):
                cleaned = cleaned.replace('.', '').replace(',', '.')
            else:
                cleaned = cleaned.replace(',', '')
        elif has_comma:
            if cleaned.count(',') > 1:
                cleaned = cleaned.replace(',', '')  # 1,234,567
            else:
                cleaned = cleaned.replace(',', '.')  # decimal comma
        elif cleaned.count('.') > 1:
            cleaned = cleaned.replace('.', '')  # 1.234.567

        if not self._AMOUNT_RE.fullmatch(cleaned):
            return None, False, f"Cannot parse amount: {text}"
        return f"{float(cleaned):.2f}", True, None

    def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
        """Normalize date to ``YYYY-MM-DD``.

        Patterns are tried in priority order and, within a pattern, every
        candidate in the text is tried until one is a real calendar date.
        Slash- and dot-separated dates are read day-first.
        """
        from datetime import date

        for pattern, (year_group, month_group, day_group) in self._DATE_PATTERNS:
            for match in pattern.finditer(text):
                try:
                    parsed = date(
                        int(match[year_group]),
                        int(match[month_group]),
                        int(match[day_group]),
                    )
                except ValueError:
                    # Looks like a date but is not one (e.g. month 13).
                    continue
                return parsed.isoformat(), True, None
        return None, False, f"Cannot parse date: {text}"

    def extract_all_fields(
        self,
        detections: list[Detection],
        image: np.ndarray | Image.Image
    ) -> list[ExtractedField]:
        """
        Extract fields from all detections.

        Args:
            detections: List of detections
            image: Full page image

        Returns:
            List of ExtractedField objects, one per detection, in input order
        """
        return [self.extract_from_detection(detection, image) for detection in detections]

    @staticmethod
    def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
        """
        Infer OCR field from InvoiceNumber if not detected.

        In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
        When OCR is not detected but InvoiceNumber is, we can infer OCR value.

        The passed dict is updated in place and also returned.

        Args:
            fields: Dict of field_name -> normalized_value

        Returns:
            Updated fields dict with inferred OCR if applicable
        """
        # If OCR already exists, no need to infer
        if fields.get('OCR'):
            return fields

        # Only an all-digit invoice number of at least 5 digits qualifies.
        invoice_number = fields.get('InvoiceNumber')
        if invoice_number and re.fullmatch(r'\d{5,}', invoice_number):
            fields['OCR'] = invoice_number
        return fields