WIP
This commit is contained in:
@@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace:
|
||||
"--dpi",
|
||||
type=int,
|
||||
default=150,
|
||||
help="DPI for PDF rendering",
|
||||
help="DPI for PDF rendering (must match training DPI)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
@@ -63,7 +63,24 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
'--resume',
|
||||
help='Resume from checkpoint'
|
||||
action='store_true',
|
||||
help='Resume from last checkpoint'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--workers',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Number of data loader workers (default: 4, reduce if OOM)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--cache',
|
||||
action='store_true',
|
||||
help='Cache images in RAM (faster but uses more memory)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--low-memory',
|
||||
action='store_true',
|
||||
help='Enable low memory mode (batch=4, workers=2, no cache)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--train-ratio',
|
||||
@@ -103,6 +120,16 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Apply low-memory mode if specified
|
||||
if args.low_memory:
|
||||
print("🔧 Low memory mode enabled")
|
||||
args.batch = min(args.batch, 8) # Reduce from 16 to 8
|
||||
args.workers = min(args.workers, 4) # Reduce from 8 to 4
|
||||
args.cache = False
|
||||
print(f" Batch size: {args.batch}")
|
||||
print(f" Workers: {args.workers}")
|
||||
print(f" Cache: disabled")
|
||||
|
||||
# Validate dataset directory
|
||||
dataset_dir = Path(args.dataset_dir)
|
||||
temp_dir = dataset_dir / 'temp'
|
||||
@@ -181,9 +208,10 @@ def main():
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load model
|
||||
if args.resume:
|
||||
print(f"Resuming from: {args.resume}")
|
||||
model = YOLO(args.resume)
|
||||
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
|
||||
if args.resume and last_checkpoint.exists():
|
||||
print(f"Resuming from: {last_checkpoint}")
|
||||
model = YOLO(str(last_checkpoint))
|
||||
else:
|
||||
model = YOLO(args.model)
|
||||
|
||||
@@ -200,6 +228,9 @@ def main():
|
||||
'exist_ok': True,
|
||||
'pretrained': True,
|
||||
'verbose': True,
|
||||
'workers': args.workers,
|
||||
'cache': args.cache,
|
||||
'resume': args.resume and last_checkpoint.exists(),
|
||||
# Document-specific augmentation settings
|
||||
'degrees': 5.0,
|
||||
'translate': 0.05,
|
||||
|
||||
@@ -2,17 +2,33 @@
|
||||
Field Extractor Module
|
||||
|
||||
Extracts and validates field values from detected regions.
|
||||
|
||||
This module is used during inference to extract values from OCR text.
|
||||
It uses shared utilities from src.utils for text cleaning and validation.
|
||||
|
||||
Enhanced features:
|
||||
- Multi-source fusion with confidence weighting
|
||||
- Smart amount parsing with multiple strategies
|
||||
- Enhanced date format unification
|
||||
- OCR error correction integration
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
import re
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .yolo_detector import Detection, CLASS_TO_FIELD
|
||||
|
||||
# Import shared utilities for text cleaning and validation
|
||||
from src.utils.text_cleaner import TextCleaner
|
||||
from src.utils.validators import FieldValidators
|
||||
from src.utils.fuzzy_matcher import FuzzyMatcher
|
||||
from src.utils.ocr_corrections import OCRCorrections
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedField:
|
||||
@@ -27,10 +43,14 @@ class ExtractedField:
|
||||
page_no: int
|
||||
is_valid: bool = True
|
||||
validation_error: str | None = None
|
||||
# Multi-source fusion fields
|
||||
alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...]
|
||||
extraction_method: str = 'single' # 'single', 'fused', 'corrected'
|
||||
ocr_corrections_applied: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
result = {
|
||||
'field_name': self.field_name,
|
||||
'value': self.normalized_value,
|
||||
'raw_text': self.raw_text,
|
||||
@@ -40,6 +60,11 @@ class ExtractedField:
|
||||
'is_valid': self.is_valid,
|
||||
'validation_error': self.validation_error
|
||||
}
|
||||
if self.alternative_values:
|
||||
result['alternatives'] = self.alternative_values
|
||||
if self.extraction_method != 'single':
|
||||
result['extraction_method'] = self.extraction_method
|
||||
return result
|
||||
|
||||
|
||||
class FieldExtractor:
|
||||
@@ -324,21 +349,9 @@ class FieldExtractor:
|
||||
Validate using Luhn (Mod10) algorithm.
|
||||
Used for Bankgiro, Plusgiro, and OCR number validation.
|
||||
|
||||
The checksum is valid if the total modulo 10 equals 0.
|
||||
Delegates to shared FieldValidators for consistency.
|
||||
"""
|
||||
if not digits.isdigit():
|
||||
return False
|
||||
|
||||
total = 0
|
||||
for i, char in enumerate(reversed(digits)):
|
||||
digit = int(char)
|
||||
if i % 2 == 1: # Double every second digit from right
|
||||
digit *= 2
|
||||
if digit > 9:
|
||||
digit -= 9
|
||||
total += digit
|
||||
|
||||
return total % 10 == 0
|
||||
return FieldValidators.luhn_checksum(digits)
|
||||
|
||||
def _detect_giro_type(self, text: str) -> str | None:
|
||||
"""
|
||||
@@ -489,38 +502,81 @@ class FieldExtractor:
|
||||
return None, False, f"No valid Plusgiro found in text"
|
||||
|
||||
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize monetary amount."""
|
||||
# Try to extract amount using regex patterns
|
||||
# Pattern 1: Number with comma as decimal (Swedish format: 1 234,56)
|
||||
# Pattern 2: Number with dot as decimal (1234.56)
|
||||
# Pattern 3: Number followed by currency (275,60 kr or 275.60 SEK)
|
||||
"""Normalize monetary amount.
|
||||
|
||||
patterns = [
|
||||
# Swedish format with space thousand separator: 1 234,56 or 1234,56
|
||||
r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',
|
||||
# Simple decimal: 350.00 or 350,00
|
||||
r'(\d+[,\.]\d{2})',
|
||||
# Integer amount
|
||||
r'(\d{2,})',
|
||||
]
|
||||
Uses shared TextCleaner for preprocessing and FieldValidators for parsing.
|
||||
If multiple amounts are found, returns the last one (usually the total).
|
||||
"""
|
||||
# Split by newlines and process line by line to get the last valid amount
|
||||
lines = text.split('\n')
|
||||
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
# Collect all valid amounts from all lines
|
||||
all_amounts = []
|
||||
|
||||
# Pattern for Swedish amount format (with decimals)
|
||||
amount_pattern = r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?'
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Find all amounts in this line
|
||||
matches = re.findall(amount_pattern, line, re.IGNORECASE)
|
||||
for match in matches:
|
||||
amount_str = match.replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
all_amounts.append(amount)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Return the last amount found (usually the total)
|
||||
if all_amounts:
|
||||
return f"{all_amounts[-1]:.2f}", True, None
|
||||
|
||||
# Fallback: try shared validator on cleaned text
|
||||
cleaned = TextCleaner.normalize_amount_text(text)
|
||||
amount = FieldValidators.parse_amount(cleaned)
|
||||
if amount is not None and amount > 0:
|
||||
return f"{amount:.2f}", True, None
|
||||
|
||||
# Try to find any decimal number
|
||||
simple_pattern = r'(\d+[,\.]\d{2})'
|
||||
matches = re.findall(simple_pattern, text)
|
||||
if matches:
|
||||
# Take the last match (usually the total amount)
|
||||
amount_str = matches[-1]
|
||||
# Clean up
|
||||
amount_str = amount_str.replace(' ', '').replace('\xa0', '')
|
||||
# Handle comma as decimal separator
|
||||
if ',' in amount_str:
|
||||
amount_str = amount_str.replace(',', '.')
|
||||
|
||||
amount_str = matches[-1].replace(',', '.')
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
return f"{amount:.2f}", True, None
|
||||
except ValueError:
|
||||
continue
|
||||
pass
|
||||
|
||||
# Last resort: try to find integer amount (no decimals)
|
||||
# Look for patterns like "Amount: 11699" or standalone numbers
|
||||
int_pattern = r'(?:amount|belopp|summa|total)[:\s]*(\d+)'
|
||||
match = re.search(int_pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
amount = float(match.group(1))
|
||||
if amount > 0:
|
||||
return f"{amount:.2f}", True, None
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Very last resort: find any standalone number >= 3 digits
|
||||
standalone_pattern = r'\b(\d{3,})\b'
|
||||
matches = re.findall(standalone_pattern, text)
|
||||
if matches:
|
||||
# Take the last/largest number
|
||||
try:
|
||||
amount = float(matches[-1])
|
||||
if amount > 0:
|
||||
return f"{amount:.2f}", True, None
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None, False, f"Cannot parse amount: {text}"
|
||||
|
||||
@@ -528,6 +584,8 @@ class FieldExtractor:
|
||||
"""
|
||||
Normalize date from text that may contain surrounding text.
|
||||
|
||||
Uses shared FieldValidators for date parsing and validation.
|
||||
|
||||
Handles various date formats found in Swedish invoices:
|
||||
- 2025-08-29 (ISO format)
|
||||
- 2025.08.29 (dot separator)
|
||||
@@ -535,9 +593,14 @@ class FieldExtractor:
|
||||
- 29.08.2025 (European with dots)
|
||||
- 20250829 (compact format)
|
||||
"""
|
||||
# First, try using shared validator
|
||||
iso_date = FieldValidators.format_date_iso(text)
|
||||
if iso_date and FieldValidators.is_valid_date(iso_date):
|
||||
return iso_date, True, None
|
||||
|
||||
# Fallback: try original patterns for edge cases
|
||||
from datetime import datetime
|
||||
|
||||
# Common date patterns - order matters, most specific first
|
||||
patterns = [
|
||||
# ISO format: 2025-08-29
|
||||
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
|
||||
@@ -570,32 +633,76 @@ class FieldExtractor:
|
||||
"""
|
||||
Normalize payment line region text.
|
||||
|
||||
Extracts OCR, Amount, and Bankgiro from the payment line using MachineCodeParser.
|
||||
Extracts the machine-readable payment line format from OCR text.
|
||||
Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
|
||||
Examples:
|
||||
- "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
|
||||
- "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
|
||||
|
||||
Returns normalized format preserving ALL components including Amount:
|
||||
- Full format: "OCR:xxx Amount:xxx.xx BG:xxx" or "OCR:xxx Amount:xxx.xx PG:xxx"
|
||||
- This allows downstream cross-validation to extract fields properly.
|
||||
"""
|
||||
from ..ocr.machine_code_parser import MachineCodeParser
|
||||
# Pattern to match Swedish payment line format WITH amount
|
||||
# Format: # <OCR number> # <Kronor> <Öre> <Type> > <account number>#<check digits>#
|
||||
# Account number may have spaces: "78 2 1 713" -> "7821713"
|
||||
# Kronor may have OCR-induced spaces: "12 0 0" -> "1200"
|
||||
# The > symbol may be missing in low-DPI OCR, so make it optional
|
||||
# Check digits may have spaces: "#41 #" -> "#41#"
|
||||
payment_line_full_pattern = r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
|
||||
# Create a simple token-like structure for the parser
|
||||
# (The parser expects tokens, but for inference we have raw text)
|
||||
parser = MachineCodeParser()
|
||||
match = re.search(payment_line_full_pattern, text)
|
||||
if match:
|
||||
ocr_part = match.group(1).replace(' ', '')
|
||||
kronor = match.group(2).replace(' ', '') # Remove OCR-induced spaces
|
||||
ore = match.group(3)
|
||||
record_type = match.group(4)
|
||||
account = match.group(5).replace(' ', '') # Remove spaces from account number
|
||||
check_digits = match.group(6)
|
||||
|
||||
# Try to parse the standard payment line format
|
||||
result = parser._parse_standard_payment_line(text)
|
||||
# Reconstruct the clean machine-readable format
|
||||
# Format: # OCR # KRONOR ORE TYPE > ACCOUNT#CHECK#
|
||||
result = f"# {ocr_part} # {kronor} {ore} {record_type} > {account}#{check_digits}#"
|
||||
return result, True, None
|
||||
|
||||
if result:
|
||||
# Format as structured output
|
||||
parts = []
|
||||
if result.get('ocr'):
|
||||
parts.append(f"OCR:{result['ocr']}")
|
||||
if result.get('amount'):
|
||||
parts.append(f"Amount:{result['amount']}")
|
||||
if result.get('bankgiro'):
|
||||
parts.append(f"BG:{result['bankgiro']}")
|
||||
# Try pattern WITHOUT amount (some payment lines don't have amount)
|
||||
# Format: # <OCR number> # > <account number>#<check digits>#
|
||||
# > may be missing in low-DPI OCR
|
||||
# Check digits may have spaces
|
||||
payment_line_no_amount_pattern = r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
match = re.search(payment_line_no_amount_pattern, text)
|
||||
if match:
|
||||
ocr_part = match.group(1).replace(' ', '')
|
||||
account = match.group(2).replace(' ', '')
|
||||
check_digits = match.group(3)
|
||||
|
||||
if parts:
|
||||
return ' '.join(parts), True, None
|
||||
result = f"# {ocr_part} # > {account}#{check_digits}#"
|
||||
return result, True, None
|
||||
|
||||
# Fallback: return raw text if no structured parsing possible
|
||||
return text, True, None
|
||||
# Try alternative pattern: just look for the # > account# pattern (> optional)
|
||||
# Check digits may have spaces
|
||||
alt_pattern = r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
match = re.search(alt_pattern, text)
|
||||
if match:
|
||||
ocr_part = match.group(1).replace(' ', '')
|
||||
account = match.group(2).replace(' ', '')
|
||||
check_digits = match.group(3)
|
||||
|
||||
result = f"# {ocr_part} # > {account}#{check_digits}#"
|
||||
return result, True, None
|
||||
|
||||
# Try to find just the account part with # markers
|
||||
# Check digits may have spaces
|
||||
account_pattern = r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
|
||||
match = re.search(account_pattern, text)
|
||||
if match:
|
||||
account = match.group(1).replace(' ', '')
|
||||
check_digits = match.group(2)
|
||||
return f"> {account}#{check_digits}#", True, "Partial payment line (account only)"
|
||||
|
||||
# Fallback: return None if no payment line format found
|
||||
return None, False, "No valid payment line format found"
|
||||
|
||||
def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
@@ -640,24 +747,52 @@ class FieldExtractor:
|
||||
Normalize customer number extracted from OCR.
|
||||
|
||||
Customer numbers can have various formats:
|
||||
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N'
|
||||
- With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R'
|
||||
- Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N'
|
||||
- Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01'
|
||||
- Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R'
|
||||
|
||||
Note: Spaces and dashes may be removed from invoice display,
|
||||
so we need to match both 'JTY 576-3' and 'JTY5763' formats.
|
||||
"""
|
||||
from ..normalize.normalizer import FieldNormalizer
|
||||
|
||||
# Clean the text using the same logic as matcher
|
||||
text = FieldNormalizer.clean_text(text)
|
||||
|
||||
if not text:
|
||||
if not text or not text.strip():
|
||||
return None, False, "Empty text"
|
||||
|
||||
# Customer number patterns - ordered by specificity
|
||||
# Match both spaced/dashed versions and compact versions
|
||||
# Keep original text for pattern matching (don't uppercase yet)
|
||||
original_text = text.strip()
|
||||
|
||||
# Customer number patterns - ordered by specificity (most specific first)
|
||||
# All patterns use IGNORECASE so they work regardless of case
|
||||
customer_code_patterns = [
|
||||
# Pattern: 2-4 letters + space + digits + dash + single letter/digit (UMJ 436-R, EMM 256-6)
|
||||
# This is the most common Swedish customer number format
|
||||
r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b',
|
||||
# Pattern: 2-4 letters + space + digits + letter WITHOUT dash (Dwq 211X, ABC 123X)
|
||||
# Note: This is also common for customer numbers
|
||||
r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b',
|
||||
# Pattern: Word (capitalized) + space + digits + dash + letter (Umj 436-R, Billo 123-A)
|
||||
r'\b([A-Za-z][a-z]{1,10})\s+(\d{1,4})-([A-Za-z0-9])\b',
|
||||
# Pattern: Letters + digits + dash + digit/letter without space (JTY576-3)
|
||||
r'\b([A-Za-z]{2,4})(\d{1,4})-([A-Za-z0-9])\b',
|
||||
]
|
||||
|
||||
# Try specific patterns first
|
||||
for pattern in customer_code_patterns:
|
||||
match = re.search(pattern, original_text)
|
||||
if match:
|
||||
# Skip if it looks like a Swedish postal code (SE + digits)
|
||||
full_match = match.group(0)
|
||||
if full_match.upper().startswith('SE ') and re.match(r'^SE\s+\d{3}\s*\d{2}', full_match, re.IGNORECASE):
|
||||
continue
|
||||
# Reconstruct the customer number in standard format
|
||||
groups = match.groups()
|
||||
if len(groups) == 3:
|
||||
# Format: XXX NNN-X (add dash if not present, e.g., "Dwq 211X" -> "DWQ 211-X")
|
||||
result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}"
|
||||
return result, True, None
|
||||
|
||||
# Generic patterns for other formats
|
||||
generic_patterns = [
|
||||
# Pattern: Letters + space/dash + digits + dash + digit (EMM 256-6, JTY 576-3)
|
||||
r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]\d{1,2}[A-Z]?)\b',
|
||||
# Pattern: Letters + space/dash + digits + optional letter (FFL 019N, ABC 123X)
|
||||
@@ -666,25 +801,45 @@ class FieldExtractor:
|
||||
r'\b([A-Z]{2,4}\d{3,6}[A-Z]?)\b',
|
||||
# Pattern: Single letter + digits (A12345)
|
||||
r'\b([A-Z]\d{4,6}[A-Z]?)\b',
|
||||
# Pattern: Digits + dash/space + digits (123-456)
|
||||
r'\b(\d{3,6}[\s\-]\d{1,4})\b',
|
||||
]
|
||||
|
||||
all_matches = []
|
||||
for pattern in customer_code_patterns:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
all_matches.extend(matches)
|
||||
for pattern in generic_patterns:
|
||||
for match in re.finditer(pattern, original_text, re.IGNORECASE):
|
||||
matched_text = match.group(1)
|
||||
pos = match.start()
|
||||
# Filter out matches that look like postal codes or ID numbers
|
||||
# Postal codes are usually 3-5 digits without letters
|
||||
if re.match(r'^\d+$', matched_text):
|
||||
continue
|
||||
# Filter out V4 2 type matches (single letter + digit + space + digit)
|
||||
if re.match(r'^[A-Z]\d\s+\d$', matched_text, re.IGNORECASE):
|
||||
continue
|
||||
# Filter out Swedish postal codes (SE XXX XX format or SE + digits)
|
||||
# SE followed by digits is typically postal code, not customer number
|
||||
if re.match(r'^SE[\s\-]*\d', matched_text, re.IGNORECASE):
|
||||
continue
|
||||
all_matches.append((matched_text, pos))
|
||||
|
||||
if all_matches:
|
||||
# Prefer longer matches and those appearing later in text (after names)
|
||||
# Sort by position in text (later = better) and length (longer = better)
|
||||
# Prefer matches that contain both letters and digits with dash
|
||||
scored_matches = []
|
||||
for match in all_matches:
|
||||
pos = text.upper().rfind(match.upper())
|
||||
# Score: position * 0.1 + length (prefer later and longer)
|
||||
score = pos * 0.1 + len(match)
|
||||
scored_matches.append((score, match))
|
||||
for match_text, pos in all_matches:
|
||||
score = 0
|
||||
# Bonus for containing dash (likely customer number format)
|
||||
if '-' in match_text:
|
||||
score += 50
|
||||
# Bonus for format like XXX NNN-X
|
||||
if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', match_text, re.IGNORECASE):
|
||||
score += 100
|
||||
# Bonus for length (prefer medium length)
|
||||
if 6 <= len(match_text) <= 12:
|
||||
score += 20
|
||||
# Position bonus (prefer later matches, after names)
|
||||
score += pos * 0.1
|
||||
scored_matches.append((score, match_text))
|
||||
|
||||
if scored_matches:
|
||||
best_match = max(scored_matches, key=lambda x: x[0])[1]
|
||||
return best_match.strip().upper(), True, None
|
||||
|
||||
@@ -694,7 +849,7 @@ class FieldExtractor:
|
||||
]
|
||||
|
||||
for pattern in labeled_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
match = re.search(pattern, original_text, re.IGNORECASE)
|
||||
if match:
|
||||
extracted = match.group(1).strip()
|
||||
extracted = re.sub(r'[\s\.\,\:]+$', '', extracted)
|
||||
@@ -702,32 +857,18 @@ class FieldExtractor:
|
||||
return extracted.upper(), True, None
|
||||
|
||||
# Pattern 3: If text contains comma (likely "NAME, NAME CODE"), extract after last comma
|
||||
if ',' in text:
|
||||
after_comma = text.split(',')[-1].strip()
|
||||
if ',' in original_text:
|
||||
after_comma = original_text.split(',')[-1].strip()
|
||||
# Look for alphanumeric code in the part after comma
|
||||
for pattern in customer_code_patterns[:3]: # Use first 3 patterns
|
||||
code_match = re.search(pattern, after_comma, re.IGNORECASE)
|
||||
for pattern in customer_code_patterns:
|
||||
code_match = re.search(pattern, after_comma)
|
||||
if code_match:
|
||||
return code_match.group(1).strip().upper(), True, None
|
||||
|
||||
# Pattern 4: Short text - filter out name-like words
|
||||
if len(text) <= 20:
|
||||
words = text.split()
|
||||
code_parts = []
|
||||
for word in words:
|
||||
# Keep if: contains digits, or is short uppercase (likely abbreviation)
|
||||
if re.search(r'\d', word) or (len(word) <= 4 and word.isupper()):
|
||||
code_parts.append(word)
|
||||
if code_parts:
|
||||
result = ' '.join(code_parts).upper()
|
||||
if len(result) >= 3:
|
||||
groups = code_match.groups()
|
||||
if len(groups) == 3:
|
||||
result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}"
|
||||
return result, True, None
|
||||
|
||||
# Fallback: return cleaned text if reasonable
|
||||
if text and 3 <= len(text) <= 15:
|
||||
return text.upper(), True, None
|
||||
|
||||
return None, False, f"Cannot extract customer number from: {text[:50]}"
|
||||
return None, False, f"Cannot extract customer number from: {original_text[:50]}"
|
||||
|
||||
def extract_all_fields(
|
||||
self,
|
||||
@@ -779,3 +920,428 @@ class FieldExtractor:
|
||||
fields['OCR'] = invoice_number
|
||||
|
||||
return fields
|
||||
|
||||
# =========================================================================
|
||||
# Multi-Source Fusion with Confidence Weighting
|
||||
# =========================================================================
|
||||
|
||||
def fuse_multiple_detections(
|
||||
self,
|
||||
extracted_fields: list[ExtractedField]
|
||||
) -> list[ExtractedField]:
|
||||
"""
|
||||
Fuse multiple detections of the same field using confidence-weighted voting.
|
||||
|
||||
When YOLO detects the same field type multiple times (e.g., multiple Amount boxes),
|
||||
this method selects the best value or combines them intelligently.
|
||||
|
||||
Strategies:
|
||||
1. For numeric fields (Amount, OCR): prefer values that pass validation
|
||||
2. For date fields: prefer values in expected range
|
||||
3. For giro numbers: prefer values with valid Luhn checksum
|
||||
4. General: weighted vote by confidence scores
|
||||
|
||||
Args:
|
||||
extracted_fields: List of all extracted fields (may have duplicates)
|
||||
|
||||
Returns:
|
||||
List with duplicates resolved to single best value per field
|
||||
"""
|
||||
# Group fields by name
|
||||
fields_by_name: dict[str, list[ExtractedField]] = defaultdict(list)
|
||||
for field in extracted_fields:
|
||||
fields_by_name[field.field_name].append(field)
|
||||
|
||||
fused_fields = []
|
||||
|
||||
for field_name, candidates in fields_by_name.items():
|
||||
if len(candidates) == 1:
|
||||
# No fusion needed
|
||||
fused_fields.append(candidates[0])
|
||||
else:
|
||||
# Multiple candidates - fuse them
|
||||
fused = self._fuse_field_candidates(field_name, candidates)
|
||||
fused_fields.append(fused)
|
||||
|
||||
return fused_fields
|
||||
|
||||
def _fuse_field_candidates(
|
||||
self,
|
||||
field_name: str,
|
||||
candidates: list[ExtractedField]
|
||||
) -> ExtractedField:
|
||||
"""
|
||||
Fuse multiple candidates for a single field.
|
||||
|
||||
Returns the best candidate with alternatives recorded.
|
||||
"""
|
||||
# Sort by confidence (descending)
|
||||
sorted_candidates = sorted(candidates, key=lambda x: x.confidence, reverse=True)
|
||||
|
||||
# Collect all unique values with their max confidence
|
||||
value_scores: dict[str, tuple[float, ExtractedField]] = {}
|
||||
for c in sorted_candidates:
|
||||
if c.normalized_value:
|
||||
if c.normalized_value not in value_scores:
|
||||
value_scores[c.normalized_value] = (c.confidence, c)
|
||||
else:
|
||||
# Keep the higher confidence one
|
||||
if c.confidence > value_scores[c.normalized_value][0]:
|
||||
value_scores[c.normalized_value] = (c.confidence, c)
|
||||
|
||||
if not value_scores:
|
||||
# No valid values, return the highest confidence candidate
|
||||
return sorted_candidates[0]
|
||||
|
||||
# Field-specific fusion strategy
|
||||
best_value, best_field = self._select_best_value(field_name, value_scores)
|
||||
|
||||
# Record alternatives
|
||||
alternatives = [
|
||||
(v, score) for v, (score, _) in value_scores.items()
|
||||
if v != best_value
|
||||
]
|
||||
|
||||
# Create fused result
|
||||
result = ExtractedField(
|
||||
field_name=field_name,
|
||||
raw_text=best_field.raw_text,
|
||||
normalized_value=best_value,
|
||||
confidence=value_scores[best_value][0],
|
||||
detection_confidence=best_field.detection_confidence,
|
||||
ocr_confidence=best_field.ocr_confidence,
|
||||
bbox=best_field.bbox,
|
||||
page_no=best_field.page_no,
|
||||
is_valid=best_field.is_valid,
|
||||
validation_error=best_field.validation_error,
|
||||
alternative_values=alternatives,
|
||||
extraction_method='fused' if len(value_scores) > 1 else 'single'
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _select_best_value(
|
||||
self,
|
||||
field_name: str,
|
||||
value_scores: dict[str, tuple[float, ExtractedField]]
|
||||
) -> tuple[str, ExtractedField]:
|
||||
"""
|
||||
Select the best value for a field using field-specific logic.
|
||||
|
||||
Returns (best_value, best_field)
|
||||
"""
|
||||
items = list(value_scores.items())
|
||||
|
||||
# Field-specific selection
|
||||
if field_name in ('Bankgiro', 'Plusgiro', 'OCR'):
|
||||
# Prefer values with valid Luhn checksum
|
||||
for value, (score, field) in items:
|
||||
digits = re.sub(r'\D', '', value)
|
||||
if FieldValidators.luhn_checksum(digits):
|
||||
return value, field
|
||||
|
||||
elif field_name == 'Amount':
|
||||
# Prefer larger amounts (usually the total, not subtotals)
|
||||
amounts = []
|
||||
for value, (score, field) in items:
|
||||
try:
|
||||
amt = float(value.replace(',', '.'))
|
||||
amounts.append((amt, value, field))
|
||||
except ValueError:
|
||||
continue
|
||||
if amounts:
|
||||
# Return the largest amount
|
||||
amounts.sort(reverse=True)
|
||||
return amounts[0][1], amounts[0][2]
|
||||
|
||||
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
# Prefer dates in reasonable range
|
||||
from datetime import datetime
|
||||
for value, (score, field) in items:
|
||||
try:
|
||||
dt = datetime.strptime(value, '%Y-%m-%d')
|
||||
# Prefer recent dates (within last 2 years and next 1 year)
|
||||
now = datetime.now()
|
||||
if now.year - 2 <= dt.year <= now.year + 1:
|
||||
return value, field
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Default: return highest confidence value
|
||||
best = max(items, key=lambda x: x[1][0])
|
||||
return best[0], best[1][1]
|
||||
|
||||
# =========================================================================
|
||||
# Enhanced Amount Parsing
|
||||
# =========================================================================
|
||||
|
||||
def _normalize_amount_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
Enhanced amount parsing with multiple strategies.
|
||||
|
||||
Strategies:
|
||||
1. Pattern matching for Swedish formats
|
||||
2. Context-aware extraction (look for keywords like "Total", "Summa")
|
||||
3. OCR error correction for common digit errors
|
||||
4. Multi-amount handling (prefer last/largest as total)
|
||||
|
||||
This method replaces the original _normalize_amount when enhanced mode is enabled.
|
||||
"""
|
||||
# Strategy 1: Apply OCR corrections first
|
||||
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
|
||||
|
||||
# Strategy 2: Look for labeled amounts (highest priority)
|
||||
labeled_patterns = [
|
||||
# Swedish patterns
|
||||
(r'(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 1.0),
|
||||
(r'(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 0.8), # Lower priority for VAT
|
||||
# Generic pattern
|
||||
(r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?', 0.7),
|
||||
]
|
||||
|
||||
candidates = []
|
||||
for pattern, priority in labeled_patterns:
|
||||
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
|
||||
amount_str = match.group(1).replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if 0 < amount < 10_000_000: # Reasonable range
|
||||
candidates.append((amount, priority, match.start()))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if candidates:
|
||||
# Sort by priority (desc), then by position (later is usually total)
|
||||
candidates.sort(key=lambda x: (-x[1], -x[2]))
|
||||
best_amount = candidates[0][0]
|
||||
return f"{best_amount:.2f}", True, None
|
||||
|
||||
# Strategy 3: Parse with shared validator
|
||||
cleaned = TextCleaner.normalize_amount_text(corrected_text)
|
||||
amount = FieldValidators.parse_amount(cleaned)
|
||||
if amount is not None and 0 < amount < 10_000_000:
|
||||
return f"{amount:.2f}", True, None
|
||||
|
||||
# Strategy 4: Try to extract any decimal number as fallback
|
||||
decimal_pattern = r'(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})'
|
||||
matches = re.findall(decimal_pattern, corrected_text)
|
||||
if matches:
|
||||
# Clean and parse each match
|
||||
amounts = []
|
||||
for m in matches:
|
||||
cleaned_m = m.replace(' ', '').replace('.', '').replace(',', '.')
|
||||
# Handle Swedish format: "1 234,56" -> "1234.56"
|
||||
if ',' in m and '.' not in m:
|
||||
cleaned_m = m.replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
amt = float(cleaned_m)
|
||||
if 0 < amt < 10_000_000:
|
||||
amounts.append(amt)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if amounts:
|
||||
# Return the last/largest amount (usually the total)
|
||||
return f"{max(amounts):.2f}", True, None
|
||||
|
||||
return None, False, f"Cannot parse amount: {text[:50]}"
|
||||
|
||||
# =========================================================================
|
||||
# Enhanced Date Parsing
|
||||
# =========================================================================
|
||||
|
||||
def _normalize_date_enhanced(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""
|
||||
Enhanced date parsing with comprehensive format support.
|
||||
|
||||
Supports:
|
||||
- ISO: 2024-12-29, 2024/12/29
|
||||
- European: 29.12.2024, 29/12/2024, 29-12-2024
|
||||
- Swedish text: "29 december 2024", "29 dec 2024"
|
||||
- Compact: 20241229
|
||||
- With OCR corrections: 2O24-12-29 -> 2024-12-29
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# Apply OCR corrections
|
||||
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
|
||||
|
||||
# Try shared validator first
|
||||
iso_date = FieldValidators.format_date_iso(corrected_text)
|
||||
if iso_date and FieldValidators.is_valid_date(iso_date):
|
||||
return iso_date, True, None
|
||||
|
||||
# Swedish month names
|
||||
swedish_months = {
|
||||
'januari': 1, 'jan': 1,
|
||||
'februari': 2, 'feb': 2,
|
||||
'mars': 3, 'mar': 3,
|
||||
'april': 4, 'apr': 4,
|
||||
'maj': 5,
|
||||
'juni': 6, 'jun': 6,
|
||||
'juli': 7, 'jul': 7,
|
||||
'augusti': 8, 'aug': 8,
|
||||
'september': 9, 'sep': 9, 'sept': 9,
|
||||
'oktober': 10, 'okt': 10,
|
||||
'november': 11, 'nov': 11,
|
||||
'december': 12, 'dec': 12,
|
||||
}
|
||||
|
||||
# Pattern for Swedish text dates: "29 december 2024" or "29 dec 2024"
|
||||
swedish_pattern = r'(\d{1,2})\s+([a-zåäö]+)\s+(\d{4})'
|
||||
match = re.search(swedish_pattern, corrected_text.lower())
|
||||
if match:
|
||||
day = int(match.group(1))
|
||||
month_name = match.group(2)
|
||||
year = int(match.group(3))
|
||||
if month_name in swedish_months:
|
||||
month = swedish_months[month_name]
|
||||
try:
|
||||
dt = datetime(year, month, day)
|
||||
if 2000 <= dt.year <= 2100:
|
||||
return dt.strftime('%Y-%m-%d'), True, None
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Extended patterns
|
||||
patterns = [
|
||||
# ISO format: 2025-08-29, 2025/08/29
|
||||
(r'(\d{4})[-/](\d{1,2})[-/](\d{1,2})', 'ymd'),
|
||||
# Dot format: 2025.08.29
|
||||
(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', 'ymd'),
|
||||
# European slash: 29/08/2025
|
||||
(r'(\d{1,2})/(\d{1,2})/(\d{4})', 'dmy'),
|
||||
# European dot: 29.08.2025
|
||||
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', 'dmy'),
|
||||
# European dash: 29-08-2025
|
||||
(r'(\d{1,2})-(\d{1,2})-(\d{4})', 'dmy'),
|
||||
# Compact: 20250829
|
||||
(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', 'ymd_compact'),
|
||||
]
|
||||
|
||||
for pattern, fmt in patterns:
|
||||
match = re.search(pattern, corrected_text)
|
||||
if match:
|
||||
try:
|
||||
if fmt == 'ymd':
|
||||
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
elif fmt == 'dmy':
|
||||
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
elif fmt == 'ymd_compact':
|
||||
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
else:
|
||||
continue
|
||||
|
||||
dt = datetime(year, month, day)
|
||||
if 2000 <= dt.year <= 2100:
|
||||
return dt.strftime('%Y-%m-%d'), True, None
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return None, False, f"Cannot parse date: {text[:50]}"
|
||||
|
||||
# =========================================================================
|
||||
# Apply OCR Corrections to Raw Text
|
||||
# =========================================================================
|
||||
|
||||
def apply_ocr_corrections(
|
||||
self,
|
||||
field_name: str,
|
||||
raw_text: str
|
||||
) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Apply OCR corrections to raw text based on field type.
|
||||
|
||||
Returns (corrected_text, list_of_corrections_applied)
|
||||
"""
|
||||
corrections_applied = []
|
||||
|
||||
if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_org_number'):
|
||||
# Aggressive correction for numeric fields
|
||||
result = OCRCorrections.correct_digits(raw_text, aggressive=True)
|
||||
if result.corrections_applied:
|
||||
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
||||
return result.corrected, corrections_applied
|
||||
|
||||
elif field_name == 'Amount':
|
||||
# Conservative correction for amounts (preserve decimal separators)
|
||||
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
|
||||
if result.corrections_applied:
|
||||
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
||||
return result.corrected, corrections_applied
|
||||
|
||||
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
# Conservative correction for dates
|
||||
result = OCRCorrections.correct_digits(raw_text, aggressive=False)
|
||||
if result.corrections_applied:
|
||||
corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
|
||||
return result.corrected, corrections_applied
|
||||
|
||||
# No correction for other fields
|
||||
return raw_text, []
|
||||
|
||||
# =========================================================================
|
||||
# Extraction with All Enhancements
|
||||
# =========================================================================
|
||||
|
||||
    def extract_with_enhancements(
        self,
        detection: Detection,
        pdf_tokens: list,
        image_width: int,
        image_height: int,
        use_enhanced_parsing: bool = True
    ) -> ExtractedField:
        """
        Extract field value with all enhancements enabled.

        Combines:
        1. OCR error correction (via apply_ocr_corrections)
        2. Enhanced amount/date parsing
        3. Multi-strategy extraction

        Args:
            detection: Detection object
            pdf_tokens: PDF text tokens
            image_width: Image width in pixels
            image_height: Image height in pixels
            use_enhanced_parsing: Whether to use enhanced parsing methods

        Returns:
            ExtractedField with enhancements applied (the base result is
            mutated in place when an enhanced value wins)
        """
        # First, extract using the standard (non-enhanced) method.
        base_result = self.extract_from_detection_with_pdf(
            detection, pdf_tokens, image_width, image_height
        )

        if not use_enhanced_parsing:
            return base_result

        # Apply field-type-specific OCR digit corrections to the raw text.
        corrected_text, corrections = self.apply_ocr_corrections(
            base_result.field_name, base_result.raw_text
        )

        # Re-normalize only when there is something to gain: a correction
        # actually changed the text, or the standard pass produced no value.
        if corrections or base_result.normalized_value is None:
            if base_result.field_name == 'Amount':
                normalized, is_valid, error = self._normalize_amount_enhanced(corrected_text)
            elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'):
                normalized, is_valid, error = self._normalize_date_enhanced(corrected_text)
            else:
                # Re-run standard normalization with corrected text
                normalized, is_valid, error = self._normalize_and_validate(
                    base_result.field_name, corrected_text
                )

            # Overwrite the base result only when the enhanced value is an
            # improvement: the base had no value at all, or the enhanced
            # value passed validation.
            if normalized and (not base_result.normalized_value or is_valid):
                base_result.normalized_value = normalized
                base_result.is_valid = is_valid
                base_result.validation_error = error
                base_result.ocr_corrections_applied = corrections
                if corrections:
                    # Mark that the final value came from corrected OCR text.
                    base_result.extraction_method = 'corrected'

        return base_result
|
||||
|
||||
@@ -214,6 +214,43 @@ class InferencePipeline:
|
||||
# Perform cross-validation if payment_line is detected
|
||||
self._cross_validate_payment_line(result)
|
||||
|
||||
def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]:
|
||||
"""
|
||||
Parse machine-readable Swedish payment line format.
|
||||
|
||||
Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
|
||||
Example: "# 11000770600242 # 1200 00 5 > 3082963#41#"
|
||||
|
||||
Returns: (ocr, amount, account) tuple
|
||||
"""
|
||||
# Pattern with amount
|
||||
pattern_full = r'#\s*(\d+)\s*#\s*(\d+)\s+(\d{2})\s+\d\s*>\s*(\d+)#\d+#'
|
||||
match = re.search(pattern_full, payment_line)
|
||||
if match:
|
||||
ocr = match.group(1)
|
||||
kronor = match.group(2)
|
||||
ore = match.group(3)
|
||||
account = match.group(4)
|
||||
amount = f"{kronor}.{ore}"
|
||||
return ocr, amount, account
|
||||
|
||||
# Pattern without amount
|
||||
pattern_no_amount = r'#\s*(\d+)\s*#\s*>\s*(\d+)#\d+#'
|
||||
match = re.search(pattern_no_amount, payment_line)
|
||||
if match:
|
||||
ocr = match.group(1)
|
||||
account = match.group(2)
|
||||
return ocr, None, account
|
||||
|
||||
# Fallback: partial pattern
|
||||
pattern_partial = r'>\s*(\d+)#\d+#'
|
||||
match = re.search(pattern_partial, payment_line)
|
||||
if match:
|
||||
account = match.group(1)
|
||||
return None, None, account
|
||||
|
||||
return None, None, None
|
||||
|
||||
def _cross_validate_payment_line(self, result: InferenceResult) -> None:
|
||||
"""
|
||||
Cross-validate payment_line data against other detected fields.
|
||||
@@ -236,23 +273,26 @@ class InferencePipeline:
|
||||
cv = CrossValidationResult()
|
||||
cv.details = []
|
||||
|
||||
# Parse payment_line format: "OCR:12345 Amount:100,00 BG:123-4567"
|
||||
pl_parts = {}
|
||||
for part in str(payment_line).split():
|
||||
if ':' in part:
|
||||
key, value = part.split(':', 1)
|
||||
pl_parts[key.upper()] = value
|
||||
# Parse machine-readable payment line format
|
||||
ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line))
|
||||
|
||||
cv.payment_line_ocr = pl_parts.get('OCR')
|
||||
cv.payment_line_amount = pl_parts.get('AMOUNT')
|
||||
cv.payment_line_ocr = ocr
|
||||
cv.payment_line_amount = amount
|
||||
|
||||
# Determine account type from payment_line
|
||||
if pl_parts.get('BG'):
|
||||
cv.payment_line_account = pl_parts['BG']
|
||||
# Determine account type based on digit count
|
||||
if account:
|
||||
# Bankgiro: 7-8 digits, Plusgiro: typically fewer
|
||||
if len(account) >= 7:
|
||||
cv.payment_line_account_type = 'bankgiro'
|
||||
elif pl_parts.get('PG'):
|
||||
cv.payment_line_account = pl_parts['PG']
|
||||
# Format: XXX-XXXX or XXXX-XXXX
|
||||
if len(account) == 7:
|
||||
cv.payment_line_account = f"{account[:3]}-{account[3:]}"
|
||||
else:
|
||||
cv.payment_line_account = f"{account[:4]}-{account[4:]}"
|
||||
else:
|
||||
cv.payment_line_account_type = 'plusgiro'
|
||||
# Format: XXXXXXX-X
|
||||
cv.payment_line_account = f"{account[:-1]}-{account[-1]}"
|
||||
|
||||
# Cross-validate and OVERRIDE with payment_line values
|
||||
|
||||
|
||||
@@ -266,6 +266,38 @@ class TestNormalizePaymentLine:
|
||||
assert is_valid is True
|
||||
# Bankgiro should be normalized despite spaces
|
||||
|
||||
def test_payment_line_with_spaces_in_check_digits(self, extractor):
|
||||
"""Test payment line with spaces around check digits: #41 # instead of #41#."""
|
||||
text = "# 6026726908 # 736 00 9 > 5692041 #41 #"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert "6026726908" in result
|
||||
assert "736 00" in result
|
||||
assert "5692041#41#" in result
|
||||
|
||||
def test_payment_line_with_ocr_spaces_in_amount(self, extractor):
|
||||
"""Test payment line with OCR-induced spaces in amount: '12 0 0 00' -> '1200 00'."""
|
||||
text = "# 11000770600242 # 12 0 0 00 5 3082963#41#"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert "11000770600242" in result
|
||||
assert "1200 00" in result
|
||||
assert "3082963#41#" in result
|
||||
|
||||
def test_payment_line_without_greater_symbol(self, extractor):
|
||||
"""Test payment line with missing > symbol (low-DPI OCR issue)."""
|
||||
text = "# 11000770600242 # 1200 00 5 3082963#41#"
|
||||
result, is_valid, error = extractor._normalize_payment_line(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert "11000770600242" in result
|
||||
assert "1200 00" in result
|
||||
|
||||
|
||||
class TestNormalizeCustomerNumber:
|
||||
"""Tests for customer number normalization."""
|
||||
@@ -284,6 +316,33 @@ class TestNormalizeCustomerNumber:
|
||||
result, is_valid, error = extractor._normalize_customer_number("JTY5763")
|
||||
assert result is not None
|
||||
|
||||
def test_format_without_dash(self, extractor):
|
||||
"""Test customer number format without dash: Dwq 211X -> DWQ 211-X."""
|
||||
text = "Dwq 211X Billo SE 106 43 Stockholm"
|
||||
result, is_valid, error = extractor._normalize_customer_number(text)
|
||||
|
||||
assert result is not None
|
||||
assert is_valid is True
|
||||
assert result == "DWQ 211-X"
|
||||
|
||||
def test_swedish_postal_code_exclusion(self, extractor):
|
||||
"""Test that Swedish postal codes are excluded: SE 106 43 should not be extracted."""
|
||||
text = "SE 106 43 Stockholm"
|
||||
result, is_valid, error = extractor._normalize_customer_number(text)
|
||||
|
||||
# Should not extract postal code
|
||||
assert result is None or "SE 106" not in result
|
||||
|
||||
def test_customer_number_with_postal_code_in_text(self, extractor):
|
||||
"""Test extracting customer number when postal code is also present."""
|
||||
text = "Customer: ABC 123X, Address: SE 106 43 Stockholm"
|
||||
result, is_valid, error = extractor._normalize_customer_number(text)
|
||||
|
||||
assert result is not None
|
||||
assert "ABC" in result
|
||||
# Should not extract postal code
|
||||
assert "SE 106" not in result if result else True
|
||||
|
||||
|
||||
class TestNormalizeSupplierOrgNumber:
|
||||
"""Tests for supplier organization number normalization."""
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
Field Normalization Module
|
||||
|
||||
Normalizes field values to generate multiple candidate forms for matching.
|
||||
|
||||
This module generates variants of CSV values for matching against OCR text.
|
||||
It uses shared utilities from src.utils for text cleaning and OCR error variants.
|
||||
"""
|
||||
|
||||
import re
|
||||
@@ -9,6 +12,10 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Callable
|
||||
|
||||
# Import shared utilities
|
||||
from src.utils.text_cleaner import TextCleaner
|
||||
from src.utils.format_variants import FormatVariants
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalizedValue:
|
||||
@@ -39,15 +46,11 @@ class FieldNormalizer:
|
||||
|
||||
@staticmethod
|
||||
def clean_text(text: str) -> str:
|
||||
"""Remove invisible characters and normalize whitespace and dashes."""
|
||||
# Remove zero-width characters
|
||||
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
|
||||
# Normalize different dash types to standard hyphen-minus (ASCII 45)
|
||||
# en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212), middle dot (·, U+00B7)
|
||||
text = re.sub(r'[\u2013\u2014\u2212\u00b7]', '-', text)
|
||||
# Normalize whitespace
|
||||
text = ' '.join(text.split())
|
||||
return text.strip()
|
||||
"""Remove invisible characters and normalize whitespace and dashes.
|
||||
|
||||
Delegates to shared TextCleaner for consistency.
|
||||
"""
|
||||
return TextCleaner.clean_text(text)
|
||||
|
||||
@staticmethod
|
||||
def normalize_invoice_number(value: str) -> list[str]:
|
||||
@@ -81,57 +84,44 @@ class FieldNormalizer:
|
||||
"""
|
||||
Normalize Bankgiro number.
|
||||
|
||||
Uses shared FormatVariants plus OCR error variants.
|
||||
|
||||
Examples:
|
||||
'5393-9484' -> ['5393-9484', '53939484']
|
||||
'53939484' -> ['53939484', '5393-9484']
|
||||
"""
|
||||
value = FieldNormalizer.clean_text(value)
|
||||
digits_only = re.sub(r'\D', '', value)
|
||||
# Use shared module for base variants
|
||||
variants = set(FormatVariants.bankgiro_variants(value))
|
||||
|
||||
variants = [value]
|
||||
# Add OCR error variants
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits:
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
if digits_only:
|
||||
# Add without dash
|
||||
variants.append(digits_only)
|
||||
|
||||
# Add with dash (format: XXXX-XXXX for 8 digits)
|
||||
if len(digits_only) == 8:
|
||||
with_dash = f"{digits_only[:4]}-{digits_only[4:]}"
|
||||
variants.append(with_dash)
|
||||
elif len(digits_only) == 7:
|
||||
# Some bankgiro numbers are 7 digits: XXX-XXXX
|
||||
with_dash = f"{digits_only[:3]}-{digits_only[3:]}"
|
||||
variants.append(with_dash)
|
||||
|
||||
return list(set(v for v in variants if v))
|
||||
return list(v for v in variants if v)
|
||||
|
||||
@staticmethod
|
||||
def normalize_plusgiro(value: str) -> list[str]:
|
||||
"""
|
||||
Normalize Plusgiro number.
|
||||
|
||||
Uses shared FormatVariants plus OCR error variants.
|
||||
|
||||
Examples:
|
||||
'1234567-8' -> ['1234567-8', '12345678']
|
||||
'12345678' -> ['12345678', '1234567-8']
|
||||
"""
|
||||
value = FieldNormalizer.clean_text(value)
|
||||
digits_only = re.sub(r'\D', '', value)
|
||||
# Use shared module for base variants
|
||||
variants = set(FormatVariants.plusgiro_variants(value))
|
||||
|
||||
variants = [value]
|
||||
# Add OCR error variants
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits:
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
if digits_only:
|
||||
variants.append(digits_only)
|
||||
|
||||
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
|
||||
if len(digits_only) == 8:
|
||||
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
|
||||
variants.append(with_dash)
|
||||
# Also handle 6+1 format
|
||||
elif len(digits_only) == 7:
|
||||
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
|
||||
variants.append(with_dash)
|
||||
|
||||
return list(set(v for v in variants if v))
|
||||
return list(v for v in variants if v)
|
||||
|
||||
@staticmethod
|
||||
def normalize_organisation_number(value: str) -> list[str]:
|
||||
@@ -141,60 +131,27 @@ class FieldNormalizer:
|
||||
Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
|
||||
Swedish VAT format: SE + org_number (10 digits) + 01
|
||||
|
||||
Uses shared FormatVariants for comprehensive variant generation,
|
||||
plus OCR error variants.
|
||||
|
||||
Examples:
|
||||
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
|
||||
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
|
||||
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
|
||||
"""
|
||||
value = FieldNormalizer.clean_text(value)
|
||||
# Use shared module for base variants
|
||||
variants = set(FormatVariants.organisation_number_variants(value))
|
||||
|
||||
# Check if input is a VAT number (starts with SE, ends with 01)
|
||||
org_digits = None
|
||||
if value.upper().startswith('SE') and len(value) >= 12:
|
||||
# Extract org number from VAT: SE + 10 digits + 01
|
||||
potential_org = re.sub(r'\D', '', value[2:]) # Remove SE prefix, keep digits
|
||||
if len(potential_org) == 12 and potential_org.endswith('01'):
|
||||
org_digits = potential_org[:-2] # Remove trailing 01
|
||||
elif len(potential_org) == 10:
|
||||
org_digits = potential_org
|
||||
# Add OCR error variants for digit sequences
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits and len(digits) >= 10:
|
||||
# Generate variants where OCR might have misread characters
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]):
|
||||
variants.add(ocr_var)
|
||||
if len(ocr_var) == 10:
|
||||
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
|
||||
|
||||
if org_digits is None:
|
||||
org_digits = re.sub(r'\D', '', value)
|
||||
|
||||
variants = [value]
|
||||
|
||||
if org_digits:
|
||||
variants.append(org_digits)
|
||||
|
||||
# Standard format: NNNNNN-NNNN (10 digits total)
|
||||
if len(org_digits) == 10:
|
||||
with_dash = f"{org_digits[:6]}-{org_digits[6:]}"
|
||||
variants.append(with_dash)
|
||||
|
||||
# Swedish VAT format: SE + org_number + 01
|
||||
vat_number = f"SE{org_digits}01"
|
||||
variants.append(vat_number)
|
||||
variants.append(vat_number.lower()) # se556123456701
|
||||
# With spaces: SE 5561234567 01
|
||||
variants.append(f"SE {org_digits} 01")
|
||||
variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01")
|
||||
# Without 01 suffix (some invoices show just SE + org)
|
||||
variants.append(f"SE{org_digits}")
|
||||
variants.append(f"SE {org_digits}")
|
||||
|
||||
# Some may have 12 digits (century prefix): NNNNNNNN-NNNN
|
||||
elif len(org_digits) == 12:
|
||||
with_dash = f"{org_digits[:8]}-{org_digits[8:]}"
|
||||
variants.append(with_dash)
|
||||
# Also try without century prefix
|
||||
short_version = org_digits[2:]
|
||||
variants.append(short_version)
|
||||
variants.append(f"{short_version[:6]}-{short_version[6:]}")
|
||||
# VAT with short version
|
||||
vat_number = f"SE{short_version}01"
|
||||
variants.append(vat_number)
|
||||
|
||||
return list(set(v for v in variants if v))
|
||||
return list(v for v in variants if v)
|
||||
|
||||
@staticmethod
|
||||
def normalize_supplier_accounts(value: str) -> list[str]:
|
||||
|
||||
@@ -42,6 +42,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from src.pdf.extractor import Token as TextToken
|
||||
from src.utils.validators import FieldValidators
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -484,21 +485,42 @@ class MachineCodeParser:
|
||||
def format_account(account_digits: str) -> tuple[str, str]:
|
||||
"""Format account and determine type (bankgiro or plusgiro).
|
||||
|
||||
Uses context keywords first, then falls back to Luhn validation
|
||||
to determine the most likely account type.
|
||||
|
||||
Returns: (formatted_account, account_type)
|
||||
"""
|
||||
if is_plusgiro_context:
|
||||
# Plusgiro format: XXXXXXX-X
|
||||
# Context explicitly indicates Plusgiro
|
||||
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
|
||||
return formatted, 'plusgiro'
|
||||
else:
|
||||
# Bankgiro format: XXX-XXXX or XXXX-XXXX
|
||||
|
||||
# No explicit context - use Luhn validation to determine type
|
||||
# Try both formats and see which passes Luhn check
|
||||
|
||||
# Format as Plusgiro: XXXXXXX-X (all digits, check digit at end)
|
||||
pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
|
||||
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
|
||||
|
||||
# Format as Bankgiro: XXX-XXXX or XXXX-XXXX
|
||||
if len(account_digits) == 7:
|
||||
formatted = f"{account_digits[:3]}-{account_digits[3:]}"
|
||||
bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}"
|
||||
elif len(account_digits) == 8:
|
||||
formatted = f"{account_digits[:4]}-{account_digits[4:]}"
|
||||
bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}"
|
||||
else:
|
||||
formatted = account_digits
|
||||
return formatted, 'bankgiro'
|
||||
bg_formatted = account_digits
|
||||
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
|
||||
|
||||
# Decision logic:
|
||||
# 1. If only one format passes Luhn, use that
|
||||
# 2. If both pass or both fail, default to Bankgiro (more common in payment lines)
|
||||
if pg_valid and not bg_valid:
|
||||
return pg_formatted, 'plusgiro'
|
||||
elif bg_valid and not pg_valid:
|
||||
return bg_formatted, 'bankgiro'
|
||||
else:
|
||||
# Both valid or both invalid - default to bankgiro
|
||||
return bg_formatted, 'bankgiro'
|
||||
|
||||
# Try primary pattern
|
||||
match = self.PAYMENT_LINE_PATTERN.search(raw_line)
|
||||
|
||||
34
src/utils/__init__.py
Normal file
34
src/utils/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Shared utilities for invoice field extraction and matching.
|
||||
|
||||
This module provides common functionality used by both:
|
||||
- Inference stage (field_extractor.py) - extracting values from OCR text
|
||||
- Matching stage (normalizer.py) - generating variants for CSV matching
|
||||
|
||||
Modules:
|
||||
- TextCleaner: Unicode normalization and OCR error correction
|
||||
- FormatVariants: Generate format variants for matching
|
||||
- FieldValidators: Validate field values (Luhn, dates, amounts)
|
||||
- FuzzyMatcher: Fuzzy string matching with OCR awareness
|
||||
- OCRCorrections: Comprehensive OCR error correction
|
||||
- ContextExtractor: Context-aware field extraction
|
||||
"""
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
from .format_variants import FormatVariants
|
||||
from .validators import FieldValidators
|
||||
from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
|
||||
from .ocr_corrections import OCRCorrections, CorrectionResult
|
||||
from .context_extractor import ContextExtractor, ExtractionCandidate
|
||||
|
||||
__all__ = [
|
||||
'TextCleaner',
|
||||
'FormatVariants',
|
||||
'FieldValidators',
|
||||
'FuzzyMatcher',
|
||||
'FuzzyMatchResult',
|
||||
'OCRCorrections',
|
||||
'CorrectionResult',
|
||||
'ContextExtractor',
|
||||
'ExtractionCandidate',
|
||||
]
|
||||
433
src/utils/context_extractor.py
Normal file
433
src/utils/context_extractor.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Context-Aware Extraction Module
|
||||
|
||||
Extracts field values using contextual cues and label detection.
|
||||
Improves extraction accuracy by understanding the semantic context.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, NamedTuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
from .validators import FieldValidators
|
||||
|
||||
|
||||
@dataclass
class ExtractionCandidate:
    """A candidate extracted value with metadata."""
    # The extracted value text, stripped of surrounding whitespace.
    value: str
    # The full matched text, including any label that preceded the value.
    raw_text: str
    # Label text that identified the field ('' for pattern-only matches).
    context_label: str
    # Heuristic confidence score; label-based matches score higher than
    # pattern-only ones (which use a fixed 0.6).
    confidence: float
    position: int  # Character position in source text
    extraction_method: str  # 'label', 'pattern', 'proximity'
|
||||
|
||||
|
||||
class ContextExtractor:
|
||||
"""
|
||||
Context-aware field extraction.
|
||||
|
||||
Uses multiple strategies:
|
||||
1. Label detection - finds values after field labels
|
||||
2. Pattern matching - uses field-specific regex patterns
|
||||
3. Proximity analysis - finds values near related terms
|
||||
4. Validation filtering - removes invalid candidates
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Swedish Label Patterns (what appears before the value)
|
||||
# =========================================================================
|
||||
|
||||
    # Maps field name -> list of regex alternatives for the label that
    # precedes the value on the invoice. Patterns are combined with the
    # VALUE_PATTERNS entries and matched case-insensitively in
    # extract_with_label, so each label pattern should end where the value
    # is expected to begin (hence the trailing [:\s]* separators).
    LABEL_PATTERNS = {
        'InvoiceNumber': [
            # Swedish
            r'(?:faktura|fakt)\.?\s*(?:nr|nummer|#)?[:\s]*',
            r'(?:fakturanummer|fakturanr)[:\s]*',
            r'(?:vår\s+referens)[:\s]*',
            # English
            r'(?:invoice)\s*(?:no|number|#)?[:\s]*',
            r'inv[.:\s]*#?',
        ],
        'Amount': [
            # Swedish
            r'(?:att\s+)?betala[:\s]*',
            r'(?:total|totalt|summa)[:\s]*',
            r'(?:belopp)[:\s]*',
            r'(?:slutsumma)[:\s]*',
            r'(?:att\s+erlägga)[:\s]*',
            # English
            r'(?:total|amount|sum)[:\s]*',
            r'(?:amount\s+due)[:\s]*',
        ],
        'InvoiceDate': [
            # Swedish
            r'(?:faktura)?datum[:\s]*',
            r'(?:fakt\.?\s*datum)[:\s]*',
            # English
            r'(?:invoice\s+)?date[:\s]*',
        ],
        'InvoiceDueDate': [
            # Swedish
            r'(?:förfall(?:o)?datum)[:\s]*',
            r'(?:betalas\s+senast)[:\s]*',
            r'(?:sista\s+betalningsdag)[:\s]*',
            r'(?:förfaller)[:\s]*',
            # English
            r'(?:due\s+date)[:\s]*',
            r'(?:payment\s+due)[:\s]*',
        ],
        'OCR': [
            r'(?:ocr)[:\s]*',
            r'(?:ocr\s*-?\s*nummer)[:\s]*',
            r'(?:referens(?:nummer)?)[:\s]*',
            r'(?:betalningsreferens)[:\s]*',
        ],
        'Bankgiro': [
            r'(?:bankgiro|bg)[:\s]*',
            r'(?:bank\s*giro)[:\s]*',
        ],
        'Plusgiro': [
            r'(?:plusgiro|pg)[:\s]*',
            r'(?:plus\s*giro)[:\s]*',
            r'(?:postgiro)[:\s]*',
        ],
        'supplier_organisation_number': [
            r'(?:org\.?\s*(?:nr|nummer)?)[:\s]*',
            r'(?:organisationsnummer)[:\s]*',
            r'(?:org\.?\s*id)[:\s]*',
            r'(?:vat\s*(?:no|number|nr)?)[:\s]*',
            r'(?:moms(?:reg)?\.?\s*(?:nr|nummer)?)[:\s]*',
            # NOTE(review): a bare 'se' label is very broad and may
            # over-match ordinary text — confirm this is intentional.
            r'(?:se)[:\s]*',  # VAT prefix
        ],
        'customer_number': [
            r'(?:kund(?:nr|nummer)?)[:\s]*',
            r'(?:kundnummer)[:\s]*',
            r'(?:customer\s*(?:no|number|id)?)[:\s]*',
            r'(?:er\s+referens)[:\s]*',
        ],
    }
