Yaojia Wang
2026-01-22 22:03:24 +01:00
parent 4ea4bc96d4
commit 8fd61ea928
19 changed files with 4069 additions and 226 deletions

View File

@@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace:
"--dpi",
type=int,
default=150,
help="DPI for PDF rendering",
help="DPI for PDF rendering (must match training DPI)",
)
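The updated help text notes that the rendering DPI must match the training DPI. As a quick illustration (simple geometry, not values taken from this repository), rendering the same A4 page at a different DPI rescales every pixel coordinate, so boxes learned at 150 DPI no longer line up:

A4_INCHES = (8.27, 11.69)

def page_size_px(dpi: int) -> tuple[int, int]:
    """Pixel dimensions of an A4 page rendered at the given DPI."""
    return round(A4_INCHES[0] * dpi), round(A4_INCHES[1] * dpi)

print(page_size_px(150))  # roughly (1240, 1754)
print(page_size_px(300))  # roughly (2481, 3507), about twice the scale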
parser.add_argument(

View File

@@ -63,7 +63,24 @@ def main():
)
parser.add_argument(
'--resume',
help='Resume from checkpoint'
action='store_true',
help='Resume from last checkpoint'
)
parser.add_argument(
'--workers',
type=int,
default=4,
help='Number of data loader workers (default: 4, reduce if OOM)'
)
parser.add_argument(
'--cache',
action='store_true',
help='Cache images in RAM (faster but uses more memory)'
)
parser.add_argument(
'--low-memory',
action='store_true',
help='Enable low memory mode (cap batch at 8, cap workers at 4, disable cache)'
)
parser.add_argument(
'--train-ratio',
@@ -103,6 +120,16 @@ def main():
args = parser.parse_args()
# Apply low-memory mode if specified
if args.low_memory:
print("🔧 Low memory mode enabled")
args.batch = min(args.batch, 8) # Reduce from 16 to 8
args.workers = min(args.workers, 4) # Reduce from 8 to 4
args.cache = False
print(f" Batch size: {args.batch}")
print(f" Workers: {args.workers}")
print(f" Cache: disabled")
# Validate dataset directory
dataset_dir = Path(args.dataset_dir)
temp_dir = dataset_dir / 'temp'
@@ -181,9 +208,10 @@ def main():
from ultralytics import YOLO
# Load model
if args.resume:
print(f"Resuming from: {args.resume}")
model = YOLO(args.resume)
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
if args.resume and last_checkpoint.exists():
print(f"Resuming from: {last_checkpoint}")
model = YOLO(str(last_checkpoint))
else:
model = YOLO(args.model)
@@ -200,6 +228,9 @@ def main():
'exist_ok': True,
'pretrained': True,
'verbose': True,
'workers': args.workers,
'cache': args.cache,
'resume': args.resume and last_checkpoint.exists(),
# Document-specific augmentation settings
'degrees': 5.0,
'translate': 0.05,
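A minimal standalone sketch of the resume rule introduced above: only resume when weights/last.pt actually exists under the run directory, otherwise start fresh from the base model. The path layout mirrors the diff; the helper itself is hypothetical.

from pathlib import Path

def resolve_resume(project: str, name: str, want_resume: bool) -> tuple[str | None, bool]:
    """Return (checkpoint_to_load, resume_flag); resume only if last.pt exists."""
    last = Path(project) / name / "weights" / "last.pt"
    if want_resume and last.exists():
        return str(last), True
    return None, False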

View File

@@ -2,17 +2,33 @@
Field Extractor Module
Extracts and validates field values from detected regions.
This module is used during inference to extract values from OCR text.
It uses shared utilities from src.utils for text cleaning and validation.
Enhanced features:
- Multi-source fusion with confidence weighting
- Smart amount parsing with multiple strategies
- Enhanced date format unification
- OCR error correction integration
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from collections import defaultdict
import re
import numpy as np
from PIL import Image
from .yolo_detector import Detection, CLASS_TO_FIELD
# Import shared utilities for text cleaning and validation
from src.utils.text_cleaner import TextCleaner
from src.utils.validators import FieldValidators
from src.utils.fuzzy_matcher import FuzzyMatcher
from src.utils.ocr_corrections import OCRCorrections
@dataclass
class ExtractedField:
@@ -27,10 +43,14 @@ class ExtractedField:
page_no: int
is_valid: bool = True
validation_error: str | None = None
# Multi-source fusion fields
alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...]
extraction_method: str = 'single' # 'single', 'fused', 'corrected'
ocr_corrections_applied: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
result = {
'field_name': self.field_name,
'value': self.normalized_value,
'raw_text': self.raw_text,
@@ -40,6 +60,11 @@ class ExtractedField:
'is_valid': self.is_valid,
'validation_error': self.validation_error
}
if self.alternative_values:
result['alternatives'] = self.alternative_values
if self.extraction_method != 'single':
result['extraction_method'] = self.extraction_method
return result
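For reference, a hypothetical to_dict() payload for a fused Amount detection (keys abridged to those shown above; the values are invented):

example = {
    "field_name": "Amount",
    "value": "294.00",
    "raw_text": "Att betala 294,00 kr",
    "is_valid": True,
    "validation_error": None,
    "alternatives": [("58.80", 0.41)],   # only emitted when fusion kept other candidates
    "extraction_method": "fused",        # omitted for plain single-detection extraction
}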
class FieldExtractor:
@@ -324,21 +349,9 @@ class FieldExtractor:
Validate using Luhn (Mod10) algorithm.
Used for Bankgiro, Plusgiro, and OCR number validation.
The checksum is valid if the total modulo 10 equals 0.
Delegates to shared FieldValidators for consistency.
"""
if not digits.isdigit():
return False
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
if i % 2 == 1: # Double every second digit from right
digit *= 2
if digit > 9:
digit -= 9
total += digit
return total % 10 == 0
return FieldValidators.luhn_checksum(digits)
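The inline Mod10 implementation removed here now lives in the shared validator. For reference, a standalone sketch equivalent to the deleted logic (FieldValidators.luhn_checksum is assumed to behave the same way):

def luhn_checksum(digits: str) -> bool:
    """Standalone Mod10 check matching the logic deleted above."""
    if not digits.isdigit():
        return False
    total = 0
    for i, ch in enumerate(reversed(digits)):
        d = int(ch)
        if i % 2 == 1:        # double every second digit from the right
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0

print(luhn_checksum("53939484"))  # True: the Bankgiro example used later in this commit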
def _detect_giro_type(self, text: str) -> str | None:
"""
@@ -489,45 +502,90 @@ class FieldExtractor:
return None, False, f"No valid Plusgiro found in text"
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize monetary amount."""
# Try to extract amount using regex patterns
# Pattern 1: Number with comma as decimal (Swedish format: 1 234,56)
# Pattern 2: Number with dot as decimal (1234.56)
# Pattern 3: Number followed by currency (275,60 kr or 275.60 SEK)
"""Normalize monetary amount.
patterns = [
# Swedish format with space thousand separator: 1 234,56 or 1234,56
r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',
# Simple decimal: 350.00 or 350,00
r'(\d+[,\.]\d{2})',
# Integer amount
r'(\d{2,})',
]
Uses shared TextCleaner for preprocessing and FieldValidators for parsing.
If multiple amounts are found, returns the last one (usually the total).
"""
# Split by newlines and process line by line to get the last valid amount
lines = text.split('\n')
for pattern in patterns:
matches = re.findall(pattern, text, re.IGNORECASE)
if matches:
# Take the last match (usually the total amount)
amount_str = matches[-1]
# Clean up
amount_str = amount_str.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
if ',' in amount_str:
amount_str = amount_str.replace(',', '.')
# Collect all valid amounts from all lines
all_amounts = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?'
for line in lines:
line = line.strip()
if not line:
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(' ', '').replace(',', '.')
try:
amount = float(amount_str)
if amount > 0:
return f"{amount:.2f}", True, None
all_amounts.append(amount)
except ValueError:
continue
# Return the last amount found (usually the total)
if all_amounts:
return f"{all_amounts[-1]:.2f}", True, None
# Fallback: try shared validator on cleaned text
cleaned = TextCleaner.normalize_amount_text(text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and amount > 0:
return f"{amount:.2f}", True, None
# Try to find any decimal number
simple_pattern = r'(\d+[,\.]\d{2})'
matches = re.findall(simple_pattern, text)
if matches:
amount_str = matches[-1].replace(',', '.')
try:
amount = float(amount_str)
if amount > 0:
return f"{amount:.2f}", True, None
except ValueError:
pass
# Last resort: try to find integer amount (no decimals)
# Look for patterns like "Amount: 11699" or standalone numbers
int_pattern = r'(?:amount|belopp|summa|total)[:\s]*(\d+)'
match = re.search(int_pattern, text, re.IGNORECASE)
if match:
try:
amount = float(match.group(1))
if amount > 0:
return f"{amount:.2f}", True, None
except ValueError:
pass
# Very last resort: find any standalone number >= 3 digits
standalone_pattern = r'\b(\d{3,})\b'
matches = re.findall(standalone_pattern, text)
if matches:
# Take the last/largest number
try:
amount = float(matches[-1])
if amount > 0:
return f"{amount:.2f}", True, None
except ValueError:
pass
return None, False, f"Cannot parse amount: {text}"
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Normalize date from text that may contain surrounding text.
Uses shared FieldValidators for date parsing and validation.
Handles various date formats found in Swedish invoices:
- 2025-08-29 (ISO format)
- 2025.08.29 (dot separator)
@@ -535,9 +593,14 @@ class FieldExtractor:
- 29.08.2025 (European with dots)
- 20250829 (compact format)
"""
# First, try using shared validator
iso_date = FieldValidators.format_date_iso(text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return iso_date, True, None
# Fallback: try original patterns for edge cases
from datetime import datetime
# Common date patterns - order matters, most specific first
patterns = [
# ISO format: 2025-08-29
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
@@ -570,32 +633,76 @@ class FieldExtractor:
"""
Normalize payment line region text.
Extracts OCR, Amount, and Bankgiro from the payment line using MachineCodeParser.
Extracts the machine-readable payment line format from OCR text.
Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Examples:
- "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
- "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
Returns normalized format preserving ALL components including Amount:
- Full format: "OCR:xxx Amount:xxx.xx BG:xxx" or "OCR:xxx Amount:xxx.xx PG:xxx"
- This allows downstream cross-validation to extract fields properly.
"""
from ..ocr.machine_code_parser import MachineCodeParser
# Pattern to match Swedish payment line format WITH amount
# Format: # <OCR number> # <Kronor> <Öre> <Type> > <account number>#<check digits>#
# Account number may have spaces: "78 2 1 713" -> "7821713"
# Kronor may have OCR-induced spaces: "12 0 0" -> "1200"
# The > symbol may be missing in low-DPI OCR, so make it optional
# Check digits may have spaces: "#41 #" -> "#41#"
payment_line_full_pattern = r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
# Create a simple token-like structure for the parser
# (The parser expects tokens, but for inference we have raw text)
parser = MachineCodeParser()
match = re.search(payment_line_full_pattern, text)
if match:
ocr_part = match.group(1).replace(' ', '')
kronor = match.group(2).replace(' ', '') # Remove OCR-induced spaces
ore = match.group(3)
record_type = match.group(4)
account = match.group(5).replace(' ', '') # Remove spaces from account number
check_digits = match.group(6)
# Try to parse the standard payment line format
result = parser._parse_standard_payment_line(text)
# Reconstruct the clean machine-readable format
# Format: # OCR # KRONOR ORE TYPE > ACCOUNT#CHECK#
result = f"# {ocr_part} # {kronor} {ore} {record_type} > {account}#{check_digits}#"
return result, True, None
if result:
# Format as structured output
parts = []
if result.get('ocr'):
parts.append(f"OCR:{result['ocr']}")
if result.get('amount'):
parts.append(f"Amount:{result['amount']}")
if result.get('bankgiro'):
parts.append(f"BG:{result['bankgiro']}")
# Try pattern WITHOUT amount (some payment lines don't have amount)
# Format: # <OCR number> # > <account number>#<check digits>#
# > may be missing in low-DPI OCR
# Check digits may have spaces
payment_line_no_amount_pattern = r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(payment_line_no_amount_pattern, text)
if match:
ocr_part = match.group(1).replace(' ', '')
account = match.group(2).replace(' ', '')
check_digits = match.group(3)
if parts:
return ' '.join(parts), True, None
result = f"# {ocr_part} # > {account}#{check_digits}#"
return result, True, None
# Fallback: return raw text if no structured parsing possible
return text, True, None
# Try alternative pattern: just look for the # > account# pattern (> optional)
# Check digits may have spaces
alt_pattern = r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(alt_pattern, text)
if match:
ocr_part = match.group(1).replace(' ', '')
account = match.group(2).replace(' ', '')
check_digits = match.group(3)
result = f"# {ocr_part} # > {account}#{check_digits}#"
return result, True, None
# Try to find just the account part with # markers
# Check digits may have spaces
account_pattern = r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
match = re.search(account_pattern, text)
if match:
account = match.group(1).replace(' ', '')
check_digits = match.group(2)
return f"> {account}#{check_digits}#", True, "Partial payment line (account only)"
# Fallback: return None if no payment line format found
return None, False, "No valid payment line format found"
def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""
@@ -640,24 +747,52 @@ class FieldExtractor:
Normalize customer number extracted from OCR.
Customer numbers can have various formats:
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N'
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R'
- Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
- Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'
- Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R'
Note: Spaces and dashes may be removed from invoice display,
so we need to match both 'JTY 576-3' and 'JTY5763' formats.
"""
from ..normalize.normalizer import FieldNormalizer
# Clean the text using the same logic as matcher
text = FieldNormalizer.clean_text(text)
if not text:
if not text or not text.strip():
return None, False, "Empty text"
# Customer number patterns - ordered by specificity
# Match both spaced/dashed versions and compact versions
# Keep original text for pattern matching (don't uppercase yet)
original_text = text.strip()
# Customer number patterns - ordered by specificity (most specific first)
# All patterns use IGNORECASE so they work regardless of case
customer_code_patterns = [
# Pattern: 2-4 letters + space + digits + dash + single letter/digit (UMJ 436-R, EMM 256-6)
# This is the most common Swedish customer number format
r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b',
# Pattern: 2-4 letters + space + digits + letter WITHOUT dash (Dwq 211X, ABC 123X)
# Note: This is also common for customer numbers
r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b',
# Pattern: Word (capitalized) + space + digits + dash + letter (Umj 436-R, Billo 123-A)
r'\b([A-Za-z][a-z]{1,10})\s+(\d{1,4})-([A-Za-z0-9])\b',
# Pattern: Letters + digits + dash + digit/letter without space (JTY576-3)
r'\b([A-Za-z]{2,4})(\d{1,4})-([A-Za-z0-9])\b',
]
# Try specific patterns first
for pattern in customer_code_patterns:
match = re.search(pattern, original_text)
if match:
# Skip if it looks like a Swedish postal code (SE + digits)
full_match = match.group(0)
if full_match.upper().startswith('SE ') and re.match(r'^SE\s+\d{3}\s*\d{2}', full_match, re.IGNORECASE):
continue
# Reconstruct the customer number in standard format
groups = match.groups()
if len(groups) == 3:
# Format: XXX NNN-X (add dash if not present, e.g., "Dwq 211X" -> "DWQ 211-X")
result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}"
return result, True, None
# Generic patterns for other formats
generic_patterns = [
# Pattern: Letters + space/dash + digits + dash + digit (EMM 256-6, JTY 576-3)
r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]\d{1,2}[A-Z]?)\b',
# Pattern: Letters + space/dash + digits + optional letter (FFL 019N, ABC 123X)
@@ -666,27 +801,47 @@ class FieldExtractor:
r'\b([A-Z]{2,4}\d{3,6}[A-Z]?)\b',
# Pattern: Single letter + digits (A12345)
r'\b([A-Z]\d{4,6}[A-Z]?)\b',
# Pattern: Digits + dash/space + digits (123-456)
r'\b(\d{3,6}[\s\-]\d{1,4})\b',
]
all_matches = []
for pattern in customer_code_patterns:
matches = re.findall(pattern, text, re.IGNORECASE)
all_matches.extend(matches)
for pattern in generic_patterns:
for match in re.finditer(pattern, original_text, re.IGNORECASE):
matched_text = match.group(1)
pos = match.start()
# Filter out matches that look like postal codes or ID numbers
# Postal codes are usually 3-5 digits without letters
if re.match(r'^\d+$', matched_text):
continue
# Filter out V4 2 type matches (single letter + digit + space + digit)
if re.match(r'^[A-Z]\d\s+\d$', matched_text, re.IGNORECASE):
continue
# Filter out Swedish postal codes (SE XXX XX format or SE + digits)
# SE followed by digits is typically postal code, not customer number
if re.match(r'^SE[\s\-]*\d', matched_text, re.IGNORECASE):
continue
all_matches.append((matched_text, pos))
if all_matches:
# Prefer longer matches and those appearing later in text (after names)
# Sort by position in text (later = better) and length (longer = better)
# Prefer matches that contain both letters and digits with dash
scored_matches = []
for match in all_matches:
pos = text.upper().rfind(match.upper())
# Score: position * 0.1 + length (prefer later and longer)
score = pos * 0.1 + len(match)
scored_matches.append((score, match))
for match_text, pos in all_matches:
score = 0
# Bonus for containing dash (likely customer number format)
if '-' in match_text:
score += 50
# Bonus for format like XXX NNN-X
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', match_text, re.IGNORECASE):
score += 100
# Bonus for length (prefer medium length)
if 6 <= len(match_text) <= 12:
score += 20
# Position bonus (prefer later matches, after names)
score += pos * 0.1
scored_matches.append((score, match_text))
best_match = max(scored_matches, key=lambda x: x[0])[1]
return best_match.strip().upper(), True, None
if scored_matches:
best_match = max(scored_matches, key=lambda x: x[0])[1]
return best_match.strip().upper(), True, None
# Pattern 2: Look for explicit labels
labeled_patterns = [
@@ -694,7 +849,7 @@ class FieldExtractor:
]
for pattern in labeled_patterns:
match = re.search(pattern, text, re.IGNORECASE)
match = re.search(pattern, original_text, re.IGNORECASE)
if match:
extracted = match.group(1).strip()
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
@@ -702,32 +857,18 @@ class FieldExtractor:
return extracted.upper(), True, None
# Pattern 3: If text contains comma (likely "NAME, NAME CODE"), extract after last comma
if ',' in text:
after_comma = text.split(',')[-1].strip()
if ',' in original_text:
after_comma = original_text.split(',')[-1].strip()
# Look for alphanumeric code in the part after comma
for pattern in customer_code_patterns[:3]: # Use first 3 patterns
code_match = re.search(pattern, after_comma, re.IGNORECASE)
for pattern in customer_code_patterns:
code_match = re.search(pattern, after_comma)
if code_match:
return code_match.group(1).strip().upper(), True, None
groups = code_match.groups()
if len(groups) == 3:
result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}"
return result, True, None
# Pattern 4: Short text - filter out name-like words
if len(text) <= 20:
words = text.split()
code_parts = []
for word in words:
# Keep if: contains digits, or is short uppercase (likely abbreviation)
if re.search(r'\d', word) or (len(word) <= 4 and word.isupper()):
code_parts.append(word)
if code_parts:
result = ' '.join(code_parts).upper()
if len(result) >= 3:
return result, True, None
# Fallback: return cleaned text if reasonable
if text and 3 <= len(text) <= 15:
return text.upper(), True, None
return None, False, f"Cannot extract customer number from: {text[:50]}"
return None, False, f"Cannot extract customer number from: {original_text[:50]}"
def extract_all_fields(
self,
@@ -779,3 +920,428 @@ class FieldExtractor:
fields['OCR'] = invoice_number
return fields
# =========================================================================
# Multi-Source Fusion with Confidence Weighting
# =========================================================================
def fuse_multiple_detections(
self,
extracted_fields: list[ExtractedField]
) -> list[ExtractedField]:
"""
Fuse multiple detections of the same field using confidence-weighted voting.
When YOLO detects the same field type multiple times (e.g., multiple Amount boxes),
this method selects the best value or combines them intelligently.
Strategies:
1. For numeric fields (Amount, OCR): prefer values that pass validation
2. For date fields: prefer values in expected range
3. For giro numbers: prefer values with valid Luhn checksum
4. General: weighted vote by confidence scores
Args:
extracted_fields: List of all extracted fields (may have duplicates)
Returns:
List with duplicates resolved to single best value per field
"""
# Group fields by name
fields_by_name: dict[str, list[ExtractedField]] = defaultdict(list)
for field in extracted_fields:
fields_by_name[field.field_name].append(field)
fused_fields = []
for field_name, candidates in fields_by_name.items():
if len(candidates) == 1:
# No fusion needed
fused_fields.append(candidates[0])
else:
# Multiple candidates - fuse them
fused = self._fuse_field_candidates(field_name, candidates)
fused_fields.append(fused)
return fused_fields
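A toy sketch of the grouping step: candidates keyed by field name, one winner per field by confidence. The real method works on ExtractedField objects and applies the field-specific rules below; the triples here are invented.

from collections import defaultdict

detections = [
    ("Amount", "294.00", 0.91),
    ("Amount", "58.80", 0.47),
    ("OCR", "94228110015950070", 0.88),
]
by_field = defaultdict(list)
for name, value, conf in detections:
    by_field[name].append((value, conf))

winners = {name: max(cands, key=lambda c: c[1]) for name, cands in by_field.items()}
print(winners)  # {'Amount': ('294.00', 0.91), 'OCR': ('94228110015950070', 0.88)}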
def _fuse_field_candidates(
self,
field_name: str,
candidates: list[ExtractedField]
) -> ExtractedField:
"""
Fuse multiple candidates for a single field.
Returns the best candidate with alternatives recorded.
"""
# Sort by confidence (descending)
sorted_candidates = sorted(candidates, key=lambda x: x.confidence, reverse=True)
# Collect all unique values with their max confidence
value_scores: dict[str, tuple[float, ExtractedField]] = {}
for c in sorted_candidates:
if c.normalized_value:
if c.normalized_value not in value_scores:
value_scores[c.normalized_value] = (c.confidence, c)
else:
# Keep the higher confidence one
if c.confidence > value_scores[c.normalized_value][0]:
value_scores[c.normalized_value] = (c.confidence, c)
if not value_scores:
# No valid values, return the highest confidence candidate
return sorted_candidates[0]
# Field-specific fusion strategy
best_value, best_field = self._select_best_value(field_name, value_scores)
# Record alternatives
alternatives = [
(v, score) for v, (score, _) in value_scores.items()
if v != best_value
]
# Create fused result
result = ExtractedField(
field_name=field_name,
raw_text=best_field.raw_text,
normalized_value=best_value,
confidence=value_scores[best_value][0],
detection_confidence=best_field.detection_confidence,
ocr_confidence=best_field.ocr_confidence,
bbox=best_field.bbox,
page_no=best_field.page_no,
is_valid=best_field.is_valid,
validation_error=best_field.validation_error,
alternative_values=alternatives,
extraction_method='fused' if len(value_scores) > 1 else 'single'
)
return result
def _select_best_value(
self,
field_name: str,
value_scores: dict[str, tuple[float, ExtractedField]]
) -> tuple[str, ExtractedField]:
"""
Select the best value for a field using field-specific logic.
Returns (best_value, best_field)
"""
items = list(value_scores.items())
# Field-specific selection
if field_name in ('Bankgiro', 'Plusgiro', 'OCR'):
# Prefer values with valid Luhn checksum
for value, (score, field) in items:
digits = re.sub(r'\D', '', value)
if FieldValidators.luhn_checksum(digits):
return value, field
elif field_name == 'Amount':
# Prefer larger amounts (usually the total, not subtotals)
amounts = []
for value, (score, field) in items:
try:
amt = float(value.replace(',', '.'))
amounts.append((amt, value, field))
except ValueError:
continue
if amounts:
# Return the largest amount
amounts.sort(reverse=True)
return amounts[0][1], amounts[0][2]
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Prefer dates in reasonable range
from datetime import datetime
for value, (score, field) in items:
try:
dt = datetime.strptime(value, '%Y-%m-%d')
# Prefer recent dates (within last 2 years and next 1 year)
now = datetime.now()
if now.year - 2 <= dt.year <= now.year + 1:
return value, field
except ValueError:
continue
# Default: return highest confidence value
best = max(items, key=lambda x: x[1][0])
return best[0], best[1][1]
# =========================================================================
# Enhanced Amount Parsing
# =========================================================================
def _normalize_amount_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Enhanced amount parsing with multiple strategies.
Strategies:
1. Pattern matching for Swedish formats
2. Context-aware extraction (look for keywords like "Total", "Summa")
3. OCR error correction for common digit errors
4. Multi-amount handling (prefer last/largest as total)
This method replaces the original _normalize_amount when enhanced mode is enabled.
"""
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
labeled_patterns = [
# Swedish patterns
(r'(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 1.0),
(r'(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 0.8), # Lower priority for VAT
# Generic pattern
(r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?', 0.7),
]
candidates = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(' ', '').replace(',', '.')
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
candidates.append((amount, priority, match.start()))
except ValueError:
continue
if candidates:
# Sort by priority (desc), then by position (later is usually total)
candidates.sort(key=lambda x: (-x[1], -x[2]))
best_amount = candidates[0][0]
return f"{best_amount:.2f}", True, None
# Strategy 3: Parse with shared validator
cleaned = TextCleaner.normalize_amount_text(corrected_text)
amount = FieldValidators.parse_amount(cleaned)
if amount is not None and 0 < amount < 10_000_000:
return f"{amount:.2f}", True, None
# Strategy 4: Try to extract any decimal number as fallback
decimal_pattern = r'(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})'
matches = re.findall(decimal_pattern, corrected_text)
if matches:
# Clean and parse each match
amounts = []
for m in matches:
cleaned_m = m.replace(' ', '').replace('.', '').replace(',', '.')
# Handle Swedish format: "1 234,56" -> "1234.56"
if ',' in m and '.' not in m:
cleaned_m = m.replace(' ', '').replace(',', '.')
try:
amt = float(cleaned_m)
if 0 < amt < 10_000_000:
amounts.append(amt)
except ValueError:
continue
if amounts:
# Return the last/largest amount (usually the total)
return f"{max(amounts):.2f}", True, None
return None, False, f"Cannot parse amount: {text[:50]}"
# =========================================================================
# Enhanced Date Parsing
# =========================================================================
def _normalize_date_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
"""
Enhanced date parsing with comprehensive format support.
Supports:
- ISO: 2024-12-29, 2024/12/29
- European: 29.12.2024, 29/12/2024, 29-12-2024
- Swedish text: "29 december 2024", "29 dec 2024"
- Compact: 20241229
- With OCR corrections: 2O24-12-29 -> 2024-12-29
"""
from datetime import datetime
# Apply OCR corrections
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Try shared validator first
iso_date = FieldValidators.format_date_iso(corrected_text)
if iso_date and FieldValidators.is_valid_date(iso_date):
return iso_date, True, None
# Swedish month names
swedish_months = {
'januari': 1, 'jan': 1,
'februari': 2, 'feb': 2,
'mars': 3, 'mar': 3,
'april': 4, 'apr': 4,
'maj': 5,
'juni': 6, 'jun': 6,
'juli': 7, 'jul': 7,
'augusti': 8, 'aug': 8,
'september': 9, 'sep': 9, 'sept': 9,
'oktober': 10, 'okt': 10,
'november': 11, 'nov': 11,
'december': 12, 'dec': 12,
}
# Pattern for Swedish text dates: "29 december 2024" or "29 dec 2024"
swedish_pattern = r'(\d{1,2})\s+([a-zåäö]+)\s+(\d{4})'
match = re.search(swedish_pattern, corrected_text.lower())
if match:
day = int(match.group(1))
month_name = match.group(2)
year = int(match.group(3))
if month_name in swedish_months:
month = swedish_months[month_name]
try:
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return dt.strftime('%Y-%m-%d'), True, None
except ValueError:
pass
# Extended patterns
patterns = [
# ISO format: 2025-08-29, 2025/08/29
(r'(\d{4})[-/](\d{1,2})[-/](\d{1,2})', 'ymd'),
# Dot format: 2025.08.29
(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', 'ymd'),
# European slash: 29/08/2025
(r'(\d{1,2})/(\d{1,2})/(\d{4})', 'dmy'),
# European dot: 29.08.2025
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', 'dmy'),
# European dash: 29-08-2025
(r'(\d{1,2})-(\d{1,2})-(\d{4})', 'dmy'),
# Compact: 20250829
(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', 'ymd_compact'),
]
for pattern, fmt in patterns:
match = re.search(pattern, corrected_text)
if match:
try:
if fmt == 'ymd':
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
elif fmt == 'dmy':
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
elif fmt == 'ymd_compact':
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
else:
continue
dt = datetime(year, month, day)
if 2000 <= dt.year <= 2100:
return dt.strftime('%Y-%m-%d'), True, None
except ValueError:
continue
return None, False, f"Cannot parse date: {text[:50]}"
# =========================================================================
# Apply OCR Corrections to Raw Text
# =========================================================================
def apply_ocr_corrections(
self,
field_name: str,
raw_text: str
) -> tuple[str, list[str]]:
"""
Apply OCR corrections to raw text based on field type.
Returns (corrected_text, list_of_corrections_applied)
"""
corrections_applied = []
if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_org_number'):
# Aggressive correction for numeric fields
result = OCRCorrections.correct_digits(raw_text, aggressive=True)
if result.corrections_applied:
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
return result.corrected, corrections_applied
elif field_name == 'Amount':
# Conservative correction for amounts (preserve decimal separators)
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
if result.corrections_applied:
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
return result.corrected, corrections_applied
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
# Conservative correction for dates
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
if result.corrections_applied:
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
return result.corrected, corrections_applied
# No correction for other fields
return raw_text, []
# =========================================================================
# Extraction with All Enhancements
# =========================================================================
def extract_with_enhancements(
self,
detection: Detection,
pdf_tokens: list,
image_width: int,
image_height: int,
use_enhanced_parsing: bool = True
) -> ExtractedField:
"""
Extract field value with all enhancements enabled.
Combines:
1. OCR error correction
2. Enhanced amount/date parsing
3. Multi-strategy extraction
Args:
detection: Detection object
pdf_tokens: PDF text tokens
image_width: Image width in pixels
image_height: Image height in pixels
use_enhanced_parsing: Whether to use enhanced parsing methods
Returns:
ExtractedField with enhancements applied
"""
# First, extract using standard method
base_result = self.extract_from_detection_with_pdf(
detection, pdf_tokens, image_width, image_height
)
if not use_enhanced_parsing:
return base_result
# Apply OCR corrections
corrected_text, corrections = self.apply_ocr_corrections(
base_result.field_name, base_result.raw_text
)
# Re-normalize with enhanced methods if corrections were applied
if corrections or base_result.normalized_value is None:
if base_result.field_name == 'Amount':
normalized, is_valid, error = self._normalize_amount_enhanced(corrected_text)
elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
normalized, is_valid, error = self._normalize_date_enhanced(corrected_text)
else:
# Re-run standard normalization with corrected text
normalized, is_valid, error = self._normalize_and_validate(
base_result.field_name, corrected_text
)
# Update result if we got a better value
if normalized and (not base_result.normalized_value or is_valid):
base_result.normalized_value = normalized
base_result.is_valid = is_valid
base_result.validation_error = error
base_result.ocr_corrections_applied = corrections
if corrections:
base_result.extraction_method = 'corrected'
return base_result

View File

@@ -214,6 +214,43 @@ class InferencePipeline:
# Perform cross-validation if payment_line is detected
self._cross_validate_payment_line(result)
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
"""
Parse machine-readable Swedish payment line format.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
Returns: (ocr, amount, account) tuple
"""
# Pattern with amount
pattern_full = r'#\s*(\d+)\s*#\s*(\d+)\s+(\d{2})\s+\d\s*>\s*(\d+)#\d+#'
match = re.search(pattern_full, payment_line)
if match:
ocr = match.group(1)
kronor = match.group(2)
ore = match.group(3)
account = match.group(4)
amount = f"{kronor}.{ore}"
return ocr, amount, account
# Pattern without amount
pattern_no_amount = r'#\s*(\d+)\s*#\s*>\s*(\d+)#\d+#'
match = re.search(pattern_no_amount, payment_line)
if match:
ocr = match.group(1)
account = match.group(2)
return ocr, None, account
# Fallback: partial pattern
pattern_partial = r'>\s*(\d+)#\d+#'
match = re.search(pattern_partial, payment_line)
if match:
account = match.group(1)
return None, None, account
return None, None, None
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
"""
Cross-validate payment_line data against other detected fields.
@@ -236,23 +273,26 @@ class InferencePipeline:
cv = CrossValidationResult()
cv.details = []
# Parse payment_line format: "OCR:12345 Amount:100,00 BG:123-4567"
pl_parts = {}
for part in str(payment_line).split():
if ':' in part:
key, value = part.split(':', 1)
pl_parts[key.upper()] = value
# Parse machine-readable payment line format
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
cv.payment_line_ocr = pl_parts.get('OCR')
cv.payment_line_amount = pl_parts.get('AMOUNT')
cv.payment_line_ocr = ocr
cv.payment_line_amount = amount
# Determine account type from payment_line
if pl_parts.get('BG'):
cv.payment_line_account = pl_parts['BG']
cv.payment_line_account_type = 'bankgiro'
elif pl_parts.get('PG'):
cv.payment_line_account = pl_parts['PG']
cv.payment_line_account_type = 'plusgiro'
# Determine account type based on digit count
if account:
# Bankgiro: 7-8 digits, Plusgiro: typically fewer
if len(account) >= 7:
cv.payment_line_account_type = 'bankgiro'
# Format: XXX-XXXX or XXXX-XXXX
if len(account) == 7:
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
else:
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
else:
cv.payment_line_account_type = 'plusgiro'
# Format: XXXXXXX-X
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
# Cross-validate and OVERRIDE with payment_line values
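A standalone sketch of the digit-count rule above: seven or more digits are treated as Bankgiro (with the 7/8-digit hyphen placement), anything shorter as Plusgiro.

def format_account(account: str) -> tuple[str, str]:
    if len(account) >= 7:
        split = 3 if len(account) == 7 else 4
        return f"{account[:split]}-{account[split:]}", "bankgiro"
    return f"{account[:-1]}-{account[-1]}", "plusgiro"

print(format_account("48666036"))  # ('4866-6036', 'bankgiro')
print(format_account("3082963"))   # ('308-2963', 'bankgiro')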

View File

@@ -266,6 +266,38 @@ class TestNormalizePaymentLine:
assert is_valid is True
# Bankgiro should be normalized despite spaces
def test_payment_line_with_spaces_in_check_digits(self, extractor):
"""Test payment line with spaces around check digits: #41 # instead of #41#."""
text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
result, is_valid, error = extractor._normalize_payment_line(text)
assert result is not None
assert is_valid is True
assert "6026726908" in result
assert "736 00" in result
assert "5692041#41#" in result
def test_payment_line_with_ocr_spaces_in_amount(self, extractor):
"""Test payment line with OCR-induced spaces in amount: '12 0 0 00' -> '1200 00'."""
text = "# 11000770600242 # 12 0 0 00 5 3082963#41#"
result, is_valid, error = extractor._normalize_payment_line(text)
assert result is not None
assert is_valid is True
assert "11000770600242" in result
assert "1200 00" in result
assert "3082963#41#" in result
def test_payment_line_without_greater_symbol(self, extractor):
"""Test payment line with missing > symbol (low-DPI OCR issue)."""
text = "# 11000770600242 # 1200 00 5 3082963#41#"
result, is_valid, error = extractor._normalize_payment_line(text)
assert result is not None
assert is_valid is True
assert "11000770600242" in result
assert "1200 00" in result
class TestNormalizeCustomerNumber:
"""Tests for customer number normalization."""
@@ -284,6 +316,33 @@ class TestNormalizeCustomerNumber:
result, is_valid, error = extractor._normalize_customer_number("JTY5763")
assert result is not None
def test_format_without_dash(self, extractor):
"""Test customer number format without dash: Dwq 211X -> DWQ 211-X."""
text = "Dwq 211X Billo SE 106 43 Stockholm"
result, is_valid, error = extractor._normalize_customer_number(text)
assert result is not None
assert is_valid is True
assert result == "DWQ 211-X"
def test_swedish_postal_code_exclusion(self, extractor):
"""Test that Swedish postal codes are excluded: SE 106 43 should not be extracted."""
text = "SE 106 43 Stockholm"
result, is_valid, error = extractor._normalize_customer_number(text)
# Should not extract postal code
assert result is None or "SE 106" not in result
def test_customer_number_with_postal_code_in_text(self, extractor):
"""Test extracting customer number when postal code is also present."""
text = "Customer: ABC 123X, Address: SE 106 43 Stockholm"
result, is_valid, error = extractor._normalize_customer_number(text)
assert result is not None
assert "ABC" in result
# Should not extract the postal code (result is already asserted non-None above)
assert "SE 106" not in result
class TestNormalizeSupplierOrgNumber:
"""Tests for supplier organization number normalization."""

View File

@@ -2,6 +2,9 @@
Field Normalization Module
Normalizes field values to generate multiple candidate forms for matching.
This module generates variants of CSV values for matching against OCR text.
It uses shared utilities from src.utils for text cleaning and OCR error variants.
"""
import re
@@ -9,6 +12,10 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Callable
# Import shared utilities
from src.utils.text_cleaner import TextCleaner
from src.utils.format_variants import FormatVariants
@dataclass
class NormalizedValue:
@@ -39,15 +46,11 @@ class FieldNormalizer:
@staticmethod
def clean_text(text: str) -> str:
"""Remove invisible characters and normalize whitespace and dashes."""
# Remove zero-width characters
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
# Normalize different dash types to standard hyphen-minus (ASCII 45)
# en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212), middle dot (·, U+00B7)
text = re.sub(r'[\u2013\u2014\u2212\u00b7]', '-', text)
# Normalize whitespace
text = ' '.join(text.split())
return text.strip()
"""Remove invisible characters and normalize whitespace and dashes.
Delegates to shared TextCleaner for consistency.
"""
return TextCleaner.clean_text(text)
@staticmethod
def normalize_invoice_number(value: str) -> list[str]:
@@ -81,57 +84,44 @@ class FieldNormalizer:
"""
Normalize Bankgiro number.
Uses shared FormatVariants plus OCR error variants.
Examples:
'5393-9484' -> ['5393-9484', '53939484']
'53939484' -> ['53939484', '5393-9484']
"""
value = FieldNormalizer.clean_text(value)
digits_only = re.sub(r'\D', '', value)
# Use shared module for base variants
variants = set(FormatVariants.bankgiro_variants(value))
variants = [value]
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
if digits_only:
# Add without dash
variants.append(digits_only)
# Add with dash (format: XXXX-XXXX for 8 digits)
if len(digits_only) == 8:
with_dash = f"{digits_only[:4]}-{digits_only[4:]}"
variants.append(with_dash)
elif len(digits_only) == 7:
# Some bankgiro numbers are 7 digits: XXX-XXXX
with_dash = f"{digits_only[:3]}-{digits_only[3:]}"
variants.append(with_dash)
return list(set(v for v in variants if v))
return list(v for v in variants if v)
@staticmethod
def normalize_plusgiro(value: str) -> list[str]:
"""
Normalize Plusgiro number.
Uses shared FormatVariants plus OCR error variants.
Examples:
'1234567-8' -> ['1234567-8', '12345678']
'12345678' -> ['12345678', '1234567-8']
"""
value = FieldNormalizer.clean_text(value)
digits_only = re.sub(r'\D', '', value)
# Use shared module for base variants
variants = set(FormatVariants.plusgiro_variants(value))
variants = [value]
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also handle 6+1 format
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
return list(set(v for v in variants if v))
return list(v for v in variants if v)
@staticmethod
def normalize_organisation_number(value: str) -> list[str]:
@@ -141,60 +131,27 @@ class FieldNormalizer:
Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
Swedish VAT format: SE + org_number (10 digits) + 01
Uses shared FormatVariants for comprehensive variant generation,
plus OCR error variants.
Examples:
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
"""
value = FieldNormalizer.clean_text(value)
# Use shared module for base variants
variants = set(FormatVariants.organisation_number_variants(value))
# Check if input is a VAT number (starts with SE, ends with 01)
org_digits = None
if value.upper().startswith('SE') and len(value) >= 12:
# Extract org number from VAT: SE + 10 digits + 01
potential_org = re.sub(r'\D', '', value[2:]) # Remove SE prefix, keep digits
if len(potential_org) == 12 and potential_org.endswith('01'):
org_digits = potential_org[:-2] # Remove trailing 01
elif len(potential_org) == 10:
org_digits = potential_org
# Add OCR error variants for digit sequences
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits and len(digits) >= 10:
# Generate variants where OCR might have misread characters
for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]):
variants.add(ocr_var)
if len(ocr_var) == 10:
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
if org_digits is None:
org_digits = re.sub(r'\D', '', value)
variants = [value]
if org_digits:
variants.append(org_digits)
# Standard format: NNNNNN-NNNN (10 digits total)
if len(org_digits) == 10:
with_dash = f"{org_digits[:6]}-{org_digits[6:]}"
variants.append(with_dash)
# Swedish VAT format: SE + org_number + 01
vat_number = f"SE{org_digits}01"
variants.append(vat_number)
variants.append(vat_number.lower()) # se556123456701
# With spaces: SE 5561234567 01
variants.append(f"SE {org_digits} 01")
variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01")
# Without 01 suffix (some invoices show just SE + org)
variants.append(f"SE{org_digits}")
variants.append(f"SE {org_digits}")
# Some may have 12 digits (century prefix): NNNNNNNN-NNNN
elif len(org_digits) == 12:
with_dash = f"{org_digits[:8]}-{org_digits[8:]}"
variants.append(with_dash)
# Also try without century prefix
short_version = org_digits[2:]
variants.append(short_version)
variants.append(f"{short_version[:6]}-{short_version[6:]}")
# VAT with short version
vat_number = f"SE{short_version}01"
variants.append(vat_number)
return list(set(v for v in variants if v))
return list(v for v in variants if v)
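For orientation, an illustrative variant set for the '556123-4567' example in the docstring; the exact output of FormatVariants.organisation_number_variants may differ.

def org_variants(value: str) -> set[str]:
    digits = ''.join(ch for ch in value if ch.isdigit())
    out = {value, digits}
    if len(digits) == 10:
        out |= {f"{digits[:6]}-{digits[6:]}",   # 556123-4567
                f"SE{digits}01",                # VAT form: SE556123456701
                f"SE {digits} 01",
                f"SE{digits}"}
    return out

print(sorted(org_variants("556123-4567")))
# ['556123-4567', '5561234567', 'SE 5561234567 01', 'SE5561234567', 'SE556123456701']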
@staticmethod
def normalize_supplier_accounts(value: str) -> list[str]:

View File

@@ -42,6 +42,7 @@ from dataclasses import dataclass, field
from typing import Optional
from src.pdf.extractor import Token as TextToken
from src.utils.validators import FieldValidators
@dataclass
@@ -484,21 +485,42 @@ class MachineCodeParser:
def format_account(account_digits: str) -> tuple[str, str]:
"""Format account and determine type (bankgiro or plusgiro).
Uses context keywords first, then falls back to Luhn validation
to determine the most likely account type.
Returns: (formatted_account, account_type)
"""
if is_plusgiro_context:
# Plusgiro format: XXXXXXX-X
# Context explicitly indicates Plusgiro
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
return formatted, 'plusgiro'
# No explicit context - use Luhn validation to determine type
# Try both formats and see which passes Luhn check
# Format as Plusgiro: XXXXXXX-X (all digits, check digit at end)
pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
# Format as Bankgiro: XXX-XXXX or XXXX-XXXX
if len(account_digits) == 7:
bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}"
elif len(account_digits) == 8:
bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}"
else:
# Bankgiro format: XXX-XXXX or XXXX-XXXX
if len(account_digits) == 7:
formatted = f"{account_digits[:3]}-{account_digits[3:]}"
elif len(account_digits) == 8:
formatted = f"{account_digits[:4]}-{account_digits[4:]}"
else:
formatted = account_digits
return formatted, 'bankgiro'
bg_formatted = account_digits
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
# Decision logic:
# 1. If only one format passes Luhn, use that
# 2. If both pass or both fail, default to Bankgiro (more common in payment lines)
if pg_valid and not bg_valid:
return pg_formatted, 'plusgiro'
elif bg_valid and not pg_valid:
return bg_formatted, 'bankgiro'
else:
# Both valid or both invalid - default to bankgiro
return bg_formatted, 'bankgiro'
# Try primary pattern
match = self.PAYMENT_LINE_PATTERN.search(raw_line)
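The decision logic above reduces to a small table: only an unambiguous Luhn result selects Plusgiro, and every other combination defaults to Bankgiro. A toy sketch:

def decide(pg_valid: bool, bg_valid: bool) -> str:
    if pg_valid and not bg_valid:
        return "plusgiro"
    if bg_valid and not pg_valid:
        return "bankgiro"
    return "bankgiro"   # both valid or both invalid: default

assert decide(True, False) == "plusgiro"
assert decide(False, True) == decide(True, True) == decide(False, False) == "bankgiro"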

src/utils/__init__.py (new file, 34 additions)
View File

@@ -0,0 +1,34 @@
"""
Shared utilities for invoice field extraction and matching.
This module provides common functionality used by both:
- Inference stage (field_extractor.py) - extracting values from OCR text
- Matching stage (normalizer.py) - generating variants for CSV matching
Modules:
- TextCleaner: Unicode normalization and OCR error correction
- FormatVariants: Generate format variants for matching
- FieldValidators: Validate field values (Luhn, dates, amounts)
- FuzzyMatcher: Fuzzy string matching with OCR awareness
- OCRCorrections: Comprehensive OCR error correction
- ContextExtractor: Context-aware field extraction
"""
from .text_cleaner import TextCleaner
from .format_variants import FormatVariants
from .validators import FieldValidators
from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
from .ocr_corrections import OCRCorrections, CorrectionResult
from .context_extractor import ContextExtractor, ExtractionCandidate
__all__ = [
'TextCleaner',
'FormatVariants',
'FieldValidators',
'FuzzyMatcher',
'FuzzyMatchResult',
'OCRCorrections',
'CorrectionResult',
'ContextExtractor',
'ExtractionCandidate',
]
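A hedged usage sketch: the inference and matching stages are meant to import the same helpers so cleaning and validation stay consistent. The calls below only use signatures that appear elsewhere in this commit; the printed result is what one would expect, not verified here.

from src.utils import TextCleaner, FieldValidators

text = "Bankgiro: 5393\u20139484"   # en-dash from OCR instead of a plain hyphen
digits = TextCleaner.extract_digits(text, apply_ocr_correction=False)
print(FieldValidators.luhn_checksum(digits))   # expected: True for the valid Bankgiro 5393-9484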

View File

@@ -0,0 +1,433 @@
"""
Context-Aware Extraction Module
Extracts field values using contextual cues and label detection.
Improves extraction accuracy by understanding the semantic context.
"""
import re
from typing import Optional, NamedTuple
from dataclasses import dataclass
from .text_cleaner import TextCleaner
from .validators import FieldValidators
@dataclass
class ExtractionCandidate:
"""A candidate extracted value with metadata."""
value: str
raw_text: str
context_label: str
confidence: float
position: int # Character position in source text
extraction_method: str # 'label', 'pattern', 'proximity'
class ContextExtractor:
"""
Context-aware field extraction.
Uses multiple strategies:
1. Label detection - finds values after field labels
2. Pattern matching - uses field-specific regex patterns
3. Proximity analysis - finds values near related terms
4. Validation filtering - removes invalid candidates
"""
# =========================================================================
# Swedish Label Patterns (what appears before the value)
# =========================================================================
LABEL_PATTERNS = {
'InvoiceNumber': [
# Swedish
r'(?:faktura|fakt)\.?\s*(?:nr|nummer|#)?[:\s]*',
r'(?:fakturanummer|fakturanr)[:\s]*',
r'(?:vår\s+referens)[:\s]*',
# English
r'(?:invoice)\s*(?:no|number|#)?[:\s]*',
r'inv[.:\s]*#?',
],
'Amount': [
# Swedish
r'(?:att\s+)?betala[:\s]*',
r'(?:total|totalt|summa)[:\s]*',
r'(?:belopp)[:\s]*',
r'(?:slutsumma)[:\s]*',
r'(?:att\s+erlägga)[:\s]*',
# English
r'(?:total|amount|sum)[:\s]*',
r'(?:amount\s+due)[:\s]*',
],
'InvoiceDate': [
# Swedish
r'(?:faktura)?datum[:\s]*',
r'(?:fakt\.?\s*datum)[:\s]*',
# English
r'(?:invoice\s+)?date[:\s]*',
],
'InvoiceDueDate': [
# Swedish
r'(?:förfall(?:o)?datum)[:\s]*',
r'(?:betalas\s+senast)[:\s]*',
r'(?:sista\s+betalningsdag)[:\s]*',
r'(?:förfaller)[:\s]*',
# English
r'(?:due\s+date)[:\s]*',
r'(?:payment\s+due)[:\s]*',
],
'OCR': [
r'(?:ocr)[:\s]*',
r'(?:ocr\s*-?\s*nummer)[:\s]*',
r'(?:referens(?:nummer)?)[:\s]*',
r'(?:betalningsreferens)[:\s]*',
],
'Bankgiro': [
r'(?:bankgiro|bg)[:\s]*',
r'(?:bank\s*giro)[:\s]*',
],
'Plusgiro': [
r'(?:plusgiro|pg)[:\s]*',
r'(?:plus\s*giro)[:\s]*',
r'(?:postgiro)[:\s]*',
],
'supplier_organisation_number': [
r'(?:org\.?\s*(?:nr|nummer)?)[:\s]*',
r'(?:organisationsnummer)[:\s]*',
r'(?:org\.?\s*id)[:\s]*',
r'(?:vat\s*(?:no|number|nr)?)[:\s]*',
r'(?:moms(?:reg)?\.?\s*(?:nr|nummer)?)[:\s]*',
r'(?:se)[:\s]*', # VAT prefix
],
'customer_number': [
r'(?:kund(?:nr|nummer)?)[:\s]*',
r'(?:kundnummer)[:\s]*',
r'(?:customer\s*(?:no|number|id)?)[:\s]*',
r'(?:er\s+referens)[:\s]*',
],
}
# =========================================================================
# Value Patterns (what the value looks like)
# =========================================================================
VALUE_PATTERNS = {
'InvoiceNumber': [
r'[A-Z]{0,3}\d{3,15}', # Alphanumeric: INV12345
r'\d{3,15}', # Pure digits
r'20\d{2}[-/]\d{3,8}', # Year prefix: 2024-001
],
'Amount': [
r'\d{1,3}(?:[\s.]\d{3})*[,]\d{2}', # Swedish: 1 234,56
r'\d{1,3}(?:[,]\d{3})*[.]\d{2}', # US: 1,234.56
r'\d+[,.]\d{2}', # Simple: 123,45
r'\d+', # Integer
],
'InvoiceDate': [
r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}', # ISO-like
r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}', # European
r'\d{8}', # Compact YYYYMMDD
],
'InvoiceDueDate': [
r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}',
r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}',
r'\d{8}',
],
'OCR': [
r'\d{10,25}', # Long digit sequence
],
'Bankgiro': [
r'\d{3,4}[-\s]?\d{4}', # XXX-XXXX or XXXX-XXXX
r'\d{7,8}', # Without separator
],
'Plusgiro': [
r'\d{1,7}[-\s]?\d', # XXXXXXX-X
r'\d{2,8}', # Without separator
],
'supplier_organisation_number': [
r'\d{6}[-\s]?\d{4}', # NNNNNN-NNNN
r'\d{10}', # Without separator
r'SE\s?\d{10,12}(?:\s?01)?', # VAT format
],
'customer_number': [
r'[A-Z]{0,5}\s?[-]?\s?\d{1,10}', # EMM 256-6
r'\d{3,15}', # Pure digits
],
}
# =========================================================================
# Extraction Methods
# =========================================================================
@classmethod
def extract_with_label(
cls,
text: str,
field_name: str,
validate: bool = True
) -> list[ExtractionCandidate]:
"""
Extract field values by finding labels and taking following values.
Example: "Fakturanummer: 12345" -> extracts "12345"
"""
candidates = []
label_patterns = cls.LABEL_PATTERNS.get(field_name, [])
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
for label_pattern in label_patterns:
for value_pattern in value_patterns:
# Combine label + value patterns
full_pattern = f'({label_pattern})({value_pattern})'
matches = re.finditer(full_pattern, text, re.IGNORECASE)
for match in matches:
label = match.group(1).strip()
value = match.group(2).strip()
# Validate if requested
if validate and not cls._validate_value(field_name, value):
continue
# Calculate confidence based on label specificity
confidence = cls._calculate_label_confidence(label, field_name)
candidates.append(ExtractionCandidate(
value=value,
raw_text=match.group(0),
context_label=label,
confidence=confidence,
position=match.start(),
extraction_method='label'
))
return candidates
@classmethod
def extract_with_pattern(
cls,
text: str,
field_name: str,
validate: bool = True
) -> list[ExtractionCandidate]:
"""
Extract field values using only value patterns (no label required).
This is a fallback when no labels are found.
"""
candidates = []
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
for pattern in value_patterns:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
value = match.group(0).strip()
# Validate if requested
if validate and not cls._validate_value(field_name, value):
continue
# Lower confidence for pattern-only extraction
confidence = 0.6
candidates.append(ExtractionCandidate(
value=value,
raw_text=value,
context_label='',
confidence=confidence,
position=match.start(),
extraction_method='pattern'
))
return candidates
@classmethod
def extract_field(
cls,
text: str,
field_name: str,
validate: bool = True
) -> list[ExtractionCandidate]:
"""
Extract all candidate values for a field using multiple strategies.
Returns candidates sorted by confidence (highest first).
"""
candidates = []
# Strategy 1: Label-based extraction (highest confidence)
label_candidates = cls.extract_with_label(text, field_name, validate)
candidates.extend(label_candidates)
# Strategy 2: Pattern-based extraction (fallback)
if not label_candidates:
pattern_candidates = cls.extract_with_pattern(text, field_name, validate)
candidates.extend(pattern_candidates)
# Remove duplicates (same value, keep highest confidence)
seen_values = {}
for candidate in candidates:
normalized = TextCleaner.normalize_for_comparison(candidate.value)
if normalized not in seen_values or candidate.confidence > seen_values[normalized].confidence:
seen_values[normalized] = candidate
# Sort by confidence
result = sorted(seen_values.values(), key=lambda x: x.confidence, reverse=True)
return result
@classmethod
def extract_best(
cls,
text: str,
field_name: str,
validate: bool = True
) -> Optional[ExtractionCandidate]:
"""
Extract the best (highest confidence) candidate for a field.
"""
candidates = cls.extract_field(text, field_name, validate)
return candidates[0] if candidates else None
@classmethod
def extract_all_fields(cls, text: str) -> dict[str, list[ExtractionCandidate]]:
"""
Extract all known fields from text.
Returns a dictionary mapping field names to their candidates.
"""
results = {}
for field_name in cls.LABEL_PATTERNS.keys():
candidates = cls.extract_field(text, field_name)
if candidates:
results[field_name] = candidates
return results
# =========================================================================
# Helper Methods
# =========================================================================
@classmethod
def _validate_value(cls, field_name: str, value: str) -> bool:
"""Validate a value based on field type."""
field_lower = field_name.lower()
if 'date' in field_lower:
return FieldValidators.is_valid_date(value)
elif 'amount' in field_lower:
return FieldValidators.is_valid_amount(value)
elif 'bankgiro' in field_lower:
# Basic format check, not Luhn
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return 7 <= len(digits) <= 8
elif 'plusgiro' in field_lower:
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return 2 <= len(digits) <= 8
elif 'ocr' in field_lower:
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return 10 <= len(digits) <= 25
elif 'org' in field_lower:
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return len(digits) >= 10
else:
# For other fields, just check it's not empty
return bool(value.strip())
@classmethod
def _calculate_label_confidence(cls, label: str, field_name: str) -> float:
"""
Calculate confidence based on how specific the label is.
More specific labels = higher confidence.
"""
label_lower = label.lower()
# Very specific labels
very_specific = {
'InvoiceNumber': ['fakturanummer', 'invoice number', 'fakturanr'],
'Amount': ['att betala', 'slutsumma', 'amount due'],
'InvoiceDate': ['fakturadatum', 'invoice date'],
'InvoiceDueDate': ['förfallodatum', 'förfallodag', 'due date'],
'OCR': ['ocr', 'betalningsreferens'],
'Bankgiro': ['bankgiro'],
'Plusgiro': ['plusgiro', 'postgiro'],
'supplier_organisation_number': ['organisationsnummer', 'org nummer'],
'customer_number': ['kundnummer', 'customer number'],
}
# Check for very specific match
if field_name in very_specific:
for specific in very_specific[field_name]:
if specific in label_lower:
return 0.95
# Moderately specific
moderate = {
'InvoiceNumber': ['faktura', 'invoice', 'nr'],
'Amount': ['total', 'summa', 'belopp'],
'InvoiceDate': ['datum', 'date'],
'InvoiceDueDate': ['förfall', 'due'],
}
if field_name in moderate:
for mod in moderate[field_name]:
if mod in label_lower:
return 0.85
# Generic match
return 0.75
@classmethod
def find_field_context(cls, text: str, position: int, window: int = 50) -> str:
"""
Get the surrounding context for a position in text.
Useful for understanding what field a value belongs to.
"""
start = max(0, position - window)
end = min(len(text), position + window)
return text[start:end]
@classmethod
def identify_field_type(cls, text: str, value: str) -> Optional[str]:
"""
Try to identify what field type a value belongs to based on context.
Looks at text surrounding the value to find labels.
"""
# Find the value in text
pos = text.find(value)
if pos == -1:
return None
# Get context before the value
context_before = text[max(0, pos - 50):pos].lower()
# Check each field's labels
for field_name, patterns in cls.LABEL_PATTERNS.items():
for pattern in patterns:
if re.search(pattern, context_before, re.IGNORECASE):
return field_name
return None
# =========================================================================
# Convenience functions
# =========================================================================
def extract_field_with_context(text: str, field_name: str) -> Optional[str]:
"""Convenience function to extract a field value."""
candidate = ContextExtractor.extract_best(text, field_name)
return candidate.value if candidate else None
def extract_all_with_context(text: str) -> dict[str, str]:
"""Convenience function to extract all fields."""
all_candidates = ContextExtractor.extract_all_fields(text)
return {
field: candidates[0].value
for field, candidates in all_candidates.items()
if candidates
}
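A minimal usage sketch for the context extractor above (illustrative only; it assumes the module is importable as src.utils.context_extractor, consistent with the src.utils imports used elsewhere in this commit):
from src.utils.context_extractor import ContextExtractor, extract_all_with_context
sample = "Fakturanummer: 12345678  Att betala: 1 234,56 kr"
best = ContextExtractor.extract_best(sample, "InvoiceNumber")
print(best.value if best else None)      # expected: "12345678", extracted via the label pattern
print(extract_all_with_context(sample))  # field name -> best value for every field with a match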

View File

@@ -0,0 +1,610 @@
"""
Format Variants Generator
Generates multiple format variants for invoice field values.
Used by both inference (to try different extractions) and matching (to match CSV values).
"""
import re
from datetime import datetime
from typing import Optional
from .text_cleaner import TextCleaner
class FormatVariants:
"""
Generates format variants for different field types.
The same logic is used for:
- Inference: trying different formats to extract a value
- Matching: generating variants of CSV values to match against OCR text
"""
# Swedish month names for date parsing
SWEDISH_MONTHS = {
'januari': '01', 'jan': '01',
'februari': '02', 'feb': '02',
'mars': '03', 'mar': '03',
'april': '04', 'apr': '04',
'maj': '05',
'juni': '06', 'jun': '06',
'juli': '07', 'jul': '07',
'augusti': '08', 'aug': '08',
'september': '09', 'sep': '09', 'sept': '09',
'oktober': '10', 'okt': '10',
'november': '11', 'nov': '11',
'december': '12', 'dec': '12',
}
# =========================================================================
# Organization Number Variants
# =========================================================================
@classmethod
def organisation_number_variants(cls, value: str) -> list[str]:
"""
Generate all format variants for Swedish organization number.
Input formats handled:
- "556123-4567" (standard with hyphen)
- "5561234567" (no hyphen)
- "SE556123456701" (VAT format)
- "SE 556123-4567 01" (VAT with spaces)
Returns all possible variants for matching.
"""
value = TextCleaner.clean_text(value)
value_upper = value.upper()
variants = set()
# Extract pure digits
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
# If the value is in VAT format, extract the embedded org number
# SE + 10 digits + 01 = "SE556123456701"
if value_upper.startswith('SE') and len(digits) == 12 and digits.endswith('01'):
# VAT format: SE + org_number + 01
digits = digits[:10]
elif digits.startswith('46') and len(digits) == 14:
# SE prefix in numeric (46 is SE in phone code): 46 + 10 digits + 01
digits = digits[2:12]
if len(digits) == 12:
# 12 digits may carry a century prefix: NNNNNNNN-NNNN (19556123-4567)
variants.add(value)
variants.add(digits) # 195561234567
# Hyphenated format
variants.add(f"{digits[:8]}-{digits[8:]}") # 19556123-4567
# Take the last 10 digits as the standard org number
short_digits = digits[2:] # 5561234567
variants.add(short_digits)
variants.add(f"{short_digits[:6]}-{short_digits[6:]}") # 556123-4567
# VAT format
variants.add(f"SE{short_digits}01") # SE556123456701
return list(v for v in variants if v)
if len(digits) != 10:
# Not the standard 10 digits: return the original value plus cleaned variants
variants.add(value)
if digits:
variants.add(digits)
return list(variants)
# Generate all variants
# 1. Pure digits
variants.add(digits) # 5561234567
# 2. Standard format (NNNNNN-NNNN)
with_hyphen = f"{digits[:6]}-{digits[6:]}"
variants.add(with_hyphen) # 556123-4567
# 3. VAT format
vat_compact = f"SE{digits}01"
variants.add(vat_compact) # SE556123456701
variants.add(vat_compact.lower()) # se556123456701
vat_spaced = f"SE {digits[:6]}-{digits[6:]} 01"
variants.add(vat_spaced) # SE 556123-4567 01
vat_spaced_no_hyphen = f"SE {digits} 01"
variants.add(vat_spaced_no_hyphen) # SE 5561234567 01
# 4. Sometimes with the country code but without the 01 suffix
variants.add(f"SE{digits}") # SE5561234567
variants.add(f"SE {digits}") # SE 5561234567
variants.add(f"SE{digits[:6]}-{digits[6:]}") # SE556123-4567
# 5. Possible OCR error variants
ocr_variants = TextCleaner.generate_ocr_variants(digits)
for ocr_var in ocr_variants:
if len(ocr_var) == 10:
variants.add(ocr_var)
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
return list(v for v in variants if v)
# =========================================================================
# Bankgiro Variants
# =========================================================================
@classmethod
def bankgiro_variants(cls, value: str) -> list[str]:
"""
Generate variants for Bankgiro number.
Formats:
- 7 digits: XXX-XXXX (e.g., 123-4567)
- 8 digits: XXXX-XXXX (e.g., 1234-5678)
"""
value = TextCleaner.clean_text(value)
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
variants = set()
variants.add(value)
if not digits or len(digits) < 7 or len(digits) > 8:
return list(v for v in variants if v)
# Pure digits
variants.add(digits)
# Hyphenated format
if len(digits) == 7:
variants.add(f"{digits[:3]}-{digits[3:]}") # XXX-XXXX
elif len(digits) == 8:
variants.add(f"{digits[:4]}-{digits[4:]}") # XXXX-XXXX
# Some 8-digit numbers also use the XXX-XXXXX format
variants.add(f"{digits[:3]}-{digits[3:]}")
# Space-separated format (OCR sometimes reads it this way)
if len(digits) == 7:
variants.add(f"{digits[:3]} {digits[3:]}")
elif len(digits) == 8:
variants.add(f"{digits[:4]} {digits[4:]}")
# OCR error variants
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Plusgiro Variants
# =========================================================================
@classmethod
def plusgiro_variants(cls, value: str) -> list[str]:
"""
Generate variants for Plusgiro number.
Format: XXXXXXX-X (7 digits + check digit) or shorter
Examples: 1234567-8, 12345-6, 1-8
"""
value = TextCleaner.clean_text(value)
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
variants = set()
variants.add(value)
if not digits or len(digits) < 2 or len(digits) > 8:
return list(v for v in variants if v)
# Pure digits
variants.add(digits)
# Plusgiro format: the last digit is a check digit, separated by a hyphen
main_part = digits[:-1]
check_digit = digits[-1]
variants.add(f"{main_part}-{check_digit}")
# Sometimes with a space instead of the hyphen
variants.add(f"{main_part} {check_digit}")
# Grouped format (common for long numbers): XX XX XX-X
if len(digits) >= 6:
# Try the XX XX XX-X format
spaced = ' '.join([digits[i:i + 2] for i in range(0, len(digits) - 1, 2)])
if len(digits) % 2 == 0:
spaced = spaced[:-1] + '-' + digits[-1]
else:
spaced = spaced + '-' + digits[-1]
variants.add(spaced.replace('- ', '-'))
# OCR error variants
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Amount Variants
# =========================================================================
@classmethod
def amount_variants(cls, value: str) -> list[str]:
"""
Generate variants for monetary amounts.
Handles:
- Swedish: 1 234,56 (space thousand, comma decimal)
- German: 1.234,56 (dot thousand, comma decimal)
- US/UK: 1,234.56 (comma thousand, dot decimal)
- Integer: 1234 -> 1234.00
Returns variants with different separators and with/without decimals.
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
# Try to parse as a numeric value
amount = cls._parse_amount(value)
if amount is None:
return list(v for v in variants if v)
# Generate variants in different formats
int_part = int(amount)
dec_part = round((amount - int_part) * 100)
# 1. Basic formats
if dec_part == 0:
variants.add(str(int_part)) # 1234
variants.add(f"{int_part}.00") # 1234.00
variants.add(f"{int_part},00") # 1234,00
else:
variants.add(f"{int_part}.{dec_part:02d}") # 1234.56
variants.add(f"{int_part},{dec_part:02d}") # 1234,56
# 2. With thousand separators
int_str = str(int_part)
if len(int_str) > 3:
# Insert a separator every 3 digits, from right to left
parts = []
while int_str:
parts.append(int_str[-3:])
int_str = int_str[:-3]
parts.reverse()
# Space separator (Swedish)
space_sep = ' '.join(parts)
if dec_part == 0:
variants.add(space_sep)
else:
variants.add(f"{space_sep},{dec_part:02d}")
variants.add(f"{space_sep}.{dec_part:02d}")
# Dot separator (German)
dot_sep = '.'.join(parts)
if dec_part == 0:
variants.add(dot_sep)
else:
variants.add(f"{dot_sep},{dec_part:02d}")
# Comma separator (US)
comma_sep = ','.join(parts)
if dec_part == 0:
variants.add(comma_sep)
else:
variants.add(f"{comma_sep}.{dec_part:02d}")
# 3. With currency symbols
base_amounts = [f"{int_part}.{dec_part:02d}", f"{int_part},{dec_part:02d}"]
if dec_part == 0:
base_amounts.append(str(int_part))
for base in base_amounts:
variants.add(f"{base} kr")
variants.add(f"{base} SEK")
variants.add(f"{base}kr")
variants.add(f"SEK {base}")
return list(v for v in variants if v)
@classmethod
def _parse_amount(cls, text: str) -> Optional[float]:
"""Parse amount from various formats."""
text = TextCleaner.normalize_amount_text(text)
# Remove everything except digits and separators
clean = re.sub(r'[^\d,.\s]', '', text)
if not clean:
return None
# Detect the format
# Swedish format: 1 234,56 or 1234,56
if re.match(r'^[\d\s]+,\d{2}$', clean):
clean = clean.replace(' ', '').replace(',', '.')
try:
return float(clean)
except ValueError:
pass
# German format: 1.234,56
if re.match(r'^[\d.]+,\d{2}$', clean):
clean = clean.replace('.', '').replace(',', '.')
try:
return float(clean)
except ValueError:
pass
# US format: 1,234.56
if re.match(r'^[\d,]+\.\d{2}$', clean):
clean = clean.replace(',', '')
try:
return float(clean)
except ValueError:
pass
# Simple format
clean = clean.replace(' ', '').replace(',', '.')
# If there are multiple dots, keep only the last one as the decimal point
if clean.count('.') > 1:
parts = clean.rsplit('.', 1)
clean = parts[0].replace('.', '') + '.' + parts[1]
try:
return float(clean)
except ValueError:
return None
# =========================================================================
# Date Variants
# =========================================================================
@classmethod
def date_variants(cls, value: str) -> list[str]:
"""
Generate variants for dates.
Input can be:
- ISO: 2024-12-29
- European: 29/12/2024, 29.12.2024
- Swedish text: "29 december 2024"
- Compact: 20241229
Returns all format variants.
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
# Try to parse the date
parsed = cls._parse_date(value)
if parsed is None:
return list(v for v in variants if v)
year, month, day = parsed
# Generate all format variants
# ISO
variants.add(f"{year}-{month:02d}-{day:02d}")
variants.add(f"{year}-{month}-{day}") # without zero padding
# Dot-separated (common in Sweden)
variants.add(f"{year}.{month:02d}.{day:02d}")
variants.add(f"{day:02d}.{month:02d}.{year}")
# Slash-separated
variants.add(f"{day:02d}/{month:02d}/{year}")
variants.add(f"{month:02d}/{day:02d}/{year}") # US format
variants.add(f"{year}/{month:02d}/{day:02d}")
# Compact format
variants.add(f"{year}{month:02d}{day:02d}")
# With month name (Swedish)
for month_name, month_num in cls.SWEDISH_MONTHS.items():
if month_num == f"{month:02d}":
variants.add(f"{day} {month_name} {year}")
variants.add(f"{day:02d} {month_name} {year}")
# Capitalized month name
variants.add(f"{day} {month_name.capitalize()} {year}")
# Two-digit year
short_year = str(year)[2:]
variants.add(f"{day:02d}.{month:02d}.{short_year}")
variants.add(f"{day:02d}/{month:02d}/{short_year}")
variants.add(f"{short_year}-{month:02d}-{day:02d}")
return list(v for v in variants if v)
@classmethod
def _parse_date(cls, text: str) -> Optional[tuple[int, int, int]]:
"""
Parse date from text, returns (year, month, day) or None.
"""
text = TextCleaner.clean_text(text).lower()
# ISO: 2024-12-29
match = re.search(r'(\d{4})-(\d{1,2})-(\d{1,2})', text)
if match:
return int(match.group(1)), int(match.group(2)), int(match.group(3))
# Dot format: 2024.12.29
match = re.search(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', text)
if match:
return int(match.group(1)), int(match.group(2)), int(match.group(3))
# European: 29/12/2024 or 29.12.2024
match = re.search(r'(\d{1,2})[/.](\d{1,2})[/.](\d{4})', text)
if match:
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
# Sanity-check the date
if 1 <= day <= 31 and 1 <= month <= 12:
return year, month, day
# Compact: 20241229
match = re.search(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', text)
if match:
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
if 2000 <= year <= 2100 and 1 <= month <= 12 and 1 <= day <= 31:
return year, month, day
# Swedish month name: "29 december 2024"
for month_name, month_num in cls.SWEDISH_MONTHS.items():
pattern = rf'(\d{{1,2}})\s*{month_name}\s*(\d{{4}})'
match = re.search(pattern, text)
if match:
day, year = int(match.group(1)), int(match.group(2))
return year, int(month_num), day
return None
# =========================================================================
# Invoice Number Variants
# =========================================================================
@classmethod
def invoice_number_variants(cls, value: str) -> list[str]:
"""
Generate variants for invoice numbers.
Invoice numbers are highly variable:
- Pure digits: 12345678
- Alphanumeric: A3861, INV-2024-001
- With separators: 2024/001
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
# Extract the digit part
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
variants.add(digits)
# Case variants
variants.add(value.upper())
variants.add(value.lower())
# Remove separators
no_sep = re.sub(r'[-/\s]', '', value)
variants.add(no_sep)
variants.add(no_sep.upper())
# OCR error variants
for ocr_var in TextCleaner.generate_ocr_variants(value):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# OCR Number Variants
# =========================================================================
@classmethod
def ocr_number_variants(cls, value: str) -> list[str]:
"""
Generate variants for OCR reference numbers.
OCR numbers are typically 10-25 digits.
"""
value = TextCleaner.clean_text(value)
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
variants = set()
variants.add(value)
if digits:
variants.add(digits)
# Some OCR numbers are printed in space-separated groups
if len(digits) > 4:
# Group into blocks of 4 digits
spaced = ' '.join([digits[i:i + 4] for i in range(0, len(digits), 4)])
variants.add(spaced)
# OCR error variants
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Customer Number Variants
# =========================================================================
@classmethod
def customer_number_variants(cls, value: str) -> list[str]:
"""
Generate variants for customer numbers.
Customer numbers can be very diverse:
- Pure digits: 12345
- Alphanumeric: ABC123, EMM 256-6
- With separators: 123-456
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
# Case variants
variants.add(value.upper())
variants.add(value.lower())
# Remove all separators and spaces
compact = re.sub(r'[-/\s]', '', value)
variants.add(compact)
variants.add(compact.upper())
variants.add(compact.lower())
# Pure digits
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
variants.add(digits)
# Letters and digits split out and recombined
letters = re.sub(r'[^a-zA-Z]', '', value)
if letters and digits:
variants.add(f"{letters}{digits}")
variants.add(f"{letters.upper()}{digits}")
variants.add(f"{letters} {digits}")
variants.add(f"{letters.upper()} {digits}")
variants.add(f"{letters}-{digits}")
variants.add(f"{letters.upper()}-{digits}")
# OCR error variants
for ocr_var in TextCleaner.generate_ocr_variants(value):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Generic Field Variants
# =========================================================================
@classmethod
def get_variants(cls, field_name: str, value: str) -> list[str]:
"""
Get variants for a field by name.
This is the main entry point - dispatches to specific variant generators.
"""
if not value:
return []
field_lower = field_name.lower()
# Map the field name to its variant generator
if 'organisation' in field_lower or 'org' in field_lower:
return cls.organisation_number_variants(value)
elif 'bankgiro' in field_lower or field_lower == 'bg':
return cls.bankgiro_variants(value)
elif 'plusgiro' in field_lower or field_lower == 'pg':
return cls.plusgiro_variants(value)
elif 'amount' in field_lower or 'belopp' in field_lower:
return cls.amount_variants(value)
elif 'date' in field_lower or 'datum' in field_lower:
return cls.date_variants(value)
elif 'invoice' in field_lower and 'number' in field_lower:
return cls.invoice_number_variants(value)
elif field_lower == 'invoicenumber':
return cls.invoice_number_variants(value)
elif 'ocr' in field_lower:
return cls.ocr_number_variants(value)
elif 'customer' in field_lower:
return cls.customer_number_variants(value)
else:
# Default: return the original value and a basic cleaned form
return [value, TextCleaner.clean_text(value)]
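A short sketch of how the get_variants dispatcher is typically driven (illustrative only; the expected contents follow the unit tests later in this commit):
from src.utils.format_variants import FormatVariants
print(FormatVariants.get_variants("Bankgiro", "53939484"))       # contains "5393-9484"
print(FormatVariants.get_variants("InvoiceDate", "2024-12-29"))  # contains "29.12.2024" and "20241229"
print(FormatVariants.get_variants("Amount", "1234.56"))          # contains "1234,56" (Swedish decimal comma)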

417
src/utils/fuzzy_matcher.py Normal file
View File

@@ -0,0 +1,417 @@
"""
Fuzzy Matching Module
Provides fuzzy string matching with OCR-aware similarity scoring.
Handles common OCR errors and format variations in invoice fields.
"""
import re
from typing import Optional
from dataclasses import dataclass
from .text_cleaner import TextCleaner
@dataclass
class FuzzyMatchResult:
"""Result of a fuzzy match operation."""
matched: bool
score: float # 0.0 to 1.0
ocr_value: str
expected_value: str
normalized_ocr: str
normalized_expected: str
match_type: str # 'exact', 'normalized', 'fuzzy', 'ocr_corrected'
class FuzzyMatcher:
"""
Fuzzy string matcher optimized for OCR text matching.
Provides multiple matching strategies:
1. Exact match
2. Normalized match (case-insensitive, whitespace-normalized)
3. OCR-corrected match (applying common OCR error corrections)
4. Edit distance based fuzzy match
5. Digit-sequence match (for numeric fields)
"""
# Minimum similarity threshold for fuzzy matches
DEFAULT_THRESHOLD = 0.85
# Field-specific thresholds (some fields need stricter matching)
FIELD_THRESHOLDS = {
'InvoiceNumber': 0.90,
'OCR': 0.95, # OCR numbers need high precision
'Amount': 0.95,
'Bankgiro': 0.90,
'Plusgiro': 0.90,
'InvoiceDate': 0.90,
'InvoiceDueDate': 0.90,
'supplier_organisation_number': 0.85,
'customer_number': 0.80, # More lenient for customer numbers
}
@classmethod
def get_threshold(cls, field_name: str) -> float:
"""Get the matching threshold for a specific field."""
return cls.FIELD_THRESHOLDS.get(field_name, cls.DEFAULT_THRESHOLD)
@classmethod
def levenshtein_distance(cls, s1: str, s2: str) -> int:
"""
Calculate Levenshtein (edit) distance between two strings.
This is the minimum number of single-character edits
(insertions, deletions, substitutions) needed to change s1 into s2.
"""
if len(s1) < len(s2):
return cls.levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
# Cost is 0 if characters match, 1 otherwise
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
@classmethod
def similarity_ratio(cls, s1: str, s2: str) -> float:
"""
Calculate similarity ratio between two strings.
Returns a value between 0.0 (completely different) and 1.0 (identical).
Based on Levenshtein distance normalized by the length of the longer string.
"""
if not s1 and not s2:
return 1.0
if not s1 or not s2:
return 0.0
max_len = max(len(s1), len(s2))
distance = cls.levenshtein_distance(s1, s2)
return 1.0 - (distance / max_len)
@classmethod
def ocr_aware_similarity(cls, ocr_text: str, expected: str) -> float:
"""
Calculate similarity with OCR error awareness.
This method considers common OCR errors when calculating similarity,
giving higher scores when differences are likely OCR mistakes.
"""
if not ocr_text or not expected:
return 0.0 if ocr_text != expected else 1.0
# First try exact match
if ocr_text == expected:
return 1.0
# Try with OCR corrections applied to ocr_text
corrected = TextCleaner.apply_ocr_digit_corrections(ocr_text)
if corrected == expected:
return 0.98 # Slightly less than exact match
# Try normalized comparison
norm_ocr = TextCleaner.normalize_for_comparison(ocr_text)
norm_expected = TextCleaner.normalize_for_comparison(expected)
if norm_ocr == norm_expected:
return 0.95
# Calculate base similarity
base_sim = cls.similarity_ratio(norm_ocr, norm_expected)
# Boost score if differences are common OCR errors
boost = cls._calculate_ocr_error_boost(ocr_text, expected)
return min(1.0, base_sim + boost)
@classmethod
def _calculate_ocr_error_boost(cls, ocr_text: str, expected: str) -> float:
"""
Calculate a score boost based on whether differences are likely OCR errors.
Returns a value between 0.0 and 0.1.
"""
if len(ocr_text) != len(expected):
return 0.0
ocr_errors = 0
total_diffs = 0
for oc, ec in zip(ocr_text, expected):
if oc != ec:
total_diffs += 1
# Check if this is a known OCR confusion pair
if cls._is_ocr_confusion_pair(oc, ec):
ocr_errors += 1
if total_diffs == 0:
return 0.0
# Boost proportional to how many differences are OCR errors
ocr_error_ratio = ocr_errors / total_diffs
return ocr_error_ratio * 0.1
@classmethod
def _is_ocr_confusion_pair(cls, char1: str, char2: str) -> bool:
"""Check if two characters are commonly confused in OCR."""
confusion_pairs = {
('0', 'O'), ('0', 'o'), ('0', 'D'), ('0', 'Q'),
('1', 'l'), ('1', 'I'), ('1', 'i'), ('1', '|'),
('2', 'Z'), ('2', 'z'),
('5', 'S'), ('5', 's'),
('6', 'G'), ('6', 'b'),
('8', 'B'),
('9', 'g'), ('9', 'q'),
}
pair = (char1, char2)
return pair in confusion_pairs or (char2, char1) in confusion_pairs
@classmethod
def match_digits(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
"""
Match digit sequences with OCR error tolerance.
Optimized for numeric fields like OCR numbers, amounts, etc.
"""
# Extract digits
ocr_digits = TextCleaner.extract_digits(ocr_text, apply_ocr_correction=True)
expected_digits = TextCleaner.extract_digits(expected, apply_ocr_correction=False)
# Exact match after extraction
if ocr_digits == expected_digits:
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_digits,
normalized_expected=expected_digits,
match_type='exact'
)
# Calculate similarity
score = cls.ocr_aware_similarity(ocr_digits, expected_digits)
return FuzzyMatchResult(
matched=score >= threshold,
score=score,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_digits,
normalized_expected=expected_digits,
match_type='fuzzy' if score >= threshold else 'no_match'
)
@classmethod
def match_amount(cls, ocr_text: str, expected: str, threshold: float = 0.95) -> FuzzyMatchResult:
"""
Match monetary amounts with format tolerance.
Handles different decimal separators (. vs ,) and thousand separators.
"""
from .validators import FieldValidators
# Parse both amounts
ocr_amount = FieldValidators.parse_amount(ocr_text)
expected_amount = FieldValidators.parse_amount(expected)
if ocr_amount is None or expected_amount is None:
# Can't parse, fall back to string matching
return cls.match_string(ocr_text, expected, threshold)
# Compare numeric values
if abs(ocr_amount - expected_amount) < 0.01: # Within 1 cent
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=f"{ocr_amount:.2f}",
normalized_expected=f"{expected_amount:.2f}",
match_type='exact'
)
# Calculate relative difference
max_val = max(abs(ocr_amount), abs(expected_amount))
if max_val > 0:
diff_ratio = abs(ocr_amount - expected_amount) / max_val
score = max(0.0, 1.0 - diff_ratio)
else:
score = 1.0 if ocr_amount == expected_amount else 0.0
return FuzzyMatchResult(
matched=score >= threshold,
score=score,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=f"{ocr_amount:.2f}" if ocr_amount is not None else ocr_text,
normalized_expected=f"{expected_amount:.2f}" if expected_amount is not None else expected,
match_type='fuzzy' if score >= threshold else 'no_match'
)
@classmethod
def match_date(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
"""
Match dates with format tolerance.
Handles different date formats (ISO, European, compact, etc.)
"""
from .validators import FieldValidators
# Parse both dates to ISO format
ocr_iso = FieldValidators.format_date_iso(ocr_text)
expected_iso = FieldValidators.format_date_iso(expected)
if ocr_iso and expected_iso:
if ocr_iso == expected_iso:
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_iso,
normalized_expected=expected_iso,
match_type='exact'
)
# Fall back to string matching on digits
return cls.match_digits(ocr_text, expected, threshold)
@classmethod
def match_string(cls, ocr_text: str, expected: str, threshold: float = 0.85) -> FuzzyMatchResult:
"""
General string matching with multiple strategies.
Tries exact, normalized, and fuzzy matching in order.
"""
# Clean both strings
ocr_clean = TextCleaner.clean_text(ocr_text)
expected_clean = TextCleaner.clean_text(expected)
# Strategy 1: Exact match
if ocr_clean == expected_clean:
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_clean,
normalized_expected=expected_clean,
match_type='exact'
)
# Strategy 2: Case-insensitive match
if ocr_clean.lower() == expected_clean.lower():
return FuzzyMatchResult(
matched=True,
score=0.98,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_clean,
normalized_expected=expected_clean,
match_type='normalized'
)
# Strategy 3: OCR-corrected match
ocr_corrected = TextCleaner.apply_ocr_digit_corrections(ocr_clean)
if ocr_corrected == expected_clean:
return FuzzyMatchResult(
matched=True,
score=0.95,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_corrected,
normalized_expected=expected_clean,
match_type='ocr_corrected'
)
# Strategy 4: Fuzzy match
score = cls.ocr_aware_similarity(ocr_clean, expected_clean)
return FuzzyMatchResult(
matched=score >= threshold,
score=score,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_clean,
normalized_expected=expected_clean,
match_type='fuzzy' if score >= threshold else 'no_match'
)
@classmethod
def match_field(
cls,
field_name: str,
ocr_value: str,
expected_value: str,
threshold: Optional[float] = None
) -> FuzzyMatchResult:
"""
Match a field value using field-appropriate strategy.
Automatically selects the best matching strategy based on field type.
"""
if threshold is None:
threshold = cls.get_threshold(field_name)
field_lower = field_name.lower()
# Route to appropriate matcher
if 'amount' in field_lower or 'belopp' in field_lower:
return cls.match_amount(ocr_value, expected_value, threshold)
if 'date' in field_lower or 'datum' in field_lower:
return cls.match_date(ocr_value, expected_value, threshold)
if any(x in field_lower for x in ['ocr', 'bankgiro', 'plusgiro', 'org']):
# Numeric fields with OCR errors
return cls.match_digits(ocr_value, expected_value, threshold)
if 'invoice' in field_lower and 'number' in field_lower:
# Invoice numbers can be alphanumeric
return cls.match_string(ocr_value, expected_value, threshold)
# Default to string matching
return cls.match_string(ocr_value, expected_value, threshold)
@classmethod
def find_best_match(
cls,
ocr_value: str,
candidates: list[str],
field_name: str = '',
threshold: Optional[float] = None
) -> Optional[tuple[str, FuzzyMatchResult]]:
"""
Find the best matching candidate from a list.
Returns (matched_value, match_result) or None if no match above threshold.
"""
if threshold is None:
threshold = cls.get_threshold(field_name) if field_name else cls.DEFAULT_THRESHOLD
best_match = None
best_result = None
for candidate in candidates:
result = cls.match_field(field_name, ocr_value, candidate, threshold=0.0)
if result.score >= threshold:
if best_result is None or result.score > best_result.score:
best_match = candidate
best_result = result
if best_match:
return (best_match, best_result)
return None
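A compact usage sketch for the matcher (illustrative only; field names follow FIELD_THRESHOLDS above):
from src.utils.fuzzy_matcher import FuzzyMatcher
print(FuzzyMatcher.match_field("Amount", "1 234,56", "1234.56").matched)            # True: compared numerically
print(FuzzyMatcher.match_field("InvoiceDate", "29.12.2024", "2024-12-29").matched)  # True: normalized to ISO first
print(FuzzyMatcher.match_digits("556O234567", "5560234567").score)                  # high, O/0 treated as OCR confusion
print(FuzzyMatcher.find_best_match("12345", ["12345", "12346", "99999"], "InvoiceNumber"))  # ("12345", FuzzyMatchResult(...))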

View File

@@ -0,0 +1,384 @@
"""
OCR Error Corrections Module
Provides comprehensive OCR error correction tables and correction functions.
Based on common OCR recognition errors in Swedish invoice documents.
"""
import re
from typing import Optional
from dataclasses import dataclass
@dataclass
class CorrectionResult:
"""Result of an OCR correction operation."""
original: str
corrected: str
corrections_applied: list[tuple[int, str, str]] # (position, from_char, to_char)
confidence: float # How confident we are in the correction
class OCRCorrections:
"""
Comprehensive OCR error correction utilities.
Provides:
- Character-level corrections for digits
- Word-level corrections for common Swedish terms
- Context-aware corrections
- Multiple correction strategies
"""
# =========================================================================
# Character-level OCR errors (digit fields)
# =========================================================================
# Characters commonly misread as digits
CHAR_TO_DIGIT = {
# Letters that look like digits
'O': '0', 'o': '0', # O -> 0
'Q': '0', # Q -> 0 (less common)
'D': '0', # D -> 0 (in some fonts)
'l': '1', 'I': '1', # l/I -> 1
'i': '1', # i without dot -> 1
'|': '1', # pipe -> 1
'!': '1', # exclamation -> 1
'Z': '2', 'z': '2', # Z -> 2
'E': '3', # E -> 3 (rare)
'A': '4', 'h': '4', # A/h -> 4 (in some fonts)
'S': '5', 's': '5', # S -> 5
'G': '6', 'b': '6', # G/b -> 6
'T': '7', 't': '7', # T -> 7 (rare)
'B': '8', # B -> 8
'g': '9', 'q': '9', # g/q -> 9
}
# Digits commonly misread as other characters
DIGIT_TO_CHAR = {
'0': ['O', 'o', 'D', 'Q'],
'1': ['l', 'I', 'i', '|', '!'],
'2': ['Z', 'z'],
'3': ['E'],
'4': ['A', 'h'],
'5': ['S', 's'],
'6': ['G', 'b'],
'7': ['T', 't'],
'8': ['B'],
'9': ['g', 'q'],
}
# Bidirectional confusion pairs (either direction is possible)
CONFUSION_PAIRS = [
('0', 'O'), ('0', 'o'), ('0', 'D'),
('1', 'l'), ('1', 'I'), ('1', '|'),
('2', 'Z'), ('2', 'z'),
('5', 'S'), ('5', 's'),
('6', 'G'), ('6', 'b'),
('8', 'B'),
('9', 'g'), ('9', 'q'),
]
# =========================================================================
# Word-level OCR errors (Swedish invoice terms)
# =========================================================================
# Common Swedish invoice terms and their OCR misreadings
SWEDISH_TERM_CORRECTIONS = {
# Faktura (Invoice)
'faktura': ['Faktura', 'FAKTURA', 'faktúra', 'faKtura'],
'fakturanummer': ['Fakturanummer', 'FAKTURANUMMER', 'fakturanr', 'fakt.nr'],
'fakturadatum': ['Fakturadatum', 'FAKTURADATUM', 'fakt.datum'],
# Belopp (Amount)
'belopp': ['Belopp', 'BELOPP', 'be1opp', 'bel0pp'],
'summa': ['Summa', 'SUMMA', '5umma'],
'total': ['Total', 'TOTAL', 'tota1', 't0tal'],
'moms': ['Moms', 'MOMS', 'm0ms'],
# Dates
'förfallodatum': ['Förfallodatum', 'FÖRFALLODATUM', 'förfa11odatum'],
'datum': ['Datum', 'DATUM', 'dátum'],
# Payment
'bankgiro': ['Bankgiro', 'BANKGIRO', 'BG', 'bg', 'bank giro'],
'plusgiro': ['Plusgiro', 'PLUSGIRO', 'PG', 'pg', 'plus giro'],
'postgiro': ['Postgiro', 'POSTGIRO'],
'ocr': ['OCR', 'ocr', '0CR', 'OcR'],
# Organization
'organisationsnummer': ['Organisationsnummer', 'ORGANISATIONSNUMMER', 'org.nr', 'orgnr'],
'kundnummer': ['Kundnummer', 'KUNDNUMMER', 'kund nr', 'kundnr'],
# Currency
'kronor': ['Kronor', 'KRONOR', 'kr', 'KR', 'SEK', 'sek'],
'öre': ['Öre', 'ÖRE', 'ore', 'ORE'],
}
# =========================================================================
# Context patterns
# =========================================================================
# Patterns that indicate the following/preceding text is a specific field
CONTEXT_INDICATORS = {
'invoice_number': [
r'faktura\s*(?:nr|nummer)?[:\s]*',
r'invoice\s*(?:no|number)?[:\s]*',
r'fakt\.?\s*nr[:\s]*',
r'inv[:\s]*#?',
],
'amount': [
r'(?:att\s+)?betala[:\s]*',
r'total[t]?[:\s]*',
r'summa[:\s]*',
r'belopp[:\s]*',
r'amount[:\s]*',
],
'date': [
r'datum[:\s]*',
r'date[:\s]*',
r'förfall(?:o)?datum[:\s]*',
r'fakturadatum[:\s]*',
],
'ocr': [
r'ocr[:\s]*',
r'referens[:\s]*',
r'betalningsreferens[:\s]*',
],
'bankgiro': [
r'bankgiro[:\s]*',
r'bg[:\s]*',
r'bank\s*giro[:\s]*',
],
'plusgiro': [
r'plusgiro[:\s]*',
r'pg[:\s]*',
r'plus\s*giro[:\s]*',
r'postgiro[:\s]*',
],
'org_number': [
r'org\.?\s*(?:nr|nummer)?[:\s]*',
r'organisationsnummer[:\s]*',
r'vat[:\s]*',
r'moms(?:reg)?\.?\s*(?:nr|nummer)?[:\s]*',
],
}
# =========================================================================
# Correction Methods
# =========================================================================
@classmethod
def correct_digits(cls, text: str, aggressive: bool = False) -> CorrectionResult:
"""
Apply digit corrections to text.
Args:
text: Input text
aggressive: If True, correct all potential digit-like characters.
If False, only correct characters adjacent to digits.
Returns:
CorrectionResult with original, corrected text, and details.
"""
corrections = []
result = []
for i, char in enumerate(text):
if char.isdigit():
result.append(char)
elif char in cls.CHAR_TO_DIGIT:
if aggressive:
# Always correct
corrected_char = cls.CHAR_TO_DIGIT[char]
corrections.append((i, char, corrected_char))
result.append(corrected_char)
else:
# Only correct if adjacent to digit
prev_is_digit = i > 0 and (text[i-1].isdigit() or text[i-1] in cls.CHAR_TO_DIGIT)
next_is_digit = i < len(text) - 1 and (text[i+1].isdigit() or text[i+1] in cls.CHAR_TO_DIGIT)
if prev_is_digit or next_is_digit:
corrected_char = cls.CHAR_TO_DIGIT[char]
corrections.append((i, char, corrected_char))
result.append(corrected_char)
else:
result.append(char)
else:
result.append(char)
corrected = ''.join(result)
confidence = 1.0 - (len(corrections) * 0.05) # Decrease confidence per correction
return CorrectionResult(
original=text,
corrected=corrected,
corrections_applied=corrections,
confidence=max(0.5, confidence)
)
@classmethod
def generate_digit_variants(cls, text: str) -> list[str]:
"""
Generate all possible digit interpretations of a text.
Useful for matching when we don't know which direction the OCR error went.
"""
if not text:
return [text]
variants = {text}
# For each character that could be confused
for i, char in enumerate(text):
new_variants = set()
for existing in variants:
# If it's a digit, add letter variants
if char.isdigit() and char in cls.DIGIT_TO_CHAR:
for replacement in cls.DIGIT_TO_CHAR[char]:
new_variants.add(existing[:i] + replacement + existing[i+1:])
# If it's a letter that looks like a digit, add digit variant
if char in cls.CHAR_TO_DIGIT:
new_variants.add(existing[:i] + cls.CHAR_TO_DIGIT[char] + existing[i+1:])
variants.update(new_variants)
# Limit explosion - only keep reasonable number
if len(variants) > 100:
break
return list(variants)
@classmethod
def correct_swedish_term(cls, text: str) -> str:
"""
Correct common Swedish invoice terms that may have OCR errors.
"""
text_lower = text.lower()
for canonical, variants in cls.SWEDISH_TERM_CORRECTIONS.items():
for variant in variants:
if variant.lower() in text_lower:
# Replace with canonical form (preserving case of first letter)
pattern = re.compile(re.escape(variant), re.IGNORECASE)
if text[0].isupper():
replacement = canonical.capitalize()
else:
replacement = canonical
text = pattern.sub(replacement, text)
return text
@classmethod
def extract_with_context(cls, text: str, field_type: str) -> Optional[str]:
"""
Extract a field value using context indicators.
Looks for patterns like "Fakturanr: 12345" and extracts "12345".
"""
patterns = cls.CONTEXT_INDICATORS.get(field_type, [])
for pattern in patterns:
# Look for pattern followed by value
full_pattern = pattern + r'([^\s,;]+)'
match = re.search(full_pattern, text, re.IGNORECASE)
if match:
return match.group(1)
return None
@classmethod
def is_likely_ocr_error(cls, char1: str, char2: str) -> bool:
"""
Check if two characters are commonly confused in OCR.
"""
pair = (char1, char2)
reverse_pair = (char2, char1)
for p in cls.CONFUSION_PAIRS:
if pair == p or reverse_pair == p:
return True
return False
@classmethod
def count_potential_ocr_errors(cls, s1: str, s2: str) -> tuple[int, int]:
"""
Count how many character differences between two strings
are likely OCR errors vs other differences.
Returns: (ocr_errors, other_errors)
"""
if len(s1) != len(s2):
return (0, abs(len(s1) - len(s2)))
ocr_errors = 0
other_errors = 0
for c1, c2 in zip(s1, s2):
if c1 != c2:
if cls.is_likely_ocr_error(c1, c2):
ocr_errors += 1
else:
other_errors += 1
return (ocr_errors, other_errors)
@classmethod
def suggest_corrections(cls, text: str, expected_type: str = 'digit') -> list[tuple[str, float]]:
"""
Suggest possible corrections for a text with confidence scores.
Returns list of (corrected_text, confidence) tuples, sorted by confidence.
"""
suggestions = []
if expected_type == 'digit':
# Apply digit corrections with different levels of aggressiveness
mild = cls.correct_digits(text, aggressive=False)
if mild.corrected != text:
suggestions.append((mild.corrected, mild.confidence))
aggressive = cls.correct_digits(text, aggressive=True)
if aggressive.corrected != text and aggressive.corrected != mild.corrected:
suggestions.append((aggressive.corrected, aggressive.confidence * 0.9))
# Generate variants
variants = cls.generate_digit_variants(text)
for variant in variants[:10]: # Limit to top 10
if variant != text and variant not in [s[0] for s in suggestions]:
# Lower confidence for variants
suggestions.append((variant, 0.7))
# Sort by confidence
suggestions.sort(key=lambda x: x[1], reverse=True)
return suggestions
# =========================================================================
# Convenience functions
# =========================================================================
def correct_ocr_digits(text: str, aggressive: bool = False) -> str:
"""Convenience function to correct OCR digit errors."""
return OCRCorrections.correct_digits(text, aggressive).corrected
def generate_ocr_variants(text: str) -> list[str]:
"""Convenience function to generate OCR variants."""
return OCRCorrections.generate_digit_variants(text)
def is_ocr_confusion(char1: str, char2: str) -> bool:
"""Convenience function to check if characters are OCR confusable."""
return OCRCorrections.is_likely_ocr_error(char1, char2)
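A similar sketch for the correction helpers in this module (illustrative only):
from src.utils.ocr_corrections import OCRCorrections, correct_ocr_digits
print(correct_ocr_digits("556O23"))                                               # "556023"
result = OCRCorrections.correct_digits("5S6l23", aggressive=False)
print(result.corrected, len(result.corrections_applied))                          # "556123" 2
print(OCRCorrections.suggest_corrections("556O23")[:3])                           # ranked (candidate, confidence) pairs
print(OCRCorrections.extract_with_context("Fakturanr: 12345", "invoice_number"))  # "12345"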

View File

@@ -0,0 +1,399 @@
"""
Tests for advanced utility modules:
- FuzzyMatcher
- OCRCorrections
- ContextExtractor
"""
import pytest
from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
from .ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants
from .context_extractor import ContextExtractor, extract_field_with_context
class TestFuzzyMatcher:
"""Tests for FuzzyMatcher class."""
def test_levenshtein_distance_identical(self):
"""Test distance for identical strings."""
assert FuzzyMatcher.levenshtein_distance("hello", "hello") == 0
def test_levenshtein_distance_one_char(self):
"""Test distance for one character difference."""
assert FuzzyMatcher.levenshtein_distance("hello", "hallo") == 1
assert FuzzyMatcher.levenshtein_distance("hello", "hell") == 1
assert FuzzyMatcher.levenshtein_distance("hello", "helloo") == 1
def test_levenshtein_distance_multiple(self):
"""Test distance for multiple differences."""
assert FuzzyMatcher.levenshtein_distance("hello", "world") == 4
assert FuzzyMatcher.levenshtein_distance("", "hello") == 5
def test_similarity_ratio_identical(self):
"""Test similarity for identical strings."""
assert FuzzyMatcher.similarity_ratio("hello", "hello") == 1.0
def test_similarity_ratio_similar(self):
"""Test similarity for similar strings."""
ratio = FuzzyMatcher.similarity_ratio("hello", "hallo")
assert 0.8 <= ratio <= 0.9 # One char different in 5-char string
def test_similarity_ratio_different(self):
"""Test similarity for different strings."""
ratio = FuzzyMatcher.similarity_ratio("hello", "world")
assert ratio < 0.5
def test_ocr_aware_similarity_exact(self):
"""Test OCR-aware similarity for exact match."""
assert FuzzyMatcher.ocr_aware_similarity("12345", "12345") == 1.0
def test_ocr_aware_similarity_ocr_error(self):
"""Test OCR-aware similarity with OCR error."""
# O instead of 0
score = FuzzyMatcher.ocr_aware_similarity("1234O", "12340")
assert score >= 0.9 # Should be high due to OCR correction
def test_ocr_aware_similarity_multiple_errors(self):
"""Test OCR-aware similarity with multiple OCR errors."""
# l instead of 1, O instead of 0
score = FuzzyMatcher.ocr_aware_similarity("l234O", "12340")
assert score >= 0.85
def test_match_digits_exact(self):
"""Test digit matching for exact match."""
result = FuzzyMatcher.match_digits("12345", "12345")
assert result.matched is True
assert result.score == 1.0
assert result.match_type == 'exact'
def test_match_digits_with_separators(self):
"""Test digit matching ignoring separators."""
result = FuzzyMatcher.match_digits("123-4567", "1234567")
assert result.matched is True
assert result.normalized_ocr == "1234567"
def test_match_digits_ocr_error(self):
"""Test digit matching with OCR error."""
result = FuzzyMatcher.match_digits("556O234567", "5560234567")
assert result.matched is True
assert result.score >= 0.9
def test_match_amount_exact(self):
"""Test amount matching for exact values."""
result = FuzzyMatcher.match_amount("1234.56", "1234.56")
assert result.matched is True
assert result.score == 1.0
def test_match_amount_different_formats(self):
"""Test amount matching with different formats."""
# Swedish vs US format
result = FuzzyMatcher.match_amount("1234,56", "1234.56")
assert result.matched is True
assert result.score >= 0.99
def test_match_amount_with_spaces(self):
"""Test amount matching with thousand separators."""
result = FuzzyMatcher.match_amount("1 234,56", "1234.56")
assert result.matched is True
def test_match_date_same_date_different_format(self):
"""Test date matching with different formats."""
result = FuzzyMatcher.match_date("2024-12-29", "29.12.2024")
assert result.matched is True
assert result.score >= 0.9
def test_match_date_different_dates(self):
"""Test date matching with different dates."""
result = FuzzyMatcher.match_date("2024-12-29", "2024-12-30")
assert result.matched is False
def test_match_string_exact(self):
"""Test string matching for exact match."""
result = FuzzyMatcher.match_string("Hello World", "Hello World")
assert result.matched is True
assert result.match_type == 'exact'
def test_match_string_case_insensitive(self):
"""Test string matching case insensitivity."""
result = FuzzyMatcher.match_string("HELLO", "hello")
assert result.matched is True
assert result.match_type == 'normalized'
def test_match_string_ocr_corrected(self):
"""Test string matching with OCR corrections."""
result = FuzzyMatcher.match_string("5561234567", "556l234567")
assert result.matched is True
def test_match_field_routes_correctly(self):
"""Test that match_field routes to correct matcher."""
# Amount field
result = FuzzyMatcher.match_field("Amount", "1234.56", "1234,56")
assert result.matched is True
# Date field
result = FuzzyMatcher.match_field("InvoiceDate", "2024-12-29", "29.12.2024")
assert result.matched is True
def test_find_best_match(self):
"""Test finding best match from candidates."""
candidates = ["12345", "12346", "99999"]
result = FuzzyMatcher.find_best_match("12345", candidates, "InvoiceNumber")
assert result is not None
assert result[0] == "12345"
assert result[1].score == 1.0
def test_find_best_match_no_match(self):
"""Test finding best match when none above threshold."""
candidates = ["99999", "88888", "77777"]
result = FuzzyMatcher.find_best_match("12345", candidates, "InvoiceNumber")
assert result is None
class TestOCRCorrections:
"""Tests for OCRCorrections class."""
def test_correct_digits_simple(self):
"""Test simple digit correction."""
result = OCRCorrections.correct_digits("556O23", aggressive=False)
assert result.corrected == "556023"
assert len(result.corrections_applied) == 1
def test_correct_digits_multiple(self):
"""Test multiple digit corrections."""
result = OCRCorrections.correct_digits("5S6l23", aggressive=False)
assert result.corrected == "556123"
assert len(result.corrections_applied) == 2
def test_correct_digits_aggressive(self):
"""Test aggressive mode corrects all potential errors."""
result = OCRCorrections.correct_digits("AB123", aggressive=True)
# A -> 4, B -> 8
assert result.corrected == "48123"
def test_correct_digits_non_aggressive(self):
"""Test non-aggressive mode only corrects adjacent."""
result = OCRCorrections.correct_digits("AB 123", aggressive=False)
# A and B are adjacent to each other and both appear in CHAR_TO_DIGIT,
# so they may be corrected; the important part is that the digits themselves are unaffected.
assert "123" in result.corrected
def test_generate_digit_variants(self):
"""Test generating OCR variants."""
variants = OCRCorrections.generate_digit_variants("10")
# Should include original and variants like "1O", "I0", "IO", "l0", etc.
assert "10" in variants
assert "1O" in variants or "l0" in variants
def test_generate_digit_variants_limits(self):
"""Test that variant generation is limited."""
variants = OCRCorrections.generate_digit_variants("1234567890")
# Should be limited to prevent explosion (limit is ~100, but may slightly exceed)
assert len(variants) <= 150
def test_is_likely_ocr_error(self):
"""Test OCR error detection."""
assert OCRCorrections.is_likely_ocr_error('0', 'O') is True
assert OCRCorrections.is_likely_ocr_error('O', '0') is True
assert OCRCorrections.is_likely_ocr_error('1', 'l') is True
assert OCRCorrections.is_likely_ocr_error('5', 'S') is True
assert OCRCorrections.is_likely_ocr_error('A', 'Z') is False
def test_count_potential_ocr_errors(self):
"""Test counting OCR errors vs other errors."""
ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors("1O3", "103")
assert ocr_errors == 1 # O vs 0
assert other_errors == 0
ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors("1X3", "103")
assert ocr_errors == 0
assert other_errors == 1 # X vs 0, not a known pair
def test_suggest_corrections(self):
"""Test correction suggestions."""
suggestions = OCRCorrections.suggest_corrections("556O23", expected_type='digit')
assert len(suggestions) > 0
# First suggestion should be the corrected version
assert suggestions[0][0] == "556023"
def test_convenience_function_correct(self):
"""Test convenience function."""
assert correct_ocr_digits("556O23") == "556023"
def test_convenience_function_variants(self):
"""Test convenience function for variants."""
variants = generate_ocr_variants("10")
assert "10" in variants
class TestContextExtractor:
"""Tests for ContextExtractor class."""
def test_extract_invoice_number_with_label(self):
"""Test extracting invoice number after label."""
text = "Fakturanummer: 12345678"
candidates = ContextExtractor.extract_with_label(text, "InvoiceNumber")
assert len(candidates) > 0
assert candidates[0].value == "12345678"
assert candidates[0].extraction_method == 'label'
def test_extract_invoice_number_swedish(self):
"""Test extracting with Swedish label."""
text = "Faktura nr: A12345"
candidates = ContextExtractor.extract_with_label(text, "InvoiceNumber")
assert len(candidates) > 0
# Should extract A12345 or 12345
def test_extract_amount_with_label(self):
"""Test extracting amount after label."""
text = "Att betala: 1 234,56"
candidates = ContextExtractor.extract_with_label(text, "Amount")
assert len(candidates) > 0
def test_extract_amount_total(self):
"""Test extracting with total label."""
text = "Total: 5678,90 kr"
candidates = ContextExtractor.extract_with_label(text, "Amount")
assert len(candidates) > 0
def test_extract_date_with_label(self):
"""Test extracting date after label."""
text = "Fakturadatum: 2024-12-29"
candidates = ContextExtractor.extract_with_label(text, "InvoiceDate")
assert len(candidates) > 0
assert "2024-12-29" in candidates[0].value
def test_extract_due_date(self):
"""Test extracting due date."""
text = "Förfallodatum: 2025-01-15"
candidates = ContextExtractor.extract_with_label(text, "InvoiceDueDate")
assert len(candidates) > 0
def test_extract_bankgiro(self):
"""Test extracting Bankgiro."""
text = "Bankgiro: 1234-5678"
candidates = ContextExtractor.extract_with_label(text, "Bankgiro")
assert len(candidates) > 0
assert "1234-5678" in candidates[0].value or "12345678" in candidates[0].value
def test_extract_plusgiro(self):
"""Test extracting Plusgiro."""
text = "Plusgiro: 1234567-8"
candidates = ContextExtractor.extract_with_label(text, "Plusgiro")
assert len(candidates) > 0
def test_extract_ocr(self):
"""Test extracting OCR number."""
text = "OCR: 12345678901234"
candidates = ContextExtractor.extract_with_label(text, "OCR")
assert len(candidates) > 0
assert candidates[0].value == "12345678901234"
def test_extract_org_number(self):
"""Test extracting organization number."""
text = "Org.nr: 556123-4567"
candidates = ContextExtractor.extract_with_label(text, "supplier_organisation_number")
assert len(candidates) > 0
def test_extract_customer_number(self):
"""Test extracting customer number."""
text = "Kundnummer: EMM 256-6"
candidates = ContextExtractor.extract_with_label(text, "customer_number")
assert len(candidates) > 0
def test_extract_field_returns_sorted(self):
"""Test that extract_field returns sorted by confidence."""
text = "Fakturanummer: 12345 Invoice number: 67890"
candidates = ContextExtractor.extract_field(text, "InvoiceNumber")
if len(candidates) > 1:
# Should be sorted by confidence (descending)
assert candidates[0].confidence >= candidates[1].confidence
def test_extract_best(self):
"""Test extract_best returns single best candidate."""
text = "Fakturanummer: 12345678"
best = ContextExtractor.extract_best(text, "InvoiceNumber")
assert best is not None
assert best.value == "12345678"
def test_extract_best_no_match(self):
"""Test extract_best returns None when no match."""
text = "No invoice information here"
best = ContextExtractor.extract_best(text, "InvoiceNumber", validate=True)
# May or may not find something depending on validation
def test_extract_all_fields(self):
"""Test extracting all fields from text."""
text = """
Fakturanummer: 12345
Datum: 2024-12-29
Belopp: 1234,56
Bankgiro: 1234-5678
"""
results = ContextExtractor.extract_all_fields(text)
# Should find at least some fields
assert len(results) > 0
def test_identify_field_type(self):
"""Test identifying field type from context."""
text = "Fakturanummer: 12345"
field_type = ContextExtractor.identify_field_type(text, "12345")
assert field_type == "InvoiceNumber"
def test_convenience_function_extract(self):
"""Test convenience function."""
text = "Fakturanummer: 12345678"
value = extract_field_with_context(text, "InvoiceNumber")
assert value == "12345678"
class TestIntegration:
"""Integration tests combining multiple modules."""
def test_fuzzy_match_with_ocr_correction(self):
"""Test fuzzy matching with OCR correction."""
# Simulate OCR error: 0 -> O
ocr_text = "556O234567"
expected = "5560234567"
# First correct
corrected = correct_ocr_digits(ocr_text)
assert corrected == expected
# Then match
result = FuzzyMatcher.match_digits(ocr_text, expected)
assert result.matched is True
def test_context_extraction_with_fuzzy_match(self):
"""Test extracting value and fuzzy matching."""
text = "Fakturanummer: 1234S678" # S is OCR error for 5
# Extract
candidate = ContextExtractor.extract_best(text, "InvoiceNumber", validate=False)
assert candidate is not None
# Fuzzy match against expected
result = FuzzyMatcher.match_string(candidate.value, "12345678")
# Might match depending on threshold
if __name__ == "__main__":
pytest.main([__file__, "-v"])

235
src/utils/test_utils.py Normal file
View File

@@ -0,0 +1,235 @@
"""
Tests for shared utility modules.
"""
import pytest
from .text_cleaner import TextCleaner
from .format_variants import FormatVariants
from .validators import FieldValidators
class TestTextCleaner:
"""Tests for TextCleaner class."""
def test_clean_unicode_dashes(self):
"""Test normalization of various dash types."""
# en-dash
assert TextCleaner.clean_unicode("5561234567") == "556123-4567"
# em-dash
assert TextCleaner.clean_unicode("556123—4567") == "556123-4567"
# minus sign
assert TextCleaner.clean_unicode("5561234567") == "556123-4567"
def test_clean_unicode_spaces(self):
"""Test normalization of various space types."""
# non-breaking space
assert TextCleaner.clean_unicode("1\xa0234") == "1 234"
# zero-width space removed
assert TextCleaner.clean_unicode("123\u200b456") == "123456"
def test_ocr_digit_corrections(self):
"""Test OCR error corrections for digit fields."""
# O -> 0
assert TextCleaner.apply_ocr_digit_corrections("556O23") == "556023"
# l -> 1
assert TextCleaner.apply_ocr_digit_corrections("556l23") == "556123"
# S -> 5
assert TextCleaner.apply_ocr_digit_corrections("5S6123") == "556123"
# Mixed
assert TextCleaner.apply_ocr_digit_corrections("S56l23-4S67") == "556123-4567"
def test_extract_digits(self):
"""Test digit extraction with OCR correction."""
assert TextCleaner.extract_digits("556123-4567") == "5561234567"
assert TextCleaner.extract_digits("556O23-4567", apply_ocr_correction=True) == "5560234567"
# Without OCR correction, only extracts actual digits
assert TextCleaner.extract_digits("ABC 123 DEF", apply_ocr_correction=False) == "123"
# With OCR correction, standalone letters are not converted
# (they need to be adjacent to digits to be corrected)
assert TextCleaner.extract_digits("A 123 B", apply_ocr_correction=True) == "123"
def test_normalize_amount_text(self):
"""Test amount text normalization."""
assert TextCleaner.normalize_amount_text("1 234,56 kr") == "1234,56"
assert TextCleaner.normalize_amount_text("SEK 1234.56") == "1234.56"
assert TextCleaner.normalize_amount_text("1 234 567,89 kronor") == "1234567,89"
class TestFormatVariants:
"""Tests for FormatVariants class."""
def test_organisation_number_variants(self):
"""Test organisation number variant generation."""
variants = FormatVariants.organisation_number_variants("5561234567")
assert "5561234567" in variants # 纯数字
assert "556123-4567" in variants # 带横线
assert "SE556123456701" in variants # VAT格式
def test_organisation_number_from_vat(self):
"""Test extracting org number from VAT format."""
variants = FormatVariants.organisation_number_variants("SE556123456701")
assert "5561234567" in variants
assert "556123-4567" in variants
def test_bankgiro_variants(self):
"""Test Bankgiro variant generation."""
# 8 digits
variants = FormatVariants.bankgiro_variants("53939484")
assert "53939484" in variants
assert "5393-9484" in variants
# 7 digits
variants = FormatVariants.bankgiro_variants("1234567")
assert "1234567" in variants
assert "123-4567" in variants
def test_plusgiro_variants(self):
"""Test Plusgiro variant generation."""
variants = FormatVariants.plusgiro_variants("12345678")
assert "12345678" in variants
assert "1234567-8" in variants
def test_amount_variants(self):
"""Test amount variant generation."""
variants = FormatVariants.amount_variants("1234.56")
assert "1234.56" in variants
assert "1234,56" in variants
assert "1 234,56" in variants or "1234,56" in variants # Swedish format
def test_date_variants(self):
"""Test date variant generation."""
variants = FormatVariants.date_variants("2024-12-29")
assert "2024-12-29" in variants # ISO
assert "29.12.2024" in variants # European
assert "29/12/2024" in variants # European slash
assert "20241229" in variants # Compact
assert "29 december 2024" in variants # Swedish text
def test_invoice_number_variants(self):
"""Test invoice number variant generation."""
variants = FormatVariants.invoice_number_variants("INV-2024-001")
assert "INV-2024-001" in variants
assert "INV2024001" in variants # No separators
assert "inv-2024-001" in variants # Lowercase
def test_get_variants_dispatch(self):
"""Test get_variants dispatches to correct method."""
# Organisation number
org_variants = FormatVariants.get_variants("supplier_organisation_number", "5561234567")
assert "556123-4567" in org_variants
# Bankgiro
bg_variants = FormatVariants.get_variants("Bankgiro", "53939484")
assert "5393-9484" in bg_variants
# Amount
amount_variants = FormatVariants.get_variants("Amount", "1234.56")
assert "1234,56" in amount_variants
class TestFieldValidators:
"""Tests for FieldValidators class."""
def test_luhn_checksum_valid(self):
"""Test Luhn validation with valid numbers."""
# Valid Bankgiro numbers (with correct check digit)
assert FieldValidators.luhn_checksum("53939484") is True
# Valid OCR numbers
assert FieldValidators.luhn_checksum("1234567897") is True # check digit 7
def test_luhn_checksum_invalid(self):
"""Test Luhn validation with invalid numbers."""
assert FieldValidators.luhn_checksum("53939485") is False # wrong check digit
assert FieldValidators.luhn_checksum("1234567890") is False
def test_calculate_luhn_check_digit(self):
"""Test Luhn check digit calculation."""
# For "123456789", the check digit should make it valid
check = FieldValidators.calculate_luhn_check_digit("123456789")
full_number = "123456789" + str(check)
assert FieldValidators.luhn_checksum(full_number) is True
def test_is_valid_organisation_number(self):
"""Test organisation number validation."""
# Valid (with correct Luhn checksum)
# Note: Need actual valid org numbers for this test
# Using a well-known one: 5565006245 (placeholder)
pass # Skip without real test data
def test_is_valid_bankgiro(self):
"""Test Bankgiro validation."""
# Valid 8-digit Bankgiro with Luhn
assert FieldValidators.is_valid_bankgiro("53939484") is True
# Invalid (wrong length)
assert FieldValidators.is_valid_bankgiro("123") is False
assert FieldValidators.is_valid_bankgiro("123456789") is False # 9 digits
def test_format_bankgiro(self):
"""Test Bankgiro formatting."""
assert FieldValidators.format_bankgiro("53939484") == "5393-9484"
assert FieldValidators.format_bankgiro("1234567") == "123-4567"
assert FieldValidators.format_bankgiro("123") is None
def test_is_valid_plusgiro(self):
"""Test Plusgiro validation."""
# Valid Plusgiro (2-8 digits with Luhn)
assert FieldValidators.is_valid_plusgiro("18") is True # minimal
# Invalid (wrong length)
assert FieldValidators.is_valid_plusgiro("1") is False
def test_format_plusgiro(self):
"""Test Plusgiro formatting."""
assert FieldValidators.format_plusgiro("12345678") == "1234567-8"
assert FieldValidators.format_plusgiro("123456") == "12345-6"
def test_is_valid_amount(self):
"""Test amount validation."""
assert FieldValidators.is_valid_amount("1234.56") is True
assert FieldValidators.is_valid_amount("1 234,56") is True
assert FieldValidators.is_valid_amount("abc") is False
assert FieldValidators.is_valid_amount("-100") is False # below min
assert FieldValidators.is_valid_amount("100000000") is False # above max
def test_parse_amount(self):
"""Test amount parsing."""
assert FieldValidators.parse_amount("1234.56") == 1234.56
assert FieldValidators.parse_amount("1 234,56") == 1234.56
assert FieldValidators.parse_amount("1.234,56") == 1234.56 # German
assert FieldValidators.parse_amount("1,234.56") == 1234.56 # US
def test_is_valid_date(self):
"""Test date validation."""
assert FieldValidators.is_valid_date("2024-12-29") is True
assert FieldValidators.is_valid_date("29.12.2024") is True
assert FieldValidators.is_valid_date("29/12/2024") is True
assert FieldValidators.is_valid_date("not a date") is False
assert FieldValidators.is_valid_date("1900-01-01") is False # out of range
def test_format_date_iso(self):
"""Test date ISO formatting."""
assert FieldValidators.format_date_iso("29.12.2024") == "2024-12-29"
assert FieldValidators.format_date_iso("29/12/2024") == "2024-12-29"
assert FieldValidators.format_date_iso("2024-12-29") == "2024-12-29"
def test_validate_field_dispatch(self):
"""Test validate_field dispatches correctly."""
# Organisation number
is_valid, error = FieldValidators.validate_field("supplier_organisation_number", "")
assert is_valid is False
# Amount
is_valid, error = FieldValidators.validate_field("Amount", "1234.56")
assert is_valid is True
# Date
is_valid, error = FieldValidators.validate_field("InvoiceDate", "2024-12-29")
assert is_valid is True
if __name__ == "__main__":
pytest.main([__file__, "-v"])

src/utils/text_cleaner.py Normal file
View File

@@ -0,0 +1,244 @@
"""
Text Cleaning Module
Provides text normalization and OCR error correction utilities.
Used by both inference (field_extractor) and matching (normalizer) stages.
"""
import re
from typing import Optional
class TextCleaner:
"""
Unified text cleaning utilities for invoice processing.
Handles:
- Unicode normalization (zero-width chars, dash variants)
- OCR error correction (O/0, l/1, etc.)
- Whitespace normalization
- Swedish-specific character handling
"""
# Common OCR misrecognition corrections (for digit-only fields)
# These characters are frequently misread when digits are expected
OCR_DIGIT_CORRECTIONS = {
'O': '0', 'o': '0', # letter O -> digit 0
'Q': '0', # Q sometimes looks like 0
'l': '1', 'I': '1', # lowercase L / uppercase I -> digit 1
'|': '1', # vertical bar -> 1
'i': '1', # lowercase i -> 1
'S': '5', 's': '5', # S -> 5
'B': '8', # B -> 8
'Z': '2', 'z': '2', # Z -> 2
'G': '6', 'g': '6', # G -> 6 (in some fonts)
'A': '4', # A -> 4 (in some fonts)
'T': '7', # T -> 7 (in some fonts)
'q': '9', # q -> 9
'D': '0', # D -> 0
}
# Reverse mapping: digits misread as letters (for mixed alphanumeric fields)
OCR_LETTER_CORRECTIONS = {
'0': 'O',
'1': 'I',
'5': 'S',
'8': 'B',
'2': 'Z',
}
# Normalization of special Unicode characters
UNICODE_NORMALIZATIONS = {
# Various dashes/hyphens -> standard hyphen
'\u2013': '-', # en-dash
'\u2014': '-', # em-dash —
'\u2212': '-', # minus sign
'\u00b7': '-', # middle dot ·
'\u2010': '-', # hyphen
'\u2011': '-', # non-breaking hyphen
'\u2012': '-', # figure dash
'\u2015': '-', # horizontal bar ―
# Various space characters -> standard space
'\u00a0': ' ', # non-breaking space
'\u2002': ' ', # en space
'\u2003': ' ', # em space
'\u2009': ' ', # thin space
'\u200a': ' ', # hair space
# Zero-width characters -> removed
'\u200b': '', # zero-width space
'\u200c': '', # zero-width non-joiner
'\u200d': '', # zero-width joiner
'\ufeff': '', # BOM / zero-width no-break space
# Various quote characters -> standard quotes
'\u2018': "'", # left single quote '
'\u2019': "'", # right single quote '
'\u201c': '"', # left double quote "
'\u201d': '"', # right double quote "
}
@classmethod
def clean_unicode(cls, text: str) -> str:
"""
Normalize Unicode characters to ASCII equivalents.
Handles:
- Various dash types -> standard hyphen (-)
- Various spaces -> standard space
- Zero-width characters -> removed
- Various quotes -> standard quotes
"""
for unicode_char, replacement in cls.UNICODE_NORMALIZATIONS.items():
text = text.replace(unicode_char, replacement)
return text
@classmethod
def normalize_whitespace(cls, text: str) -> str:
"""Collapse multiple whitespace to single space and strip."""
return ' '.join(text.split())
@classmethod
def clean_text(cls, text: str) -> str:
"""
Full text cleaning pipeline.
1. Normalize Unicode
2. Normalize whitespace
3. Strip
This is safe for all field types.
"""
text = cls.clean_unicode(text)
text = cls.normalize_whitespace(text)
return text.strip()
@classmethod
def apply_ocr_digit_corrections(cls, text: str) -> str:
"""
Apply OCR error corrections for digit-only fields.
Use this when the field is expected to contain only digits
(e.g., OCR number, organization number digits, etc.)
Example:
"556l23-4S67" -> "556123-4567"
"""
result = []
for char in text:
if char in cls.OCR_DIGIT_CORRECTIONS:
result.append(cls.OCR_DIGIT_CORRECTIONS[char])
else:
result.append(char)
return ''.join(result)
@classmethod
def extract_digits(cls, text: str, apply_ocr_correction: bool = True) -> str:
"""
Extract only digits from text.
Args:
text: Input text
apply_ocr_correction: If True, apply OCR corrections ONLY to characters
that are adjacent to digits (not standalone letters)
Returns:
String containing only digits
"""
if apply_ocr_correction:
# Apply OCR corrections only to characters that sit inside what looks like a digit sequence,
# e.g. the O in "556O23" should be corrected, but the letters in "ABC 123" should not
result = []
for i, char in enumerate(text):
if char.isdigit():
result.append(char)
elif char in cls.OCR_DIGIT_CORRECTIONS:
# Check whether a neighbouring character is a digit (or correctable to one)
prev_is_digit = i > 0 and (text[i - 1].isdigit() or text[i - 1] in cls.OCR_DIGIT_CORRECTIONS)
next_is_digit = i < len(text) - 1 and (text[i + 1].isdigit() or text[i + 1] in cls.OCR_DIGIT_CORRECTIONS)
if prev_is_digit or next_is_digit:
result.append(cls.OCR_DIGIT_CORRECTIONS[char])
# All other characters are skipped
return ''.join(result)
else:
return re.sub(r'\D', '', text)
@classmethod
def clean_for_digits(cls, text: str) -> str:
"""
Clean text that should primarily contain digits.
Pipeline:
1. Clean Unicode
2. Apply OCR digit corrections
3. Normalize whitespace
Preserves separators (-, /) for formatted numbers like "556123-4567"
"""
text = cls.clean_unicode(text)
text = cls.apply_ocr_digit_corrections(text)
text = cls.normalize_whitespace(text)
return text.strip()
@classmethod
def generate_ocr_variants(cls, text: str) -> list[str]:
"""
Generate possible OCR error variants of the input text.
This is useful for matching: if we have a CSV value,
we generate variants that might appear in OCR output.
Example:
"5561234567" -> ["5561234567", "556I234567", "5561234S67", ...]
"""
variants = {text}
# Generate letter variants only for digits
for digit, letter in cls.OCR_LETTER_CORRECTIONS.items():
if digit in text:
variants.add(text.replace(digit, letter))
# Generate digit variants for letters
for letter, digit in cls.OCR_DIGIT_CORRECTIONS.items():
if letter in text:
variants.add(text.replace(letter, digit))
return list(variants)
@classmethod
def normalize_amount_text(cls, text: str) -> str:
"""
Normalize amount text for parsing.
- Removes currency symbols and labels
- Normalizes separators
- Handles Swedish format (space as thousand separator)
"""
text = cls.clean_text(text)
# Remove currency symbols and labels (word boundaries ensure whole-word matches only)
text = re.sub(r'(?i)\b(kr|sek|kronor|öre)\b', '', text)
# Remove thousands-separator spaces (Swedish: "1 234,56" -> "1234,56");
# only spaces between digits are collapsed, the digits themselves are preserved
text = re.sub(r'(\d)\s+(\d)', r'\1\2', text)
return text.strip()
@classmethod
def normalize_for_comparison(cls, text: str) -> str:
"""
Normalize text for loose comparison.
- Lowercase
- Remove all non-alphanumeric
- Apply OCR corrections
This is the most aggressive normalization, used for fuzzy matching.
"""
text = cls.clean_text(text)
text = text.lower()
text = cls.apply_ocr_digit_corrections(text)
text = re.sub(r'[^a-z0-9]', '', text)
return text

src/utils/validators.py Normal file
View File

@@ -0,0 +1,393 @@
"""
Field Validators Module
Provides validation functions for Swedish invoice fields.
Used by both inference (to validate extracted values) and matching (to filter candidates).
"""
import re
from datetime import datetime
from typing import Optional
from .text_cleaner import TextCleaner
class FieldValidators:
"""
Validators for Swedish invoice field values.
Includes:
- Luhn (Mod10) checksum validation
- Format validation for specific field types
- Range validation for dates and amounts
"""
# =========================================================================
# Luhn (Mod10) Checksum
# =========================================================================
@classmethod
def luhn_checksum(cls, digits: str) -> bool:
"""
Validate using Luhn (Mod10) algorithm.
Used for:
- Bankgiro numbers
- Plusgiro numbers
- OCR reference numbers
- Swedish organization numbers
The checksum is valid if the total modulo 10 equals 0.
"""
# Keep digits only
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
if not digits or not digits.isdigit():
return False
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
if i % 2 == 1: # every second digit, counting from the right, is doubled
digit *= 2
if digit > 9:
digit -= 9
total += digit
return total % 10 == 0
@classmethod
def calculate_luhn_check_digit(cls, digits: str) -> int:
"""
Calculate the Luhn check digit for a number.
Given a number without check digit, returns the digit that would make it valid.
"""
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
# Compute the Luhn sum of the existing digits
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
if i % 2 == 0: # note: even indices are doubled here because a check digit will be appended
digit *= 2
if digit > 9:
digit -= 9
total += digit
# Derive the required check digit
check_digit = (10 - (total % 10)) % 10
return check_digit
# =========================================================================
# Organisation Number Validation
# =========================================================================
@classmethod
def is_valid_organisation_number(cls, value: str) -> bool:
"""
Validate Swedish organisation number.
Format: NNNNNN-NNNN (10 digits)
- First digit: 1-9
- Third digit: >= 2 (distinguishes from personal numbers)
- Last digit: Luhn check digit
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
# Handle VAT formats
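# Example (from the unit tests): "SE556123456701" -> digits "556123456701"
# -> trailing "01" stripped -> "5561234567"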
if len(digits) == 12 and digits.endswith('01'):
digits = digits[:10]
elif len(digits) == 14 and digits.startswith('46') and digits.endswith('01'):
digits = digits[2:12]
if len(digits) != 10:
return False
# First digit must be 1-9
if digits[0] == '0':
return False
# Third digit >= 2 (distinguishes organisation numbers from personal numbers)
# Note: some special organisations do not follow this rule, so it is relaxed here
# if int(digits[2]) < 2:
# return False
# Luhn checksum
return cls.luhn_checksum(digits)
# =========================================================================
# Bankgiro Validation
# =========================================================================
@classmethod
def is_valid_bankgiro(cls, value: str) -> bool:
"""
Validate Swedish Bankgiro number.
Format: 7 or 8 digits with Luhn checksum
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 7 or len(digits) > 8:
return False
return cls.luhn_checksum(digits)
@classmethod
def format_bankgiro(cls, value: str) -> Optional[str]:
"""
Format Bankgiro number to standard format.
Returns: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits), or None if invalid
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) == 7:
return f"{digits[:3]}-{digits[3:]}"
elif len(digits) == 8:
return f"{digits[:4]}-{digits[4:]}"
else:
return None
# =========================================================================
# Plusgiro Validation
# =========================================================================
@classmethod
def is_valid_plusgiro(cls, value: str) -> bool:
"""
Validate Swedish Plusgiro number.
Format: 2-8 digits with Luhn checksum
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 2 or len(digits) > 8:
return False
return cls.luhn_checksum(digits)
@classmethod
def format_plusgiro(cls, value: str) -> Optional[str]:
"""
Format Plusgiro number to standard format.
Returns: XXXXXXX-X format, or None if invalid
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 2 or len(digits) > 8:
return None
return f"{digits[:-1]}-{digits[-1]}"
# =========================================================================
# OCR Number Validation
# =========================================================================
@classmethod
def is_valid_ocr_number(cls, value: str, validate_checksum: bool = True) -> bool:
"""
Validate Swedish OCR reference number.
- Typically 10-25 digits
- Usually has Luhn checksum (but not always enforced)
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 5 or len(digits) > 25:
return False
if validate_checksum:
return cls.luhn_checksum(digits)
return True
# =========================================================================
# Amount Validation
# =========================================================================
@classmethod
def is_valid_amount(cls, value: str, min_amount: float = 0.0, max_amount: float = 10_000_000.0) -> bool:
"""
Validate monetary amount.
- Must be positive (or at least >= min_amount)
- Should be within reasonable range
"""
try:
# Attempt to parse the amount
text = TextCleaner.normalize_amount_text(value)
# Normalize the decimal separator to a dot
text = text.replace(' ', '').replace(',', '.')
# If there are multiple dots, keep only the last one as the decimal point
if text.count('.') > 1:
parts = text.rsplit('.', 1)
text = parts[0].replace('.', '') + '.' + parts[1]
amount = float(text)
return min_amount <= amount <= max_amount
except (ValueError, TypeError):
return False
@classmethod
def parse_amount(cls, value: str) -> Optional[float]:
"""
Parse amount from string, handling various formats.
Returns float or None if parsing fails.
"""
try:
text = TextCleaner.normalize_amount_text(value)
text = text.replace(' ', '')
# Detect the number format and parse accordingly
# Swedish/German format: comma is the decimal separator
if re.match(r'^[\d.]+,\d{1,2}$', text):
text = text.replace('.', '').replace(',', '.')
# US format: dot is the decimal separator
elif re.match(r'^[\d,]+\.\d{1,2}$', text):
text = text.replace(',', '')
else:
# Simple format
text = text.replace(',', '.')
if text.count('.') > 1:
parts = text.rsplit('.', 1)
text = parts[0].replace('.', '') + '.' + parts[1]
return float(text)
except (ValueError, TypeError):
return None
# =========================================================================
# Date Validation
# =========================================================================
@classmethod
def is_valid_date(cls, value: str, min_year: int = 2000, max_year: int = 2100) -> bool:
"""
Validate date string.
- Year should be within reasonable range
- Month 1-12
- Day 1-31 (basic check)
"""
parsed = cls.parse_date(value)
if parsed is None:
return False
year, month, day = parsed
if not (min_year <= year <= max_year):
return False
if not (1 <= month <= 12):
return False
if not (1 <= day <= 31):
return False
# Precise calendar validation (rejects impossible dates such as Feb 30)
try:
datetime(year, month, day)
return True
except ValueError:
return False
@classmethod
def parse_date(cls, value: str) -> Optional[tuple[int, int, int]]:
"""
Parse date from string.
Returns (year, month, day) tuple or None.
"""
from .format_variants import FormatVariants
return FormatVariants._parse_date(value)
@classmethod
def format_date_iso(cls, value: str) -> Optional[str]:
"""
Format date to ISO format (YYYY-MM-DD).
Returns formatted string or None if parsing fails.
"""
parsed = cls.parse_date(value)
if parsed is None:
return None
year, month, day = parsed
return f"{year}-{month:02d}-{day:02d}"
# =========================================================================
# Invoice Number Validation
# =========================================================================
@classmethod
def is_valid_invoice_number(cls, value: str, min_length: int = 1, max_length: int = 30) -> bool:
"""
Validate invoice number.
Basic validation - just length check since invoice numbers are highly variable.
"""
clean = TextCleaner.clean_text(value)
if not clean:
return False
# Extract meaningful characters (letters and digits)
meaningful = re.sub(r'[^a-zA-Z0-9]', '', clean)
return min_length <= len(meaningful) <= max_length
# =========================================================================
# Generic Validation
# =========================================================================
@classmethod
def validate_field(cls, field_name: str, value: str) -> tuple[bool, Optional[str]]:
"""
Validate a field by name.
Returns (is_valid, error_message).
"""
if not value:
return False, "Empty value"
field_lower = field_name.lower()
if 'organisation' in field_lower or 'org' in field_lower:
if cls.is_valid_organisation_number(value):
return True, None
return False, "Invalid organisation number format or checksum"
elif 'bankgiro' in field_lower:
if cls.is_valid_bankgiro(value):
return True, None
return False, "Invalid Bankgiro format or checksum"
elif 'plusgiro' in field_lower:
if cls.is_valid_plusgiro(value):
return True, None
return False, "Invalid Plusgiro format or checksum"
elif 'ocr' in field_lower:
if cls.is_valid_ocr_number(value, validate_checksum=False):
return True, None
return False, "Invalid OCR number length"
elif 'amount' in field_lower:
if cls.is_valid_amount(value):
return True, None
return False, "Invalid amount format"
elif 'date' in field_lower:
if cls.is_valid_date(value):
return True, None
return False, "Invalid date format"
elif 'invoice' in field_lower and 'number' in field_lower:
if cls.is_valid_invoice_number(value):
return True, None
return False, "Invalid invoice number"
else:
# Default: only require a non-empty value
if TextCleaner.clean_text(value):
return True, None
return False, "Empty value after cleaning"

View File

@@ -122,6 +122,7 @@ def create_api_router(
inference_result = InferenceResult(
document_id=service_result.document_id,
success=service_result.success,
document_type=service_result.document_type,
fields=service_result.fields,
confidence=service_result.confidence,
detections=[

View File

@@ -30,6 +30,9 @@ class InferenceResult(BaseModel):
document_id: str = Field(..., description="Document identifier")
success: bool = Field(..., description="Whether inference succeeded")
document_type: str = Field(
default="invoice", description="Document type: 'invoice' or 'letter'"
)
fields: dict[str, str | None] = Field(
default_factory=dict, description="Extracted field values"
)

View File

@@ -28,6 +28,7 @@ class ServiceResult:
document_id: str
success: bool = False
document_type: str = "invoice" # "invoice" or "letter"
fields: dict[str, str | None] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
detections: list[dict] = field(default_factory=list)
@@ -145,6 +146,13 @@ class InferenceService:
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections for visualization
result.detections = [
{
@@ -202,6 +210,13 @@ class InferenceService:
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Get raw detections
result.detections = [
{