|
||||
|
||||
# =========================================================================
|
||||
# Value Patterns (what the value looks like)
|
||||
# =========================================================================
|
||||
|
||||
    # Maps field name -> list of regex alternatives describing what the
    # value itself looks like. Tried in order (most specific first) and
    # matched case-insensitively in extract_with_label / extract_with_pattern.
    VALUE_PATTERNS = {
        'InvoiceNumber': [
            r'[A-Z]{0,3}\d{3,15}',  # Alphanumeric: INV12345
            r'\d{3,15}',  # Pure digits
            r'20\d{2}[-/]\d{3,8}',  # Year prefix: 2024-001
        ],
        'Amount': [
            r'\d{1,3}(?:[\s.]\d{3})*[,]\d{2}',  # Swedish: 1 234,56
            r'\d{1,3}(?:[,]\d{3})*[.]\d{2}',  # US: 1,234.56
            r'\d+[,.]\d{2}',  # Simple: 123,45
            r'\d+',  # Integer
        ],
        'InvoiceDate': [
            r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}',  # ISO-like
            r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}',  # European
            r'\d{8}',  # Compact YYYYMMDD
        ],
        'InvoiceDueDate': [
            r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}',
            r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}',
            r'\d{8}',
        ],
        'OCR': [
            r'\d{10,25}',  # Long digit sequence
        ],
        'Bankgiro': [
            r'\d{3,4}[-\s]?\d{4}',  # XXX-XXXX or XXXX-XXXX
            r'\d{7,8}',  # Without separator
        ],
        'Plusgiro': [
            r'\d{1,7}[-\s]?\d',  # XXXXXXX-X
            r'\d{2,8}',  # Without separator
        ],
        'supplier_organisation_number': [
            r'\d{6}[-\s]?\d{4}',  # NNNNNN-NNNN
            r'\d{10}',  # Without separator
            r'SE\s?\d{10,12}(?:\s?01)?',  # VAT format
        ],
        'customer_number': [
            r'[A-Z]{0,5}\s?[-]?\s?\d{1,10}',  # EMM 256-6
            r'\d{3,15}',  # Pure digits
        ],
    }
|
||||
|
||||
# =========================================================================
|
||||
# Extraction Methods
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def extract_with_label(
|
||||
cls,
|
||||
text: str,
|
||||
field_name: str,
|
||||
validate: bool = True
|
||||
) -> list[ExtractionCandidate]:
|
||||
"""
|
||||
Extract field values by finding labels and taking following values.
|
||||
|
||||
Example: "Fakturanummer: 12345" -> extracts "12345"
|
||||
"""
|
||||
candidates = []
|
||||
label_patterns = cls.LABEL_PATTERNS.get(field_name, [])
|
||||
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
|
||||
|
||||
for label_pattern in label_patterns:
|
||||
for value_pattern in value_patterns:
|
||||
# Combine label + value patterns
|
||||
full_pattern = f'({label_pattern})({value_pattern})'
|
||||
matches = re.finditer(full_pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
label = match.group(1).strip()
|
||||
value = match.group(2).strip()
|
||||
|
||||
# Validate if requested
|
||||
if validate and not cls._validate_value(field_name, value):
|
||||
continue
|
||||
|
||||
# Calculate confidence based on label specificity
|
||||
confidence = cls._calculate_label_confidence(label, field_name)
|
||||
|
||||
candidates.append(ExtractionCandidate(
|
||||
value=value,
|
||||
raw_text=match.group(0),
|
||||
context_label=label,
|
||||
confidence=confidence,
|
||||
position=match.start(),
|
||||
extraction_method='label'
|
||||
))
|
||||
|
||||
return candidates
|
||||
|
||||
@classmethod
|
||||
def extract_with_pattern(
|
||||
cls,
|
||||
text: str,
|
||||
field_name: str,
|
||||
validate: bool = True
|
||||
) -> list[ExtractionCandidate]:
|
||||
"""
|
||||
Extract field values using only value patterns (no label required).
|
||||
|
||||
This is a fallback when no labels are found.
|
||||
"""
|
||||
candidates = []
|
||||
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
|
||||
|
||||
for pattern in value_patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
value = match.group(0).strip()
|
||||
|
||||
# Validate if requested
|
||||
if validate and not cls._validate_value(field_name, value):
|
||||
continue
|
||||
|
||||
# Lower confidence for pattern-only extraction
|
||||
confidence = 0.6
|
||||
|
||||
candidates.append(ExtractionCandidate(
|
||||
value=value,
|
||||
raw_text=value,
|
||||
context_label='',
|
||||
confidence=confidence,
|
||||
position=match.start(),
|
||||
extraction_method='pattern'
|
||||
))
|
||||
|
||||
return candidates
|
||||
|
||||
@classmethod
def extract_field(
    cls,
    text: str,
    field_name: str,
    validate: bool = True
) -> list[ExtractionCandidate]:
    """
    Extract all candidate values for a field using multiple strategies.

    Label-based extraction is tried first; pattern-only extraction is used
    as a fallback when no labeled candidates exist. Duplicate values keep
    only their highest-confidence candidate.

    Returns candidates sorted by confidence (highest first).
    """
    # Strategy 1: label-based extraction (highest confidence).
    found = cls.extract_with_label(text, field_name, validate)

    # Strategy 2: pattern-based extraction (fallback only).
    if not found:
        found = list(cls.extract_with_pattern(text, field_name, validate))

    # Deduplicate on the normalized value, keeping the best-scoring candidate.
    best_by_value: dict[str, ExtractionCandidate] = {}
    for cand in found:
        key = TextCleaner.normalize_for_comparison(cand.value)
        current = best_by_value.get(key)
        if current is None or cand.confidence > current.confidence:
            best_by_value[key] = cand

    return sorted(best_by_value.values(), key=lambda c: c.confidence, reverse=True)
|
||||
|
||||
@classmethod
def extract_best(
    cls,
    text: str,
    field_name: str,
    validate: bool = True
) -> Optional[ExtractionCandidate]:
    """
    Extract the best (highest confidence) candidate for a field,
    or None when no candidate is found.
    """
    ranked = cls.extract_field(text, field_name, validate)
    if not ranked:
        return None
    return ranked[0]
|
||||
|
||||
@classmethod
def extract_all_fields(cls, text: str) -> dict[str, list[ExtractionCandidate]]:
    """
    Extract every known field from the text.

    Returns a dictionary mapping field names to their candidate lists;
    fields with no candidates are omitted.
    """
    return {
        name: found
        for name in cls.LABEL_PATTERNS
        if (found := cls.extract_field(text, name))
    }
|
||||
|
||||
# =========================================================================
|
||||
# Helper Methods
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def _validate_value(cls, field_name: str, value: str) -> bool:
    """Validate a value based on the field type inferred from its name."""
    name = field_name.lower()

    if 'date' in name:
        return FieldValidators.is_valid_date(value)
    if 'amount' in name:
        return FieldValidators.is_valid_amount(value)

    # Numeric identifier fields: only a digit-count range is checked here
    # (no Luhn/checksum validation). Order mirrors the match priority.
    digit_ranges = {
        'bankgiro': (7, 8),
        'plusgiro': (2, 8),
        'ocr': (10, 25),
    }
    for marker, (lo, hi) in digit_ranges.items():
        if marker in name:
            count = len(TextCleaner.extract_digits(value, apply_ocr_correction=False))
            return lo <= count <= hi

    if 'org' in name:
        return len(TextCleaner.extract_digits(value, apply_ocr_correction=False)) >= 10

    # For other fields, any non-blank value passes.
    return bool(value.strip())
|
||||
|
||||
@classmethod
def _calculate_label_confidence(cls, label: str, field_name: str) -> float:
    """
    Calculate confidence based on how specific the matched label is.

    More specific labels yield higher confidence; anything else falls
    back to a generic score.
    """
    label_lower = label.lower()

    # (score, field -> marker phrases), checked from most to least specific.
    tiers = (
        (0.95, {
            'InvoiceNumber': ['fakturanummer', 'invoice number', 'fakturanr'],
            'Amount': ['att betala', 'slutsumma', 'amount due'],
            'InvoiceDate': ['fakturadatum', 'invoice date'],
            'InvoiceDueDate': ['förfallodatum', 'förfallodag', 'due date'],
            'OCR': ['ocr', 'betalningsreferens'],
            'Bankgiro': ['bankgiro'],
            'Plusgiro': ['plusgiro', 'postgiro'],
            'supplier_organisation_number': ['organisationsnummer', 'org nummer'],
            'customer_number': ['kundnummer', 'customer number'],
        }),
        (0.85, {
            'InvoiceNumber': ['faktura', 'invoice', 'nr'],
            'Amount': ['total', 'summa', 'belopp'],
            'InvoiceDate': ['datum', 'date'],
            'InvoiceDueDate': ['förfall', 'due'],
        }),
    )

    for score, table in tiers:
        for marker in table.get(field_name, ()):
            if marker in label_lower:
                return score

    # Generic match.
    return 0.75
|
||||
|
||||
@classmethod
def find_field_context(cls, text: str, position: int, window: int = 50) -> str:
    """
    Get the text surrounding a position, clamped to the string bounds.

    Useful for understanding what field a value belongs to.
    """
    lo = position - window
    if lo < 0:
        lo = 0
    hi = position + window
    if hi > len(text):
        hi = len(text)
    return text[lo:hi]
|
||||
|
||||
@classmethod
def identify_field_type(cls, text: str, value: str) -> Optional[str]:
    """
    Try to identify what field type a value belongs to based on context.

    Scans up to 50 characters of text preceding the value's first
    occurrence for any known field label; returns the first field whose
    label pattern matches, or None.
    """
    pos = text.find(value)
    if pos < 0:
        return None

    preceding = text[max(0, pos - 50):pos].lower()

    for name, patterns in cls.LABEL_PATTERNS.items():
        if any(re.search(p, preceding, re.IGNORECASE) for p in patterns):
            return name

    return None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Convenience functions
|
||||
# =========================================================================
|
||||
|
||||
def extract_field_with_context(text: str, field_name: str) -> Optional[str]:
    """Convenience wrapper: best extracted value for a field, or None."""
    best = ContextExtractor.extract_best(text, field_name)
    if best is None:
        return None
    return best.value
|
||||
|
||||
|
||||
def extract_all_with_context(text: str) -> dict[str, str]:
    """Convenience wrapper: best extracted value for every detected field."""
    results: dict[str, str] = {}
    for field, candidates in ContextExtractor.extract_all_fields(text).items():
        if candidates:
            results[field] = candidates[0].value
    return results
|
||||
610
src/utils/format_variants.py
Normal file
610
src/utils/format_variants.py
Normal file
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
Format Variants Generator
|
||||
|
||||
Generates multiple format variants for invoice field values.
|
||||
Used by both inference (to try different extractions) and matching (to match CSV values).
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
|
||||
|
||||
class FormatVariants:
    """
    Generates format variants for different field types.

    The same logic is used for:
    - Inference: trying different formats to extract a value
    - Matching: generating variants of CSV values to match against OCR text
    """

    # Swedish month names for date parsing.
    # Maps full names and common abbreviations to zero-padded month numbers;
    # 'maj' has no separate abbreviation.
    SWEDISH_MONTHS = {
        'januari': '01', 'jan': '01',
        'februari': '02', 'feb': '02',
        'mars': '03', 'mar': '03',
        'april': '04', 'apr': '04',
        'maj': '05',
        'juni': '06', 'jun': '06',
        'juli': '07', 'jul': '07',
        'augusti': '08', 'aug': '08',
        'september': '09', 'sep': '09', 'sept': '09',
        'oktober': '10', 'okt': '10',
        'november': '11', 'nov': '11',
        'december': '12', 'dec': '12',
    }
|
||||
|
||||
# =========================================================================
|
||||
# Organization Number Variants
|
||||
# =========================================================================
|
||||
|
||||
    @classmethod
    def organisation_number_variants(cls, value: str) -> list[str]:
        """
        Generate all format variants for a Swedish organization number.

        Input formats handled:
        - "556123-4567" (standard with hyphen)
        - "5561234567" (no hyphen)
        - "SE556123456701" (VAT format)
        - "SE 556123-4567 01" (VAT with spaces)

        Returns all possible variants for matching.
        """
        value = TextCleaner.clean_text(value)
        value_upper = value.upper()
        variants = set()

        # Extract the bare digits (with OCR character corrections applied).
        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)

        # If this looks like a VAT number, recover the embedded org number.
        # SE + 10 digits + 01 = "SE556123456701"
        if value_upper.startswith('SE') and len(digits) == 12 and digits.endswith('01'):
            # VAT format: SE + org_number + 01
            digits = digits[:10]
        elif digits.startswith('46') and len(digits) == 14:
            # SE prefix in numeric (46 is SE in phone code): 46 + 10 digits + 01
            digits = digits[2:12]

        if len(digits) == 12:
            # 12 digits may carry a century prefix: NNNNNNNN-NNNN (19556123-4567)
            variants.add(value)
            variants.add(digits)  # 195561234567
            # Hyphenated form
            variants.add(f"{digits[:8]}-{digits[8:]}")  # 19556123-4567
            # The last 10 digits are the standard org number
            short_digits = digits[2:]  # 5561234567
            variants.add(short_digits)
            variants.add(f"{short_digits[:6]}-{short_digits[6:]}")  # 556123-4567
            # VAT format
            variants.add(f"SE{short_digits}01")  # SE556123456701
            return list(v for v in variants if v)

        if len(digits) != 10:
            # Not a standard 10-digit number: return the raw value plus
            # the cleaned digits as a best effort.
            variants.add(value)
            if digits:
                variants.add(digits)
            return list(variants)

        # Generate all variants for the standard 10-digit form.
        # 1. Bare digits
        variants.add(digits)  # 5561234567

        # 2. Standard format (NNNNNN-NNNN)
        with_hyphen = f"{digits[:6]}-{digits[6:]}"
        variants.add(with_hyphen)  # 556123-4567

        # 3. VAT formats
        vat_compact = f"SE{digits}01"
        variants.add(vat_compact)  # SE556123456701
        variants.add(vat_compact.lower())  # se556123456701

        vat_spaced = f"SE {digits[:6]}-{digits[6:]} 01"
        variants.add(vat_spaced)  # SE 556123-4567 01

        vat_spaced_no_hyphen = f"SE {digits} 01"
        variants.add(vat_spaced_no_hyphen)  # SE 5561234567 01

        # 4. Country code without the trailing 01 suffix
        variants.add(f"SE{digits}")  # SE5561234567
        variants.add(f"SE {digits}")  # SE 5561234567
        variants.add(f"SE{digits[:6]}-{digits[6:]}")  # SE556123-4567

        # 5. Possible OCR misread variants (10-digit ones only)
        ocr_variants = TextCleaner.generate_ocr_variants(digits)
        for ocr_var in ocr_variants:
            if len(ocr_var) == 10:
                variants.add(ocr_var)
                variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")

        return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Bankgiro Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def bankgiro_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for Bankgiro number.
|
||||
|
||||
Formats:
|
||||
- 7 digits: XXX-XXXX (e.g., 123-4567)
|
||||
- 8 digits: XXXX-XXXX (e.g., 1234-5678)
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
variants = set()
|
||||
|
||||
variants.add(value)
|
||||
|
||||
if not digits or len(digits) < 7 or len(digits) > 8:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# 纯数字
|
||||
variants.add(digits)
|
||||
|
||||
# 带横线格式
|
||||
if len(digits) == 7:
|
||||
variants.add(f"{digits[:3]}-{digits[3:]}") # XXX-XXXX
|
||||
elif len(digits) == 8:
|
||||
variants.add(f"{digits[:4]}-{digits[4:]}") # XXXX-XXXX
|
||||
# 有些 8 位也用 XXX-XXXXX 格式
|
||||
variants.add(f"{digits[:3]}-{digits[3:]}")
|
||||
|
||||
# 带空格格式 (有时 OCR 会这样识别)
|
||||
if len(digits) == 7:
|
||||
variants.add(f"{digits[:3]} {digits[3:]}")
|
||||
elif len(digits) == 8:
|
||||
variants.add(f"{digits[:4]} {digits[4:]}")
|
||||
|
||||
# OCR 错误变体
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Plusgiro Variants
|
||||
# =========================================================================
|
||||
|
||||
    @classmethod
    def plusgiro_variants(cls, value: str) -> list[str]:
        """
        Generate variants for a Plusgiro number.

        Format: XXXXXXX-X (7 digits + check digit) or shorter.
        Examples: 1234567-8, 12345-6, 1-8
        """
        value = TextCleaner.clean_text(value)
        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
        variants = set()

        variants.add(value)

        # Plusgiro numbers are 2-8 digits total; otherwise only the cleaned
        # input itself is returned.
        if not digits or len(digits) < 2 or len(digits) > 8:
            return list(v for v in variants if v)

        # Bare digits
        variants.add(digits)

        # Plusgiro format: the final digit is a check digit, hyphen-separated.
        main_part = digits[:-1]
        check_digit = digits[-1]
        variants.add(f"{main_part}-{check_digit}")

        # Sometimes printed with a space instead of a hyphen.
        variants.add(f"{main_part} {check_digit}")

        # Grouped format (common for long numbers): XX XX XX-X
        if len(digits) >= 6:
            # Pair up all digits except the check digit.
            spaced = ' '.join([digits[i:i + 2] for i in range(0, len(digits) - 1, 2)])
            if len(digits) % 2 == 0:
                # NOTE(review): for even lengths this drops the last char of the
                # pairing before appending "-<check digit>" — verify the intended
                # grouping for even-length plusgiro numbers.
                spaced = spaced[:-1] + '-' + digits[-1]
            else:
                spaced = spaced + '-' + digits[-1]
            variants.add(spaced.replace('- ', '-'))

        # Possible OCR misread variants.
        for ocr_var in TextCleaner.generate_ocr_variants(digits):
            variants.add(ocr_var)

        return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Amount Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def amount_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for monetary amounts.
|
||||
|
||||
Handles:
|
||||
- Swedish: 1 234,56 (space thousand, comma decimal)
|
||||
- German: 1.234,56 (dot thousand, comma decimal)
|
||||
- US/UK: 1,234.56 (comma thousand, dot decimal)
|
||||
- Integer: 1234 -> 1234.00
|
||||
|
||||
Returns variants with different separators and with/without decimals.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# 尝试解析为数值
|
||||
amount = cls._parse_amount(value)
|
||||
if amount is None:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# 生成不同格式的变体
|
||||
int_part = int(amount)
|
||||
dec_part = round((amount - int_part) * 100)
|
||||
|
||||
# 1. 基础格式
|
||||
if dec_part == 0:
|
||||
variants.add(str(int_part)) # 1234
|
||||
variants.add(f"{int_part}.00") # 1234.00
|
||||
variants.add(f"{int_part},00") # 1234,00
|
||||
else:
|
||||
variants.add(f"{int_part}.{dec_part:02d}") # 1234.56
|
||||
variants.add(f"{int_part},{dec_part:02d}") # 1234,56
|
||||
|
||||
# 2. 带千位分隔符
|
||||
int_str = str(int_part)
|
||||
if len(int_str) > 3:
|
||||
# 从右往左每3位加分隔符
|
||||
parts = []
|
||||
while int_str:
|
||||
parts.append(int_str[-3:])
|
||||
int_str = int_str[:-3]
|
||||
parts.reverse()
|
||||
|
||||
# 空格分隔 (Swedish)
|
||||
space_sep = ' '.join(parts)
|
||||
if dec_part == 0:
|
||||
variants.add(space_sep)
|
||||
else:
|
||||
variants.add(f"{space_sep},{dec_part:02d}")
|
||||
variants.add(f"{space_sep}.{dec_part:02d}")
|
||||
|
||||
# 点分隔 (German)
|
||||
dot_sep = '.'.join(parts)
|
||||
if dec_part == 0:
|
||||
variants.add(dot_sep)
|
||||
else:
|
||||
variants.add(f"{dot_sep},{dec_part:02d}")
|
||||
|
||||
# 逗号分隔 (US)
|
||||
comma_sep = ','.join(parts)
|
||||
if dec_part == 0:
|
||||
variants.add(comma_sep)
|
||||
else:
|
||||
variants.add(f"{comma_sep}.{dec_part:02d}")
|
||||
|
||||
# 3. 带货币符号
|
||||
base_amounts = [f"{int_part}.{dec_part:02d}", f"{int_part},{dec_part:02d}"]
|
||||
if dec_part == 0:
|
||||
base_amounts.append(str(int_part))
|
||||
|
||||
for base in base_amounts:
|
||||
variants.add(f"{base} kr")
|
||||
variants.add(f"{base} SEK")
|
||||
variants.add(f"{base}kr")
|
||||
variants.add(f"SEK {base}")
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
@classmethod
|
||||
def _parse_amount(cls, text: str) -> Optional[float]:
|
||||
"""Parse amount from various formats."""
|
||||
text = TextCleaner.normalize_amount_text(text)
|
||||
|
||||
# 移除所有非数字和分隔符
|
||||
clean = re.sub(r'[^\d,.\s]', '', text)
|
||||
if not clean:
|
||||
return None
|
||||
|
||||
# 检测格式
|
||||
# 瑞典格式: 1 234,56 或 1234,56
|
||||
if re.match(r'^[\d\s]+,\d{2}$', clean):
|
||||
clean = clean.replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 德国格式: 1.234,56
|
||||
if re.match(r'^[\d.]+,\d{2}$', clean):
|
||||
clean = clean.replace('.', '').replace(',', '.')
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 美国格式: 1,234.56
|
||||
if re.match(r'^[\d,]+\.\d{2}$', clean):
|
||||
clean = clean.replace(',', '')
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 简单格式
|
||||
clean = clean.replace(' ', '').replace(',', '.')
|
||||
# 如果有多个点,只保留最后一个
|
||||
if clean.count('.') > 1:
|
||||
parts = clean.rsplit('.', 1)
|
||||
clean = parts[0].replace('.', '') + '.' + parts[1]
|
||||
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Date Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def date_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for dates.
|
||||
|
||||
Input can be:
|
||||
- ISO: 2024-12-29
|
||||
- European: 29/12/2024, 29.12.2024
|
||||
- Swedish text: "29 december 2024"
|
||||
- Compact: 20241229
|
||||
|
||||
Returns all format variants.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# 尝试解析日期
|
||||
parsed = cls._parse_date(value)
|
||||
if parsed is None:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
year, month, day = parsed
|
||||
|
||||
# 生成所有格式变体
|
||||
# ISO
|
||||
variants.add(f"{year}-{month:02d}-{day:02d}")
|
||||
variants.add(f"{year}-{month}-{day}") # 不补零
|
||||
|
||||
# 点分隔 (Swedish common)
|
||||
variants.add(f"{year}.{month:02d}.{day:02d}")
|
||||
variants.add(f"{day:02d}.{month:02d}.{year}")
|
||||
|
||||
# 斜杠分隔
|
||||
variants.add(f"{day:02d}/{month:02d}/{year}")
|
||||
variants.add(f"{month:02d}/{day:02d}/{year}") # US format
|
||||
variants.add(f"{year}/{month:02d}/{day:02d}")
|
||||
|
||||
# 紧凑格式
|
||||
variants.add(f"{year}{month:02d}{day:02d}")
|
||||
|
||||
# 带月份名 (Swedish)
|
||||
for month_name, month_num in cls.SWEDISH_MONTHS.items():
|
||||
if month_num == f"{month:02d}":
|
||||
variants.add(f"{day} {month_name} {year}")
|
||||
variants.add(f"{day:02d} {month_name} {year}")
|
||||
# 首字母大写
|
||||
variants.add(f"{day} {month_name.capitalize()} {year}")
|
||||
|
||||
# 短年份
|
||||
short_year = str(year)[2:]
|
||||
variants.add(f"{day:02d}.{month:02d}.{short_year}")
|
||||
variants.add(f"{day:02d}/{month:02d}/{short_year}")
|
||||
variants.add(f"{short_year}-{month:02d}-{day:02d}")
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
@classmethod
|
||||
def _parse_date(cls, text: str) -> Optional[tuple[int, int, int]]:
|
||||
"""
|
||||
Parse date from text, returns (year, month, day) or None.
|
||||
"""
|
||||
text = TextCleaner.clean_text(text).lower()
|
||||
|
||||
# ISO: 2024-12-29
|
||||
match = re.search(r'(\d{4})-(\d{1,2})-(\d{1,2})', text)
|
||||
if match:
|
||||
return int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
|
||||
# Dot format: 2024.12.29
|
||||
match = re.search(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', text)
|
||||
if match:
|
||||
return int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
|
||||
# European: 29/12/2024 or 29.12.2024
|
||||
match = re.search(r'(\d{1,2})[/.](\d{1,2})[/.](\d{4})', text)
|
||||
if match:
|
||||
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
# 验证日期合理性
|
||||
if 1 <= day <= 31 and 1 <= month <= 12:
|
||||
return year, month, day
|
||||
|
||||
# Compact: 20241229
|
||||
match = re.search(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', text)
|
||||
if match:
|
||||
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
if 2000 <= year <= 2100 and 1 <= month <= 12 and 1 <= day <= 31:
|
||||
return year, month, day
|
||||
|
||||
# Swedish month name: "29 december 2024"
|
||||
for month_name, month_num in cls.SWEDISH_MONTHS.items():
|
||||
pattern = rf'(\d{{1,2}})\s*{month_name}\s*(\d{{4}})'
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
day, year = int(match.group(1)), int(match.group(2))
|
||||
return year, int(month_num), day
|
||||
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def invoice_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for invoice numbers.
|
||||
|
||||
Invoice numbers are highly variable:
|
||||
- Pure digits: 12345678
|
||||
- Alphanumeric: A3861, INV-2024-001
|
||||
- With separators: 2024/001
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# 提取数字部分
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
|
||||
# 大小写变体
|
||||
variants.add(value.upper())
|
||||
variants.add(value.lower())
|
||||
|
||||
# 移除分隔符
|
||||
no_sep = re.sub(r'[-/\s]', '', value)
|
||||
variants.add(no_sep)
|
||||
variants.add(no_sep.upper())
|
||||
|
||||
# OCR 错误变体
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(value):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# OCR Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def ocr_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for OCR reference numbers.
|
||||
|
||||
OCR numbers are typically 10-25 digits.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
variants = set()
|
||||
|
||||
variants.add(value)
|
||||
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
|
||||
# 有些 OCR 号码带空格分组
|
||||
if len(digits) > 4:
|
||||
# 每 4 位分组
|
||||
spaced = ' '.join([digits[i:i + 4] for i in range(0, len(digits), 4)])
|
||||
variants.add(spaced)
|
||||
|
||||
# OCR 错误变体
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Customer Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def customer_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for customer numbers.
|
||||
|
||||
Customer numbers can be very diverse:
|
||||
- Pure digits: 12345
|
||||
- Alphanumeric: ABC123, EMM 256-6
|
||||
- With separators: 123-456
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# 大小写
|
||||
variants.add(value.upper())
|
||||
variants.add(value.lower())
|
||||
|
||||
# 移除所有分隔符和空格
|
||||
compact = re.sub(r'[-/\s]', '', value)
|
||||
variants.add(compact)
|
||||
variants.add(compact.upper())
|
||||
variants.add(compact.lower())
|
||||
|
||||
# 纯数字
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
|
||||
# 纯字母 + 数字 (分离)
|
||||
letters = re.sub(r'[^a-zA-Z]', '', value)
|
||||
if letters and digits:
|
||||
variants.add(f"{letters}{digits}")
|
||||
variants.add(f"{letters.upper()}{digits}")
|
||||
variants.add(f"{letters} {digits}")
|
||||
variants.add(f"{letters.upper()} {digits}")
|
||||
variants.add(f"{letters}-{digits}")
|
||||
variants.add(f"{letters.upper()}-{digits}")
|
||||
|
||||
# OCR 错误变体
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(value):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Generic Field Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def get_variants(cls, field_name: str, value: str) -> list[str]:
|
||||
"""
|
||||
Get variants for a field by name.
|
||||
|
||||
This is the main entry point - dispatches to specific variant generators.
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
field_lower = field_name.lower()
|
||||
|
||||
# 映射字段名到变体生成器
|
||||
if 'organisation' in field_lower or 'org' in field_lower:
|
||||
return cls.organisation_number_variants(value)
|
||||
elif 'bankgiro' in field_lower or field_lower == 'bg':
|
||||
return cls.bankgiro_variants(value)
|
||||
elif 'plusgiro' in field_lower or field_lower == 'pg':
|
||||
return cls.plusgiro_variants(value)
|
||||
elif 'amount' in field_lower or 'belopp' in field_lower:
|
||||
return cls.amount_variants(value)
|
||||
elif 'date' in field_lower or 'datum' in field_lower:
|
||||
return cls.date_variants(value)
|
||||
elif 'invoice' in field_lower and 'number' in field_lower:
|
||||
return cls.invoice_number_variants(value)
|
||||
elif field_lower == 'invoicenumber':
|
||||
return cls.invoice_number_variants(value)
|
||||
elif 'ocr' in field_lower:
|
||||
return cls.ocr_number_variants(value)
|
||||
elif 'customer' in field_lower:
|
||||
return cls.customer_number_variants(value)
|
||||
else:
|
||||
# 默认: 返回原值和基本清洗
|
||||
return [value, TextCleaner.clean_text(value)]
|
||||
417
src/utils/fuzzy_matcher.py
Normal file
417
src/utils/fuzzy_matcher.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
Fuzzy Matching Module
|
||||
|
||||
Provides fuzzy string matching with OCR-aware similarity scoring.
|
||||
Handles common OCR errors and format variations in invoice fields.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
|
||||
|
||||
@dataclass
class FuzzyMatchResult:
    """Result of a fuzzy match operation."""
    matched: bool           # True when score cleared the applicable threshold
    score: float            # 0.0 to 1.0
    ocr_value: str          # the raw value read by OCR
    expected_value: str     # the reference value matched against
    normalized_ocr: str     # OCR value after normalization
    normalized_expected: str  # expected value after normalization
    match_type: str         # 'exact', 'normalized', 'fuzzy', 'ocr_corrected'
|
||||
|
||||
|
||||
class FuzzyMatcher:
    """
    Fuzzy string matcher optimized for OCR text matching.

    Provides multiple matching strategies:
    1. Exact match
    2. Normalized match (case-insensitive, whitespace-normalized)
    3. OCR-corrected match (applying common OCR error corrections)
    4. Edit distance based fuzzy match
    5. Digit-sequence match (for numeric fields)
    """

    # Minimum similarity threshold for fuzzy matches
    DEFAULT_THRESHOLD = 0.85

    # Field-specific thresholds (some fields need stricter matching);
    # fields absent from this table fall back to DEFAULT_THRESHOLD.
    FIELD_THRESHOLDS = {
        'InvoiceNumber': 0.90,
        'OCR': 0.95,  # OCR numbers need high precision
        'Amount': 0.95,
        'Bankgiro': 0.90,
        'Plusgiro': 0.90,
        'InvoiceDate': 0.90,
        'InvoiceDueDate': 0.90,
        'supplier_organisation_number': 0.85,
        'customer_number': 0.80,  # More lenient for customer numbers
    }
|
||||
|
||||
@classmethod
|
||||
def get_threshold(cls, field_name: str) -> float:
|
||||
"""Get the matching threshold for a specific field."""
|
||||
return cls.FIELD_THRESHOLDS.get(field_name, cls.DEFAULT_THRESHOLD)
|
||||
|
||||
@classmethod
|
||||
def levenshtein_distance(cls, s1: str, s2: str) -> int:
|
||||
"""
|
||||
Calculate Levenshtein (edit) distance between two strings.
|
||||
|
||||
This is the minimum number of single-character edits
|
||||
(insertions, deletions, substitutions) needed to change s1 into s2.
|
||||
"""
|
||||
if len(s1) < len(s2):
|
||||
return cls.levenshtein_distance(s2, s1)
|
||||
|
||||
if len(s2) == 0:
|
||||
return len(s1)
|
||||
|
||||
previous_row = range(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
# Cost is 0 if characters match, 1 otherwise
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
@classmethod
|
||||
def similarity_ratio(cls, s1: str, s2: str) -> float:
|
||||
"""
|
||||
Calculate similarity ratio between two strings.
|
||||
|
||||
Returns a value between 0.0 (completely different) and 1.0 (identical).
|
||||
Based on Levenshtein distance normalized by the length of the longer string.
|
||||
"""
|
||||
if not s1 and not s2:
|
||||
return 1.0
|
||||
if not s1 or not s2:
|
||||
return 0.0
|
||||
|
||||
max_len = max(len(s1), len(s2))
|
||||
distance = cls.levenshtein_distance(s1, s2)
|
||||
return 1.0 - (distance / max_len)
|
||||
|
||||
@classmethod
|
||||
def ocr_aware_similarity(cls, ocr_text: str, expected: str) -> float:
|
||||
"""
|
||||
Calculate similarity with OCR error awareness.
|
||||
|
||||
This method considers common OCR errors when calculating similarity,
|
||||
giving higher scores when differences are likely OCR mistakes.
|
||||
"""
|
||||
if not ocr_text or not expected:
|
||||
return 0.0 if ocr_text != expected else 1.0
|
||||
|
||||
# First try exact match
|
||||
if ocr_text == expected:
|
||||
return 1.0
|
||||
|
||||
# Try with OCR corrections applied to ocr_text
|
||||
corrected = TextCleaner.apply_ocr_digit_corrections(ocr_text)
|
||||
if corrected == expected:
|
||||
return 0.98 # Slightly less than exact match
|
||||
|
||||
# Try normalized comparison
|
||||
norm_ocr = TextCleaner.normalize_for_comparison(ocr_text)
|
||||
norm_expected = TextCleaner.normalize_for_comparison(expected)
|
||||
if norm_ocr == norm_expected:
|
||||
return 0.95
|
||||
|
||||
# Calculate base similarity
|
||||
base_sim = cls.similarity_ratio(norm_ocr, norm_expected)
|
||||
|
||||
# Boost score if differences are common OCR errors
|
||||
boost = cls._calculate_ocr_error_boost(ocr_text, expected)
|
||||
|
||||
return min(1.0, base_sim + boost)
|
||||
|
||||
@classmethod
def _calculate_ocr_error_boost(cls, ocr_text: str, expected: str) -> float:
    """Return a score boost in [0.0, 0.1] for likely OCR errors.

    Only equal-length strings are eligible; the boost is proportional
    to the fraction of differing positions that are known OCR
    confusion pairs.
    """
    # Length mismatch means insertions/deletions, which this heuristic
    # does not model.
    if len(ocr_text) != len(expected):
        return 0.0

    diffs = [(oc, ec) for oc, ec in zip(ocr_text, expected) if oc != ec]
    if not diffs:
        return 0.0

    confusions = sum(
        1 for oc, ec in diffs if cls._is_ocr_confusion_pair(oc, ec)
    )
    # Fraction of differences explainable as OCR misreads, scaled to 0.1 max.
    return (confusions / len(diffs)) * 0.1
||||
@classmethod
def _is_ocr_confusion_pair(cls, char1: str, char2: str) -> bool:
    """Return True when the two characters are a known OCR confusion pair.

    The check is symmetric: ('0', 'O') and ('O', '0') both match.
    """
    confusable = {
        ('0', 'O'), ('0', 'o'), ('0', 'D'), ('0', 'Q'),
        ('1', 'l'), ('1', 'I'), ('1', 'i'), ('1', '|'),
        ('2', 'Z'), ('2', 'z'),
        ('5', 'S'), ('5', 's'),
        ('6', 'G'), ('6', 'b'),
        ('8', 'B'),
        ('9', 'g'), ('9', 'q'),
    }
    return (char1, char2) in confusable or (char2, char1) in confusable
||||
@classmethod
def match_digits(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
    """Match digit sequences with OCR error tolerance.

    Optimized for numeric fields (OCR numbers, amounts, giro numbers):
    non-digits are stripped from both sides, and common digit misreads
    are corrected on the OCR side before comparing.
    """
    ocr_digits = TextCleaner.extract_digits(ocr_text, apply_ocr_correction=True)
    expected_digits = TextCleaner.extract_digits(expected, apply_ocr_correction=False)

    if ocr_digits == expected_digits:
        # Identical after extraction/correction: perfect match.
        matched, score, match_type = True, 1.0, 'exact'
    else:
        score = cls.ocr_aware_similarity(ocr_digits, expected_digits)
        matched = score >= threshold
        match_type = 'fuzzy' if matched else 'no_match'

    return FuzzyMatchResult(
        matched=matched,
        score=score,
        ocr_value=ocr_text,
        expected_value=expected,
        normalized_ocr=ocr_digits,
        normalized_expected=expected_digits,
        match_type=match_type,
    )
||||
@classmethod
def match_amount(cls, ocr_text: str, expected: str, threshold: float = 0.95) -> FuzzyMatchResult:
    """
    Match monetary amounts with format tolerance.

    Handles different decimal separators (. vs ,) and thousand
    separators.  Falls back to string matching when either side cannot
    be parsed as a number; parsed amounts within 1 cent of each other
    count as an exact match.
    """
    from .validators import FieldValidators

    # Parse both amounts (handles separator/format differences).
    ocr_amount = FieldValidators.parse_amount(ocr_text)
    expected_amount = FieldValidators.parse_amount(expected)

    if ocr_amount is None or expected_amount is None:
        # Can't parse, fall back to string matching
        return cls.match_string(ocr_text, expected, threshold)

    # Compare numeric values; tolerate sub-cent rounding differences.
    if abs(ocr_amount - expected_amount) < 0.01:  # Within 1 cent
        return FuzzyMatchResult(
            matched=True,
            score=1.0,
            ocr_value=ocr_text,
            expected_value=expected,
            normalized_ocr=f"{ocr_amount:.2f}",
            normalized_expected=f"{expected_amount:.2f}",
            match_type='exact'
        )

    # Score by relative difference between the two values.
    max_val = max(abs(ocr_amount), abs(expected_amount))
    if max_val > 0:
        diff_ratio = abs(ocr_amount - expected_amount) / max_val
        score = max(0.0, 1.0 - diff_ratio)
    else:
        score = 1.0 if ocr_amount == expected_amount else 0.0

    # BUG FIX: the previous code used truthiness ("if ocr_amount") when
    # choosing the normalized form, so a parsed amount of exactly 0.0
    # fell back to the raw text.  Both amounts are guaranteed non-None
    # here, so always use the formatted value.
    return FuzzyMatchResult(
        matched=score >= threshold,
        score=score,
        ocr_value=ocr_text,
        expected_value=expected,
        normalized_ocr=f"{ocr_amount:.2f}",
        normalized_expected=f"{expected_amount:.2f}",
        match_type='fuzzy' if score >= threshold else 'no_match'
    )
||||
@classmethod
def match_date(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
    """Match dates leniently across formats (ISO, European, compact, ...).

    Both sides are normalized to ISO form first; when either side
    cannot be parsed, or the parsed dates differ, the comparison
    degrades to digit matching.
    """
    from .validators import FieldValidators

    ocr_iso = FieldValidators.format_date_iso(ocr_text)
    expected_iso = FieldValidators.format_date_iso(expected)

    # Both parsed and identical after normalization: perfect match.
    if ocr_iso and expected_iso and ocr_iso == expected_iso:
        return FuzzyMatchResult(
            matched=True,
            score=1.0,
            ocr_value=ocr_text,
            expected_value=expected,
            normalized_ocr=ocr_iso,
            normalized_expected=expected_iso,
            match_type='exact'
        )

    # Unparseable or different: compare the raw digit sequences instead.
    return cls.match_digits(ocr_text, expected, threshold)
||||
@classmethod
def match_string(cls, ocr_text: str, expected: str, threshold: float = 0.85) -> FuzzyMatchResult:
    """General-purpose string matcher.

    Strategies, in order: exact, case-insensitive, OCR-corrected,
    fuzzy.  The first strategy that succeeds determines the score and
    match type.
    """
    ocr_clean = TextCleaner.clean_text(ocr_text)
    expected_clean = TextCleaner.clean_text(expected)

    def _hit(score: float, match_type: str, norm_ocr: str) -> FuzzyMatchResult:
        # Small factory so each successful strategy states only what differs.
        return FuzzyMatchResult(
            matched=True,
            score=score,
            ocr_value=ocr_text,
            expected_value=expected,
            normalized_ocr=norm_ocr,
            normalized_expected=expected_clean,
            match_type=match_type,
        )

    # Strategy 1: exact match after whitespace cleanup.
    if ocr_clean == expected_clean:
        return _hit(1.0, 'exact', ocr_clean)

    # Strategy 2: same text, different case.
    if ocr_clean.lower() == expected_clean.lower():
        return _hit(0.98, 'normalized', ocr_clean)

    # Strategy 3: same text once digit-lookalike OCR errors are fixed.
    ocr_corrected = TextCleaner.apply_ocr_digit_corrections(ocr_clean)
    if ocr_corrected == expected_clean:
        return _hit(0.95, 'ocr_corrected', ocr_corrected)

    # Strategy 4: fuzzy similarity as a last resort.
    score = cls.ocr_aware_similarity(ocr_clean, expected_clean)
    ok = score >= threshold
    return FuzzyMatchResult(
        matched=ok,
        score=score,
        ocr_value=ocr_text,
        expected_value=expected,
        normalized_ocr=ocr_clean,
        normalized_expected=expected_clean,
        match_type='fuzzy' if ok else 'no_match',
    )
||||
@classmethod
def match_field(
    cls,
    field_name: str,
    ocr_value: str,
    expected_value: str,
    threshold: Optional[float] = None
) -> FuzzyMatchResult:
    """Dispatch to the matcher best suited for *field_name*.

    The field name (Swedish or English substrings) selects the
    strategy: amounts, dates, and numeric identifiers each get a
    specialized matcher; everything else uses general string matching.
    """
    if threshold is None:
        threshold = cls.get_threshold(field_name)

    name = field_name.lower()

    # Monetary fields (sv: "belopp").
    if 'amount' in name or 'belopp' in name:
        return cls.match_amount(ocr_value, expected_value, threshold)

    # Date fields (sv: "datum").
    if 'date' in name or 'datum' in name:
        return cls.match_date(ocr_value, expected_value, threshold)

    # Purely numeric identifiers prone to digit misreads.
    if any(token in name for token in ('ocr', 'bankgiro', 'plusgiro', 'org')):
        return cls.match_digits(ocr_value, expected_value, threshold)

    # Invoice numbers can be alphanumeric -> general string matching.
    # (Explicit branch kept to document the intent.)
    if 'invoice' in name and 'number' in name:
        return cls.match_string(ocr_value, expected_value, threshold)

    # Default to string matching.
    return cls.match_string(ocr_value, expected_value, threshold)
||||
@classmethod
def find_best_match(
    cls,
    ocr_value: str,
    candidates: list[str],
    field_name: str = '',
    threshold: Optional[float] = None
) -> Optional[tuple[str, FuzzyMatchResult]]:
    """
    Find the best matching candidate from a list.

    Each candidate is scored with the field-appropriate matcher; the
    highest-scoring candidate at or above *threshold* wins (ties keep
    the earliest candidate).

    Returns (matched_value, match_result) or None if no match above
    threshold.
    """
    if threshold is None:
        threshold = cls.get_threshold(field_name) if field_name else cls.DEFAULT_THRESHOLD

    best_match: Optional[str] = None
    best_result: Optional[FuzzyMatchResult] = None

    for candidate in candidates:
        # Score with threshold=0.0 so a comparable result is always produced;
        # the real threshold is applied here, not inside the matcher.
        result = cls.match_field(field_name, ocr_value, candidate, threshold=0.0)
        if result.score >= threshold:
            if best_result is None or result.score > best_result.score:
                best_match = candidate
                best_result = result

    # BUG FIX: test best_result (not the candidate string) so an empty-string
    # candidate that legitimately matched is still returned; "if best_match:"
    # was falsy for "".
    if best_result is not None:
        return (best_match, best_result)
    return None
384
src/utils/ocr_corrections.py
Normal file
384
src/utils/ocr_corrections.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
OCR Error Corrections Module
|
||||
|
||||
Provides comprehensive OCR error correction tables and correction functions.
|
||||
Based on common OCR recognition errors in Swedish invoice documents.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class CorrectionResult:
    """Result of an OCR correction operation.

    Attributes:
        original: The input text before correction.
        corrected: The text after corrections were applied.
        corrections_applied: One (position, from_char, to_char) tuple per
            character that was changed.
        confidence: Heuristic confidence in the correction; for digit
            corrections it decreases per applied change and is floored
            at 0.5.
    """
    original: str
    corrected: str
    corrections_applied: list[tuple[int, str, str]]  # (position, from_char, to_char)
    confidence: float  # How confident we are in the correction
||||
class OCRCorrections:
    """
    Comprehensive OCR error correction utilities.

    Provides:
    - Character-level corrections for digits
    - Word-level corrections for common Swedish terms
    - Context-aware corrections
    - Multiple correction strategies

    All methods are classmethods operating on the class-level tables;
    the class is used without instantiation.
    """

    # =========================================================================
    # Character-level OCR errors (digit fields)
    # =========================================================================

    # Characters commonly misread as digits
    CHAR_TO_DIGIT = {
        # Letters that look like digits
        'O': '0', 'o': '0',  # O -> 0
        'Q': '0',  # Q -> 0 (less common)
        'D': '0',  # D -> 0 (in some fonts)

        'l': '1', 'I': '1',  # l/I -> 1
        'i': '1',  # i without dot -> 1
        '|': '1',  # pipe -> 1
        '!': '1',  # exclamation -> 1

        'Z': '2', 'z': '2',  # Z -> 2

        'E': '3',  # E -> 3 (rare)

        'A': '4', 'h': '4',  # A/h -> 4 (in some fonts)

        'S': '5', 's': '5',  # S -> 5

        'G': '6', 'b': '6',  # G/b -> 6

        'T': '7', 't': '7',  # T -> 7 (rare)

        'B': '8',  # B -> 8

        'g': '9', 'q': '9',  # g/q -> 9
    }

    # Digits commonly misread as other characters (inverse direction of
    # CHAR_TO_DIGIT, used when generating variants).
    DIGIT_TO_CHAR = {
        '0': ['O', 'o', 'D', 'Q'],
        '1': ['l', 'I', 'i', '|', '!'],
        '2': ['Z', 'z'],
        '3': ['E'],
        '4': ['A', 'h'],
        '5': ['S', 's'],
        '6': ['G', 'b'],
        '7': ['T', 't'],
        '8': ['B'],
        '9': ['g', 'q'],
    }

    # Bidirectional confusion pairs (either direction is possible)
    CONFUSION_PAIRS = [
        ('0', 'O'), ('0', 'o'), ('0', 'D'),
        ('1', 'l'), ('1', 'I'), ('1', '|'),
        ('2', 'Z'), ('2', 'z'),
        ('5', 'S'), ('5', 's'),
        ('6', 'G'), ('6', 'b'),
        ('8', 'B'),
        ('9', 'g'), ('9', 'q'),
    ]

    # =========================================================================
    # Word-level OCR errors (Swedish invoice terms)
    # =========================================================================

    # Common Swedish invoice terms and their OCR misreadings
    # (canonical form -> list of observed variants/misreads).
    SWEDISH_TERM_CORRECTIONS = {
        # Faktura (Invoice)
        'faktura': ['Faktura', 'FAKTURA', 'faktúra', 'faKtura'],
        'fakturanummer': ['Fakturanummer', 'FAKTURANUMMER', 'fakturanr', 'fakt.nr'],
        'fakturadatum': ['Fakturadatum', 'FAKTURADATUM', 'fakt.datum'],

        # Belopp (Amount)
        'belopp': ['Belopp', 'BELOPP', 'be1opp', 'bel0pp'],
        'summa': ['Summa', 'SUMMA', '5umma'],
        'total': ['Total', 'TOTAL', 'tota1', 't0tal'],
        'moms': ['Moms', 'MOMS', 'm0ms'],

        # Dates
        'förfallodatum': ['Förfallodatum', 'FÖRFALLODATUM', 'förfa11odatum'],
        'datum': ['Datum', 'DATUM', 'dátum'],

        # Payment
        'bankgiro': ['Bankgiro', 'BANKGIRO', 'BG', 'bg', 'bank giro'],
        'plusgiro': ['Plusgiro', 'PLUSGIRO', 'PG', 'pg', 'plus giro'],
        'postgiro': ['Postgiro', 'POSTGIRO'],
        'ocr': ['OCR', 'ocr', '0CR', 'OcR'],

        # Organization
        'organisationsnummer': ['Organisationsnummer', 'ORGANISATIONSNUMMER', 'org.nr', 'orgnr'],
        'kundnummer': ['Kundnummer', 'KUNDNUMMER', 'kund nr', 'kundnr'],

        # Currency
        'kronor': ['Kronor', 'KRONOR', 'kr', 'KR', 'SEK', 'sek'],
        'öre': ['Öre', 'ÖRE', 'ore', 'ORE'],
    }

    # =========================================================================
    # Context patterns
    # =========================================================================

    # Patterns that indicate the following/preceding text is a specific field
    CONTEXT_INDICATORS = {
        'invoice_number': [
            r'faktura\s*(?:nr|nummer)?[:\s]*',
            r'invoice\s*(?:no|number)?[:\s]*',
            r'fakt\.?\s*nr[:\s]*',
            r'inv[:\s]*#?',
        ],
        'amount': [
            r'(?:att\s+)?betala[:\s]*',
            r'total[t]?[:\s]*',
            r'summa[:\s]*',
            r'belopp[:\s]*',
            r'amount[:\s]*',
        ],
        'date': [
            r'datum[:\s]*',
            r'date[:\s]*',
            r'förfall(?:o)?datum[:\s]*',
            r'fakturadatum[:\s]*',
        ],
        'ocr': [
            r'ocr[:\s]*',
            r'referens[:\s]*',
            r'betalningsreferens[:\s]*',
        ],
        'bankgiro': [
            r'bankgiro[:\s]*',
            r'bg[:\s]*',
            r'bank\s*giro[:\s]*',
        ],
        'plusgiro': [
            r'plusgiro[:\s]*',
            r'pg[:\s]*',
            r'plus\s*giro[:\s]*',
            r'postgiro[:\s]*',
        ],
        'org_number': [
            r'org\.?\s*(?:nr|nummer)?[:\s]*',
            r'organisationsnummer[:\s]*',
            r'vat[:\s]*',
            r'moms(?:reg)?\.?\s*(?:nr|nummer)?[:\s]*',
        ],
    }

    # =========================================================================
    # Correction Methods
    # =========================================================================

    @classmethod
    def correct_digits(cls, text: str, aggressive: bool = False) -> CorrectionResult:
        """
        Apply digit corrections to text.

        Args:
            text: Input text
            aggressive: If True, correct all potential digit-like characters.
                        If False, only correct characters adjacent to digits.

        Returns:
            CorrectionResult with original, corrected text, and details.
        """
        corrections = []
        result = []

        for i, char in enumerate(text):
            if char.isdigit():
                result.append(char)
            elif char in cls.CHAR_TO_DIGIT:
                if aggressive:
                    # Always correct
                    corrected_char = cls.CHAR_TO_DIGIT[char]
                    corrections.append((i, char, corrected_char))
                    result.append(corrected_char)
                else:
                    # Only correct if adjacent to digit
                    # NOTE: a digit-lookalike neighbour also counts as a
                    # "digit" here, so runs of lookalikes (e.g. "AB") can
                    # still be corrected in non-aggressive mode.
                    prev_is_digit = i > 0 and (text[i-1].isdigit() or text[i-1] in cls.CHAR_TO_DIGIT)
                    next_is_digit = i < len(text) - 1 and (text[i+1].isdigit() or text[i+1] in cls.CHAR_TO_DIGIT)

                    if prev_is_digit or next_is_digit:
                        corrected_char = cls.CHAR_TO_DIGIT[char]
                        corrections.append((i, char, corrected_char))
                        result.append(corrected_char)
                    else:
                        result.append(char)
            else:
                result.append(char)

        corrected = ''.join(result)
        confidence = 1.0 - (len(corrections) * 0.05)  # Decrease confidence per correction

        return CorrectionResult(
            original=text,
            corrected=corrected,
            corrections_applied=corrections,
            confidence=max(0.5, confidence)  # never drops below 0.5
        )

    @classmethod
    def generate_digit_variants(cls, text: str) -> list[str]:
        """
        Generate all possible digit interpretations of a text.

        Useful for matching when we don't know which direction the OCR
        error went.  The original text is always included in the result.
        """
        if not text:
            return [text]

        variants = {text}

        # For each character that could be confused
        for i, char in enumerate(text):
            new_variants = set()
            for existing in variants:
                # If it's a digit, add letter variants
                if char.isdigit() and char in cls.DIGIT_TO_CHAR:
                    for replacement in cls.DIGIT_TO_CHAR[char]:
                        new_variants.add(existing[:i] + replacement + existing[i+1:])

                # If it's a letter that looks like a digit, add digit variant
                if char in cls.CHAR_TO_DIGIT:
                    new_variants.add(existing[:i] + cls.CHAR_TO_DIGIT[char] + existing[i+1:])

            variants.update(new_variants)

            # Limit explosion - only keep reasonable number
            # (checked after each position, so the final count can
            # slightly exceed 100)
            if len(variants) > 100:
                break

        return list(variants)

    @classmethod
    def correct_swedish_term(cls, text: str) -> str:
        """
        Correct common Swedish invoice terms that may have OCR errors.

        Replaces any known variant (case-insensitively) with its
        canonical form, capitalizing it when the input starts with an
        uppercase letter.
        """
        text_lower = text.lower()

        for canonical, variants in cls.SWEDISH_TERM_CORRECTIONS.items():
            for variant in variants:
                if variant.lower() in text_lower:
                    # Replace with canonical form (preserving case of first letter)
                    pattern = re.compile(re.escape(variant), re.IGNORECASE)
                    if text[0].isupper():
                        replacement = canonical.capitalize()
                    else:
                        replacement = canonical
                    text = pattern.sub(replacement, text)

        return text

    @classmethod
    def extract_with_context(cls, text: str, field_type: str) -> Optional[str]:
        """
        Extract a field value using context indicators.

        Looks for patterns like "Fakturanr: 12345" and extracts "12345".
        Returns None when *field_type* is unknown or nothing matches.
        """
        patterns = cls.CONTEXT_INDICATORS.get(field_type, [])

        for pattern in patterns:
            # Look for pattern followed by value
            # (value = the next run of characters without whitespace/,/;)
            full_pattern = pattern + r'([^\s,;]+)'
            match = re.search(full_pattern, text, re.IGNORECASE)
            if match:
                return match.group(1)

        return None

    @classmethod
    def is_likely_ocr_error(cls, char1: str, char2: str) -> bool:
        """
        Check if two characters are commonly confused in OCR.

        The check is symmetric over CONFUSION_PAIRS.
        """
        pair = (char1, char2)
        reverse_pair = (char2, char1)

        for p in cls.CONFUSION_PAIRS:
            if pair == p or reverse_pair == p:
                return True

        return False

    @classmethod
    def count_potential_ocr_errors(cls, s1: str, s2: str) -> tuple[int, int]:
        """
        Count how many character differences between two strings
        are likely OCR errors vs other differences.

        Strings of unequal length are not compared character-by-character;
        the length difference is reported entirely as "other" errors.

        Returns: (ocr_errors, other_errors)
        """
        if len(s1) != len(s2):
            return (0, abs(len(s1) - len(s2)))

        ocr_errors = 0
        other_errors = 0

        for c1, c2 in zip(s1, s2):
            if c1 != c2:
                if cls.is_likely_ocr_error(c1, c2):
                    ocr_errors += 1
                else:
                    other_errors += 1

        return (ocr_errors, other_errors)

    @classmethod
    def suggest_corrections(cls, text: str, expected_type: str = 'digit') -> list[tuple[str, float]]:
        """
        Suggest possible corrections for a text with confidence scores.

        Returns list of (corrected_text, confidence) tuples, sorted by confidence.
        Only 'digit' is handled as *expected_type*; other values yield an
        empty list.
        """
        suggestions = []

        if expected_type == 'digit':
            # Apply digit corrections with different levels of aggressiveness
            mild = cls.correct_digits(text, aggressive=False)
            if mild.corrected != text:
                suggestions.append((mild.corrected, mild.confidence))

            aggressive = cls.correct_digits(text, aggressive=True)
            if aggressive.corrected != text and aggressive.corrected != mild.corrected:
                suggestions.append((aggressive.corrected, aggressive.confidence * 0.9))

            # Generate variants
            variants = cls.generate_digit_variants(text)
            # NOTE(review): variants come from a set, so this slice is the
            # first 10 in arbitrary order, not the 10 "best" — confirm intent.
            for variant in variants[:10]:  # Limit to top 10
                if variant != text and variant not in [s[0] for s in suggestions]:
                    # Lower confidence for variants
                    suggestions.append((variant, 0.7))

        # Sort by confidence
        suggestions.sort(key=lambda x: x[1], reverse=True)
        return suggestions
|
||||
# =========================================================================
|
||||
# Convenience functions
|
||||
# =========================================================================
|
||||
|
||||
def correct_ocr_digits(text: str, aggressive: bool = False) -> str:
    """Return *text* with common OCR digit misreads fixed.

    Thin wrapper around OCRCorrections.correct_digits that discards the
    correction metadata and returns only the corrected string.
    """
    result = OCRCorrections.correct_digits(text, aggressive)
    return result.corrected
|
||||
|
||||
def generate_ocr_variants(text: str) -> list[str]:
    """Return the plausible OCR digit/letter readings of *text*.

    Thin wrapper around OCRCorrections.generate_digit_variants; the
    original text is always included in the returned list.
    """
    variants = OCRCorrections.generate_digit_variants(text)
    return variants
||||
|
||||
def is_ocr_confusion(char1: str, char2: str) -> bool:
    """Return True when *char1* and *char2* are commonly confused by OCR.

    Thin wrapper around OCRCorrections.is_likely_ocr_error (symmetric).
    """
    return OCRCorrections.is_likely_ocr_error(char1, char2)
||||
399
src/utils/test_advanced_utils.py
Normal file
399
src/utils/test_advanced_utils.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Tests for advanced utility modules:
|
||||
- FuzzyMatcher
|
||||
- OCRCorrections
|
||||
- ContextExtractor
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
|
||||
from .ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants
|
||||
from .context_extractor import ContextExtractor, extract_field_with_context
|
||||
|
||||
|
||||
class TestFuzzyMatcher:
    """Tests for FuzzyMatcher class.

    Covers the distance/similarity primitives, the per-type matchers
    (digits, amounts, dates, strings), field routing, and best-match
    selection over candidate lists.
    """

    def test_levenshtein_distance_identical(self):
        """Test distance for identical strings."""
        assert FuzzyMatcher.levenshtein_distance("hello", "hello") == 0

    def test_levenshtein_distance_one_char(self):
        """Test distance for one character difference."""
        # Substitution, deletion, and insertion each cost 1.
        assert FuzzyMatcher.levenshtein_distance("hello", "hallo") == 1
        assert FuzzyMatcher.levenshtein_distance("hello", "hell") == 1
        assert FuzzyMatcher.levenshtein_distance("hello", "helloo") == 1

    def test_levenshtein_distance_multiple(self):
        """Test distance for multiple differences."""
        assert FuzzyMatcher.levenshtein_distance("hello", "world") == 4
        assert FuzzyMatcher.levenshtein_distance("", "hello") == 5

    def test_similarity_ratio_identical(self):
        """Test similarity for identical strings."""
        assert FuzzyMatcher.similarity_ratio("hello", "hello") == 1.0

    def test_similarity_ratio_similar(self):
        """Test similarity for similar strings."""
        ratio = FuzzyMatcher.similarity_ratio("hello", "hallo")
        assert 0.8 <= ratio <= 0.9  # One char different in 5-char string

    def test_similarity_ratio_different(self):
        """Test similarity for different strings."""
        ratio = FuzzyMatcher.similarity_ratio("hello", "world")
        assert ratio < 0.5

    def test_ocr_aware_similarity_exact(self):
        """Test OCR-aware similarity for exact match."""
        assert FuzzyMatcher.ocr_aware_similarity("12345", "12345") == 1.0

    def test_ocr_aware_similarity_ocr_error(self):
        """Test OCR-aware similarity with OCR error."""
        # O instead of 0
        score = FuzzyMatcher.ocr_aware_similarity("1234O", "12340")
        assert score >= 0.9  # Should be high due to OCR correction

    def test_ocr_aware_similarity_multiple_errors(self):
        """Test OCR-aware similarity with multiple OCR errors."""
        # l instead of 1, O instead of 0
        score = FuzzyMatcher.ocr_aware_similarity("l234O", "12340")
        assert score >= 0.85

    def test_match_digits_exact(self):
        """Test digit matching for exact match."""
        result = FuzzyMatcher.match_digits("12345", "12345")
        assert result.matched is True
        assert result.score == 1.0
        assert result.match_type == 'exact'

    def test_match_digits_with_separators(self):
        """Test digit matching ignoring separators."""
        result = FuzzyMatcher.match_digits("123-4567", "1234567")
        assert result.matched is True
        assert result.normalized_ocr == "1234567"

    def test_match_digits_ocr_error(self):
        """Test digit matching with OCR error."""
        result = FuzzyMatcher.match_digits("556O234567", "5560234567")
        assert result.matched is True
        assert result.score >= 0.9

    def test_match_amount_exact(self):
        """Test amount matching for exact values."""
        result = FuzzyMatcher.match_amount("1234.56", "1234.56")
        assert result.matched is True
        assert result.score == 1.0

    def test_match_amount_different_formats(self):
        """Test amount matching with different formats."""
        # Swedish vs US format
        result = FuzzyMatcher.match_amount("1234,56", "1234.56")
        assert result.matched is True
        assert result.score >= 0.99

    def test_match_amount_with_spaces(self):
        """Test amount matching with thousand separators."""
        result = FuzzyMatcher.match_amount("1 234,56", "1234.56")
        assert result.matched is True

    def test_match_date_same_date_different_format(self):
        """Test date matching with different formats."""
        # ISO vs European dotted format normalize to the same date.
        result = FuzzyMatcher.match_date("2024-12-29", "29.12.2024")
        assert result.matched is True
        assert result.score >= 0.9

    def test_match_date_different_dates(self):
        """Test date matching with different dates."""
        result = FuzzyMatcher.match_date("2024-12-29", "2024-12-30")
        assert result.matched is False

    def test_match_string_exact(self):
        """Test string matching for exact match."""
        result = FuzzyMatcher.match_string("Hello World", "Hello World")
        assert result.matched is True
        assert result.match_type == 'exact'

    def test_match_string_case_insensitive(self):
        """Test string matching case insensitivity."""
        result = FuzzyMatcher.match_string("HELLO", "hello")
        assert result.matched is True
        assert result.match_type == 'normalized'

    def test_match_string_ocr_corrected(self):
        """Test string matching with OCR corrections."""
        result = FuzzyMatcher.match_string("5561234567", "556l234567")
        assert result.matched is True

    def test_match_field_routes_correctly(self):
        """Test that match_field routes to correct matcher."""
        # Amount field
        result = FuzzyMatcher.match_field("Amount", "1234.56", "1234,56")
        assert result.matched is True

        # Date field
        result = FuzzyMatcher.match_field("InvoiceDate", "2024-12-29", "29.12.2024")
        assert result.matched is True

    def test_find_best_match(self):
        """Test finding best match from candidates."""
        candidates = ["12345", "12346", "99999"]
        result = FuzzyMatcher.find_best_match("12345", candidates, "InvoiceNumber")

        assert result is not None
        assert result[0] == "12345"
        assert result[1].score == 1.0

    def test_find_best_match_no_match(self):
        """Test finding best match when none above threshold."""
        candidates = ["99999", "88888", "77777"]
        result = FuzzyMatcher.find_best_match("12345", candidates, "InvoiceNumber")

        assert result is None
||||
|
||||
class TestOCRCorrections:
|
||||
"""Tests for OCRCorrections class."""
|
||||
|
||||
def test_correct_digits_simple(self):
|
||||
"""Test simple digit correction."""
|
||||
result = OCRCorrections.correct_digits("556O23", aggressive=False)
|
||||
assert result.corrected == "556023"
|
||||
assert len(result.corrections_applied) == 1
|
||||
|
||||
def test_correct_digits_multiple(self):
|
||||
"""Test multiple digit corrections."""
|
||||
result = OCRCorrections.correct_digits("5S6l23", aggressive=False)
|
||||
assert result.corrected == "556123"
|
||||
assert len(result.corrections_applied) == 2
|
||||
|
||||
def test_correct_digits_aggressive(self):
|
||||
"""Test aggressive mode corrects all potential errors."""
|
||||
result = OCRCorrections.correct_digits("AB123", aggressive=True)
|
||||
# A -> 4, B -> 8
|
||||
assert result.corrected == "48123"
|
||||
|
||||
def test_correct_digits_non_aggressive(self):
|
||||
"""Test non-aggressive mode only corrects adjacent."""
|
||||
result = OCRCorrections.correct_digits("AB 123", aggressive=False)
|
||||
# A and B are adjacent to each other and both in CHAR_TO_DIGIT,
|
||||
# so they may be corrected. The key is digits are not affected.
|
||||
assert "123" in result.corrected
|
||||
|
||||
def test_generate_digit_variants(self):
|
||||
"""Test generating OCR variants."""
|
||||
variants = OCRCorrections.generate_digit_variants("10")
|
||||
# Should include original and variants like "1O", "I0", "IO", "l0", etc.
|
||||
assert "10" in variants
|
||||
assert "1O" in variants or "l0" in variants
|
||||
|
||||
def test_generate_digit_variants_limits(self):
|
||||
"""Test that variant generation is limited."""
|
||||
variants = OCRCorrections.generate_digit_variants("1234567890")
|
||||
# Should be limited to prevent explosion (limit is ~100, but may slightly exceed)
|
||||
assert len(variants) <= 150
|
||||
|
||||
def test_is_likely_ocr_error(self):
|
||||
"""Test OCR error detection."""
|
||||
assert OCRCorrections.is_likely_ocr_error('0', 'O') is True
|
||||
assert OCRCorrections.is_likely_ocr_error('O', '0') is True
|
||||
assert OCRCorrections.is_likely_ocr_error('1', 'l') is True
|
||||
assert OCRCorrections.is_likely_ocr_error('5', 'S') is True
|
||||
assert OCRCorrections.is_likely_ocr_error('A', 'Z') is False
|
||||
|
||||
def test_count_potential_ocr_errors(self):
|
||||
"""Test counting OCR errors vs other errors."""
|
||||
ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors("1O3", "103")
|
||||
assert ocr_errors == 1 # O vs 0
|
||||
assert other_errors == 0
|
||||
|
||||
ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors("1X3", "103")
|
||||
assert ocr_errors == 0
|
||||
assert other_errors == 1 # X vs 0, not a known pair
|
||||
|
||||
def test_suggest_corrections(self):
|
||||
"""Test correction suggestions."""
|
||||
suggestions = OCRCorrections.suggest_corrections("556O23", expected_type='digit')
|
||||
assert len(suggestions) > 0
|
||||
# First suggestion should be the corrected version
|
||||
assert suggestions[0][0] == "556023"
|
||||
|
||||
def test_convenience_function_correct(self):
|
||||
"""Test convenience function."""
|
||||
assert correct_ocr_digits("556O23") == "556023"
|
||||
|
||||
def test_convenience_function_variants(self):
    """Module-level variant helper must include the original string."""
    assert "10" in generate_ocr_variants("10")
|
||||
|
||||
|
||||
class TestContextExtractor:
    """Tests for ContextExtractor label-driven field extraction."""

    def test_extract_invoice_number_with_label(self):
        """An invoice number following its label is extracted via 'label'."""
        found = ContextExtractor.extract_with_label("Fakturanummer: 12345678", "InvoiceNumber")
        assert len(found) > 0
        top = found[0]
        assert top.value == "12345678"
        assert top.extraction_method == 'label'

    def test_extract_invoice_number_swedish(self):
        """A Swedish 'Faktura nr' label also yields candidates."""
        found = ContextExtractor.extract_with_label("Faktura nr: A12345", "InvoiceNumber")
        assert len(found) > 0
        # Should extract A12345 or 12345.

    def test_extract_amount_with_label(self):
        """'Att betala' labels an amount."""
        found = ContextExtractor.extract_with_label("Att betala: 1 234,56", "Amount")
        assert len(found) > 0

    def test_extract_amount_total(self):
        """'Total' labels an amount with a trailing currency marker."""
        found = ContextExtractor.extract_with_label("Total: 5678,90 kr", "Amount")
        assert len(found) > 0

    def test_extract_date_with_label(self):
        """An ISO date after 'Fakturadatum' is captured."""
        found = ContextExtractor.extract_with_label("Fakturadatum: 2024-12-29", "InvoiceDate")
        assert len(found) > 0
        assert "2024-12-29" in found[0].value

    def test_extract_due_date(self):
        """A due date after 'Förfallodatum' is captured."""
        found = ContextExtractor.extract_with_label("Förfallodatum: 2025-01-15", "InvoiceDueDate")
        assert len(found) > 0

    def test_extract_bankgiro(self):
        """A Bankgiro number is captured with or without its dash."""
        found = ContextExtractor.extract_with_label("Bankgiro: 1234-5678", "Bankgiro")
        assert len(found) > 0
        assert "1234-5678" in found[0].value or "12345678" in found[0].value

    def test_extract_plusgiro(self):
        """A Plusgiro number after its label yields candidates."""
        found = ContextExtractor.extract_with_label("Plusgiro: 1234567-8", "Plusgiro")
        assert len(found) > 0

    def test_extract_ocr(self):
        """An OCR payment reference after its label is captured verbatim."""
        found = ContextExtractor.extract_with_label("OCR: 12345678901234", "OCR")
        assert len(found) > 0
        assert found[0].value == "12345678901234"

    def test_extract_org_number(self):
        """An organisation number after 'Org.nr' yields candidates."""
        found = ContextExtractor.extract_with_label("Org.nr: 556123-4567", "supplier_organisation_number")
        assert len(found) > 0

    def test_extract_customer_number(self):
        """A free-form customer number after 'Kundnummer' yields candidates."""
        found = ContextExtractor.extract_with_label("Kundnummer: EMM 256-6", "customer_number")
        assert len(found) > 0

    def test_extract_field_returns_sorted(self):
        """extract_field orders candidates by descending confidence."""
        found = ContextExtractor.extract_field("Fakturanummer: 12345 Invoice number: 67890", "InvoiceNumber")
        if len(found) > 1:
            assert found[0].confidence >= found[1].confidence

    def test_extract_best(self):
        """extract_best returns the single strongest candidate."""
        best = ContextExtractor.extract_best("Fakturanummer: 12345678", "InvoiceNumber")
        assert best is not None
        assert best.value == "12345678"

    def test_extract_best_no_match(self):
        """extract_best tolerates text with no invoice data."""
        best = ContextExtractor.extract_best("No invoice information here", "InvoiceNumber", validate=True)
        # May or may not find something depending on validation.

    def test_extract_all_fields(self):
        """extract_all_fields finds at least one field in a multi-line invoice."""
        text = """
        Fakturanummer: 12345
        Datum: 2024-12-29
        Belopp: 1234,56
        Bankgiro: 1234-5678
        """
        results = ContextExtractor.extract_all_fields(text)
        assert len(results) > 0

    def test_identify_field_type(self):
        """identify_field_type infers the field kind from surrounding context."""
        inferred = ContextExtractor.identify_field_type("Fakturanummer: 12345", "12345")
        assert inferred == "InvoiceNumber"

    def test_convenience_function_extract(self):
        """The module-level convenience wrapper returns the extracted value."""
        assert extract_field_with_context("Fakturanummer: 12345678", "InvoiceNumber") == "12345678"
|
||||
|
||||
|
||||
class TestIntegration:
    """End-to-end tests combining OCR correction, extraction and fuzzy matching."""

    def test_fuzzy_match_with_ocr_correction(self):
        """A digit string with an O/0 OCR error corrects and then matches."""
        noisy = "556O234567"
        truth = "5560234567"

        # Correction alone must fully repair the string.
        assert correct_ocr_digits(noisy) == truth

        # The fuzzy matcher must also accept the uncorrected form directly.
        outcome = FuzzyMatcher.match_digits(noisy, truth)
        assert outcome.matched is True

    def test_context_extraction_with_fuzzy_match(self):
        """A labelled value with an S/5 OCR error is still extractable."""
        sample = "Fakturanummer: 1234S678"  # S is OCR error for 5

        candidate = ContextExtractor.extract_best(sample, "InvoiceNumber", validate=False)
        assert candidate is not None

        # Might match depending on threshold.
        result = FuzzyMatcher.match_string(candidate.value, "12345678")
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
235
src/utils/test_utils.py
Normal file
235
src/utils/test_utils.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Tests for shared utility modules.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from .text_cleaner import TextCleaner
|
||||
from .format_variants import FormatVariants
|
||||
from .validators import FieldValidators
|
||||
|
||||
|
||||
class TestTextCleaner:
    """Tests for TextCleaner normalization and OCR-correction helpers."""

    def test_clean_unicode_dashes(self):
        """En-dash, em-dash and minus sign all normalize to a plain hyphen."""
        for raw in ("556123–4567", "556123—4567", "556123−4567"):
            assert TextCleaner.clean_unicode(raw) == "556123-4567"

    def test_clean_unicode_spaces(self):
        """Non-breaking spaces become plain spaces; zero-width chars vanish."""
        assert TextCleaner.clean_unicode("1\xa0234") == "1 234"
        assert TextCleaner.clean_unicode("123\u200b456") == "123456"

    def test_ocr_digit_corrections(self):
        """Letter/digit confusions are rewritten for digit fields."""
        cases = {
            "556O23": "556023",            # O -> 0
            "556l23": "556123",            # l -> 1
            "5S6123": "556123",            # S -> 5
            "S56l23-4S67": "556123-4567",  # several corrections at once
        }
        for raw, expected in cases.items():
            assert TextCleaner.apply_ocr_digit_corrections(raw) == expected

    def test_extract_digits(self):
        """Digit extraction, with and without adjacency-based OCR correction."""
        assert TextCleaner.extract_digits("556123-4567") == "5561234567"
        assert TextCleaner.extract_digits("556O23-4567", apply_ocr_correction=True) == "5560234567"
        # Without correction only literal digits survive.
        assert TextCleaner.extract_digits("ABC 123 DEF", apply_ocr_correction=False) == "123"
        # Standalone letters are never converted; they must touch digits.
        assert TextCleaner.extract_digits("A 123 B", apply_ocr_correction=True) == "123"

    def test_normalize_amount_text(self):
        """Currency labels and thousands spaces are stripped from amounts."""
        assert TextCleaner.normalize_amount_text("1 234,56 kr") == "1234,56"
        assert TextCleaner.normalize_amount_text("SEK 1234.56") == "1234.56"
        assert TextCleaner.normalize_amount_text("1 234 567,89 kronor") == "1234567,89"
|
||||
|
||||
|
||||
class TestFormatVariants:
    """Tests for FormatVariants representation generation."""

    def test_organisation_number_variants(self):
        """A bare org number expands to dashed and VAT representations."""
        results = FormatVariants.organisation_number_variants("5561234567")
        assert "5561234567" in results      # digits only
        assert "556123-4567" in results     # dashed form
        assert "SE556123456701" in results  # VAT format

    def test_organisation_number_from_vat(self):
        """A VAT-formatted number yields the underlying org-number forms."""
        results = FormatVariants.organisation_number_variants("SE556123456701")
        assert "5561234567" in results
        assert "556123-4567" in results

    def test_bankgiro_variants(self):
        """Bankgiro variants cover both 7- and 8-digit dash placements."""
        eight_digit = FormatVariants.bankgiro_variants("53939484")
        assert "53939484" in eight_digit
        assert "5393-9484" in eight_digit

        seven_digit = FormatVariants.bankgiro_variants("1234567")
        assert "1234567" in seven_digit
        assert "123-4567" in seven_digit

    def test_plusgiro_variants(self):
        """Plusgiro variants include the dashed check-digit form."""
        results = FormatVariants.plusgiro_variants("12345678")
        assert "12345678" in results
        assert "1234567-8" in results

    def test_amount_variants(self):
        """Amount variants cover dot, comma and Swedish-spaced forms."""
        results = FormatVariants.amount_variants("1234.56")
        assert "1234.56" in results
        assert "1234,56" in results
        assert "1 234,56" in results or "1234,56" in results  # Swedish format

    def test_date_variants(self):
        """Date variants span ISO, European, compact and Swedish-text forms."""
        results = FormatVariants.date_variants("2024-12-29")
        assert "2024-12-29" in results        # ISO
        assert "29.12.2024" in results        # European
        assert "29/12/2024" in results        # European slash
        assert "20241229" in results          # compact
        assert "29 december 2024" in results  # Swedish text

    def test_invoice_number_variants(self):
        """Invoice-number variants include stripped and lowercased forms."""
        results = FormatVariants.invoice_number_variants("INV-2024-001")
        assert "INV-2024-001" in results
        assert "INV2024001" in results    # separators removed
        assert "inv-2024-001" in results  # lowercase

    def test_get_variants_dispatch(self):
        """get_variants routes each field name to the matching generator."""
        # Organisation number
        assert "556123-4567" in FormatVariants.get_variants("supplier_organisation_number", "5561234567")
        # Bankgiro
        assert "5393-9484" in FormatVariants.get_variants("Bankgiro", "53939484")
        # Amount
        assert "1234,56" in FormatVariants.get_variants("Amount", "1234.56")
|
||||
|
||||
|
||||
class TestFieldValidators:
    """Tests for FieldValidators checksum, format and range validation."""

    def test_luhn_checksum_valid(self):
        """Numbers with a correct Luhn check digit validate."""
        assert FieldValidators.luhn_checksum("53939484") is True    # valid Bankgiro
        assert FieldValidators.luhn_checksum("1234567897") is True  # check digit 7

    def test_luhn_checksum_invalid(self):
        """Numbers with a wrong check digit are rejected."""
        assert FieldValidators.luhn_checksum("53939485") is False  # wrong check digit
        assert FieldValidators.luhn_checksum("1234567890") is False

    def test_calculate_luhn_check_digit(self):
        """Appending the computed check digit yields a Luhn-valid number."""
        base = "123456789"
        completed = base + str(FieldValidators.calculate_luhn_check_digit(base))
        assert FieldValidators.luhn_checksum(completed) is True

    def test_is_valid_organisation_number(self):
        """Organisation-number validation (needs real numbers as fixtures)."""
        # Placeholder: requires genuine Luhn-valid org numbers, e.g. 5565006245.
        pass  # Skip without real test data

    def test_is_valid_bankgiro(self):
        """Bankgiro validation enforces 7-8 digits plus Luhn."""
        assert FieldValidators.is_valid_bankgiro("53939484") is True
        assert FieldValidators.is_valid_bankgiro("123") is False        # too short
        assert FieldValidators.is_valid_bankgiro("123456789") is False  # 9 digits

    def test_format_bankgiro(self):
        """Bankgiro formatting inserts the dash and rejects bad lengths."""
        assert FieldValidators.format_bankgiro("53939484") == "5393-9484"
        assert FieldValidators.format_bankgiro("1234567") == "123-4567"
        assert FieldValidators.format_bankgiro("123") is None

    def test_is_valid_plusgiro(self):
        """Plusgiro validation enforces 2-8 digits plus Luhn."""
        assert FieldValidators.is_valid_plusgiro("18") is True  # minimal length
        assert FieldValidators.is_valid_plusgiro("1") is False  # too short

    def test_format_plusgiro(self):
        """Plusgiro formatting dashes off the final check digit."""
        assert FieldValidators.format_plusgiro("12345678") == "1234567-8"
        assert FieldValidators.format_plusgiro("123456") == "12345-6"

    def test_is_valid_amount(self):
        """Amounts must parse and fall inside the allowed range."""
        assert FieldValidators.is_valid_amount("1234.56") is True
        assert FieldValidators.is_valid_amount("1 234,56") is True
        assert FieldValidators.is_valid_amount("abc") is False
        assert FieldValidators.is_valid_amount("-100") is False       # below minimum
        assert FieldValidators.is_valid_amount("100000000") is False  # above maximum

    def test_parse_amount(self):
        """Amount parsing handles plain, Swedish, German and US formats."""
        expected = 1234.56
        assert FieldValidators.parse_amount("1234.56") == expected
        assert FieldValidators.parse_amount("1 234,56") == expected
        assert FieldValidators.parse_amount("1.234,56") == expected  # German
        assert FieldValidators.parse_amount("1,234.56") == expected  # US

    def test_is_valid_date(self):
        """Dates must parse and sit within the accepted year range."""
        assert FieldValidators.is_valid_date("2024-12-29") is True
        assert FieldValidators.is_valid_date("29.12.2024") is True
        assert FieldValidators.is_valid_date("29/12/2024") is True
        assert FieldValidators.is_valid_date("not a date") is False
        assert FieldValidators.is_valid_date("1900-01-01") is False  # out of range

    def test_format_date_iso(self):
        """European-style dates normalize to ISO YYYY-MM-DD."""
        for raw in ("29.12.2024", "29/12/2024", "2024-12-29"):
            assert FieldValidators.format_date_iso(raw) == "2024-12-29"

    def test_validate_field_dispatch(self):
        """validate_field routes each field name to the right validator."""
        # Organisation number: empty input must be rejected.
        is_valid, error = FieldValidators.validate_field("supplier_organisation_number", "")
        assert is_valid is False

        # Amount
        is_valid, error = FieldValidators.validate_field("Amount", "1234.56")
        assert is_valid is True

        # Date
        is_valid, error = FieldValidators.validate_field("InvoiceDate", "2024-12-29")
        assert is_valid is True
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
244
src/utils/text_cleaner.py
Normal file
244
src/utils/text_cleaner.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Text Cleaning Module
|
||||
|
||||
Provides text normalization and OCR error correction utilities.
|
||||
Used by both inference (field_extractor) and matching (normalizer) stages.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TextCleaner:
    """
    Unified text cleaning utilities for invoice processing.

    Handles:
    - Unicode normalization (zero-width chars, dash variants)
    - OCR error correction (O/0, l/1, etc.)
    - Whitespace normalization
    - Swedish-specific character handling
    """

    # Common OCR misreads for digit fields: when a field is expected to
    # contain digits, these characters are frequently misrecognized letters.
    OCR_DIGIT_CORRECTIONS = {
        'O': '0', 'o': '0',  # letter O -> digit 0
        'Q': '0',            # Q sometimes resembles 0
        'l': '1', 'I': '1',  # lowercase L / uppercase I -> digit 1
        '|': '1',            # vertical bar -> 1
        'i': '1',            # lowercase i -> 1
        'S': '5', 's': '5',  # S -> 5
        'B': '8',            # B -> 8
        'Z': '2', 'z': '2',  # Z -> 2
        'G': '6', 'g': '6',  # G -> 6 (in some fonts)
        'A': '4',            # A -> 4 (in some fonts)
        'T': '7',            # T -> 7 (in some fonts)
        'q': '9',            # q -> 9
        'D': '0',            # D -> 0
    }

    # Reverse mapping: digits misrecognized as letters
    # (used for mixed alphanumeric fields).
    OCR_LETTER_CORRECTIONS = {
        '0': 'O',
        '1': 'I',
        '5': 'S',
        '8': 'B',
        '2': 'Z',
    }

    # Unicode special-character normalization table.
    UNICODE_NORMALIZATIONS = {
        # Dash/hyphen variants -> standard hyphen
        '\u2013': '-',  # en-dash –
        '\u2014': '-',  # em-dash —
        '\u2212': '-',  # minus sign −
        '\u00b7': '-',  # middle dot ·
        '\u2010': '-',  # hyphen ‐
        '\u2011': '-',  # non-breaking hyphen ‑
        '\u2012': '-',  # figure dash ‒
        '\u2015': '-',  # horizontal bar ―

        # Space variants -> standard space
        '\u00a0': ' ',  # non-breaking space
        '\u2002': ' ',  # en space
        '\u2003': ' ',  # em space
        '\u2009': ' ',  # thin space
        '\u200a': ' ',  # hair space

        # Zero-width characters -> removed
        '\u200b': '',  # zero-width space
        '\u200c': '',  # zero-width non-joiner
        '\u200d': '',  # zero-width joiner
        '\ufeff': '',  # BOM / zero-width no-break space

        # Quote variants -> standard quotes
        '\u2018': "'",  # left single quote '
        '\u2019': "'",  # right single quote '
        '\u201c': '"',  # left double quote "
        '\u201d': '"',  # right double quote "
    }

    @classmethod
    def clean_unicode(cls, text: str) -> str:
        """
        Normalize Unicode characters to ASCII equivalents.

        Handles:
        - Various dash types -> standard hyphen (-)
        - Various spaces -> standard space
        - Zero-width characters -> removed
        - Various quotes -> standard quotes
        """
        for unicode_char, replacement in cls.UNICODE_NORMALIZATIONS.items():
            text = text.replace(unicode_char, replacement)
        return text

    @classmethod
    def normalize_whitespace(cls, text: str) -> str:
        """Collapse multiple whitespace to single space and strip."""
        return ' '.join(text.split())

    @classmethod
    def clean_text(cls, text: str) -> str:
        """
        Full text cleaning pipeline.

        1. Normalize Unicode
        2. Normalize whitespace
        3. Strip

        This is safe for all field types.
        """
        text = cls.clean_unicode(text)
        text = cls.normalize_whitespace(text)
        return text.strip()

    @classmethod
    def apply_ocr_digit_corrections(cls, text: str) -> str:
        """
        Apply OCR error corrections for digit-only fields.

        Use this when the field is expected to contain only digits
        (e.g., OCR number, organization number digits, etc.)

        Example:
            "556l23-4S67" -> "556123-4567"
        """
        result = []
        for char in text:
            if char in cls.OCR_DIGIT_CORRECTIONS:
                result.append(cls.OCR_DIGIT_CORRECTIONS[char])
            else:
                result.append(char)
        return ''.join(result)

    @classmethod
    def extract_digits(cls, text: str, apply_ocr_correction: bool = True) -> str:
        """
        Extract only digits from text.

        Args:
            text: Input text
            apply_ocr_correction: If True, apply OCR corrections ONLY to characters
                that are adjacent to digits (not standalone letters)

        Returns:
            String containing only digits
        """
        if apply_ocr_correction:
            # Only correct characters that sit inside a digit-like run:
            # the O in "556O23" should be fixed, the ABC in "ABC 123" should not.
            result = []
            for i, char in enumerate(text):
                if char.isdigit():
                    result.append(char)
                elif char in cls.OCR_DIGIT_CORRECTIONS:
                    # Accept the correction only when a neighbour is a digit
                    # (or itself a correctable character).
                    prev_is_digit = i > 0 and (text[i - 1].isdigit() or text[i - 1] in cls.OCR_DIGIT_CORRECTIONS)
                    next_is_digit = i < len(text) - 1 and (text[i + 1].isdigit() or text[i + 1] in cls.OCR_DIGIT_CORRECTIONS)
                    if prev_is_digit or next_is_digit:
                        result.append(cls.OCR_DIGIT_CORRECTIONS[char])
                # Any other character is dropped.
            return ''.join(result)
        else:
            return re.sub(r'\D', '', text)

    @classmethod
    def clean_for_digits(cls, text: str) -> str:
        """
        Clean text that should primarily contain digits.

        Pipeline:
        1. Clean Unicode
        2. Apply OCR digit corrections
        3. Normalize whitespace

        Preserves separators (-, /) for formatted numbers like "556123-4567"
        """
        text = cls.clean_unicode(text)
        text = cls.apply_ocr_digit_corrections(text)
        text = cls.normalize_whitespace(text)
        return text.strip()

    @classmethod
    def generate_ocr_variants(cls, text: str) -> list[str]:
        """
        Generate possible OCR error variants of the input text.

        This is useful for matching: if we have a CSV value,
        we generate variants that might appear in OCR output.

        Example:
            "5561234567" -> ["5561234567", "556I234567", "5561234S67", ...]
        """
        variants = {text}

        # Digits may have been read as letters.
        for digit, letter in cls.OCR_LETTER_CORRECTIONS.items():
            if digit in text:
                variants.add(text.replace(digit, letter))

        # Letters may have been read as digits.
        for letter, digit in cls.OCR_DIGIT_CORRECTIONS.items():
            if letter in text:
                variants.add(text.replace(letter, digit))

        return list(variants)

    @classmethod
    def normalize_amount_text(cls, text: str) -> str:
        """
        Normalize amount text for parsing.

        - Removes currency symbols and labels
        - Normalizes separators
        - Handles Swedish format (space as thousand separator)
        """
        text = cls.clean_text(text)

        # Strip currency labels (word boundaries ensure whole-word matches).
        text = re.sub(r'(?i)\b(kr|sek|kronor|öre)\b', '', text)

        # Drop thousands-separator spaces (Swedish: "1 234,56" -> "1234,56")
        # while keeping the digits themselves.
        text = re.sub(r'(\d)\s+(\d)', r'\1\2', text)

        return text.strip()

    @classmethod
    def normalize_for_comparison(cls, text: str) -> str:
        """
        Normalize text for loose comparison.

        - Lowercase
        - Remove all non-alphanumeric
        - Apply OCR corrections

        This is the most aggressive normalization, used for fuzzy matching.
        """
        # NOTE(review): corrections run after lowercasing, so only the
        # lowercase keys of OCR_DIGIT_CORRECTIONS can match here — confirm
        # this asymmetry is intentional.
        text = cls.clean_text(text)
        text = text.lower()
        text = cls.apply_ocr_digit_corrections(text)
        text = re.sub(r'[^a-z0-9]', '', text)
        return text
|
||||
393
src/utils/validators.py
Normal file
393
src/utils/validators.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Field Validators Module
|
||||
|
||||
Provides validation functions for Swedish invoice fields.
|
||||
Used by both inference (to validate extracted values) and matching (to filter candidates).
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
|
||||
|
||||
class FieldValidators:
|
||||
"""
|
||||
Validators for Swedish invoice field values.
|
||||
|
||||
Includes:
|
||||
- Luhn (Mod10) checksum validation
|
||||
- Format validation for specific field types
|
||||
- Range validation for dates and amounts
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Luhn (Mod10) Checksum
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def luhn_checksum(cls, digits: str) -> bool:
    """
    Validate a number with the Luhn (Mod10) algorithm.

    Used for Bankgiro numbers, Plusgiro numbers, OCR reference numbers
    and Swedish organization numbers. The number is valid when the
    weighted digit sum is divisible by 10.
    """
    # Keep literal digits only; no OCR substitution at validation time.
    cleaned = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
    if not cleaned or not cleaned.isdigit():
        return False

    checksum = 0
    for position, ch in enumerate(reversed(cleaned)):
        value = int(ch)
        # Every second digit (counting from the right) is doubled; values
        # above 9 are reduced by 9 (equivalent to summing their digits).
        if position % 2 == 1:
            value *= 2
            if value > 9:
                value -= 9
        checksum += value

    return checksum % 10 == 0
|
||||
|
||||
@classmethod
def calculate_luhn_check_digit(cls, digits: str) -> int:
    """
    Calculate the Luhn check digit for a number.

    Given a number without its check digit, returns the digit that —
    appended to it — produces a Luhn-valid number.
    """
    cleaned = TextCleaner.extract_digits(digits, apply_ocr_correction=False)

    total = 0
    for position, ch in enumerate(reversed(cleaned)):
        value = int(ch)
        # Positions shift by one once the check digit is appended, so it
        # is the even positions (from the right) that get doubled here.
        if position % 2 == 0:
            value *= 2
            if value > 9:
                value -= 9
        total += value

    # The check digit brings the total up to the next multiple of 10.
    return (10 - (total % 10)) % 10
|
||||
|
||||
# =========================================================================
|
||||
# Organisation Number Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def is_valid_organisation_number(cls, value: str) -> bool:
    """
    Validate Swedish organisation number.

    Format: NNNNNN-NNNN (10 digits). Also accepts VAT framing
    (SE + 10 digits + 01, optionally with a leading 46 country code).
    The last digit is a Luhn check digit.
    """
    digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)

    # Strip VAT framing down to the 10-digit core.
    if len(digits) == 12 and digits.endswith('01'):
        digits = digits[:10]
    elif len(digits) == 14 and digits.startswith('46') and digits.endswith('01'):
        digits = digits[2:12]

    if len(digits) != 10:
        return False

    # Leading digit must be 1-9.
    if digits[0] == '0':
        return False

    # The "third digit >= 2" rule (companies vs. personal numbers) is
    # deliberately relaxed: some special organisations do not follow it.
    # if int(digits[2]) < 2:
    #     return False

    return cls.luhn_checksum(digits)
|
||||
|
||||
# =========================================================================
|
||||
# Bankgiro Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def is_valid_bankgiro(cls, value: str) -> bool:
    """
    Validate Swedish Bankgiro number.

    Format: 7 or 8 digits with Luhn checksum
    """
    digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
    # Length gate first; then the Luhn check digit decides.
    return 7 <= len(digits) <= 8 and cls.luhn_checksum(digits)
|
||||
|
||||
@classmethod
def format_bankgiro(cls, value: str) -> Optional[str]:
    """
    Format Bankgiro number to standard format.

    Returns: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits), or None if invalid
    """
    digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)

    # The dash goes three digits from the end for 7-digit numbers,
    # four from the start for 8-digit numbers.
    split_at = {7: 3, 8: 4}.get(len(digits))
    if split_at is None:
        return None
    return f"{digits[:split_at]}-{digits[split_at:]}"
|
||||
|
||||
# =========================================================================
|
||||
# Plusgiro Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def is_valid_plusgiro(cls, value: str) -> bool:
    """
    Validate Swedish Plusgiro number.

    Format: 2-8 digits with Luhn checksum
    """
    digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
    # Length gate first; then the Luhn check digit decides.
    return 2 <= len(digits) <= 8 and cls.luhn_checksum(digits)
|
||||
|
||||
@classmethod
def format_plusgiro(cls, value: str) -> Optional[str]:
    """
    Format Plusgiro number to standard format.

    Returns: XXXXXXX-X format (check digit after the dash), or None if invalid
    """
    digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)

    if not (2 <= len(digits) <= 8):
        return None

    # The final digit is the check digit and sits after the dash.
    return f"{digits[:-1]}-{digits[-1]}"
|
||||
|
||||
# =========================================================================
|
||||
# OCR Number Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def is_valid_ocr_number(cls, value: str, validate_checksum: bool = True) -> bool:
    """
    Validate Swedish OCR reference number.

    - Typically 10-25 digits (a 5-25 digit window is accepted here)
    - Usually carries a Luhn checksum, verified unless validate_checksum
      is False (not all issuers enforce it)
    """
    digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)

    if not (5 <= len(digits) <= 25):
        return False

    return cls.luhn_checksum(digits) if validate_checksum else True
|
||||
|
||||
# =========================================================================
|
||||
# Amount Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
def is_valid_amount(cls, value: str, min_amount: float = 0.0, max_amount: float = 10_000_000.0) -> bool:
    """
    Validate monetary amount.

    - Must parse as a number
    - Must fall within [min_amount, max_amount]
    """
    try:
        normalized = TextCleaner.normalize_amount_text(value)
        # Canonicalize to a dot decimal separator.
        normalized = normalized.replace(' ', '').replace(',', '.')
        # Multiple dots: everything but the last acts as a thousands separator.
        if normalized.count('.') > 1:
            head, tail = normalized.rsplit('.', 1)
            normalized = head.replace('.', '') + '.' + tail
        amount = float(normalized)
    except (ValueError, TypeError):
        return False
    return min_amount <= amount <= max_amount
|
||||
|
||||
@classmethod
def parse_amount(cls, value: str) -> Optional[float]:
    """
    Parse amount from string, handling various formats.

    Recognizes Swedish/German (comma decimal) and US (dot decimal)
    separator conventions. Returns float or None if parsing fails.
    """
    try:
        text = TextCleaner.normalize_amount_text(value)
        text = text.replace(' ', '')

        if re.match(r'^[\d.]+,\d{1,2}$', text):
            # Swedish/German style: comma is the decimal separator.
            text = text.replace('.', '').replace(',', '.')
        elif re.match(r'^[\d,]+\.\d{1,2}$', text):
            # US style: dot is the decimal separator.
            text = text.replace(',', '')
        else:
            # Fallback: treat commas as dots, keep only the last as decimal.
            text = text.replace(',', '.')
            if text.count('.') > 1:
                head, tail = text.rsplit('.', 1)
                text = head.replace('.', '') + '.' + tail

        return float(text)
    except (ValueError, TypeError):
        return None
|
||||
|
||||
# =========================================================================
|
||||
# Date Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_date(cls, value: str, min_year: int = 2000, max_year: int = 2100) -> bool:
|
||||
"""
|
||||
Validate date string.
|
||||
|
||||
- Year should be within reasonable range
|
||||
- Month 1-12
|
||||
- Day 1-31 (basic check)
|
||||
"""
|
||||
parsed = cls.parse_date(value)
|
||||
if parsed is None:
|
||||
return False
|
||||
|
||||
year, month, day = parsed
|
||||
if not (min_year <= year <= max_year):
|
||||
return False
|
||||
if not (1 <= month <= 12):
|
||||
return False
|
||||
if not (1 <= day <= 31):
|
||||
return False
|
||||
|
||||
# 更精确的日期验证
|
||||
try:
|
||||
datetime(year, month, day)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse_date(cls, value: str) -> Optional[tuple[int, int, int]]:
|
||||
"""
|
||||
Parse date from string.
|
||||
|
||||
Returns (year, month, day) tuple or None.
|
||||
"""
|
||||
from .format_variants import FormatVariants
|
||||
return FormatVariants._parse_date(value)
|
||||
|
||||
@classmethod
|
||||
def format_date_iso(cls, value: str) -> Optional[str]:
|
||||
"""
|
||||
Format date to ISO format (YYYY-MM-DD).
|
||||
|
||||
Returns formatted string or None if parsing fails.
|
||||
"""
|
||||
parsed = cls.parse_date(value)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
year, month, day = parsed
|
||||
return f"{year}-{month:02d}-{day:02d}"
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Number Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_invoice_number(cls, value: str, min_length: int = 1, max_length: int = 30) -> bool:
|
||||
"""
|
||||
Validate invoice number.
|
||||
|
||||
Basic validation - just length check since invoice numbers are highly variable.
|
||||
"""
|
||||
clean = TextCleaner.clean_text(value)
|
||||
if not clean:
|
||||
return False
|
||||
|
||||
# 提取有意义的字符(字母和数字)
|
||||
meaningful = re.sub(r'[^a-zA-Z0-9]', '', clean)
|
||||
return min_length <= len(meaningful) <= max_length
|
||||
|
||||
# =========================================================================
|
||||
# Generic Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def validate_field(cls, field_name: str, value: str) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate a field by name.
|
||||
|
||||
Returns (is_valid, error_message).
|
||||
"""
|
||||
if not value:
|
||||
return False, "Empty value"
|
||||
|
||||
field_lower = field_name.lower()
|
||||
|
||||
if 'organisation' in field_lower or 'org' in field_lower:
|
||||
if cls.is_valid_organisation_number(value):
|
||||
return True, None
|
||||
return False, "Invalid organisation number format or checksum"
|
||||
|
||||
elif 'bankgiro' in field_lower:
|
||||
if cls.is_valid_bankgiro(value):
|
||||
return True, None
|
||||
return False, "Invalid Bankgiro format or checksum"
|
||||
|
||||
elif 'plusgiro' in field_lower:
|
||||
if cls.is_valid_plusgiro(value):
|
||||
return True, None
|
||||
return False, "Invalid Plusgiro format or checksum"
|
||||
|
||||
elif 'ocr' in field_lower:
|
||||
if cls.is_valid_ocr_number(value, validate_checksum=False):
|
||||
return True, None
|
||||
return False, "Invalid OCR number length"
|
||||
|
||||
elif 'amount' in field_lower:
|
||||
if cls.is_valid_amount(value):
|
||||
return True, None
|
||||
return False, "Invalid amount format"
|
||||
|
||||
elif 'date' in field_lower:
|
||||
if cls.is_valid_date(value):
|
||||
return True, None
|
||||
return False, "Invalid date format"
|
||||
|
||||
elif 'invoice' in field_lower and 'number' in field_lower:
|
||||
if cls.is_valid_invoice_number(value):
|
||||
return True, None
|
||||
return False, "Invalid invoice number"
|
||||
|
||||
else:
|
||||
# 默认:只检查非空
|
||||
if TextCleaner.clean_text(value):
|
||||
return True, None
|
||||
return False, "Empty value after cleaning"
|
||||
@@ -122,6 +122,7 @@ def create_api_router(
|
||||
inference_result = InferenceResult(
|
||||
document_id=service_result.document_id,
|
||||
success=service_result.success,
|
||||
document_type=service_result.document_type,
|
||||
fields=service_result.fields,
|
||||
confidence=service_result.confidence,
|
||||
detections=[
|
||||
|
||||
@@ -30,6 +30,9 @@ class InferenceResult(BaseModel):
|
||||
|
||||
document_id: str = Field(..., description="Document identifier")
|
||||
success: bool = Field(..., description="Whether inference succeeded")
|
||||
document_type: str = Field(
|
||||
default="invoice", description="Document type: 'invoice' or 'letter'"
|
||||
)
|
||||
fields: dict[str, str | None] = Field(
|
||||
default_factory=dict, description="Extracted field values"
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ class ServiceResult:
|
||||
|
||||
document_id: str
|
||||
success: bool = False
|
||||
document_type: str = "invoice" # "invoice" or "letter"
|
||||
fields: dict[str, str | None] = field(default_factory=dict)
|
||||
confidence: dict[str, float] = field(default_factory=dict)
|
||||
detections: list[dict] = field(default_factory=list)
|
||||
@@ -145,6 +146,13 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
@@ -202,6 +210,13 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user