From 8fd61ea928e7201958d17e63b42527022a13dabf Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Thu, 22 Jan 2026 22:03:24 +0100 Subject: [PATCH] WIP --- src/cli/serve.py | 2 +- src/cli/train.py | 39 +- src/inference/field_extractor.py | 782 ++++++++++++++++++++++---- src/inference/pipeline.py | 70 ++- src/inference/test_field_extractor.py | 59 ++ src/normalize/normalizer.py | 135 ++--- src/ocr/machine_code_parser.py | 40 +- src/utils/__init__.py | 34 ++ src/utils/context_extractor.py | 433 ++++++++++++++ src/utils/format_variants.py | 610 ++++++++++++++++++++ src/utils/fuzzy_matcher.py | 417 ++++++++++++++ src/utils/ocr_corrections.py | 384 +++++++++++++ src/utils/test_advanced_utils.py | 399 +++++++++++++ src/utils/test_utils.py | 235 ++++++++ src/utils/text_cleaner.py | 244 ++++++++ src/utils/validators.py | 393 +++++++++++++ src/web/routes.py | 1 + src/web/schemas.py | 3 + src/web/services.py | 15 + 19 files changed, 4069 insertions(+), 226 deletions(-) create mode 100644 src/utils/__init__.py create mode 100644 src/utils/context_extractor.py create mode 100644 src/utils/format_variants.py create mode 100644 src/utils/fuzzy_matcher.py create mode 100644 src/utils/ocr_corrections.py create mode 100644 src/utils/test_advanced_utils.py create mode 100644 src/utils/test_utils.py create mode 100644 src/utils/text_cleaner.py create mode 100644 src/utils/validators.py diff --git a/src/cli/serve.py b/src/cli/serve.py index 94cc94c..8e26dc5 100644 --- a/src/cli/serve.py +++ b/src/cli/serve.py @@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace: "--dpi", type=int, default=150, - help="DPI for PDF rendering", + help="DPI for PDF rendering (must match training DPI)", ) parser.add_argument( diff --git a/src/cli/train.py b/src/cli/train.py index 78fef9c..a063e48 100644 --- a/src/cli/train.py +++ b/src/cli/train.py @@ -63,7 +63,24 @@ def main(): ) parser.add_argument( '--resume', - help='Resume from checkpoint' + action='store_true', + help='Resume from last checkpoint' + ) + parser.add_argument( + '--workers', + type=int, + default=4, + help='Number of data loader workers (default: 4, reduce if OOM)' + ) + parser.add_argument( + '--cache', + action='store_true', + help='Cache images in RAM (faster but uses more memory)' + ) + parser.add_argument( + '--low-memory', + action='store_true', + help='Enable low memory mode (batch=4, workers=2, no cache)' ) parser.add_argument( '--train-ratio', @@ -103,6 +120,16 @@ def main(): args = parser.parse_args() + # Apply low-memory mode if specified + if args.low_memory: + print("🔧 Low memory mode enabled") + args.batch = min(args.batch, 8) # Reduce from 16 to 8 + args.workers = min(args.workers, 4) # Reduce from 8 to 4 + args.cache = False + print(f" Batch size: {args.batch}") + print(f" Workers: {args.workers}") + print(f" Cache: disabled") + # Validate dataset directory dataset_dir = Path(args.dataset_dir) temp_dir = dataset_dir / 'temp' @@ -181,9 +208,10 @@ def main(): from ultralytics import YOLO # Load model - if args.resume: - print(f"Resuming from: {args.resume}") - model = YOLO(args.resume) + last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt' + if args.resume and last_checkpoint.exists(): + print(f"Resuming from: {last_checkpoint}") + model = YOLO(str(last_checkpoint)) else: model = YOLO(args.model) @@ -200,6 +228,9 @@ def main(): 'exist_ok': True, 'pretrained': True, 'verbose': True, + 'workers': args.workers, + 'cache': args.cache, + 'resume': args.resume and last_checkpoint.exists(), # Document-specific augmentation settings 
'degrees': 5.0, 'translate': 0.05, diff --git a/src/inference/field_extractor.py b/src/inference/field_extractor.py index 3ef4ddb..c8aff67 100644 --- a/src/inference/field_extractor.py +++ b/src/inference/field_extractor.py @@ -2,17 +2,33 @@ Field Extractor Module Extracts and validates field values from detected regions. + +This module is used during inference to extract values from OCR text. +It uses shared utilities from src.utils for text cleaning and validation. + +Enhanced features: +- Multi-source fusion with confidence weighting +- Smart amount parsing with multiple strategies +- Enhanced date format unification +- OCR error correction integration """ -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Any +from collections import defaultdict import re import numpy as np from PIL import Image from .yolo_detector import Detection, CLASS_TO_FIELD +# Import shared utilities for text cleaning and validation +from src.utils.text_cleaner import TextCleaner +from src.utils.validators import FieldValidators +from src.utils.fuzzy_matcher import FuzzyMatcher +from src.utils.ocr_corrections import OCRCorrections + @dataclass class ExtractedField: @@ -27,10 +43,14 @@ class ExtractedField: page_no: int is_valid: bool = True validation_error: str | None = None + # Multi-source fusion fields + alternative_values: list[tuple[str, float]] = field(default_factory=list) # [(value, confidence), ...] + extraction_method: str = 'single' # 'single', 'fused', 'corrected' + ocr_corrections_applied: list[str] = field(default_factory=list) def to_dict(self) -> dict: """Convert to dictionary.""" - return { + result = { 'field_name': self.field_name, 'value': self.normalized_value, 'raw_text': self.raw_text, @@ -40,6 +60,11 @@ class ExtractedField: 'is_valid': self.is_valid, 'validation_error': self.validation_error } + if self.alternative_values: + result['alternatives'] = self.alternative_values + if self.extraction_method != 'single': + result['extraction_method'] = self.extraction_method + return result class FieldExtractor: @@ -324,21 +349,9 @@ class FieldExtractor: Validate using Luhn (Mod10) algorithm. Used for Bankgiro, Plusgiro, and OCR number validation. - The checksum is valid if the total modulo 10 equals 0. + Delegates to shared FieldValidators for consistency. """ - if not digits.isdigit(): - return False - - total = 0 - for i, char in enumerate(reversed(digits)): - digit = int(char) - if i % 2 == 1: # Double every second digit from right - digit *= 2 - if digit > 9: - digit -= 9 - total += digit - - return total % 10 == 0 + return FieldValidators.luhn_checksum(digits) def _detect_giro_type(self, text: str) -> str | None: """ @@ -489,45 +502,90 @@ class FieldExtractor: return None, False, f"No valid Plusgiro found in text" def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]: - """Normalize monetary amount.""" - # Try to extract amount using regex patterns - # Pattern 1: Number with comma as decimal (Swedish format: 1 234,56) - # Pattern 2: Number with dot as decimal (1234.56) - # Pattern 3: Number followed by currency (275,60 kr or 275.60 SEK) + """Normalize monetary amount. - patterns = [ - # Swedish format with space thousand separator: 1 234,56 or 1234,56 - r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?', - # Simple decimal: 350.00 or 350,00 - r'(\d+[,\.]\d{2})', - # Integer amount - r'(\d{2,})', - ] + Uses shared TextCleaner for preprocessing and FieldValidators for parsing. 
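The Luhn (Mod10) rule that `FieldValidators.luhn_checksum` now centralises is exactly what the removed inline code did; for reference, a standalone sketch of that check (the free function here is illustrative, the shared one lives in `src/utils/validators.py`):

```python
def luhn_checksum(digits: str) -> bool:
    """Mod10 check shared by Bankgiro, Plusgiro and OCR reference validation."""
    if not digits.isdigit():
        return False
    total = 0
    for i, char in enumerate(reversed(digits)):
        digit = int(char)
        if i % 2 == 1:        # double every second digit from the right
            digit *= 2
            if digit > 9:
                digit -= 9
        total += digit
    return total % 10 == 0    # valid when the weighted sum is divisible by 10

print(luhn_checksum("53939484"))  # True -- the '5393-9484' Bankgiro example used later in the diff
```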
+ If multiple amounts are found, returns the last one (usually the total). + """ + # Split by newlines and process line by line to get the last valid amount + lines = text.split('\n') - for pattern in patterns: - matches = re.findall(pattern, text, re.IGNORECASE) - if matches: - # Take the last match (usually the total amount) - amount_str = matches[-1] - # Clean up - amount_str = amount_str.replace(' ', '').replace('\xa0', '') - # Handle comma as decimal separator - if ',' in amount_str: - amount_str = amount_str.replace(',', '.') + # Collect all valid amounts from all lines + all_amounts = [] + # Pattern for Swedish amount format (with decimals) + amount_pattern = r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?' + + for line in lines: + line = line.strip() + if not line: + continue + + # Find all amounts in this line + matches = re.findall(amount_pattern, line, re.IGNORECASE) + for match in matches: + amount_str = match.replace(' ', '').replace(',', '.') try: amount = float(amount_str) if amount > 0: - return f"{amount:.2f}", True, None + all_amounts.append(amount) except ValueError: continue + # Return the last amount found (usually the total) + if all_amounts: + return f"{all_amounts[-1]:.2f}", True, None + + # Fallback: try shared validator on cleaned text + cleaned = TextCleaner.normalize_amount_text(text) + amount = FieldValidators.parse_amount(cleaned) + if amount is not None and amount > 0: + return f"{amount:.2f}", True, None + + # Try to find any decimal number + simple_pattern = r'(\d+[,\.]\d{2})' + matches = re.findall(simple_pattern, text) + if matches: + amount_str = matches[-1].replace(',', '.') + try: + amount = float(amount_str) + if amount > 0: + return f"{amount:.2f}", True, None + except ValueError: + pass + + # Last resort: try to find integer amount (no decimals) + # Look for patterns like "Amount: 11699" or standalone numbers + int_pattern = r'(?:amount|belopp|summa|total)[:\s]*(\d+)' + match = re.search(int_pattern, text, re.IGNORECASE) + if match: + try: + amount = float(match.group(1)) + if amount > 0: + return f"{amount:.2f}", True, None + except ValueError: + pass + + # Very last resort: find any standalone number >= 3 digits + standalone_pattern = r'\b(\d{3,})\b' + matches = re.findall(standalone_pattern, text) + if matches: + # Take the last/largest number + try: + amount = float(matches[-1]) + if amount > 0: + return f"{amount:.2f}", True, None + except ValueError: + pass + return None, False, f"Cannot parse amount: {text}" def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]: """ Normalize date from text that may contain surrounding text. + Uses shared FieldValidators for date parsing and validation. + Handles various date formats found in Swedish invoices: - 2025-08-29 (ISO format) - 2025.08.29 (dot separator) @@ -535,9 +593,14 @@ class FieldExtractor: - 29.08.2025 (European with dots) - 20250829 (compact format) """ + # First, try using shared validator + iso_date = FieldValidators.format_date_iso(text) + if iso_date and FieldValidators.is_valid_date(iso_date): + return iso_date, True, None + + # Fallback: try original patterns for edge cases from datetime import datetime - # Common date patterns - order matters, most specific first patterns = [ # ISO format: 2025-08-29 (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"), @@ -570,32 +633,76 @@ class FieldExtractor: """ Normalize payment line region text. - Extracts OCR, Amount, and Bankgiro from the payment line using MachineCodeParser. 
+        Extracts the machine-readable payment line format from OCR text.
+        Standard Swedish payment line format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
+
+        Examples:
+        - "# 94228110015950070 # 15658 00 8 > 48666036#14#" -> includes amount 15658.00
+        - "# 11000770600242 # 1200 00 5 > 3082963#41#" -> includes amount 1200.00
+
+        Returns normalized format preserving ALL components including Amount:
+        - Full format: "OCR:xxx Amount:xxx.xx BG:xxx" or "OCR:xxx Amount:xxx.xx PG:xxx"
+        - This allows downstream cross-validation to extract fields properly.
         """
-        from ..ocr.machine_code_parser import MachineCodeParser
+        # Pattern to match Swedish payment line format WITH amount
+        # Format: # <OCR> # <Kronor> <Öre> <Type> > <Account>#<Check>#
+        # Account number may have spaces: "78 2 1 713" -> "7821713"
+        # Kronor may have OCR-induced spaces: "12 0 0" -> "1200"
+        # The > symbol may be missing in low-DPI OCR, so make it optional
+        # Check digits may have spaces: "#41 #" -> "#41#"
+        payment_line_full_pattern = r'#\s*(\d[\d\s]*)\s*#\s*([\d\s]+?)\s+(\d{2})\s+(\d)\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
-
-        # Create a simple token-like structure for the parser
-        # (The parser expects tokens, but for inference we have raw text)
-        parser = MachineCodeParser()
+        match = re.search(payment_line_full_pattern, text)
+        if match:
+            ocr_part = match.group(1).replace(' ', '')
+            kronor = match.group(2).replace(' ', '')  # Remove OCR-induced spaces
+            ore = match.group(3)
+            record_type = match.group(4)
+            account = match.group(5).replace(' ', '')  # Remove spaces from account number
+            check_digits = match.group(6)
-
-        # Try to parse the standard payment line format
-        result = parser._parse_standard_payment_line(text)
+            # Reconstruct the clean machine-readable format
+            # Format: # OCR # KRONOR ORE TYPE > ACCOUNT#CHECK#
+            result = f"# {ocr_part} # {kronor} {ore} {record_type} > {account}#{check_digits}#"
+            return result, True, None
-
-        if result:
-            # Format as structured output
-            parts = []
-            if result.get('ocr'):
-                parts.append(f"OCR:{result['ocr']}")
-            if result.get('amount'):
-                parts.append(f"Amount:{result['amount']}")
-            if result.get('bankgiro'):
-                parts.append(f"BG:{result['bankgiro']}")
+        # Try pattern WITHOUT amount (some payment lines don't have amount)
+        # Format: # <OCR> # > <Account>#<Check>#
+        # > may be missing in low-DPI OCR
+        # Check digits may have spaces
+        payment_line_no_amount_pattern = r'#\s*(\d[\d\s]*)\s*#\s*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
+        match = re.search(payment_line_no_amount_pattern, text)
+        if match:
+            ocr_part = match.group(1).replace(' ', '')
+            account = match.group(2).replace(' ', '')
+            check_digits = match.group(3)
-
-            if parts:
-                return ' '.join(parts), True, None
+            result = f"# {ocr_part} # > {account}#{check_digits}#"
+            return result, True, None
-
-        # Fallback: return raw text if no structured parsing possible
-        return text, True, None
+        # Try alternative pattern: just look for the # > account# pattern (> optional)
+        # Check digits may have spaces
+        alt_pattern = r'(\d[\d\s]{10,})\s*#[^>]*>?\s*([\d\s]+)\s*#\s*(\d+)\s*#'
+        match = re.search(alt_pattern, text)
+        if match:
+            ocr_part = match.group(1).replace(' ', '')
+            account = match.group(2).replace(' ', '')
+            check_digits = match.group(3)
+
+            result = f"# {ocr_part} # > {account}#{check_digits}#"
+            return result, True, None
+
+        # Try to find just the account part with # markers
+        # Check digits may have spaces
+        account_pattern = r'>\s*([\d\s]+)\s*#\s*(\d+)\s*#'
+        match = re.search(account_pattern, text)
+        if match:
+            account = match.group(1).replace(' ', '')
+            check_digits = match.group(2)
+            return f"> {account}#{check_digits}#", True,
"Partial payment line (account only)" + + # Fallback: return None if no payment line format found + return None, False, "No valid payment line format found" def _normalize_supplier_org_number(self, text: str) -> tuple[str | None, bool, str | None]: """ @@ -640,24 +747,52 @@ class FieldExtractor: Normalize customer number extracted from OCR. Customer numbers can have various formats: - - With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N' + - With separators: 'JTY 576-3', 'EMM 256-6', 'FFL 019N', 'UMJ 436-R' - Compact (no separators): 'JTY5763', 'EMM2566', 'FFL019N' - Mixed with names: 'VIKSTRÖM, ELIAS CH FFL 01' -> extract 'FFL 01' + - Address format: 'Umj 436-R Billo' -> extract 'UMJ 436-R' Note: Spaces and dashes may be removed from invoice display, so we need to match both 'JTY 576-3' and 'JTY5763' formats. """ - from ..normalize.normalizer import FieldNormalizer - - # Clean the text using the same logic as matcher - text = FieldNormalizer.clean_text(text) - - if not text: + if not text or not text.strip(): return None, False, "Empty text" - # Customer number patterns - ordered by specificity - # Match both spaced/dashed versions and compact versions + # Keep original text for pattern matching (don't uppercase yet) + original_text = text.strip() + + # Customer number patterns - ordered by specificity (most specific first) + # All patterns use IGNORECASE so they work regardless of case customer_code_patterns = [ + # Pattern: 2-4 letters + space + digits + dash + single letter/digit (UMJ 436-R, EMM 256-6) + # This is the most common Swedish customer number format + r'\b([A-Za-z]{2,4})\s+(\d{1,4})-([A-Za-z0-9])\b', + # Pattern: 2-4 letters + space + digits + letter WITHOUT dash (Dwq 211X, ABC 123X) + # Note: This is also common for customer numbers + r'\b([A-Za-z]{2,4})\s+(\d{2,4})([A-Za-z])\b', + # Pattern: Word (capitalized) + space + digits + dash + letter (Umj 436-R, Billo 123-A) + r'\b([A-Za-z][a-z]{1,10})\s+(\d{1,4})-([A-Za-z0-9])\b', + # Pattern: Letters + digits + dash + digit/letter without space (JTY576-3) + r'\b([A-Za-z]{2,4})(\d{1,4})-([A-Za-z0-9])\b', + ] + + # Try specific patterns first + for pattern in customer_code_patterns: + match = re.search(pattern, original_text) + if match: + # Skip if it looks like a Swedish postal code (SE + digits) + full_match = match.group(0) + if full_match.upper().startswith('SE ') and re.match(r'^SE\s+\d{3}\s*\d{2}', full_match, re.IGNORECASE): + continue + # Reconstruct the customer number in standard format + groups = match.groups() + if len(groups) == 3: + # Format: XXX NNN-X (add dash if not present, e.g., "Dwq 211X" -> "DWQ 211-X") + result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}" + return result, True, None + + # Generic patterns for other formats + generic_patterns = [ # Pattern: Letters + space/dash + digits + dash + digit (EMM 256-6, JTY 576-3) r'\b([A-Z]{2,4}[\s\-]?\d{1,4}[\s\-]\d{1,2}[A-Z]?)\b', # Pattern: Letters + space/dash + digits + optional letter (FFL 019N, ABC 123X) @@ -666,27 +801,47 @@ class FieldExtractor: r'\b([A-Z]{2,4}\d{3,6}[A-Z]?)\b', # Pattern: Single letter + digits (A12345) r'\b([A-Z]\d{4,6}[A-Z]?)\b', - # Pattern: Digits + dash/space + digits (123-456) - r'\b(\d{3,6}[\s\-]\d{1,4})\b', ] all_matches = [] - for pattern in customer_code_patterns: - matches = re.findall(pattern, text, re.IGNORECASE) - all_matches.extend(matches) + for pattern in generic_patterns: + for match in re.finditer(pattern, original_text, re.IGNORECASE): + matched_text = match.group(1) + pos = match.start() + # Filter out 
matches that look like postal codes or ID numbers + # Postal codes are usually 3-5 digits without letters + if re.match(r'^\d+$', matched_text): + continue + # Filter out V4 2 type matches (single letter + digit + space + digit) + if re.match(r'^[A-Z]\d\s+\d$', matched_text, re.IGNORECASE): + continue + # Filter out Swedish postal codes (SE XXX XX format or SE + digits) + # SE followed by digits is typically postal code, not customer number + if re.match(r'^SE[\s\-]*\d', matched_text, re.IGNORECASE): + continue + all_matches.append((matched_text, pos)) if all_matches: - # Prefer longer matches and those appearing later in text (after names) - # Sort by position in text (later = better) and length (longer = better) + # Prefer matches that contain both letters and digits with dash scored_matches = [] - for match in all_matches: - pos = text.upper().rfind(match.upper()) - # Score: position * 0.1 + length (prefer later and longer) - score = pos * 0.1 + len(match) - scored_matches.append((score, match)) + for match_text, pos in all_matches: + score = 0 + # Bonus for containing dash (likely customer number format) + if '-' in match_text: + score += 50 + # Bonus for format like XXX NNN-X + if re.match(r'^[A-Z]{2,4}\s*\d{1,4}-[A-Z0-9]$', match_text, re.IGNORECASE): + score += 100 + # Bonus for length (prefer medium length) + if 6 <= len(match_text) <= 12: + score += 20 + # Position bonus (prefer later matches, after names) + score += pos * 0.1 + scored_matches.append((score, match_text)) - best_match = max(scored_matches, key=lambda x: x[0])[1] - return best_match.strip().upper(), True, None + if scored_matches: + best_match = max(scored_matches, key=lambda x: x[0])[1] + return best_match.strip().upper(), True, None # Pattern 2: Look for explicit labels labeled_patterns = [ @@ -694,7 +849,7 @@ class FieldExtractor: ] for pattern in labeled_patterns: - match = re.search(pattern, text, re.IGNORECASE) + match = re.search(pattern, original_text, re.IGNORECASE) if match: extracted = match.group(1).strip() extracted = re.sub(r'[\s\.\,\:]+$', '', extracted) @@ -702,32 +857,18 @@ class FieldExtractor: return extracted.upper(), True, None # Pattern 3: If text contains comma (likely "NAME, NAME CODE"), extract after last comma - if ',' in text: - after_comma = text.split(',')[-1].strip() + if ',' in original_text: + after_comma = original_text.split(',')[-1].strip() # Look for alphanumeric code in the part after comma - for pattern in customer_code_patterns[:3]: # Use first 3 patterns - code_match = re.search(pattern, after_comma, re.IGNORECASE) + for pattern in customer_code_patterns: + code_match = re.search(pattern, after_comma) if code_match: - return code_match.group(1).strip().upper(), True, None + groups = code_match.groups() + if len(groups) == 3: + result = f"{groups[0].upper()} {groups[1]}-{groups[2].upper()}" + return result, True, None - # Pattern 4: Short text - filter out name-like words - if len(text) <= 20: - words = text.split() - code_parts = [] - for word in words: - # Keep if: contains digits, or is short uppercase (likely abbreviation) - if re.search(r'\d', word) or (len(word) <= 4 and word.isupper()): - code_parts.append(word) - if code_parts: - result = ' '.join(code_parts).upper() - if len(result) >= 3: - return result, True, None - - # Fallback: return cleaned text if reasonable - if text and 3 <= len(text) <= 15: - return text.upper(), True, None - - return None, False, f"Cannot extract customer number from: {text[:50]}" + return None, False, f"Cannot extract customer number from: 
{original_text[:50]}" def extract_all_fields( self, @@ -779,3 +920,428 @@ class FieldExtractor: fields['OCR'] = invoice_number return fields + + # ========================================================================= + # Multi-Source Fusion with Confidence Weighting + # ========================================================================= + + def fuse_multiple_detections( + self, + extracted_fields: list[ExtractedField] + ) -> list[ExtractedField]: + """ + Fuse multiple detections of the same field using confidence-weighted voting. + + When YOLO detects the same field type multiple times (e.g., multiple Amount boxes), + this method selects the best value or combines them intelligently. + + Strategies: + 1. For numeric fields (Amount, OCR): prefer values that pass validation + 2. For date fields: prefer values in expected range + 3. For giro numbers: prefer values with valid Luhn checksum + 4. General: weighted vote by confidence scores + + Args: + extracted_fields: List of all extracted fields (may have duplicates) + + Returns: + List with duplicates resolved to single best value per field + """ + # Group fields by name + fields_by_name: dict[str, list[ExtractedField]] = defaultdict(list) + for field in extracted_fields: + fields_by_name[field.field_name].append(field) + + fused_fields = [] + + for field_name, candidates in fields_by_name.items(): + if len(candidates) == 1: + # No fusion needed + fused_fields.append(candidates[0]) + else: + # Multiple candidates - fuse them + fused = self._fuse_field_candidates(field_name, candidates) + fused_fields.append(fused) + + return fused_fields + + def _fuse_field_candidates( + self, + field_name: str, + candidates: list[ExtractedField] + ) -> ExtractedField: + """ + Fuse multiple candidates for a single field. + + Returns the best candidate with alternatives recorded. + """ + # Sort by confidence (descending) + sorted_candidates = sorted(candidates, key=lambda x: x.confidence, reverse=True) + + # Collect all unique values with their max confidence + value_scores: dict[str, tuple[float, ExtractedField]] = {} + for c in sorted_candidates: + if c.normalized_value: + if c.normalized_value not in value_scores: + value_scores[c.normalized_value] = (c.confidence, c) + else: + # Keep the higher confidence one + if c.confidence > value_scores[c.normalized_value][0]: + value_scores[c.normalized_value] = (c.confidence, c) + + if not value_scores: + # No valid values, return the highest confidence candidate + return sorted_candidates[0] + + # Field-specific fusion strategy + best_value, best_field = self._select_best_value(field_name, value_scores) + + # Record alternatives + alternatives = [ + (v, score) for v, (score, _) in value_scores.items() + if v != best_value + ] + + # Create fused result + result = ExtractedField( + field_name=field_name, + raw_text=best_field.raw_text, + normalized_value=best_value, + confidence=value_scores[best_value][0], + detection_confidence=best_field.detection_confidence, + ocr_confidence=best_field.ocr_confidence, + bbox=best_field.bbox, + page_no=best_field.page_no, + is_valid=best_field.is_valid, + validation_error=best_field.validation_error, + alternative_values=alternatives, + extraction_method='fused' if len(value_scores) > 1 else 'single' + ) + + return result + + def _select_best_value( + self, + field_name: str, + value_scores: dict[str, tuple[float, ExtractedField]] + ) -> tuple[str, ExtractedField]: + """ + Select the best value for a field using field-specific logic. 
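Stripped of the ExtractedField plumbing, the fusion logic above is: group candidates by field name, keep each distinct normalized value at its best confidence, then apply field-specific tie-breakers (Luhn, largest amount, plausible date) with highest confidence as the default. A compressed sketch of that default path:

```python
from collections import defaultdict

def fuse(candidates: list[tuple[str, str, float]]) -> dict[str, str]:
    """candidates: (field_name, normalized_value, confidence) triples."""
    by_field: dict[str, dict[str, float]] = defaultdict(dict)
    for name, value, conf in candidates:
        if value and conf > by_field[name].get(value, 0.0):
            by_field[name][value] = conf   # keep the best confidence per distinct value
    # default rule: highest-confidence value wins; field-specific rules can override
    return {name: max(scores, key=scores.get) for name, scores in by_field.items()}

print(fuse([("Amount", "1200.00", 0.91), ("Amount", "1186.00", 0.62),
            ("OCR", "94228110015950070", 0.88)]))
# {'Amount': '1200.00', 'OCR': '94228110015950070'}
```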
+ + Returns (best_value, best_field) + """ + items = list(value_scores.items()) + + # Field-specific selection + if field_name in ('Bankgiro', 'Plusgiro', 'OCR'): + # Prefer values with valid Luhn checksum + for value, (score, field) in items: + digits = re.sub(r'\D', '', value) + if FieldValidators.luhn_checksum(digits): + return value, field + + elif field_name == 'Amount': + # Prefer larger amounts (usually the total, not subtotals) + amounts = [] + for value, (score, field) in items: + try: + amt = float(value.replace(',', '.')) + amounts.append((amt, value, field)) + except ValueError: + continue + if amounts: + # Return the largest amount + amounts.sort(reverse=True) + return amounts[0][1], amounts[0][2] + + elif field_name in ('InvoiceDate', 'InvoiceDueDate'): + # Prefer dates in reasonable range + from datetime import datetime + for value, (score, field) in items: + try: + dt = datetime.strptime(value, '%Y-%m-%d') + # Prefer recent dates (within last 2 years and next 1 year) + now = datetime.now() + if now.year - 2 <= dt.year <= now.year + 1: + return value, field + except ValueError: + continue + + # Default: return highest confidence value + best = max(items, key=lambda x: x[1][0]) + return best[0], best[1][1] + + # ========================================================================= + # Enhanced Amount Parsing + # ========================================================================= + + def _normalize_amount_enhanced(self, text: str) -> tuple[str | None, bool, str | None]: + """ + Enhanced amount parsing with multiple strategies. + + Strategies: + 1. Pattern matching for Swedish formats + 2. Context-aware extraction (look for keywords like "Total", "Summa") + 3. OCR error correction for common digit errors + 4. Multi-amount handling (prefer last/largest as total) + + This method replaces the original _normalize_amount when enhanced mode is enabled. 
+ """ + # Strategy 1: Apply OCR corrections first + corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected + + # Strategy 2: Look for labeled amounts (highest priority) + labeled_patterns = [ + # Swedish patterns + (r'(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 1.0), + (r'(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})', 0.8), # Lower priority for VAT + # Generic pattern + (r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?', 0.7), + ] + + candidates = [] + for pattern, priority in labeled_patterns: + for match in re.finditer(pattern, corrected_text, re.IGNORECASE): + amount_str = match.group(1).replace(' ', '').replace(',', '.') + try: + amount = float(amount_str) + if 0 < amount < 10_000_000: # Reasonable range + candidates.append((amount, priority, match.start())) + except ValueError: + continue + + if candidates: + # Sort by priority (desc), then by position (later is usually total) + candidates.sort(key=lambda x: (-x[1], -x[2])) + best_amount = candidates[0][0] + return f"{best_amount:.2f}", True, None + + # Strategy 3: Parse with shared validator + cleaned = TextCleaner.normalize_amount_text(corrected_text) + amount = FieldValidators.parse_amount(cleaned) + if amount is not None and 0 < amount < 10_000_000: + return f"{amount:.2f}", True, None + + # Strategy 4: Try to extract any decimal number as fallback + decimal_pattern = r'(\d{1,3}(?:[\s\.]?\d{3})*[,\.]\d{2})' + matches = re.findall(decimal_pattern, corrected_text) + if matches: + # Clean and parse each match + amounts = [] + for m in matches: + cleaned_m = m.replace(' ', '').replace('.', '').replace(',', '.') + # Handle Swedish format: "1 234,56" -> "1234.56" + if ',' in m and '.' not in m: + cleaned_m = m.replace(' ', '').replace(',', '.') + try: + amt = float(cleaned_m) + if 0 < amt < 10_000_000: + amounts.append(amt) + except ValueError: + continue + + if amounts: + # Return the last/largest amount (usually the total) + return f"{max(amounts):.2f}", True, None + + return None, False, f"Cannot parse amount: {text[:50]}" + + # ========================================================================= + # Enhanced Date Parsing + # ========================================================================= + + def _normalize_date_enhanced(self, text: str) -> tuple[str | None, bool, str | None]: + """ + Enhanced date parsing with comprehensive format support. 
+
+        Supports:
+        - ISO: 2024-12-29, 2024/12/29
+        - European: 29.12.2024, 29/12/2024, 29-12-2024
+        - Swedish text: "29 december 2024", "29 dec 2024"
+        - Compact: 20241229
+        - With OCR corrections: 2O24-12-29 -> 2024-12-29
+        """
+        from datetime import datetime
+
+        # Apply OCR corrections
+        corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
+
+        # Try shared validator first
+        iso_date = FieldValidators.format_date_iso(corrected_text)
+        if iso_date and FieldValidators.is_valid_date(iso_date):
+            return iso_date, True, None
+
+        # Swedish month names
+        swedish_months = {
+            'januari': 1, 'jan': 1,
+            'februari': 2, 'feb': 2,
+            'mars': 3, 'mar': 3,
+            'april': 4, 'apr': 4,
+            'maj': 5,
+            'juni': 6, 'jun': 6,
+            'juli': 7, 'jul': 7,
+            'augusti': 8, 'aug': 8,
+            'september': 9, 'sep': 9, 'sept': 9,
+            'oktober': 10, 'okt': 10,
+            'november': 11, 'nov': 11,
+            'december': 12, 'dec': 12,
+        }
+
+        # Pattern for Swedish text dates: "29 december 2024" or "29 dec 2024"
+        swedish_pattern = r'(\d{1,2})\s+([a-zåäö]+)\s+(\d{4})'
+        match = re.search(swedish_pattern, corrected_text.lower())
+        if match:
+            day = int(match.group(1))
+            month_name = match.group(2)
+            year = int(match.group(3))
+            if month_name in swedish_months:
+                month = swedish_months[month_name]
+                try:
+                    dt = datetime(year, month, day)
+                    if 2000 <= dt.year <= 2100:
+                        return dt.strftime('%Y-%m-%d'), True, None
+                except ValueError:
+                    pass
+
+        # Extended patterns
+        patterns = [
+            # ISO format: 2025-08-29, 2025/08/29
+            (r'(\d{4})[-/](\d{1,2})[-/](\d{1,2})', 'ymd'),
+            # Dot format: 2025.08.29
+            (r'(\d{4})\.(\d{1,2})\.(\d{1,2})', 'ymd'),
+            # European slash: 29/08/2025
+            (r'(\d{1,2})/(\d{1,2})/(\d{4})', 'dmy'),
+            # European dot: 29.08.2025
+            (r'(\d{1,2})\.(\d{1,2})\.(\d{4})', 'dmy'),
+            # European dash: 29-08-2025
+            (r'(\d{1,2})-(\d{1,2})-(\d{4})', 'dmy'),
+            # Compact: 20250829
+            (r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', 'ymd'),
+        ]
+
+        for pattern, order in patterns:
+            match = re.search(pattern, corrected_text)
+            if match:
+                try:
+                    if order == 'ymd':
+                        year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
+                    else:
+                        day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
+                    dt = datetime(year, month, day)
+                    if 2000 <= dt.year <= 2100:
+                        return dt.strftime('%Y-%m-%d'), True, None
+                except ValueError:
+                    continue
+
+        return None, False, f"Cannot parse date: {text[:50]}"
+
+    def apply_ocr_corrections(self, field_name: str, raw_text: str) -> tuple[str, list[str]]:
+        """
+        Apply OCR corrections to raw text based on field type.
+
+        Returns (corrected_text, list_of_corrections_applied)
+        """
+        corrections_applied = []
+
+        if field_name in ('OCR', 'Bankgiro', 'Plusgiro', 'supplier_org_number'):
+            # Aggressive correction for numeric fields
+            result = OCRCorrections.correct_digits(raw_text, aggressive=True)
+            if result.corrections_applied:
+                corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
+            return result.corrected, corrections_applied
+
+        elif field_name == 'Amount':
+            # Conservative correction for amounts (preserve decimal separators)
+            result = OCRCorrections.correct_digits(raw_text, aggressive=False)
+            if result.corrections_applied:
+                corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
+            return result.corrected, corrections_applied
+
+        elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
+            # Conservative correction for dates
+            result = OCRCorrections.correct_digits(raw_text, aggressive=False)
+            if result.corrections_applied:
+                corrections_applied = [f"{c[1]}->{c[2]}" for c in result.corrections_applied]
+            return result.corrected, corrections_applied
+
+        # No correction for other fields
+        return raw_text, []
+
+    # =========================================================================
+    # Extraction with All Enhancements
+    # =========================================================================
+
+    def extract_with_enhancements(
+        self,
+        detection: Detection,
+        pdf_tokens: list,
+        image_width: int,
+        image_height: int,
+        use_enhanced_parsing: bool = True
+    ) -> ExtractedField:
+        """
+        Extract field value with all enhancements enabled.
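Usage-wise the policy is simple: digit-only identifiers get the aggressive pass, amounts and dates get the conservative pass so separators survive. A hedged sketch of what a call looks like; the exact substitutions `OCRCorrections.correct_digits` reports depend on the correction tables in `src/utils/ocr_corrections.py`, which are not shown in this hunk:

```python
from src.utils.ocr_corrections import OCRCorrections

# Aggressive pass for a digit-only field (e.g. an OCR reference with a misread 'O')
result = OCRCorrections.correct_digits("9422811OO15950070", aggressive=True)
print(result.corrected)              # expected "94228110015950070" if O->0 is applied
print(result.corrections_applied)    # raw substitution records consumed by apply_ocr_corrections

# Conservative pass keeps separators, per the 2O24-12-29 example in the docstring above
print(OCRCorrections.correct_digits("2O24-12-29", aggressive=False).corrected)
```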
+ + Combines: + 1. OCR error correction + 2. Enhanced amount/date parsing + 3. Multi-strategy extraction + + Args: + detection: Detection object + pdf_tokens: PDF text tokens + image_width: Image width in pixels + image_height: Image height in pixels + use_enhanced_parsing: Whether to use enhanced parsing methods + + Returns: + ExtractedField with enhancements applied + """ + # First, extract using standard method + base_result = self.extract_from_detection_with_pdf( + detection, pdf_tokens, image_width, image_height + ) + + if not use_enhanced_parsing: + return base_result + + # Apply OCR corrections + corrected_text, corrections = self.apply_ocr_corrections( + base_result.field_name, base_result.raw_text + ) + + # Re-normalize with enhanced methods if corrections were applied + if corrections or base_result.normalized_value is None: + if base_result.field_name == 'Amount': + normalized, is_valid, error = self._normalize_amount_enhanced(corrected_text) + elif base_result.field_name in ('InvoiceDate', 'InvoiceDueDate'): + normalized, is_valid, error = self._normalize_date_enhanced(corrected_text) + else: + # Re-run standard normalization with corrected text + normalized, is_valid, error = self._normalize_and_validate( + base_result.field_name, corrected_text + ) + + # Update result if we got a better value + if normalized and (not base_result.normalized_value or is_valid): + base_result.normalized_value = normalized + base_result.is_valid = is_valid + base_result.validation_error = error + base_result.ocr_corrections_applied = corrections + if corrections: + base_result.extraction_method = 'corrected' + + return base_result diff --git a/src/inference/pipeline.py b/src/inference/pipeline.py index ca987d8..08ea33c 100644 --- a/src/inference/pipeline.py +++ b/src/inference/pipeline.py @@ -214,6 +214,43 @@ class InferencePipeline: # Perform cross-validation if payment_line is detected self._cross_validate_payment_line(result) + def _parse_machine_readable_payment_line(self, payment_line: str) -> tuple[str | None, str | None, str | None]: + """ + Parse machine-readable Swedish payment line format. + + Format: # # <Öre> > ## + Example: "# 11000770600242 # 1200 00 5 > 3082963#41#" + + Returns: (ocr, amount, account) tuple + """ + # Pattern with amount + pattern_full = r'#\s*(\d+)\s*#\s*(\d+)\s+(\d{2})\s+\d\s*>\s*(\d+)#\d+#' + match = re.search(pattern_full, payment_line) + if match: + ocr = match.group(1) + kronor = match.group(2) + ore = match.group(3) + account = match.group(4) + amount = f"{kronor}.{ore}" + return ocr, amount, account + + # Pattern without amount + pattern_no_amount = r'#\s*(\d+)\s*#\s*>\s*(\d+)#\d+#' + match = re.search(pattern_no_amount, payment_line) + if match: + ocr = match.group(1) + account = match.group(2) + return ocr, None, account + + # Fallback: partial pattern + pattern_partial = r'>\s*(\d+)#\d+#' + match = re.search(pattern_partial, payment_line) + if match: + account = match.group(1) + return None, None, account + + return None, None, None + def _cross_validate_payment_line(self, result: InferenceResult) -> None: """ Cross-validate payment_line data against other detected fields. 
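Running the full-format pattern from `_parse_machine_readable_payment_line` against the docstring example shows the hand-off: OCR reference, amount as kronor.öre, and the raw account digits that the cross-validation hunk below reformats (7 digits, so Bankgiro "308-2963"). Standalone check:

```python
import re

pattern_full = r'#\s*(\d+)\s*#\s*(\d+)\s+(\d{2})\s+\d\s*>\s*(\d+)#\d+#'
payment_line = "# 11000770600242 # 1200 00 5 > 3082963#41#"

m = re.search(pattern_full, payment_line)
assert m is not None
ocr, kronor, ore, account = m.group(1), m.group(2), m.group(3), m.group(4)
print(ocr, f"{kronor}.{ore}", account)   # 11000770600242 1200.00 3082963
```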
@@ -236,23 +273,26 @@ class InferencePipeline: cv = CrossValidationResult() cv.details = [] - # Parse payment_line format: "OCR:12345 Amount:100,00 BG:123-4567" - pl_parts = {} - for part in str(payment_line).split(): - if ':' in part: - key, value = part.split(':', 1) - pl_parts[key.upper()] = value + # Parse machine-readable payment line format + ocr, amount, account = self._parse_machine_readable_payment_line(str(payment_line)) - cv.payment_line_ocr = pl_parts.get('OCR') - cv.payment_line_amount = pl_parts.get('AMOUNT') + cv.payment_line_ocr = ocr + cv.payment_line_amount = amount - # Determine account type from payment_line - if pl_parts.get('BG'): - cv.payment_line_account = pl_parts['BG'] - cv.payment_line_account_type = 'bankgiro' - elif pl_parts.get('PG'): - cv.payment_line_account = pl_parts['PG'] - cv.payment_line_account_type = 'plusgiro' + # Determine account type based on digit count + if account: + # Bankgiro: 7-8 digits, Plusgiro: typically fewer + if len(account) >= 7: + cv.payment_line_account_type = 'bankgiro' + # Format: XXX-XXXX or XXXX-XXXX + if len(account) == 7: + cv.payment_line_account = f"{account[:3]}-{account[3:]}" + else: + cv.payment_line_account = f"{account[:4]}-{account[4:]}" + else: + cv.payment_line_account_type = 'plusgiro' + # Format: XXXXXXX-X + cv.payment_line_account = f"{account[:-1]}-{account[-1]}" # Cross-validate and OVERRIDE with payment_line values diff --git a/src/inference/test_field_extractor.py b/src/inference/test_field_extractor.py index 47d1c1b..dc0fb60 100644 --- a/src/inference/test_field_extractor.py +++ b/src/inference/test_field_extractor.py @@ -266,6 +266,38 @@ class TestNormalizePaymentLine: assert is_valid is True # Bankgiro should be normalized despite spaces + def test_payment_line_with_spaces_in_check_digits(self, extractor): + """Test payment line with spaces around check digits: #41 # instead of #41#.""" + text = "# 6026726908 # 736 00 9 > 5692041 #41 #" + result, is_valid, error = extractor._normalize_payment_line(text) + + assert result is not None + assert is_valid is True + assert "6026726908" in result + assert "736 00" in result + assert "5692041#41#" in result + + def test_payment_line_with_ocr_spaces_in_amount(self, extractor): + """Test payment line with OCR-induced spaces in amount: '12 0 0 00' -> '1200 00'.""" + text = "# 11000770600242 # 12 0 0 00 5 3082963#41#" + result, is_valid, error = extractor._normalize_payment_line(text) + + assert result is not None + assert is_valid is True + assert "11000770600242" in result + assert "1200 00" in result + assert "3082963#41#" in result + + def test_payment_line_without_greater_symbol(self, extractor): + """Test payment line with missing > symbol (low-DPI OCR issue).""" + text = "# 11000770600242 # 1200 00 5 3082963#41#" + result, is_valid, error = extractor._normalize_payment_line(text) + + assert result is not None + assert is_valid is True + assert "11000770600242" in result + assert "1200 00" in result + class TestNormalizeCustomerNumber: """Tests for customer number normalization.""" @@ -284,6 +316,33 @@ class TestNormalizeCustomerNumber: result, is_valid, error = extractor._normalize_customer_number("JTY5763") assert result is not None + def test_format_without_dash(self, extractor): + """Test customer number format without dash: Dwq 211X -> DWQ 211-X.""" + text = "Dwq 211X Billo SE 106 43 Stockholm" + result, is_valid, error = extractor._normalize_customer_number(text) + + assert result is not None + assert is_valid is True + assert result == "DWQ 211-X" + + def 
test_swedish_postal_code_exclusion(self, extractor): + """Test that Swedish postal codes are excluded: SE 106 43 should not be extracted.""" + text = "SE 106 43 Stockholm" + result, is_valid, error = extractor._normalize_customer_number(text) + + # Should not extract postal code + assert result is None or "SE 106" not in result + + def test_customer_number_with_postal_code_in_text(self, extractor): + """Test extracting customer number when postal code is also present.""" + text = "Customer: ABC 123X, Address: SE 106 43 Stockholm" + result, is_valid, error = extractor._normalize_customer_number(text) + + assert result is not None + assert "ABC" in result + # Should not extract postal code + assert "SE 106" not in result if result else True + class TestNormalizeSupplierOrgNumber: """Tests for supplier organization number normalization.""" diff --git a/src/normalize/normalizer.py b/src/normalize/normalizer.py index dbedaa2..1e8d33c 100644 --- a/src/normalize/normalizer.py +++ b/src/normalize/normalizer.py @@ -2,6 +2,9 @@ Field Normalization Module Normalizes field values to generate multiple candidate forms for matching. + +This module generates variants of CSV values for matching against OCR text. +It uses shared utilities from src.utils for text cleaning and OCR error variants. """ import re @@ -9,6 +12,10 @@ from dataclasses import dataclass from datetime import datetime from typing import Callable +# Import shared utilities +from src.utils.text_cleaner import TextCleaner +from src.utils.format_variants import FormatVariants + @dataclass class NormalizedValue: @@ -39,15 +46,11 @@ class FieldNormalizer: @staticmethod def clean_text(text: str) -> str: - """Remove invisible characters and normalize whitespace and dashes.""" - # Remove zero-width characters - text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text) - # Normalize different dash types to standard hyphen-minus (ASCII 45) - # en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212), middle dot (·, U+00B7) - text = re.sub(r'[\u2013\u2014\u2212\u00b7]', '-', text) - # Normalize whitespace - text = ' '.join(text.split()) - return text.strip() + """Remove invisible characters and normalize whitespace and dashes. + + Delegates to shared TextCleaner for consistency. + """ + return TextCleaner.clean_text(text) @staticmethod def normalize_invoice_number(value: str) -> list[str]: @@ -81,57 +84,44 @@ class FieldNormalizer: """ Normalize Bankgiro number. + Uses shared FormatVariants plus OCR error variants. 
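For orientation, the base variants `FormatVariants.bankgiro_variants` is expected to return are the same dash/no-dash pair the removed inline code produced; a minimal sketch of that expansion (the OCR-confusion variants added via `TextCleaner.generate_ocr_variants` come on top of this):

```python
import re

def bankgiro_variants(value: str) -> set[str]:
    """'5393-9484' <-> '53939484'; 7-digit numbers use the XXX-XXXX form."""
    digits = re.sub(r'\D', '', value)
    variants = {value.strip(), digits}
    if len(digits) == 8:
        variants.add(f"{digits[:4]}-{digits[4:]}")
    elif len(digits) == 7:
        variants.add(f"{digits[:3]}-{digits[3:]}")
    return {v for v in variants if v}

print(sorted(bankgiro_variants("5393-9484")))  # ['5393-9484', '53939484']
```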
+ Examples: '5393-9484' -> ['5393-9484', '53939484'] '53939484' -> ['53939484', '5393-9484'] """ - value = FieldNormalizer.clean_text(value) - digits_only = re.sub(r'\D', '', value) + # Use shared module for base variants + variants = set(FormatVariants.bankgiro_variants(value)) - variants = [value] + # Add OCR error variants + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + if digits: + for ocr_var in TextCleaner.generate_ocr_variants(digits): + variants.add(ocr_var) - if digits_only: - # Add without dash - variants.append(digits_only) - - # Add with dash (format: XXXX-XXXX for 8 digits) - if len(digits_only) == 8: - with_dash = f"{digits_only[:4]}-{digits_only[4:]}" - variants.append(with_dash) - elif len(digits_only) == 7: - # Some bankgiro numbers are 7 digits: XXX-XXXX - with_dash = f"{digits_only[:3]}-{digits_only[3:]}" - variants.append(with_dash) - - return list(set(v for v in variants if v)) + return list(v for v in variants if v) @staticmethod def normalize_plusgiro(value: str) -> list[str]: """ Normalize Plusgiro number. + Uses shared FormatVariants plus OCR error variants. + Examples: '1234567-8' -> ['1234567-8', '12345678'] '12345678' -> ['12345678', '1234567-8'] """ - value = FieldNormalizer.clean_text(value) - digits_only = re.sub(r'\D', '', value) + # Use shared module for base variants + variants = set(FormatVariants.plusgiro_variants(value)) - variants = [value] + # Add OCR error variants + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + if digits: + for ocr_var in TextCleaner.generate_ocr_variants(digits): + variants.add(ocr_var) - if digits_only: - variants.append(digits_only) - - # Plusgiro format: XXXXXXX-X (7 digits + check digit) - if len(digits_only) == 8: - with_dash = f"{digits_only[:-1]}-{digits_only[-1]}" - variants.append(with_dash) - # Also handle 6+1 format - elif len(digits_only) == 7: - with_dash = f"{digits_only[:-1]}-{digits_only[-1]}" - variants.append(with_dash) - - return list(set(v for v in variants if v)) + return list(v for v in variants if v) @staticmethod def normalize_organisation_number(value: str) -> list[str]: @@ -141,60 +131,27 @@ class FieldNormalizer: Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits) Swedish VAT format: SE + org_number (10 digits) + 01 + Uses shared FormatVariants for comprehensive variant generation, + plus OCR error variants. + Examples: '556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...] '5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...] 'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...] 
""" - value = FieldNormalizer.clean_text(value) + # Use shared module for base variants + variants = set(FormatVariants.organisation_number_variants(value)) - # Check if input is a VAT number (starts with SE, ends with 01) - org_digits = None - if value.upper().startswith('SE') and len(value) >= 12: - # Extract org number from VAT: SE + 10 digits + 01 - potential_org = re.sub(r'\D', '', value[2:]) # Remove SE prefix, keep digits - if len(potential_org) == 12 and potential_org.endswith('01'): - org_digits = potential_org[:-2] # Remove trailing 01 - elif len(potential_org) == 10: - org_digits = potential_org + # Add OCR error variants for digit sequences + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + if digits and len(digits) >= 10: + # Generate variants where OCR might have misread characters + for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]): + variants.add(ocr_var) + if len(ocr_var) == 10: + variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}") - if org_digits is None: - org_digits = re.sub(r'\D', '', value) - - variants = [value] - - if org_digits: - variants.append(org_digits) - - # Standard format: NNNNNN-NNNN (10 digits total) - if len(org_digits) == 10: - with_dash = f"{org_digits[:6]}-{org_digits[6:]}" - variants.append(with_dash) - - # Swedish VAT format: SE + org_number + 01 - vat_number = f"SE{org_digits}01" - variants.append(vat_number) - variants.append(vat_number.lower()) # se556123456701 - # With spaces: SE 5561234567 01 - variants.append(f"SE {org_digits} 01") - variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01") - # Without 01 suffix (some invoices show just SE + org) - variants.append(f"SE{org_digits}") - variants.append(f"SE {org_digits}") - - # Some may have 12 digits (century prefix): NNNNNNNN-NNNN - elif len(org_digits) == 12: - with_dash = f"{org_digits[:8]}-{org_digits[8:]}" - variants.append(with_dash) - # Also try without century prefix - short_version = org_digits[2:] - variants.append(short_version) - variants.append(f"{short_version[:6]}-{short_version[6:]}") - # VAT with short version - vat_number = f"SE{short_version}01" - variants.append(vat_number) - - return list(set(v for v in variants if v)) + return list(v for v in variants if v) @staticmethod def normalize_supplier_accounts(value: str) -> list[str]: diff --git a/src/ocr/machine_code_parser.py b/src/ocr/machine_code_parser.py index 7ba7f97..951773b 100644 --- a/src/ocr/machine_code_parser.py +++ b/src/ocr/machine_code_parser.py @@ -42,6 +42,7 @@ from dataclasses import dataclass, field from typing import Optional from src.pdf.extractor import Token as TextToken +from src.utils.validators import FieldValidators @dataclass @@ -484,21 +485,42 @@ class MachineCodeParser: def format_account(account_digits: str) -> tuple[str, str]: """Format account and determine type (bankgiro or plusgiro). + Uses context keywords first, then falls back to Luhn validation + to determine the most likely account type. 
+ Returns: (formatted_account, account_type) """ if is_plusgiro_context: - # Plusgiro format: XXXXXXX-X + # Context explicitly indicates Plusgiro formatted = f"{account_digits[:-1]}-{account_digits[-1]}" return formatted, 'plusgiro' + + # No explicit context - use Luhn validation to determine type + # Try both formats and see which passes Luhn check + + # Format as Plusgiro: XXXXXXX-X (all digits, check digit at end) + pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}" + pg_valid = FieldValidators.is_valid_plusgiro(account_digits) + + # Format as Bankgiro: XXX-XXXX or XXXX-XXXX + if len(account_digits) == 7: + bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}" + elif len(account_digits) == 8: + bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}" else: - # Bankgiro format: XXX-XXXX or XXXX-XXXX - if len(account_digits) == 7: - formatted = f"{account_digits[:3]}-{account_digits[3:]}" - elif len(account_digits) == 8: - formatted = f"{account_digits[:4]}-{account_digits[4:]}" - else: - formatted = account_digits - return formatted, 'bankgiro' + bg_formatted = account_digits + bg_valid = FieldValidators.is_valid_bankgiro(account_digits) + + # Decision logic: + # 1. If only one format passes Luhn, use that + # 2. If both pass or both fail, default to Bankgiro (more common in payment lines) + if pg_valid and not bg_valid: + return pg_formatted, 'plusgiro' + elif bg_valid and not pg_valid: + return bg_formatted, 'bankgiro' + else: + # Both valid or both invalid - default to bankgiro + return bg_formatted, 'bankgiro' # Try primary pattern match = self.PAYMENT_LINE_PATTERN.search(raw_line) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..8dff22f --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,34 @@ +""" +Shared utilities for invoice field extraction and matching. + +This module provides common functionality used by both: +- Inference stage (field_extractor.py) - extracting values from OCR text +- Matching stage (normalizer.py) - generating variants for CSV matching + +Modules: +- TextCleaner: Unicode normalization and OCR error correction +- FormatVariants: Generate format variants for matching +- FieldValidators: Validate field values (Luhn, dates, amounts) +- FuzzyMatcher: Fuzzy string matching with OCR awareness +- OCRCorrections: Comprehensive OCR error correction +- ContextExtractor: Context-aware field extraction +""" + +from .text_cleaner import TextCleaner +from .format_variants import FormatVariants +from .validators import FieldValidators +from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult +from .ocr_corrections import OCRCorrections, CorrectionResult +from .context_extractor import ContextExtractor, ExtractionCandidate + +__all__ = [ + 'TextCleaner', + 'FormatVariants', + 'FieldValidators', + 'FuzzyMatcher', + 'FuzzyMatchResult', + 'OCRCorrections', + 'CorrectionResult', + 'ContextExtractor', + 'ExtractionCandidate', +] diff --git a/src/utils/context_extractor.py b/src/utils/context_extractor.py new file mode 100644 index 0000000..14eda78 --- /dev/null +++ b/src/utils/context_extractor.py @@ -0,0 +1,433 @@ +""" +Context-Aware Extraction Module + +Extracts field values using contextual cues and label detection. +Improves extraction accuracy by understanding the semantic context. 
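The extractor defined below is, at its core, label-regex times value-regex composition; a stripped-down sketch using one InvoiceNumber label pattern and one value pattern from the tables that follow:

```python
import re

label = r'(?:faktura|fakt)\.?\s*(?:nr|nummer|#)?[:\s]*'   # label: "Fakturanr:", "Faktura nr", ...
value = r'\d{3,15}'                                       # value: a plain digit run

full_pattern = f'({label})({value})'                      # same composition as extract_with_label
m = re.search(full_pattern, "Fakturanr: 2024123", re.IGNORECASE)
if m:
    print(repr(m.group(1)), "->", m.group(2))             # 'Fakturanr: ' -> 2024123
```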
+""" + +import re +from typing import Optional, NamedTuple +from dataclasses import dataclass + +from .text_cleaner import TextCleaner +from .validators import FieldValidators + + +@dataclass +class ExtractionCandidate: + """A candidate extracted value with metadata.""" + value: str + raw_text: str + context_label: str + confidence: float + position: int # Character position in source text + extraction_method: str # 'label', 'pattern', 'proximity' + + +class ContextExtractor: + """ + Context-aware field extraction. + + Uses multiple strategies: + 1. Label detection - finds values after field labels + 2. Pattern matching - uses field-specific regex patterns + 3. Proximity analysis - finds values near related terms + 4. Validation filtering - removes invalid candidates + """ + + # ========================================================================= + # Swedish Label Patterns (what appears before the value) + # ========================================================================= + + LABEL_PATTERNS = { + 'InvoiceNumber': [ + # Swedish + r'(?:faktura|fakt)\.?\s*(?:nr|nummer|#)?[:\s]*', + r'(?:fakturanummer|fakturanr)[:\s]*', + r'(?:vår\s+referens)[:\s]*', + # English + r'(?:invoice)\s*(?:no|number|#)?[:\s]*', + r'inv[.:\s]*#?', + ], + 'Amount': [ + # Swedish + r'(?:att\s+)?betala[:\s]*', + r'(?:total|totalt|summa)[:\s]*', + r'(?:belopp)[:\s]*', + r'(?:slutsumma)[:\s]*', + r'(?:att\s+erlägga)[:\s]*', + # English + r'(?:total|amount|sum)[:\s]*', + r'(?:amount\s+due)[:\s]*', + ], + 'InvoiceDate': [ + # Swedish + r'(?:faktura)?datum[:\s]*', + r'(?:fakt\.?\s*datum)[:\s]*', + # English + r'(?:invoice\s+)?date[:\s]*', + ], + 'InvoiceDueDate': [ + # Swedish + r'(?:förfall(?:o)?datum)[:\s]*', + r'(?:betalas\s+senast)[:\s]*', + r'(?:sista\s+betalningsdag)[:\s]*', + r'(?:förfaller)[:\s]*', + # English + r'(?:due\s+date)[:\s]*', + r'(?:payment\s+due)[:\s]*', + ], + 'OCR': [ + r'(?:ocr)[:\s]*', + r'(?:ocr\s*-?\s*nummer)[:\s]*', + r'(?:referens(?:nummer)?)[:\s]*', + r'(?:betalningsreferens)[:\s]*', + ], + 'Bankgiro': [ + r'(?:bankgiro|bg)[:\s]*', + r'(?:bank\s*giro)[:\s]*', + ], + 'Plusgiro': [ + r'(?:plusgiro|pg)[:\s]*', + r'(?:plus\s*giro)[:\s]*', + r'(?:postgiro)[:\s]*', + ], + 'supplier_organisation_number': [ + r'(?:org\.?\s*(?:nr|nummer)?)[:\s]*', + r'(?:organisationsnummer)[:\s]*', + r'(?:org\.?\s*id)[:\s]*', + r'(?:vat\s*(?:no|number|nr)?)[:\s]*', + r'(?:moms(?:reg)?\.?\s*(?:nr|nummer)?)[:\s]*', + r'(?:se)[:\s]*', # VAT prefix + ], + 'customer_number': [ + r'(?:kund(?:nr|nummer)?)[:\s]*', + r'(?:kundnummer)[:\s]*', + r'(?:customer\s*(?:no|number|id)?)[:\s]*', + r'(?:er\s+referens)[:\s]*', + ], + } + + # ========================================================================= + # Value Patterns (what the value looks like) + # ========================================================================= + + VALUE_PATTERNS = { + 'InvoiceNumber': [ + r'[A-Z]{0,3}\d{3,15}', # Alphanumeric: INV12345 + r'\d{3,15}', # Pure digits + r'20\d{2}[-/]\d{3,8}', # Year prefix: 2024-001 + ], + 'Amount': [ + r'\d{1,3}(?:[\s.]\d{3})*[,]\d{2}', # Swedish: 1 234,56 + r'\d{1,3}(?:[,]\d{3})*[.]\d{2}', # US: 1,234.56 + r'\d+[,.]\d{2}', # Simple: 123,45 + r'\d+', # Integer + ], + 'InvoiceDate': [ + r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}', # ISO-like + r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}', # European + r'\d{8}', # Compact YYYYMMDD + ], + 'InvoiceDueDate': [ + r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}', + r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}', + r'\d{8}', + ], + 'OCR': [ + r'\d{10,25}', # Long digit sequence + ], + 'Bankgiro': [ + 
r'\d{3,4}[-\s]?\d{4}', # XXX-XXXX or XXXX-XXXX + r'\d{7,8}', # Without separator + ], + 'Plusgiro': [ + r'\d{1,7}[-\s]?\d', # XXXXXXX-X + r'\d{2,8}', # Without separator + ], + 'supplier_organisation_number': [ + r'\d{6}[-\s]?\d{4}', # NNNNNN-NNNN + r'\d{10}', # Without separator + r'SE\s?\d{10,12}(?:\s?01)?', # VAT format + ], + 'customer_number': [ + r'[A-Z]{0,5}\s?[-]?\s?\d{1,10}', # EMM 256-6 + r'\d{3,15}', # Pure digits + ], + } + + # ========================================================================= + # Extraction Methods + # ========================================================================= + + @classmethod + def extract_with_label( + cls, + text: str, + field_name: str, + validate: bool = True + ) -> list[ExtractionCandidate]: + """ + Extract field values by finding labels and taking following values. + + Example: "Fakturanummer: 12345" -> extracts "12345" + """ + candidates = [] + label_patterns = cls.LABEL_PATTERNS.get(field_name, []) + value_patterns = cls.VALUE_PATTERNS.get(field_name, []) + + for label_pattern in label_patterns: + for value_pattern in value_patterns: + # Combine label + value patterns + full_pattern = f'({label_pattern})({value_pattern})' + matches = re.finditer(full_pattern, text, re.IGNORECASE) + + for match in matches: + label = match.group(1).strip() + value = match.group(2).strip() + + # Validate if requested + if validate and not cls._validate_value(field_name, value): + continue + + # Calculate confidence based on label specificity + confidence = cls._calculate_label_confidence(label, field_name) + + candidates.append(ExtractionCandidate( + value=value, + raw_text=match.group(0), + context_label=label, + confidence=confidence, + position=match.start(), + extraction_method='label' + )) + + return candidates + + @classmethod + def extract_with_pattern( + cls, + text: str, + field_name: str, + validate: bool = True + ) -> list[ExtractionCandidate]: + """ + Extract field values using only value patterns (no label required). + + This is a fallback when no labels are found. + """ + candidates = [] + value_patterns = cls.VALUE_PATTERNS.get(field_name, []) + + for pattern in value_patterns: + matches = re.finditer(pattern, text, re.IGNORECASE) + + for match in matches: + value = match.group(0).strip() + + # Validate if requested + if validate and not cls._validate_value(field_name, value): + continue + + # Lower confidence for pattern-only extraction + confidence = 0.6 + + candidates.append(ExtractionCandidate( + value=value, + raw_text=value, + context_label='', + confidence=confidence, + position=match.start(), + extraction_method='pattern' + )) + + return candidates + + @classmethod + def extract_field( + cls, + text: str, + field_name: str, + validate: bool = True + ) -> list[ExtractionCandidate]: + """ + Extract all candidate values for a field using multiple strategies. + + Returns candidates sorted by confidence (highest first). 
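Typical call pattern for the class, assuming the module is importable as `src.utils.context_extractor` and that the shared validators accept the formats shown; the confidence values come from `_calculate_label_confidence` further down (0.95 for a "very specific" label such as "att betala"):

```python
from src.utils.context_extractor import ContextExtractor

text = "Fakturadatum: 2025-08-29\nAtt betala: 1 234,56 kr"

best = ContextExtractor.extract_best(text, "Amount")
if best:
    # expected: value "1 234,56", extraction_method "label", confidence around 0.95
    print(best.value, best.extraction_method, best.confidence)

for cand in ContextExtractor.extract_field(text, "InvoiceDate"):
    print(cand.value, repr(cand.context_label))   # 2025-08-29 'Fakturadatum: '
```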
+ """ + candidates = [] + + # Strategy 1: Label-based extraction (highest confidence) + label_candidates = cls.extract_with_label(text, field_name, validate) + candidates.extend(label_candidates) + + # Strategy 2: Pattern-based extraction (fallback) + if not label_candidates: + pattern_candidates = cls.extract_with_pattern(text, field_name, validate) + candidates.extend(pattern_candidates) + + # Remove duplicates (same value, keep highest confidence) + seen_values = {} + for candidate in candidates: + normalized = TextCleaner.normalize_for_comparison(candidate.value) + if normalized not in seen_values or candidate.confidence > seen_values[normalized].confidence: + seen_values[normalized] = candidate + + # Sort by confidence + result = sorted(seen_values.values(), key=lambda x: x.confidence, reverse=True) + return result + + @classmethod + def extract_best( + cls, + text: str, + field_name: str, + validate: bool = True + ) -> Optional[ExtractionCandidate]: + """ + Extract the best (highest confidence) candidate for a field. + """ + candidates = cls.extract_field(text, field_name, validate) + return candidates[0] if candidates else None + + @classmethod + def extract_all_fields(cls, text: str) -> dict[str, list[ExtractionCandidate]]: + """ + Extract all known fields from text. + + Returns a dictionary mapping field names to their candidates. + """ + results = {} + for field_name in cls.LABEL_PATTERNS.keys(): + candidates = cls.extract_field(text, field_name) + if candidates: + results[field_name] = candidates + return results + + # ========================================================================= + # Helper Methods + # ========================================================================= + + @classmethod + def _validate_value(cls, field_name: str, value: str) -> bool: + """Validate a value based on field type.""" + field_lower = field_name.lower() + + if 'date' in field_lower: + return FieldValidators.is_valid_date(value) + elif 'amount' in field_lower: + return FieldValidators.is_valid_amount(value) + elif 'bankgiro' in field_lower: + # Basic format check, not Luhn + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + return 7 <= len(digits) <= 8 + elif 'plusgiro' in field_lower: + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + return 2 <= len(digits) <= 8 + elif 'ocr' in field_lower: + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + return 10 <= len(digits) <= 25 + elif 'org' in field_lower: + digits = TextCleaner.extract_digits(value, apply_ocr_correction=False) + return len(digits) >= 10 + else: + # For other fields, just check it's not empty + return bool(value.strip()) + + @classmethod + def _calculate_label_confidence(cls, label: str, field_name: str) -> float: + """ + Calculate confidence based on how specific the label is. + + More specific labels = higher confidence. 
+ """ + label_lower = label.lower() + + # Very specific labels + very_specific = { + 'InvoiceNumber': ['fakturanummer', 'invoice number', 'fakturanr'], + 'Amount': ['att betala', 'slutsumma', 'amount due'], + 'InvoiceDate': ['fakturadatum', 'invoice date'], + 'InvoiceDueDate': ['förfallodatum', 'förfallodag', 'due date'], + 'OCR': ['ocr', 'betalningsreferens'], + 'Bankgiro': ['bankgiro'], + 'Plusgiro': ['plusgiro', 'postgiro'], + 'supplier_organisation_number': ['organisationsnummer', 'org nummer'], + 'customer_number': ['kundnummer', 'customer number'], + } + + # Check for very specific match + if field_name in very_specific: + for specific in very_specific[field_name]: + if specific in label_lower: + return 0.95 + + # Moderately specific + moderate = { + 'InvoiceNumber': ['faktura', 'invoice', 'nr'], + 'Amount': ['total', 'summa', 'belopp'], + 'InvoiceDate': ['datum', 'date'], + 'InvoiceDueDate': ['förfall', 'due'], + } + + if field_name in moderate: + for mod in moderate[field_name]: + if mod in label_lower: + return 0.85 + + # Generic match + return 0.75 + + @classmethod + def find_field_context(cls, text: str, position: int, window: int = 50) -> str: + """ + Get the surrounding context for a position in text. + + Useful for understanding what field a value belongs to. + """ + start = max(0, position - window) + end = min(len(text), position + window) + return text[start:end] + + @classmethod + def identify_field_type(cls, text: str, value: str) -> Optional[str]: + """ + Try to identify what field type a value belongs to based on context. + + Looks at text surrounding the value to find labels. + """ + # Find the value in text + pos = text.find(value) + if pos == -1: + return None + + # Get context before the value + context_before = text[max(0, pos - 50):pos].lower() + + # Check each field's labels + for field_name, patterns in cls.LABEL_PATTERNS.items(): + for pattern in patterns: + if re.search(pattern, context_before, re.IGNORECASE): + return field_name + + return None + + +# ========================================================================= +# Convenience functions +# ========================================================================= + +def extract_field_with_context(text: str, field_name: str) -> Optional[str]: + """Convenience function to extract a field value.""" + candidate = ContextExtractor.extract_best(text, field_name) + return candidate.value if candidate else None + + +def extract_all_with_context(text: str) -> dict[str, str]: + """Convenience function to extract all fields.""" + all_candidates = ContextExtractor.extract_all_fields(text) + return { + field: candidates[0].value + for field, candidates in all_candidates.items() + if candidates + } diff --git a/src/utils/format_variants.py b/src/utils/format_variants.py new file mode 100644 index 0000000..583e526 --- /dev/null +++ b/src/utils/format_variants.py @@ -0,0 +1,610 @@ +""" +Format Variants Generator + +Generates multiple format variants for invoice field values. +Used by both inference (to try different extractions) and matching (to match CSV values). +""" + +import re +from datetime import datetime +from typing import Optional + +from .text_cleaner import TextCleaner + + +class FormatVariants: + """ + Generates format variants for different field types. 
+
+    The same logic is used for:
+    - Inference: trying different formats to extract a value
+    - Matching: generating variants of CSV values to match against OCR text
+    """
+
+    # Swedish month names for date parsing
+    SWEDISH_MONTHS = {
+        'januari': '01', 'jan': '01',
+        'februari': '02', 'feb': '02',
+        'mars': '03', 'mar': '03',
+        'april': '04', 'apr': '04',
+        'maj': '05',
+        'juni': '06', 'jun': '06',
+        'juli': '07', 'jul': '07',
+        'augusti': '08', 'aug': '08',
+        'september': '09', 'sep': '09', 'sept': '09',
+        'oktober': '10', 'okt': '10',
+        'november': '11', 'nov': '11',
+        'december': '12', 'dec': '12',
+    }
+
+    # =========================================================================
+    # Organization Number Variants
+    # =========================================================================
+
+    @classmethod
+    def organisation_number_variants(cls, value: str) -> list[str]:
+        """
+        Generate all format variants for Swedish organization number.
+
+        Input formats handled:
+        - "556123-4567" (standard with hyphen)
+        - "5561234567" (no hyphen)
+        - "SE556123456701" (VAT format)
+        - "SE 556123-4567 01" (VAT with spaces)
+
+        Returns all possible variants for matching.
+        """
+        value = TextCleaner.clean_text(value)
+        value_upper = value.upper()
+        variants = set()
+
+        # Extract the bare digits
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        # If this is a VAT number, extract the embedded org number
+        # SE + 10 digits + 01 = "SE556123456701"
+        if value_upper.startswith('SE') and len(digits) == 12 and digits.endswith('01'):
+            # VAT format: SE + org_number + 01
+            digits = digits[:10]
+        elif digits.startswith('46') and len(digits) == 14:
+            # SE prefix in numeric (46 is SE in phone code): 46 + 10 digits + 01
+            digits = digits[2:12]
+
+        if len(digits) == 12:
+            # 12 digits may include a century prefix: NNNNNNNN-NNNN (19556123-4567)
+            variants.add(value)
+            variants.add(digits)  # 195561234567
+            # Hyphenated format
+            variants.add(f"{digits[:8]}-{digits[8:]}")  # 19556123-4567
+            # Use the last 10 digits as the standard org number
+            short_digits = digits[2:]  # 5561234567
+            variants.add(short_digits)
+            variants.add(f"{short_digits[:6]}-{short_digits[6:]}")  # 556123-4567
+            # VAT format
+            variants.add(f"SE{short_digits}01")  # SE556123456701
+            return list(v for v in variants if v)
+
+        if len(digits) != 10:
+            # Not the standard 10 digits: return the original value plus the cleaned digits
+            variants.add(value)
+            if digits:
+                variants.add(digits)
+            return list(variants)
+
+        # Generate all variants
+        # 1. Bare digits
+        variants.add(digits)  # 5561234567
+
+        # 2. Standard format (NNNNNN-NNNN)
+        with_hyphen = f"{digits[:6]}-{digits[6:]}"
+        variants.add(with_hyphen)  # 556123-4567
+
+        # 3. VAT formats
+        vat_compact = f"SE{digits}01"
+        variants.add(vat_compact)  # SE556123456701
+        variants.add(vat_compact.lower())  # se556123456701
+
+        vat_spaced = f"SE {digits[:6]}-{digits[6:]} 01"
+        variants.add(vat_spaced)  # SE 556123-4567 01
+
+        vat_spaced_no_hyphen = f"SE {digits} 01"
+        variants.add(vat_spaced_no_hyphen)  # SE 5561234567 01
+
+        # 4. Sometimes written with the country code but without the 01 suffix
+        variants.add(f"SE{digits}")  # SE5561234567
+        variants.add(f"SE {digits}")  # SE 5561234567
+        variants.add(f"SE{digits[:6]}-{digits[6:]}")  # SE556123-4567
+
+        # 5. Possible OCR error variants
+        ocr_variants = TextCleaner.generate_ocr_variants(digits)
+        for ocr_var in ocr_variants:
+            if len(ocr_var) == 10:
+                variants.add(ocr_var)
+                variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
+
+        return list(v for v in variants if v)
+
+    # =========================================================================
+    # Bankgiro Variants
+    # =========================================================================
+
+    @classmethod
+    def bankgiro_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for Bankgiro number.
+
+        Formats:
+        - 7 digits: XXX-XXXX (e.g., 123-4567)
+        - 8 digits: XXXX-XXXX (e.g., 1234-5678)
+        """
+        value = TextCleaner.clean_text(value)
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+        variants = set()
+
+        variants.add(value)
+
+        if not digits or len(digits) < 7 or len(digits) > 8:
+            return list(v for v in variants if v)
+
+        # Bare digits
+        variants.add(digits)
+
+        # Hyphenated formats
+        if len(digits) == 7:
+            variants.add(f"{digits[:3]}-{digits[3:]}")  # XXX-XXXX
+        elif len(digits) == 8:
+            variants.add(f"{digits[:4]}-{digits[4:]}")  # XXXX-XXXX
+            # Some 8-digit numbers are also written as XXX-XXXXX
+            variants.add(f"{digits[:3]}-{digits[3:]}")
+
+        # Space-separated formats (OCR sometimes reads them this way)
+        if len(digits) == 7:
+            variants.add(f"{digits[:3]} {digits[3:]}")
+        elif len(digits) == 8:
+            variants.add(f"{digits[:4]} {digits[4:]}")
+
+        # OCR error variants
+        for ocr_var in TextCleaner.generate_ocr_variants(digits):
+            variants.add(ocr_var)
+
+        return list(v for v in variants if v)
+
+    # =========================================================================
+    # Plusgiro Variants
+    # =========================================================================
+
+    @classmethod
+    def plusgiro_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for Plusgiro number.
+
+        Format: XXXXXXX-X (7 digits + check digit) or shorter
+        Examples: 1234567-8, 12345-6, 1-8
+        """
+        value = TextCleaner.clean_text(value)
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+        variants = set()
+
+        variants.add(value)
+
+        if not digits or len(digits) < 2 or len(digits) > 8:
+            return list(v for v in variants if v)
+
+        # Bare digits
+        variants.add(digits)
+
+        # Plusgiro format: the last digit is a check digit, separated by a hyphen
+        main_part = digits[:-1]
+        check_digit = digits[-1]
+        variants.add(f"{main_part}-{check_digit}")
+
+        # Sometimes written with a space instead
+        variants.add(f"{main_part} {check_digit}")
+
+        # Grouped format (common for long numbers): XX XX XX-X
+        if len(digits) >= 6:
+            # Try the XX XX XX-X layout
+            spaced = ' '.join([digits[i:i + 2] for i in range(0, len(digits) - 1, 2)])
+            if len(digits) % 2 == 0:
+                spaced = spaced[:-1] + '-' + digits[-1]
+            else:
+                spaced = spaced + '-' + digits[-1]
+            variants.add(spaced.replace('- ', '-'))
+
+        # OCR error variants
+        for ocr_var in TextCleaner.generate_ocr_variants(digits):
+            variants.add(ocr_var)
+
+        return list(v for v in variants if v)
+
+    # =========================================================================
+    # Amount Variants
+    # =========================================================================
+
+    @classmethod
+    def amount_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for monetary amounts.
+
+        Handles:
+        - Swedish: 1 234,56 (space thousand, comma decimal)
+        - German: 1.234,56 (dot thousand, comma decimal)
+        - US/UK: 1,234.56 (comma thousand, dot decimal)
+        - Integer: 1234 -> 1234.00
+
+        Returns variants with different separators and with/without decimals.
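Both giro generators funnel the input through TextCleaner.extract_digits before re-adding the printed separators; a short sketch of the resulting sets (membership checks only, since the full lists also contain OCR-confusion spellings), assuming src.utils is importable:

# Illustrative only: mirrors the formats documented in the docstrings above.
from src.utils.format_variants import FormatVariants

bg = FormatVariants.bankgiro_variants("5393-9484")
assert "53939484" in bg and "5393-9484" in bg    # 8-digit Bankgiro, with and without hyphen

pg = FormatVariants.plusgiro_variants("1234567-8")
assert "12345678" in pg and "1234567-8" in pg    # check digit split off with a hyphen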
+        """
+        value = TextCleaner.clean_text(value)
+        variants = set()
+        variants.add(value)
+
+        # Try to parse the text as a numeric value
+        amount = cls._parse_amount(value)
+        if amount is None:
+            return list(v for v in variants if v)
+
+        # Generate variants in the different formats
+        int_part = int(amount)
+        dec_part = round((amount - int_part) * 100)
+
+        # 1. Basic formats
+        if dec_part == 0:
+            variants.add(str(int_part))  # 1234
+            variants.add(f"{int_part}.00")  # 1234.00
+            variants.add(f"{int_part},00")  # 1234,00
+        else:
+            variants.add(f"{int_part}.{dec_part:02d}")  # 1234.56
+            variants.add(f"{int_part},{dec_part:02d}")  # 1234,56
+
+        # 2. With thousand separators
+        int_str = str(int_part)
+        if len(int_str) > 3:
+            # Insert a separator every 3 digits, from right to left
+            parts = []
+            while int_str:
+                parts.append(int_str[-3:])
+                int_str = int_str[:-3]
+            parts.reverse()
+
+            # Space separator (Swedish)
+            space_sep = ' '.join(parts)
+            if dec_part == 0:
+                variants.add(space_sep)
+            else:
+                variants.add(f"{space_sep},{dec_part:02d}")
+                variants.add(f"{space_sep}.{dec_part:02d}")
+
+            # Dot separator (German)
+            dot_sep = '.'.join(parts)
+            if dec_part == 0:
+                variants.add(dot_sep)
+            else:
+                variants.add(f"{dot_sep},{dec_part:02d}")
+
+            # Comma separator (US)
+            comma_sep = ','.join(parts)
+            if dec_part == 0:
+                variants.add(comma_sep)
+            else:
+                variants.add(f"{comma_sep}.{dec_part:02d}")
+
+        # 3. With currency markers
+        base_amounts = [f"{int_part}.{dec_part:02d}", f"{int_part},{dec_part:02d}"]
+        if dec_part == 0:
+            base_amounts.append(str(int_part))
+
+        for base in base_amounts:
+            variants.add(f"{base} kr")
+            variants.add(f"{base} SEK")
+            variants.add(f"{base}kr")
+            variants.add(f"SEK {base}")
+
+        return list(v for v in variants if v)
+
+    @classmethod
+    def _parse_amount(cls, text: str) -> Optional[float]:
+        """Parse amount from various formats."""
+        text = TextCleaner.normalize_amount_text(text)
+
+        # Remove everything that is not a digit or a separator
+        clean = re.sub(r'[^\d,.\s]', '', text)
+        if not clean:
+            return None
+
+        # Detect the format
+        # Swedish format: 1 234,56 or 1234,56
+        if re.match(r'^[\d\s]+,\d{2}$', clean):
+            clean = clean.replace(' ', '').replace(',', '.')
+            try:
+                return float(clean)
+            except ValueError:
+                pass
+
+        # German format: 1.234,56
+        if re.match(r'^[\d.]+,\d{2}$', clean):
+            clean = clean.replace('.', '').replace(',', '.')
+            try:
+                return float(clean)
+            except ValueError:
+                pass
+
+        # US format: 1,234.56
+        if re.match(r'^[\d,]+\.\d{2}$', clean):
+            clean = clean.replace(',', '')
+            try:
+                return float(clean)
+            except ValueError:
+                pass
+
+        # Simple format
+        clean = clean.replace(' ', '').replace(',', '.')
+        # If there is more than one dot, keep only the last one
+        if clean.count('.') > 1:
+            parts = clean.rsplit('.', 1)
+            clean = parts[0].replace('.', '') + '.' + parts[1]
+
+        try:
+            return float(clean)
+        except ValueError:
+            return None
+
+    # =========================================================================
+    # Date Variants
+    # =========================================================================
+
+    @classmethod
+    def date_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for dates.
+
+        Input can be:
+        - ISO: 2024-12-29
+        - European: 29/12/2024, 29.12.2024
+        - Swedish text: "29 december 2024"
+        - Compact: 20241229
+
+        Returns all format variants.
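_parse_amount settles the separator question first (Swedish, then German, then US patterns) and the variant generator rebuilds all three conventions from the parsed value; a small sketch, assuming src.utils is importable:

# Illustrative only: Swedish-formatted input expanded back into the other conventions.
from src.utils.format_variants import FormatVariants

variants = FormatVariants.amount_variants("1 234,56")
assert "1234.56" in variants     # canonical dot-decimal form
assert "1,234.56" in variants    # US thousand separator
assert "1.234,56" in variants    # German thousand separator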
+        """
+        value = TextCleaner.clean_text(value)
+        variants = set()
+        variants.add(value)
+
+        # Try to parse the date
+        parsed = cls._parse_date(value)
+        if parsed is None:
+            return list(v for v in variants if v)
+
+        year, month, day = parsed
+
+        # Generate all format variants
+        # ISO
+        variants.add(f"{year}-{month:02d}-{day:02d}")
+        variants.add(f"{year}-{month}-{day}")  # without zero padding
+
+        # Dot separated (Swedish common)
+        variants.add(f"{year}.{month:02d}.{day:02d}")
+        variants.add(f"{day:02d}.{month:02d}.{year}")
+
+        # Slash separated
+        variants.add(f"{day:02d}/{month:02d}/{year}")
+        variants.add(f"{month:02d}/{day:02d}/{year}")  # US format
+        variants.add(f"{year}/{month:02d}/{day:02d}")
+
+        # Compact format
+        variants.add(f"{year}{month:02d}{day:02d}")
+
+        # With month name (Swedish)
+        for month_name, month_num in cls.SWEDISH_MONTHS.items():
+            if month_num == f"{month:02d}":
+                variants.add(f"{day} {month_name} {year}")
+                variants.add(f"{day:02d} {month_name} {year}")
+                # Capitalised month name
+                variants.add(f"{day} {month_name.capitalize()} {year}")
+
+        # Short year
+        short_year = str(year)[2:]
+        variants.add(f"{day:02d}.{month:02d}.{short_year}")
+        variants.add(f"{day:02d}/{month:02d}/{short_year}")
+        variants.add(f"{short_year}-{month:02d}-{day:02d}")
+
+        return list(v for v in variants if v)
+
+    @classmethod
+    def _parse_date(cls, text: str) -> Optional[tuple[int, int, int]]:
+        """
+        Parse date from text, returns (year, month, day) or None.
+        """
+        text = TextCleaner.clean_text(text).lower()
+
+        # ISO: 2024-12-29
+        match = re.search(r'(\d{4})-(\d{1,2})-(\d{1,2})', text)
+        if match:
+            return int(match.group(1)), int(match.group(2)), int(match.group(3))
+
+        # Dot format: 2024.12.29
+        match = re.search(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', text)
+        if match:
+            return int(match.group(1)), int(match.group(2)), int(match.group(3))
+
+        # European: 29/12/2024 or 29.12.2024
+        match = re.search(r'(\d{1,2})[/.](\d{1,2})[/.](\d{4})', text)
+        if match:
+            day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
+            # Sanity-check the date
+            if 1 <= day <= 31 and 1 <= month <= 12:
+                return year, month, day
+
+        # Compact: 20241229
+        match = re.search(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', text)
+        if match:
+            year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
+            if 1 <= month <= 12 and 1 <= day <= 31:
+                return year, month, day
+
+        return None
+
+    # =========================================================================
+    # Invoice Number Variants
+    # =========================================================================
+
+    @classmethod
+    def invoice_number_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for invoice numbers.
+
+        Invoice numbers are highly variable:
+        - Pure digits: 12345678
+        - Alphanumeric: A3861, INV-2024-001
+        - With separators: 2024/001
+        """
+        value = TextCleaner.clean_text(value)
+        variants = set()
+        variants.add(value)
+
+        # Extract the digit portion
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
+        if digits:
+            variants.add(digits)
+
+        # Case variants
+        variants.add(value.upper())
+        variants.add(value.lower())
+
+        # Without separators
+        no_sep = re.sub(r'[-/\s]', '', value)
+        variants.add(no_sep)
+        variants.add(no_sep.upper())
+
+        # OCR error variants
+        for ocr_var in TextCleaner.generate_ocr_variants(value):
+            variants.add(ocr_var)
+
+        return list(v for v in variants if v)
+
+    # =========================================================================
+    # OCR Number Variants
+    # =========================================================================
+
+    @classmethod
+    def ocr_number_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for OCR reference numbers.
+
+        OCR numbers are typically 10-25 digits.
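date_variants round-trips through _parse_date, so one recognised spelling fans out into ISO, European, compact and Swedish month-name forms; a quick sketch, assuming src.utils is importable:

# Illustrative only: all of these spellings are generated from the parsed (2024, 12, 29).
from src.utils.format_variants import FormatVariants

variants = FormatVariants.date_variants("2024-12-29")
for expected in ("20241229", "29.12.2024", "29/12/2024", "29 december 2024"):
    assert expected in variants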
+        """
+        value = TextCleaner.clean_text(value)
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+        variants = set()
+
+        variants.add(value)
+
+        if digits:
+            variants.add(digits)
+
+        # Some OCR numbers are printed in space-separated groups
+        if len(digits) > 4:
+            # Group into blocks of 4
+            spaced = ' '.join([digits[i:i + 4] for i in range(0, len(digits), 4)])
+            variants.add(spaced)
+
+        # OCR error variants
+        for ocr_var in TextCleaner.generate_ocr_variants(digits):
+            variants.add(ocr_var)
+
+        return list(v for v in variants if v)
+
+    # =========================================================================
+    # Customer Number Variants
+    # =========================================================================
+
+    @classmethod
+    def customer_number_variants(cls, value: str) -> list[str]:
+        """
+        Generate variants for customer numbers.
+
+        Customer numbers can be very diverse:
+        - Pure digits: 12345
+        - Alphanumeric: ABC123, EMM 256-6
+        - With separators: 123-456
+        """
+        value = TextCleaner.clean_text(value)
+        variants = set()
+        variants.add(value)
+
+        # Case variants
+        variants.add(value.upper())
+        variants.add(value.lower())
+
+        # Strip all separators and whitespace
+        compact = re.sub(r'[-/\s]', '', value)
+        variants.add(compact)
+        variants.add(compact.upper())
+        variants.add(compact.lower())
+
+        # Bare digits
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
+        if digits:
+            variants.add(digits)
+
+        # Letters + digits (split apart and recombined)
+        letters = re.sub(r'[^a-zA-Z]', '', value)
+        if letters and digits:
+            variants.add(f"{letters}{digits}")
+            variants.add(f"{letters.upper()}{digits}")
+            variants.add(f"{letters} {digits}")
+            variants.add(f"{letters.upper()} {digits}")
+            variants.add(f"{letters}-{digits}")
+            variants.add(f"{letters.upper()}-{digits}")
+
+        # OCR error variants
+        for ocr_var in TextCleaner.generate_ocr_variants(value):
+            variants.add(ocr_var)
+
+        return list(v for v in variants if v)
+
+    # =========================================================================
+    # Generic Field Variants
+    # =========================================================================
+
+    @classmethod
+    def get_variants(cls, field_name: str, value: str) -> list[str]:
+        """
+        Get variants for a field by name.
+
+        This is the main entry point - dispatches to specific variant generators.
+        """
+        if not value:
+            return []
+
+        field_lower = field_name.lower()
+
+        # Map the field name to its variant generator
+        if 'organisation' in field_lower or 'org' in field_lower:
+            return cls.organisation_number_variants(value)
+        elif 'bankgiro' in field_lower or field_lower == 'bg':
+            return cls.bankgiro_variants(value)
+        elif 'plusgiro' in field_lower or field_lower == 'pg':
+            return cls.plusgiro_variants(value)
+        elif 'amount' in field_lower or 'belopp' in field_lower:
+            return cls.amount_variants(value)
+        elif 'date' in field_lower or 'datum' in field_lower:
+            return cls.date_variants(value)
+        elif 'invoice' in field_lower and 'number' in field_lower:
+            return cls.invoice_number_variants(value)
+        elif field_lower == 'invoicenumber':
+            return cls.invoice_number_variants(value)
+        elif 'ocr' in field_lower:
+            return cls.ocr_number_variants(value)
+        elif 'customer' in field_lower:
+            return cls.customer_number_variants(value)
+        else:
+            # Default: return the original value plus a basic cleaned form
+            return [value, TextCleaner.clean_text(value)]
diff --git a/src/utils/fuzzy_matcher.py b/src/utils/fuzzy_matcher.py
new file mode 100644
index 0000000..aeb257c
--- /dev/null
+++ b/src/utils/fuzzy_matcher.py
@@ -0,0 +1,417 @@
+"""
+Fuzzy Matching Module
+
+Provides fuzzy string matching with OCR-aware similarity scoring.
+Handles common OCR errors and format variations in invoice fields.
+""" + +import re +from typing import Optional +from dataclasses import dataclass + +from .text_cleaner import TextCleaner + + +@dataclass +class FuzzyMatchResult: + """Result of a fuzzy match operation.""" + matched: bool + score: float # 0.0 to 1.0 + ocr_value: str + expected_value: str + normalized_ocr: str + normalized_expected: str + match_type: str # 'exact', 'normalized', 'fuzzy', 'ocr_corrected' + + +class FuzzyMatcher: + """ + Fuzzy string matcher optimized for OCR text matching. + + Provides multiple matching strategies: + 1. Exact match + 2. Normalized match (case-insensitive, whitespace-normalized) + 3. OCR-corrected match (applying common OCR error corrections) + 4. Edit distance based fuzzy match + 5. Digit-sequence match (for numeric fields) + """ + + # Minimum similarity threshold for fuzzy matches + DEFAULT_THRESHOLD = 0.85 + + # Field-specific thresholds (some fields need stricter matching) + FIELD_THRESHOLDS = { + 'InvoiceNumber': 0.90, + 'OCR': 0.95, # OCR numbers need high precision + 'Amount': 0.95, + 'Bankgiro': 0.90, + 'Plusgiro': 0.90, + 'InvoiceDate': 0.90, + 'InvoiceDueDate': 0.90, + 'supplier_organisation_number': 0.85, + 'customer_number': 0.80, # More lenient for customer numbers + } + + @classmethod + def get_threshold(cls, field_name: str) -> float: + """Get the matching threshold for a specific field.""" + return cls.FIELD_THRESHOLDS.get(field_name, cls.DEFAULT_THRESHOLD) + + @classmethod + def levenshtein_distance(cls, s1: str, s2: str) -> int: + """ + Calculate Levenshtein (edit) distance between two strings. + + This is the minimum number of single-character edits + (insertions, deletions, substitutions) needed to change s1 into s2. + """ + if len(s1) < len(s2): + return cls.levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = range(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + # Cost is 0 if characters match, 1 otherwise + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + @classmethod + def similarity_ratio(cls, s1: str, s2: str) -> float: + """ + Calculate similarity ratio between two strings. + + Returns a value between 0.0 (completely different) and 1.0 (identical). + Based on Levenshtein distance normalized by the length of the longer string. + """ + if not s1 and not s2: + return 1.0 + if not s1 or not s2: + return 0.0 + + max_len = max(len(s1), len(s2)) + distance = cls.levenshtein_distance(s1, s2) + return 1.0 - (distance / max_len) + + @classmethod + def ocr_aware_similarity(cls, ocr_text: str, expected: str) -> float: + """ + Calculate similarity with OCR error awareness. + + This method considers common OCR errors when calculating similarity, + giving higher scores when differences are likely OCR mistakes. 
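The similarity score is Levenshtein distance normalised by the longer string, and the OCR-aware wrapper only adds on top of that; a worked sketch of the numbers involved, assuming src.utils is importable:

# Illustrative only: the values follow directly from the definitions above.
from src.utils.fuzzy_matcher import FuzzyMatcher

# One substitution in a 10-character string: distance 1, ratio 1 - 1/10 = 0.9.
assert FuzzyMatcher.levenshtein_distance("5561234567", "556l234567") == 1
assert abs(FuzzyMatcher.similarity_ratio("5561234567", "556l234567") - 0.9) < 1e-9

# The OCR-aware score is higher for the same pair, because l/1 is a known confusion
# and the correction step maps the OCR reading onto the expected string.
assert FuzzyMatcher.ocr_aware_similarity("556l234567", "5561234567") >= 0.9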
+ """ + if not ocr_text or not expected: + return 0.0 if ocr_text != expected else 1.0 + + # First try exact match + if ocr_text == expected: + return 1.0 + + # Try with OCR corrections applied to ocr_text + corrected = TextCleaner.apply_ocr_digit_corrections(ocr_text) + if corrected == expected: + return 0.98 # Slightly less than exact match + + # Try normalized comparison + norm_ocr = TextCleaner.normalize_for_comparison(ocr_text) + norm_expected = TextCleaner.normalize_for_comparison(expected) + if norm_ocr == norm_expected: + return 0.95 + + # Calculate base similarity + base_sim = cls.similarity_ratio(norm_ocr, norm_expected) + + # Boost score if differences are common OCR errors + boost = cls._calculate_ocr_error_boost(ocr_text, expected) + + return min(1.0, base_sim + boost) + + @classmethod + def _calculate_ocr_error_boost(cls, ocr_text: str, expected: str) -> float: + """ + Calculate a score boost based on whether differences are likely OCR errors. + + Returns a value between 0.0 and 0.1. + """ + if len(ocr_text) != len(expected): + return 0.0 + + ocr_errors = 0 + total_diffs = 0 + + for oc, ec in zip(ocr_text, expected): + if oc != ec: + total_diffs += 1 + # Check if this is a known OCR confusion pair + if cls._is_ocr_confusion_pair(oc, ec): + ocr_errors += 1 + + if total_diffs == 0: + return 0.0 + + # Boost proportional to how many differences are OCR errors + ocr_error_ratio = ocr_errors / total_diffs + return ocr_error_ratio * 0.1 + + @classmethod + def _is_ocr_confusion_pair(cls, char1: str, char2: str) -> bool: + """Check if two characters are commonly confused in OCR.""" + confusion_pairs = { + ('0', 'O'), ('0', 'o'), ('0', 'D'), ('0', 'Q'), + ('1', 'l'), ('1', 'I'), ('1', 'i'), ('1', '|'), + ('2', 'Z'), ('2', 'z'), + ('5', 'S'), ('5', 's'), + ('6', 'G'), ('6', 'b'), + ('8', 'B'), + ('9', 'g'), ('9', 'q'), + } + + pair = (char1, char2) + return pair in confusion_pairs or (char2, char1) in confusion_pairs + + @classmethod + def match_digits(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult: + """ + Match digit sequences with OCR error tolerance. + + Optimized for numeric fields like OCR numbers, amounts, etc. + """ + # Extract digits + ocr_digits = TextCleaner.extract_digits(ocr_text, apply_ocr_correction=True) + expected_digits = TextCleaner.extract_digits(expected, apply_ocr_correction=False) + + # Exact match after extraction + if ocr_digits == expected_digits: + return FuzzyMatchResult( + matched=True, + score=1.0, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_digits, + normalized_expected=expected_digits, + match_type='exact' + ) + + # Calculate similarity + score = cls.ocr_aware_similarity(ocr_digits, expected_digits) + + return FuzzyMatchResult( + matched=score >= threshold, + score=score, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_digits, + normalized_expected=expected_digits, + match_type='fuzzy' if score >= threshold else 'no_match' + ) + + @classmethod + def match_amount(cls, ocr_text: str, expected: str, threshold: float = 0.95) -> FuzzyMatchResult: + """ + Match monetary amounts with format tolerance. + + Handles different decimal separators (. vs ,) and thousand separators. 
+ """ + from .validators import FieldValidators + + # Parse both amounts + ocr_amount = FieldValidators.parse_amount(ocr_text) + expected_amount = FieldValidators.parse_amount(expected) + + if ocr_amount is None or expected_amount is None: + # Can't parse, fall back to string matching + return cls.match_string(ocr_text, expected, threshold) + + # Compare numeric values + if abs(ocr_amount - expected_amount) < 0.01: # Within 1 cent + return FuzzyMatchResult( + matched=True, + score=1.0, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=f"{ocr_amount:.2f}", + normalized_expected=f"{expected_amount:.2f}", + match_type='exact' + ) + + # Calculate relative difference + max_val = max(abs(ocr_amount), abs(expected_amount)) + if max_val > 0: + diff_ratio = abs(ocr_amount - expected_amount) / max_val + score = max(0.0, 1.0 - diff_ratio) + else: + score = 1.0 if ocr_amount == expected_amount else 0.0 + + return FuzzyMatchResult( + matched=score >= threshold, + score=score, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=f"{ocr_amount:.2f}" if ocr_amount else ocr_text, + normalized_expected=f"{expected_amount:.2f}" if expected_amount else expected, + match_type='fuzzy' if score >= threshold else 'no_match' + ) + + @classmethod + def match_date(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult: + """ + Match dates with format tolerance. + + Handles different date formats (ISO, European, compact, etc.) + """ + from .validators import FieldValidators + + # Parse both dates to ISO format + ocr_iso = FieldValidators.format_date_iso(ocr_text) + expected_iso = FieldValidators.format_date_iso(expected) + + if ocr_iso and expected_iso: + if ocr_iso == expected_iso: + return FuzzyMatchResult( + matched=True, + score=1.0, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_iso, + normalized_expected=expected_iso, + match_type='exact' + ) + + # Fall back to string matching on digits + return cls.match_digits(ocr_text, expected, threshold) + + @classmethod + def match_string(cls, ocr_text: str, expected: str, threshold: float = 0.85) -> FuzzyMatchResult: + """ + General string matching with multiple strategies. + + Tries exact, normalized, and fuzzy matching in order. 
+ """ + # Clean both strings + ocr_clean = TextCleaner.clean_text(ocr_text) + expected_clean = TextCleaner.clean_text(expected) + + # Strategy 1: Exact match + if ocr_clean == expected_clean: + return FuzzyMatchResult( + matched=True, + score=1.0, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_clean, + normalized_expected=expected_clean, + match_type='exact' + ) + + # Strategy 2: Case-insensitive match + if ocr_clean.lower() == expected_clean.lower(): + return FuzzyMatchResult( + matched=True, + score=0.98, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_clean, + normalized_expected=expected_clean, + match_type='normalized' + ) + + # Strategy 3: OCR-corrected match + ocr_corrected = TextCleaner.apply_ocr_digit_corrections(ocr_clean) + if ocr_corrected == expected_clean: + return FuzzyMatchResult( + matched=True, + score=0.95, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_corrected, + normalized_expected=expected_clean, + match_type='ocr_corrected' + ) + + # Strategy 4: Fuzzy match + score = cls.ocr_aware_similarity(ocr_clean, expected_clean) + + return FuzzyMatchResult( + matched=score >= threshold, + score=score, + ocr_value=ocr_text, + expected_value=expected, + normalized_ocr=ocr_clean, + normalized_expected=expected_clean, + match_type='fuzzy' if score >= threshold else 'no_match' + ) + + @classmethod + def match_field( + cls, + field_name: str, + ocr_value: str, + expected_value: str, + threshold: Optional[float] = None + ) -> FuzzyMatchResult: + """ + Match a field value using field-appropriate strategy. + + Automatically selects the best matching strategy based on field type. + """ + if threshold is None: + threshold = cls.get_threshold(field_name) + + field_lower = field_name.lower() + + # Route to appropriate matcher + if 'amount' in field_lower or 'belopp' in field_lower: + return cls.match_amount(ocr_value, expected_value, threshold) + + if 'date' in field_lower or 'datum' in field_lower: + return cls.match_date(ocr_value, expected_value, threshold) + + if any(x in field_lower for x in ['ocr', 'bankgiro', 'plusgiro', 'org']): + # Numeric fields with OCR errors + return cls.match_digits(ocr_value, expected_value, threshold) + + if 'invoice' in field_lower and 'number' in field_lower: + # Invoice numbers can be alphanumeric + return cls.match_string(ocr_value, expected_value, threshold) + + # Default to string matching + return cls.match_string(ocr_value, expected_value, threshold) + + @classmethod + def find_best_match( + cls, + ocr_value: str, + candidates: list[str], + field_name: str = '', + threshold: Optional[float] = None + ) -> Optional[tuple[str, FuzzyMatchResult]]: + """ + Find the best matching candidate from a list. + + Returns (matched_value, match_result) or None if no match above threshold. 
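match_field is the routing entry point and find_best_match is the candidate search built on it. A minimal sketch, assuming the patch is applied so that src.utils is importable; it also shows the intended pairing with FormatVariants, where a known CSV value is expanded into its page spellings and scored against OCR text:

# Illustrative only: uses only classmethods defined in this patch.
from src.utils.format_variants import FormatVariants
from src.utils.fuzzy_matcher import FuzzyMatcher

# Field-aware routing: amounts and dates are compared after normalisation.
assert FuzzyMatcher.match_field("Amount", "1 234,56", "1234.56").matched
assert FuzzyMatcher.match_field("InvoiceDate", "29.12.2024", "2024-12-29").matched

# Candidate search: the best-scoring candidate above the field threshold wins.
best = FuzzyMatcher.find_best_match("556O234567", ["5560234567", "9999999999"],
                                    field_name="supplier_organisation_number")
assert best is not None and best[0] == "5560234567"

# Pairing with FormatVariants: expand a CSV ground-truth value, then score the OCR
# text (here a cropped field read with 'l' instead of '1') against each variant.
ocr_text, csv_value = "556l23-4567", "556123-4567"
scores = [
    FuzzyMatcher.match_field("supplier_organisation_number", ocr_text, variant).score
    for variant in FormatVariants.get_variants("supplier_organisation_number", csv_value)
]
assert max(scores) >= 0.9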
+ """ + if threshold is None: + threshold = cls.get_threshold(field_name) if field_name else cls.DEFAULT_THRESHOLD + + best_match = None + best_result = None + + for candidate in candidates: + result = cls.match_field(field_name, ocr_value, candidate, threshold=0.0) + if result.score >= threshold: + if best_result is None or result.score > best_result.score: + best_match = candidate + best_result = result + + if best_match: + return (best_match, best_result) + return None diff --git a/src/utils/ocr_corrections.py b/src/utils/ocr_corrections.py new file mode 100644 index 0000000..be59613 --- /dev/null +++ b/src/utils/ocr_corrections.py @@ -0,0 +1,384 @@ +""" +OCR Error Corrections Module + +Provides comprehensive OCR error correction tables and correction functions. +Based on common OCR recognition errors in Swedish invoice documents. +""" + +import re +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CorrectionResult: + """Result of an OCR correction operation.""" + original: str + corrected: str + corrections_applied: list[tuple[int, str, str]] # (position, from_char, to_char) + confidence: float # How confident we are in the correction + + +class OCRCorrections: + """ + Comprehensive OCR error correction utilities. + + Provides: + - Character-level corrections for digits + - Word-level corrections for common Swedish terms + - Context-aware corrections + - Multiple correction strategies + """ + + # ========================================================================= + # Character-level OCR errors (digit fields) + # ========================================================================= + + # Characters commonly misread as digits + CHAR_TO_DIGIT = { + # Letters that look like digits + 'O': '0', 'o': '0', # O -> 0 + 'Q': '0', # Q -> 0 (less common) + 'D': '0', # D -> 0 (in some fonts) + + 'l': '1', 'I': '1', # l/I -> 1 + 'i': '1', # i without dot -> 1 + '|': '1', # pipe -> 1 + '!': '1', # exclamation -> 1 + + 'Z': '2', 'z': '2', # Z -> 2 + + 'E': '3', # E -> 3 (rare) + + 'A': '4', 'h': '4', # A/h -> 4 (in some fonts) + + 'S': '5', 's': '5', # S -> 5 + + 'G': '6', 'b': '6', # G/b -> 6 + + 'T': '7', 't': '7', # T -> 7 (rare) + + 'B': '8', # B -> 8 + + 'g': '9', 'q': '9', # g/q -> 9 + } + + # Digits commonly misread as other characters + DIGIT_TO_CHAR = { + '0': ['O', 'o', 'D', 'Q'], + '1': ['l', 'I', 'i', '|', '!'], + '2': ['Z', 'z'], + '3': ['E'], + '4': ['A', 'h'], + '5': ['S', 's'], + '6': ['G', 'b'], + '7': ['T', 't'], + '8': ['B'], + '9': ['g', 'q'], + } + + # Bidirectional confusion pairs (either direction is possible) + CONFUSION_PAIRS = [ + ('0', 'O'), ('0', 'o'), ('0', 'D'), + ('1', 'l'), ('1', 'I'), ('1', '|'), + ('2', 'Z'), ('2', 'z'), + ('5', 'S'), ('5', 's'), + ('6', 'G'), ('6', 'b'), + ('8', 'B'), + ('9', 'g'), ('9', 'q'), + ] + + # ========================================================================= + # Word-level OCR errors (Swedish invoice terms) + # ========================================================================= + + # Common Swedish invoice terms and their OCR misreadings + SWEDISH_TERM_CORRECTIONS = { + # Faktura (Invoice) + 'faktura': ['Faktura', 'FAKTURA', 'faktúra', 'faKtura'], + 'fakturanummer': ['Fakturanummer', 'FAKTURANUMMER', 'fakturanr', 'fakt.nr'], + 'fakturadatum': ['Fakturadatum', 'FAKTURADATUM', 'fakt.datum'], + + # Belopp (Amount) + 'belopp': ['Belopp', 'BELOPP', 'be1opp', 'bel0pp'], + 'summa': ['Summa', 'SUMMA', '5umma'], + 'total': ['Total', 'TOTAL', 'tota1', 't0tal'], + 'moms': ['Moms', 'MOMS', 
'm0ms'], + + # Dates + 'förfallodatum': ['Förfallodatum', 'FÖRFALLODATUM', 'förfa11odatum'], + 'datum': ['Datum', 'DATUM', 'dátum'], + + # Payment + 'bankgiro': ['Bankgiro', 'BANKGIRO', 'BG', 'bg', 'bank giro'], + 'plusgiro': ['Plusgiro', 'PLUSGIRO', 'PG', 'pg', 'plus giro'], + 'postgiro': ['Postgiro', 'POSTGIRO'], + 'ocr': ['OCR', 'ocr', '0CR', 'OcR'], + + # Organization + 'organisationsnummer': ['Organisationsnummer', 'ORGANISATIONSNUMMER', 'org.nr', 'orgnr'], + 'kundnummer': ['Kundnummer', 'KUNDNUMMER', 'kund nr', 'kundnr'], + + # Currency + 'kronor': ['Kronor', 'KRONOR', 'kr', 'KR', 'SEK', 'sek'], + 'öre': ['Öre', 'ÖRE', 'ore', 'ORE'], + } + + # ========================================================================= + # Context patterns + # ========================================================================= + + # Patterns that indicate the following/preceding text is a specific field + CONTEXT_INDICATORS = { + 'invoice_number': [ + r'faktura\s*(?:nr|nummer)?[:\s]*', + r'invoice\s*(?:no|number)?[:\s]*', + r'fakt\.?\s*nr[:\s]*', + r'inv[:\s]*#?', + ], + 'amount': [ + r'(?:att\s+)?betala[:\s]*', + r'total[t]?[:\s]*', + r'summa[:\s]*', + r'belopp[:\s]*', + r'amount[:\s]*', + ], + 'date': [ + r'datum[:\s]*', + r'date[:\s]*', + r'förfall(?:o)?datum[:\s]*', + r'fakturadatum[:\s]*', + ], + 'ocr': [ + r'ocr[:\s]*', + r'referens[:\s]*', + r'betalningsreferens[:\s]*', + ], + 'bankgiro': [ + r'bankgiro[:\s]*', + r'bg[:\s]*', + r'bank\s*giro[:\s]*', + ], + 'plusgiro': [ + r'plusgiro[:\s]*', + r'pg[:\s]*', + r'plus\s*giro[:\s]*', + r'postgiro[:\s]*', + ], + 'org_number': [ + r'org\.?\s*(?:nr|nummer)?[:\s]*', + r'organisationsnummer[:\s]*', + r'vat[:\s]*', + r'moms(?:reg)?\.?\s*(?:nr|nummer)?[:\s]*', + ], + } + + # ========================================================================= + # Correction Methods + # ========================================================================= + + @classmethod + def correct_digits(cls, text: str, aggressive: bool = False) -> CorrectionResult: + """ + Apply digit corrections to text. + + Args: + text: Input text + aggressive: If True, correct all potential digit-like characters. + If False, only correct characters adjacent to digits. + + Returns: + CorrectionResult with original, corrected text, and details. + """ + corrections = [] + result = [] + + for i, char in enumerate(text): + if char.isdigit(): + result.append(char) + elif char in cls.CHAR_TO_DIGIT: + if aggressive: + # Always correct + corrected_char = cls.CHAR_TO_DIGIT[char] + corrections.append((i, char, corrected_char)) + result.append(corrected_char) + else: + # Only correct if adjacent to digit + prev_is_digit = i > 0 and (text[i-1].isdigit() or text[i-1] in cls.CHAR_TO_DIGIT) + next_is_digit = i < len(text) - 1 and (text[i+1].isdigit() or text[i+1] in cls.CHAR_TO_DIGIT) + + if prev_is_digit or next_is_digit: + corrected_char = cls.CHAR_TO_DIGIT[char] + corrections.append((i, char, corrected_char)) + result.append(corrected_char) + else: + result.append(char) + else: + result.append(char) + + corrected = ''.join(result) + confidence = 1.0 - (len(corrections) * 0.05) # Decrease confidence per correction + + return CorrectionResult( + original=text, + corrected=corrected, + corrections_applied=corrections, + confidence=max(0.5, confidence) + ) + + @classmethod + def generate_digit_variants(cls, text: str) -> list[str]: + """ + Generate all possible digit interpretations of a text. + + Useful for matching when we don't know which direction the OCR error went. 
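correct_digits is conservative by default: a mapped letter is rewritten only when it touches something digit-like, while aggressive mode rewrites every mapped character. A small sketch of the difference (printed rather than asserted, since the exact positions depend on the input), assuming src.utils is importable:

# Illustrative only: same input, two correction modes.
from src.utils.ocr_corrections import OCRCorrections

text = "Org.nr 556O23-4S67"

mild = OCRCorrections.correct_digits(text, aggressive=False)
print(mild.corrected)              # the O and S inside the number become 0 and 5
print(mild.corrections_applied)    # [(position, from_char, to_char), ...]

strict = OCRCorrections.correct_digits(text, aggressive=True)
print(strict.corrected)            # every mapped letter is rewritten, e.g. the O in "Org" too
print(strict.confidence)           # confidence drops slightly per applied correction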
+ """ + if not text: + return [text] + + variants = {text} + + # For each character that could be confused + for i, char in enumerate(text): + new_variants = set() + for existing in variants: + # If it's a digit, add letter variants + if char.isdigit() and char in cls.DIGIT_TO_CHAR: + for replacement in cls.DIGIT_TO_CHAR[char]: + new_variants.add(existing[:i] + replacement + existing[i+1:]) + + # If it's a letter that looks like a digit, add digit variant + if char in cls.CHAR_TO_DIGIT: + new_variants.add(existing[:i] + cls.CHAR_TO_DIGIT[char] + existing[i+1:]) + + variants.update(new_variants) + + # Limit explosion - only keep reasonable number + if len(variants) > 100: + break + + return list(variants) + + @classmethod + def correct_swedish_term(cls, text: str) -> str: + """ + Correct common Swedish invoice terms that may have OCR errors. + """ + text_lower = text.lower() + + for canonical, variants in cls.SWEDISH_TERM_CORRECTIONS.items(): + for variant in variants: + if variant.lower() in text_lower: + # Replace with canonical form (preserving case of first letter) + pattern = re.compile(re.escape(variant), re.IGNORECASE) + if text[0].isupper(): + replacement = canonical.capitalize() + else: + replacement = canonical + text = pattern.sub(replacement, text) + + return text + + @classmethod + def extract_with_context(cls, text: str, field_type: str) -> Optional[str]: + """ + Extract a field value using context indicators. + + Looks for patterns like "Fakturanr: 12345" and extracts "12345". + """ + patterns = cls.CONTEXT_INDICATORS.get(field_type, []) + + for pattern in patterns: + # Look for pattern followed by value + full_pattern = pattern + r'([^\s,;]+)' + match = re.search(full_pattern, text, re.IGNORECASE) + if match: + return match.group(1) + + return None + + @classmethod + def is_likely_ocr_error(cls, char1: str, char2: str) -> bool: + """ + Check if two characters are commonly confused in OCR. + """ + pair = (char1, char2) + reverse_pair = (char2, char1) + + for p in cls.CONFUSION_PAIRS: + if pair == p or reverse_pair == p: + return True + + return False + + @classmethod + def count_potential_ocr_errors(cls, s1: str, s2: str) -> tuple[int, int]: + """ + Count how many character differences between two strings + are likely OCR errors vs other differences. + + Returns: (ocr_errors, other_errors) + """ + if len(s1) != len(s2): + return (0, abs(len(s1) - len(s2))) + + ocr_errors = 0 + other_errors = 0 + + for c1, c2 in zip(s1, s2): + if c1 != c2: + if cls.is_likely_ocr_error(c1, c2): + ocr_errors += 1 + else: + other_errors += 1 + + return (ocr_errors, other_errors) + + @classmethod + def suggest_corrections(cls, text: str, expected_type: str = 'digit') -> list[tuple[str, float]]: + """ + Suggest possible corrections for a text with confidence scores. + + Returns list of (corrected_text, confidence) tuples, sorted by confidence. 
+ """ + suggestions = [] + + if expected_type == 'digit': + # Apply digit corrections with different levels of aggressiveness + mild = cls.correct_digits(text, aggressive=False) + if mild.corrected != text: + suggestions.append((mild.corrected, mild.confidence)) + + aggressive = cls.correct_digits(text, aggressive=True) + if aggressive.corrected != text and aggressive.corrected != mild.corrected: + suggestions.append((aggressive.corrected, aggressive.confidence * 0.9)) + + # Generate variants + variants = cls.generate_digit_variants(text) + for variant in variants[:10]: # Limit to top 10 + if variant != text and variant not in [s[0] for s in suggestions]: + # Lower confidence for variants + suggestions.append((variant, 0.7)) + + # Sort by confidence + suggestions.sort(key=lambda x: x[1], reverse=True) + return suggestions + + +# ========================================================================= +# Convenience functions +# ========================================================================= + +def correct_ocr_digits(text: str, aggressive: bool = False) -> str: + """Convenience function to correct OCR digit errors.""" + return OCRCorrections.correct_digits(text, aggressive).corrected + + +def generate_ocr_variants(text: str) -> list[str]: + """Convenience function to generate OCR variants.""" + return OCRCorrections.generate_digit_variants(text) + + +def is_ocr_confusion(char1: str, char2: str) -> bool: + """Convenience function to check if characters are OCR confusable.""" + return OCRCorrections.is_likely_ocr_error(char1, char2) diff --git a/src/utils/test_advanced_utils.py b/src/utils/test_advanced_utils.py new file mode 100644 index 0000000..ac02513 --- /dev/null +++ b/src/utils/test_advanced_utils.py @@ -0,0 +1,399 @@ +""" +Tests for advanced utility modules: +- FuzzyMatcher +- OCRCorrections +- ContextExtractor +""" + +import pytest +from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult +from .ocr_corrections import OCRCorrections, correct_ocr_digits, generate_ocr_variants +from .context_extractor import ContextExtractor, extract_field_with_context + + +class TestFuzzyMatcher: + """Tests for FuzzyMatcher class.""" + + def test_levenshtein_distance_identical(self): + """Test distance for identical strings.""" + assert FuzzyMatcher.levenshtein_distance("hello", "hello") == 0 + + def test_levenshtein_distance_one_char(self): + """Test distance for one character difference.""" + assert FuzzyMatcher.levenshtein_distance("hello", "hallo") == 1 + assert FuzzyMatcher.levenshtein_distance("hello", "hell") == 1 + assert FuzzyMatcher.levenshtein_distance("hello", "helloo") == 1 + + def test_levenshtein_distance_multiple(self): + """Test distance for multiple differences.""" + assert FuzzyMatcher.levenshtein_distance("hello", "world") == 4 + assert FuzzyMatcher.levenshtein_distance("", "hello") == 5 + + def test_similarity_ratio_identical(self): + """Test similarity for identical strings.""" + assert FuzzyMatcher.similarity_ratio("hello", "hello") == 1.0 + + def test_similarity_ratio_similar(self): + """Test similarity for similar strings.""" + ratio = FuzzyMatcher.similarity_ratio("hello", "hallo") + assert 0.8 <= ratio <= 0.9 # One char different in 5-char string + + def test_similarity_ratio_different(self): + """Test similarity for different strings.""" + ratio = FuzzyMatcher.similarity_ratio("hello", "world") + assert ratio < 0.5 + + def test_ocr_aware_similarity_exact(self): + """Test OCR-aware similarity for exact match.""" + assert FuzzyMatcher.ocr_aware_similarity("12345", 
"12345") == 1.0 + + def test_ocr_aware_similarity_ocr_error(self): + """Test OCR-aware similarity with OCR error.""" + # O instead of 0 + score = FuzzyMatcher.ocr_aware_similarity("1234O", "12340") + assert score >= 0.9 # Should be high due to OCR correction + + def test_ocr_aware_similarity_multiple_errors(self): + """Test OCR-aware similarity with multiple OCR errors.""" + # l instead of 1, O instead of 0 + score = FuzzyMatcher.ocr_aware_similarity("l234O", "12340") + assert score >= 0.85 + + def test_match_digits_exact(self): + """Test digit matching for exact match.""" + result = FuzzyMatcher.match_digits("12345", "12345") + assert result.matched is True + assert result.score == 1.0 + assert result.match_type == 'exact' + + def test_match_digits_with_separators(self): + """Test digit matching ignoring separators.""" + result = FuzzyMatcher.match_digits("123-4567", "1234567") + assert result.matched is True + assert result.normalized_ocr == "1234567" + + def test_match_digits_ocr_error(self): + """Test digit matching with OCR error.""" + result = FuzzyMatcher.match_digits("556O234567", "5560234567") + assert result.matched is True + assert result.score >= 0.9 + + def test_match_amount_exact(self): + """Test amount matching for exact values.""" + result = FuzzyMatcher.match_amount("1234.56", "1234.56") + assert result.matched is True + assert result.score == 1.0 + + def test_match_amount_different_formats(self): + """Test amount matching with different formats.""" + # Swedish vs US format + result = FuzzyMatcher.match_amount("1234,56", "1234.56") + assert result.matched is True + assert result.score >= 0.99 + + def test_match_amount_with_spaces(self): + """Test amount matching with thousand separators.""" + result = FuzzyMatcher.match_amount("1 234,56", "1234.56") + assert result.matched is True + + def test_match_date_same_date_different_format(self): + """Test date matching with different formats.""" + result = FuzzyMatcher.match_date("2024-12-29", "29.12.2024") + assert result.matched is True + assert result.score >= 0.9 + + def test_match_date_different_dates(self): + """Test date matching with different dates.""" + result = FuzzyMatcher.match_date("2024-12-29", "2024-12-30") + assert result.matched is False + + def test_match_string_exact(self): + """Test string matching for exact match.""" + result = FuzzyMatcher.match_string("Hello World", "Hello World") + assert result.matched is True + assert result.match_type == 'exact' + + def test_match_string_case_insensitive(self): + """Test string matching case insensitivity.""" + result = FuzzyMatcher.match_string("HELLO", "hello") + assert result.matched is True + assert result.match_type == 'normalized' + + def test_match_string_ocr_corrected(self): + """Test string matching with OCR corrections.""" + result = FuzzyMatcher.match_string("5561234567", "556l234567") + assert result.matched is True + + def test_match_field_routes_correctly(self): + """Test that match_field routes to correct matcher.""" + # Amount field + result = FuzzyMatcher.match_field("Amount", "1234.56", "1234,56") + assert result.matched is True + + # Date field + result = FuzzyMatcher.match_field("InvoiceDate", "2024-12-29", "29.12.2024") + assert result.matched is True + + def test_find_best_match(self): + """Test finding best match from candidates.""" + candidates = ["12345", "12346", "99999"] + result = FuzzyMatcher.find_best_match("12345", candidates, "InvoiceNumber") + + assert result is not None + assert result[0] == "12345" + assert result[1].score == 1.0 + + 
def test_find_best_match_no_match(self): + """Test finding best match when none above threshold.""" + candidates = ["99999", "88888", "77777"] + result = FuzzyMatcher.find_best_match("12345", candidates, "InvoiceNumber") + + assert result is None + + +class TestOCRCorrections: + """Tests for OCRCorrections class.""" + + def test_correct_digits_simple(self): + """Test simple digit correction.""" + result = OCRCorrections.correct_digits("556O23", aggressive=False) + assert result.corrected == "556023" + assert len(result.corrections_applied) == 1 + + def test_correct_digits_multiple(self): + """Test multiple digit corrections.""" + result = OCRCorrections.correct_digits("5S6l23", aggressive=False) + assert result.corrected == "556123" + assert len(result.corrections_applied) == 2 + + def test_correct_digits_aggressive(self): + """Test aggressive mode corrects all potential errors.""" + result = OCRCorrections.correct_digits("AB123", aggressive=True) + # A -> 4, B -> 8 + assert result.corrected == "48123" + + def test_correct_digits_non_aggressive(self): + """Test non-aggressive mode only corrects adjacent.""" + result = OCRCorrections.correct_digits("AB 123", aggressive=False) + # A and B are adjacent to each other and both in CHAR_TO_DIGIT, + # so they may be corrected. The key is digits are not affected. + assert "123" in result.corrected + + def test_generate_digit_variants(self): + """Test generating OCR variants.""" + variants = OCRCorrections.generate_digit_variants("10") + # Should include original and variants like "1O", "I0", "IO", "l0", etc. + assert "10" in variants + assert "1O" in variants or "l0" in variants + + def test_generate_digit_variants_limits(self): + """Test that variant generation is limited.""" + variants = OCRCorrections.generate_digit_variants("1234567890") + # Should be limited to prevent explosion (limit is ~100, but may slightly exceed) + assert len(variants) <= 150 + + def test_is_likely_ocr_error(self): + """Test OCR error detection.""" + assert OCRCorrections.is_likely_ocr_error('0', 'O') is True + assert OCRCorrections.is_likely_ocr_error('O', '0') is True + assert OCRCorrections.is_likely_ocr_error('1', 'l') is True + assert OCRCorrections.is_likely_ocr_error('5', 'S') is True + assert OCRCorrections.is_likely_ocr_error('A', 'Z') is False + + def test_count_potential_ocr_errors(self): + """Test counting OCR errors vs other errors.""" + ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors("1O3", "103") + assert ocr_errors == 1 # O vs 0 + assert other_errors == 0 + + ocr_errors, other_errors = OCRCorrections.count_potential_ocr_errors("1X3", "103") + assert ocr_errors == 0 + assert other_errors == 1 # X vs 0, not a known pair + + def test_suggest_corrections(self): + """Test correction suggestions.""" + suggestions = OCRCorrections.suggest_corrections("556O23", expected_type='digit') + assert len(suggestions) > 0 + # First suggestion should be the corrected version + assert suggestions[0][0] == "556023" + + def test_convenience_function_correct(self): + """Test convenience function.""" + assert correct_ocr_digits("556O23") == "556023" + + def test_convenience_function_variants(self): + """Test convenience function for variants.""" + variants = generate_ocr_variants("10") + assert "10" in variants + + +class TestContextExtractor: + """Tests for ContextExtractor class.""" + + def test_extract_invoice_number_with_label(self): + """Test extracting invoice number after label.""" + text = "Fakturanummer: 12345678" + candidates = 
ContextExtractor.extract_with_label(text, "InvoiceNumber") + + assert len(candidates) > 0 + assert candidates[0].value == "12345678" + assert candidates[0].extraction_method == 'label' + + def test_extract_invoice_number_swedish(self): + """Test extracting with Swedish label.""" + text = "Faktura nr: A12345" + candidates = ContextExtractor.extract_with_label(text, "InvoiceNumber") + + assert len(candidates) > 0 + # Should extract A12345 or 12345 + + def test_extract_amount_with_label(self): + """Test extracting amount after label.""" + text = "Att betala: 1 234,56" + candidates = ContextExtractor.extract_with_label(text, "Amount") + + assert len(candidates) > 0 + + def test_extract_amount_total(self): + """Test extracting with total label.""" + text = "Total: 5678,90 kr" + candidates = ContextExtractor.extract_with_label(text, "Amount") + + assert len(candidates) > 0 + + def test_extract_date_with_label(self): + """Test extracting date after label.""" + text = "Fakturadatum: 2024-12-29" + candidates = ContextExtractor.extract_with_label(text, "InvoiceDate") + + assert len(candidates) > 0 + assert "2024-12-29" in candidates[0].value + + def test_extract_due_date(self): + """Test extracting due date.""" + text = "Förfallodatum: 2025-01-15" + candidates = ContextExtractor.extract_with_label(text, "InvoiceDueDate") + + assert len(candidates) > 0 + + def test_extract_bankgiro(self): + """Test extracting Bankgiro.""" + text = "Bankgiro: 1234-5678" + candidates = ContextExtractor.extract_with_label(text, "Bankgiro") + + assert len(candidates) > 0 + assert "1234-5678" in candidates[0].value or "12345678" in candidates[0].value + + def test_extract_plusgiro(self): + """Test extracting Plusgiro.""" + text = "Plusgiro: 1234567-8" + candidates = ContextExtractor.extract_with_label(text, "Plusgiro") + + assert len(candidates) > 0 + + def test_extract_ocr(self): + """Test extracting OCR number.""" + text = "OCR: 12345678901234" + candidates = ContextExtractor.extract_with_label(text, "OCR") + + assert len(candidates) > 0 + assert candidates[0].value == "12345678901234" + + def test_extract_org_number(self): + """Test extracting organization number.""" + text = "Org.nr: 556123-4567" + candidates = ContextExtractor.extract_with_label(text, "supplier_organisation_number") + + assert len(candidates) > 0 + + def test_extract_customer_number(self): + """Test extracting customer number.""" + text = "Kundnummer: EMM 256-6" + candidates = ContextExtractor.extract_with_label(text, "customer_number") + + assert len(candidates) > 0 + + def test_extract_field_returns_sorted(self): + """Test that extract_field returns sorted by confidence.""" + text = "Fakturanummer: 12345 Invoice number: 67890" + candidates = ContextExtractor.extract_field(text, "InvoiceNumber") + + if len(candidates) > 1: + # Should be sorted by confidence (descending) + assert candidates[0].confidence >= candidates[1].confidence + + def test_extract_best(self): + """Test extract_best returns single best candidate.""" + text = "Fakturanummer: 12345678" + best = ContextExtractor.extract_best(text, "InvoiceNumber") + + assert best is not None + assert best.value == "12345678" + + def test_extract_best_no_match(self): + """Test extract_best returns None when no match.""" + text = "No invoice information here" + best = ContextExtractor.extract_best(text, "InvoiceNumber", validate=True) + + # May or may not find something depending on validation + + def test_extract_all_fields(self): + """Test extracting all fields from text.""" + text = """ + 
Fakturanummer: 12345
+        Datum: 2024-12-29
+        Belopp: 1234,56
+        Bankgiro: 1234-5678
+        """
+        results = ContextExtractor.extract_all_fields(text)
+
+        # Should find at least some fields
+        assert len(results) > 0
+
+    def test_identify_field_type(self):
+        """Test identifying field type from context."""
+        text = "Fakturanummer: 12345"
+        field_type = ContextExtractor.identify_field_type(text, "12345")
+
+        assert field_type == "InvoiceNumber"
+
+    def test_convenience_function_extract(self):
+        """Test convenience function."""
+        text = "Fakturanummer: 12345678"
+        value = extract_field_with_context(text, "InvoiceNumber")
+
+        assert value == "12345678"
+
+
+class TestIntegration:
+    """Integration tests combining multiple modules."""
+
+    def test_fuzzy_match_with_ocr_correction(self):
+        """Test fuzzy matching with OCR correction."""
+        # Simulate OCR error: 0 -> O
+        ocr_text = "556O234567"
+        expected = "5560234567"
+
+        # First correct
+        corrected = correct_ocr_digits(ocr_text)
+        assert corrected == expected
+
+        # Then match
+        result = FuzzyMatcher.match_digits(ocr_text, expected)
+        assert result.matched is True
+
+    def test_context_extraction_with_fuzzy_match(self):
+        """Test extracting value and fuzzy matching."""
+        text = "Fakturanummer: 1234S678"  # S is OCR error for 5
+
+        # Extract
+        candidate = ContextExtractor.extract_best(text, "InvoiceNumber", validate=False)
+        assert candidate is not None
+
+        # Fuzzy match against expected
+        result = FuzzyMatcher.match_string(candidate.value, "12345678")
+        # Might match depending on threshold
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/src/utils/test_utils.py b/src/utils/test_utils.py
new file mode 100644
index 0000000..455ba34
--- /dev/null
+++ b/src/utils/test_utils.py
@@ -0,0 +1,235 @@
+"""
+Tests for shared utility modules.
+"""
+
+import pytest
+from .text_cleaner import TextCleaner
+from .format_variants import FormatVariants
+from .validators import FieldValidators
+
+
+class TestTextCleaner:
+    """Tests for TextCleaner class."""
+
+    def test_clean_unicode_dashes(self):
+        """Test normalization of various dash types."""
+        # en-dash
+        assert TextCleaner.clean_unicode("556123–4567") == "556123-4567"
+        # em-dash
+        assert TextCleaner.clean_unicode("556123—4567") == "556123-4567"
+        # minus sign
+        assert TextCleaner.clean_unicode("556123−4567") == "556123-4567"
+
+    def test_clean_unicode_spaces(self):
+        """Test normalization of various space types."""
+        # non-breaking space
+        assert TextCleaner.clean_unicode("1\xa0234") == "1 234"
+        # zero-width space removed
+        assert TextCleaner.clean_unicode("123\u200b456") == "123456"
+
+    def test_ocr_digit_corrections(self):
+        """Test OCR error corrections for digit fields."""
+        # O -> 0
+        assert TextCleaner.apply_ocr_digit_corrections("556O23") == "556023"
+        # l -> 1
+        assert TextCleaner.apply_ocr_digit_corrections("556l23") == "556123"
+        # S -> 5
+        assert TextCleaner.apply_ocr_digit_corrections("5S6123") == "556123"
+        # Mixed
+        assert TextCleaner.apply_ocr_digit_corrections("S56l23-4S67") == "556123-4567"
+
+    def test_extract_digits(self):
+        """Test digit extraction with OCR correction."""
+        assert TextCleaner.extract_digits("556123-4567") == "5561234567"
+        assert TextCleaner.extract_digits("556O23-4567", apply_ocr_correction=True) == "5560234567"
+        # Without OCR correction, only extracts actual digits
+        assert TextCleaner.extract_digits("ABC 123 DEF", apply_ocr_correction=False) == "123"
+        # With OCR correction, standalone letters are not converted
+        # (they need to be adjacent to digits to be corrected)
+        assert TextCleaner.extract_digits("A 123 B", apply_ocr_correction=True) == "123"
+
+    def test_normalize_amount_text(self):
+        """Test amount text normalization."""
+        assert TextCleaner.normalize_amount_text("1 234,56 kr") == "1234,56"
+        assert TextCleaner.normalize_amount_text("SEK 1234.56") == "1234.56"
+        assert TextCleaner.normalize_amount_text("1 234 567,89 kronor") == "1234567,89"
+
+
+class TestFormatVariants:
+    """Tests for FormatVariants class."""
+
+    def test_organisation_number_variants(self):
+        """Test organisation number variant generation."""
+        variants = FormatVariants.organisation_number_variants("5561234567")
+
+        assert "5561234567" in variants  # plain digits
+        assert "556123-4567" in variants  # with hyphen
+        assert "SE556123456701" in variants  # VAT format
+
+    def test_organisation_number_from_vat(self):
+        """Test extracting org number from VAT format."""
+        variants = FormatVariants.organisation_number_variants("SE556123456701")
+
+        assert "5561234567" in variants
+        assert "556123-4567" in variants
+
+    def test_bankgiro_variants(self):
+        """Test Bankgiro variant generation."""
+        # 8 digits
+        variants = FormatVariants.bankgiro_variants("53939484")
+        assert "53939484" in variants
+        assert "5393-9484" in variants
+
+        # 7 digits
+        variants = FormatVariants.bankgiro_variants("1234567")
+        assert "1234567" in variants
+        assert "123-4567" in variants
+
+    def test_plusgiro_variants(self):
+        """Test Plusgiro variant generation."""
+        variants = FormatVariants.plusgiro_variants("12345678")
+        assert "12345678" in variants
+        assert "1234567-8" in variants
+
+    def test_amount_variants(self):
+        """Test amount variant generation."""
+        variants = FormatVariants.amount_variants("1234.56")
+
+        assert "1234.56" in variants
+        assert "1234,56" in variants
+        assert "1 234,56" in variants or "1234,56" 
in variants # Swedish format + + def test_date_variants(self): + """Test date variant generation.""" + variants = FormatVariants.date_variants("2024-12-29") + + assert "2024-12-29" in variants # ISO + assert "29.12.2024" in variants # European + assert "29/12/2024" in variants # European slash + assert "20241229" in variants # Compact + assert "29 december 2024" in variants # Swedish text + + def test_invoice_number_variants(self): + """Test invoice number variant generation.""" + variants = FormatVariants.invoice_number_variants("INV-2024-001") + + assert "INV-2024-001" in variants + assert "INV2024001" in variants # No separators + assert "inv-2024-001" in variants # Lowercase + + def test_get_variants_dispatch(self): + """Test get_variants dispatches to correct method.""" + # Organisation number + org_variants = FormatVariants.get_variants("supplier_organisation_number", "5561234567") + assert "556123-4567" in org_variants + + # Bankgiro + bg_variants = FormatVariants.get_variants("Bankgiro", "53939484") + assert "5393-9484" in bg_variants + + # Amount + amount_variants = FormatVariants.get_variants("Amount", "1234.56") + assert "1234,56" in amount_variants + + +class TestFieldValidators: + """Tests for FieldValidators class.""" + + def test_luhn_checksum_valid(self): + """Test Luhn validation with valid numbers.""" + # Valid Bankgiro numbers (with correct check digit) + assert FieldValidators.luhn_checksum("53939484") is True + # Valid OCR numbers + assert FieldValidators.luhn_checksum("1234567897") is True # check digit 7 + + def test_luhn_checksum_invalid(self): + """Test Luhn validation with invalid numbers.""" + assert FieldValidators.luhn_checksum("53939485") is False # wrong check digit + assert FieldValidators.luhn_checksum("1234567890") is False + + def test_calculate_luhn_check_digit(self): + """Test Luhn check digit calculation.""" + # For "123456789", the check digit should make it valid + check = FieldValidators.calculate_luhn_check_digit("123456789") + full_number = "123456789" + str(check) + assert FieldValidators.luhn_checksum(full_number) is True + + def test_is_valid_organisation_number(self): + """Test organisation number validation.""" + # Valid (with correct Luhn checksum) + # Note: Need actual valid org numbers for this test + # Using a well-known one: 5565006245 (placeholder) + pass # Skip without real test data + + def test_is_valid_bankgiro(self): + """Test Bankgiro validation.""" + # Valid 8-digit Bankgiro with Luhn + assert FieldValidators.is_valid_bankgiro("53939484") is True + # Invalid (wrong length) + assert FieldValidators.is_valid_bankgiro("123") is False + assert FieldValidators.is_valid_bankgiro("123456789") is False # 9 digits + + def test_format_bankgiro(self): + """Test Bankgiro formatting.""" + assert FieldValidators.format_bankgiro("53939484") == "5393-9484" + assert FieldValidators.format_bankgiro("1234567") == "123-4567" + assert FieldValidators.format_bankgiro("123") is None + + def test_is_valid_plusgiro(self): + """Test Plusgiro validation.""" + # Valid Plusgiro (2-8 digits with Luhn) + assert FieldValidators.is_valid_plusgiro("18") is True # minimal + # Invalid (wrong length) + assert FieldValidators.is_valid_plusgiro("1") is False + + def test_format_plusgiro(self): + """Test Plusgiro formatting.""" + assert FieldValidators.format_plusgiro("12345678") == "1234567-8" + assert FieldValidators.format_plusgiro("123456") == "12345-6" + + def test_is_valid_amount(self): + """Test amount validation.""" + assert 
FieldValidators.is_valid_amount("1234.56") is True
+        assert FieldValidators.is_valid_amount("1 234,56") is True
+        assert FieldValidators.is_valid_amount("abc") is False
+        assert FieldValidators.is_valid_amount("-100") is False  # below min
+        assert FieldValidators.is_valid_amount("100000000") is False  # above max
+
+    def test_parse_amount(self):
+        """Test amount parsing."""
+        assert FieldValidators.parse_amount("1234.56") == 1234.56
+        assert FieldValidators.parse_amount("1 234,56") == 1234.56
+        assert FieldValidators.parse_amount("1.234,56") == 1234.56  # German
+        assert FieldValidators.parse_amount("1,234.56") == 1234.56  # US
+
+    def test_is_valid_date(self):
+        """Test date validation."""
+        assert FieldValidators.is_valid_date("2024-12-29") is True
+        assert FieldValidators.is_valid_date("29.12.2024") is True
+        assert FieldValidators.is_valid_date("29/12/2024") is True
+        assert FieldValidators.is_valid_date("not a date") is False
+        assert FieldValidators.is_valid_date("1900-01-01") is False  # out of range
+
+    def test_format_date_iso(self):
+        """Test date ISO formatting."""
+        assert FieldValidators.format_date_iso("29.12.2024") == "2024-12-29"
+        assert FieldValidators.format_date_iso("29/12/2024") == "2024-12-29"
+        assert FieldValidators.format_date_iso("2024-12-29") == "2024-12-29"
+
+    def test_validate_field_dispatch(self):
+        """Test validate_field dispatches correctly."""
+        # Organisation number
+        is_valid, error = FieldValidators.validate_field("supplier_organisation_number", "")
+        assert is_valid is False
+
+        # Amount
+        is_valid, error = FieldValidators.validate_field("Amount", "1234.56")
+        assert is_valid is True
+
+        # Date
+        is_valid, error = FieldValidators.validate_field("InvoiceDate", "2024-12-29")
+        assert is_valid is True
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/src/utils/text_cleaner.py b/src/utils/text_cleaner.py
new file mode 100644
index 0000000..dbe29ef
--- /dev/null
+++ b/src/utils/text_cleaner.py
@@ -0,0 +1,244 @@
+"""
+Text Cleaning Module
+
+Provides text normalization and OCR error correction utilities.
+Used by both inference (field_extractor) and matching (normalizer) stages.
+"""
+
+import re
+from typing import Optional
+
+
+class TextCleaner:
+    """
+    Unified text cleaning utilities for invoice processing.
+
+    Handles:
+    - Unicode normalization (zero-width chars, dash variants)
+    - OCR error correction (O/0, l/1, etc.)
+    - Whitespace normalization
+    - Swedish-specific character handling
+    """
+
+    # Common OCR misread corrections (for digit-only fields)
+    # These characters are frequently misrecognized when digits are expected
+    OCR_DIGIT_CORRECTIONS = {
+        'O': '0', 'o': '0',  # letter O -> digit 0
+        'Q': '0',            # Q sometimes looks like 0
+        'l': '1', 'I': '1',  # lowercase L / uppercase I -> digit 1
+        '|': '1',            # vertical bar -> 1
+        'i': '1',            # lowercase i -> 1
+        'S': '5', 's': '5',  # S -> 5
+        'B': '8',            # B -> 8
+        'Z': '2', 'z': '2',  # Z -> 2
+        'G': '6', 'g': '6',  # G -> 6 (in some fonts)
+        'A': '4',            # A -> 4 (in some fonts)
+        'T': '7',            # T -> 7 (in some fonts)
+        'q': '9',            # q -> 9
+        'D': '0',            # D -> 0
+    }
+
+    # Reverse mapping: digits misread as letters (for mixed alphanumeric fields)
+    OCR_LETTER_CORRECTIONS = {
+        '0': 'O',
+        '1': 'I',
+        '5': 'S',
+        '8': 'B',
+        '2': 'Z',
+    }
+
+    # Normalization of special Unicode characters
+    UNICODE_NORMALIZATIONS = {
+        # Various dashes -> standard hyphen
+        '\u2013': '-',  # en-dash –
+        '\u2014': '-',  # em-dash —
+        '\u2212': '-',  # minus sign −
+        '\u00b7': '-',  # middle dot ·
+        '\u2010': '-',  # hyphen ‐
+        '\u2011': '-',  # non-breaking hyphen ‑
+        '\u2012': '-',  # figure dash ‒
+        '\u2015': '-',  # horizontal bar ―
+
+        # Various spaces -> standard space
+        '\u00a0': ' ',  # non-breaking space
+        '\u2002': ' ',  # en space
+        '\u2003': ' ',  # em space
+        '\u2009': ' ',  # thin space
+        '\u200a': ' ',  # hair space
+
+        # Zero-width characters -> removed
+        '\u200b': '',  # zero-width space
+        '\u200c': '',  # zero-width non-joiner
+        '\u200d': '',  # zero-width joiner
+        '\ufeff': '',  # BOM / zero-width no-break space
+
+        # Various quotes -> standard quotes
+        '\u2018': "'",  # left single quote '
+        '\u2019': "'",  # right single quote '
+        '\u201c': '"',  # left double quote "
+        '\u201d': '"',  # right double quote "
+    }
+
+    @classmethod
+    def clean_unicode(cls, text: str) -> str:
+        """
+        Normalize Unicode characters to ASCII equivalents.
+
+        Handles:
+        - Various dash types -> standard hyphen (-)
+        - Various spaces -> standard space
+        - Zero-width characters -> removed
+        - Various quotes -> standard quotes
+        """
+        for unicode_char, replacement in cls.UNICODE_NORMALIZATIONS.items():
+            text = text.replace(unicode_char, replacement)
+        return text
+
+    @classmethod
+    def normalize_whitespace(cls, text: str) -> str:
+        """Collapse multiple whitespace to single space and strip."""
+        return ' '.join(text.split())
+
+    @classmethod
+    def clean_text(cls, text: str) -> str:
+        """
+        Full text cleaning pipeline.
+
+        1. Normalize Unicode
+        2. Normalize whitespace
+        3. Strip
+
+        This is safe for all field types.
+        """
+        text = cls.clean_unicode(text)
+        text = cls.normalize_whitespace(text)
+        return text.strip()
+
+    @classmethod
+    def apply_ocr_digit_corrections(cls, text: str) -> str:
+        """
+        Apply OCR error corrections for digit-only fields.
+
+        Use this when the field is expected to contain only digits
+        (e.g., OCR number, organization number digits, etc.)
+
+        Example:
+            "556l23-4S67" -> "556123-4567"
+        """
+        result = []
+        for char in text:
+            if char in cls.OCR_DIGIT_CORRECTIONS:
+                result.append(cls.OCR_DIGIT_CORRECTIONS[char])
+            else:
+                result.append(char)
+        return ''.join(result)
+
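Aside (not lines from the diff): a minimal usage sketch of the digit-correction helpers above, using the same values as the tests in test_utils.py; the import path assumes the src.utils package this patch introduces.

    from src.utils.text_cleaner import TextCleaner

    # "S" and "l" are corrected because the field is expected to hold digits
    TextCleaner.apply_ocr_digit_corrections("5S6l23-4S67")   # -> "556123-4567"
    # extract_digits also drops separators; OCR correction is on by default
    TextCleaner.extract_digits("556O23-4567")                # -> "5560234567"
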
+    @classmethod
+    def extract_digits(cls, text: str, apply_ocr_correction: bool = True) -> str:
+        """
+        Extract only digits from text.
+
+        Args:
+            text: Input text
+            apply_ocr_correction: If True, apply OCR corrections ONLY to characters
+                that are adjacent to digits (not standalone letters)
+
+        Returns:
+            String containing only digits
+        """
+        if apply_ocr_correction:
+            # Only apply OCR corrections to characters that sit inside a digit sequence,
+            # e.g. the O in "556O23" should be corrected, but the ABC in "ABC 123" should not
+            result = []
+            for i, char in enumerate(text):
+                if char.isdigit():
+                    result.append(char)
+                elif char in cls.OCR_DIGIT_CORRECTIONS:
+                    # Check whether a neighbouring character is (or corrects to) a digit
+                    prev_is_digit = i > 0 and (text[i - 1].isdigit() or text[i - 1] in cls.OCR_DIGIT_CORRECTIONS)
+                    next_is_digit = i < len(text) - 1 and (text[i + 1].isdigit() or text[i + 1] in cls.OCR_DIGIT_CORRECTIONS)
+                    if prev_is_digit or next_is_digit:
+                        result.append(cls.OCR_DIGIT_CORRECTIONS[char])
+                # All other characters are skipped
+            return ''.join(result)
+        else:
+            return re.sub(r'\D', '', text)
+
+    @classmethod
+    def clean_for_digits(cls, text: str) -> str:
+        """
+        Clean text that should primarily contain digits.
+
+        Pipeline:
+        1. Clean Unicode
+        2. Apply OCR digit corrections
+        3. Normalize whitespace
+
+        Preserves separators (-, /) for formatted numbers like "556123-4567"
+        """
+        text = cls.clean_unicode(text)
+        text = cls.apply_ocr_digit_corrections(text)
+        text = cls.normalize_whitespace(text)
+        return text.strip()
+
+    @classmethod
+    def generate_ocr_variants(cls, text: str) -> list[str]:
+        """
+        Generate possible OCR error variants of the input text.
+
+        This is useful for matching: if we have a CSV value,
+        we generate variants that might appear in OCR output.
+
+        Example:
+            "5561234567" -> ["5561234567", "556I234567", "SS61234S67", ...]
+        """
+        variants = {text}
+
+        # Generate letter variants for digits
+        for digit, letter in cls.OCR_LETTER_CORRECTIONS.items():
+            if digit in text:
+                variants.add(text.replace(digit, letter))
+
+        # Generate digit variants for letters
+        for letter, digit in cls.OCR_DIGIT_CORRECTIONS.items():
+            if letter in text:
+                variants.add(text.replace(letter, digit))
+
+        return list(variants)
+
+    @classmethod
+    def normalize_amount_text(cls, text: str) -> str:
+        """
+        Normalize amount text for parsing.
+
+        - Removes currency symbols and labels
+        - Normalizes separators
+        - Handles Swedish format (space as thousand separator)
+        """
+        text = cls.clean_text(text)
+
+        # Remove currency symbols and labels (word boundaries ensure whole-word matches)
+        text = re.sub(r'(?i)\b(kr|sek|kronor|öre)\b', '', text)
+
+        # Remove thousand-separator spaces (Swedish: "1 234,56" -> "1234,56")
+        # while keeping the digits around the decimal separator
+        text = re.sub(r'(\d)\s+(\d)', r'\1\2', text)
+
+        return text.strip()
+
+    @classmethod
+    def normalize_for_comparison(cls, text: str) -> str:
+        """
+        Normalize text for loose comparison.
+
+        - Lowercase
+        - Remove all non-alphanumeric
+        - Apply OCR corrections
+
+        This is the most aggressive normalization, used for fuzzy matching.
+        """
+        text = cls.clean_text(text)
+        text = text.lower()
+        text = cls.apply_ocr_digit_corrections(text)
+        text = re.sub(r'[^a-z0-9]', '', text)
+        return text
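Aside (not lines from the diff): the last two helpers are the matching-oriented entry points; a short sketch of how they behave given the mappings defined above, with the import path assumed as in this patch.

    from src.utils.text_cleaner import TextCleaner

    # Aggressive key for fuzzy comparison: lowercased, OCR-corrected, alphanumeric only
    TextCleaner.normalize_for_comparison("Faktura-Nr: 556 123–4567")  # -> "fakturanr5561234567"

    # Variants of a known CSV value that may appear in OCR output
    # (str.replace swaps every occurrence, hence "SS61234S67")
    TextCleaner.generate_ocr_variants("5561234567")  # includes "556I234567" and "SS61234S67"
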
diff --git a/src/utils/validators.py b/src/utils/validators.py
new file mode 100644
index 0000000..2f8b643
--- /dev/null
+++ b/src/utils/validators.py
@@ -0,0 +1,393 @@
+"""
+Field Validators Module
+
+Provides validation functions for Swedish invoice fields.
+Used by both inference (to validate extracted values) and matching (to filter candidates).
+"""
+
+import re
+from datetime import datetime
+from typing import Optional
+
+from .text_cleaner import TextCleaner
+
+
+class FieldValidators:
+    """
+    Validators for Swedish invoice field values.
+
+    Includes:
+    - Luhn (Mod10) checksum validation
+    - Format validation for specific field types
+    - Range validation for dates and amounts
+    """
+
+    # =========================================================================
+    # Luhn (Mod10) Checksum
+    # =========================================================================
+
+    @classmethod
+    def luhn_checksum(cls, digits: str) -> bool:
+        """
+        Validate using Luhn (Mod10) algorithm.
+
+        Used for:
+        - Bankgiro numbers
+        - Plusgiro numbers
+        - OCR reference numbers
+        - Swedish organization numbers
+
+        The checksum is valid if the total modulo 10 equals 0.
+        """
+        # Keep digits only
+        digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
+
+        if not digits or not digits.isdigit():
+            return False
+
+        total = 0
+        for i, char in enumerate(reversed(digits)):
+            digit = int(char)
+            if i % 2 == 1:  # Double every second digit, counting from the right
+                digit *= 2
+                if digit > 9:
+                    digit -= 9
+            total += digit
+
+        return total % 10 == 0
+
+    @classmethod
+    def calculate_luhn_check_digit(cls, digits: str) -> int:
+        """
+        Calculate the Luhn check digit for a number.
+
+        Given a number without check digit, returns the digit that would make it valid.
+        """
+        digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
+
+        # Compute the Luhn sum of the existing digits
+        total = 0
+        for i, char in enumerate(reversed(digits)):
+            digit = int(char)
+            if i % 2 == 0:  # Note: even positions are doubled because a check digit will be appended
+                digit *= 2
+                if digit > 9:
+                    digit -= 9
+            total += digit
+
+        # Compute the required check digit
+        check_digit = (10 - (total % 10)) % 10
+        return check_digit
+
+    # =========================================================================
+    # Organisation Number Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_organisation_number(cls, value: str) -> bool:
+        """
+        Validate Swedish organisation number.
+
+        Format: NNNNNN-NNNN (10 digits)
+        - First digit: 1-9
+        - Third digit: >= 2 (distinguishes from personal numbers)
+        - Last digit: Luhn check digit
+        """
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        # Handle VAT formats
+        if len(digits) == 12 and digits.endswith('01'):
+            digits = digits[:10]
+        elif len(digits) == 14 and digits.startswith('46') and digits.endswith('01'):
+            digits = digits[2:12]
+
+        if len(digits) != 10:
+            return False
+
+        # First digit must be 1-9
+        if digits[0] == '0':
+            return False
+
+        # Third digit >= 2 (distinguishes organisation numbers from personal numbers)
+        # Note: some special organisations may not follow this rule, so it is relaxed here
+        # if int(digits[2]) < 2:
+        #     return False
+
+        # Luhn checksum
+        return cls.luhn_checksum(digits)
+
+    # =========================================================================
+    # Bankgiro Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_bankgiro(cls, value: str) -> bool:
+        """
+        Validate Swedish Bankgiro number.
+
+        Format: 7 or 8 digits with Luhn checksum
+        """
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        if len(digits) < 7 or len(digits) > 8:
+            return False
+
+        return cls.luhn_checksum(digits)
+
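Aside (not lines from the diff): the checksum helpers above are pinned down by test_utils.py; a compact sketch of the expected behaviour, with the import path assumed as in this patch.

    from src.utils.validators import FieldValidators

    FieldValidators.luhn_checksum("53939484")                # True  (weighted digit sum is 40)
    FieldValidators.luhn_checksum("53939485")                # False (wrong check digit)
    FieldValidators.calculate_luhn_check_digit("123456789")  # 7, so "1234567897" passes the check
    FieldValidators.is_valid_bankgiro("5393-9484")           # True  (hyphen stripped, 8 digits, Luhn OK)
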
+    @classmethod
+    def format_bankgiro(cls, value: str) -> Optional[str]:
+        """
+        Format Bankgiro number to standard format.
+
+        Returns: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits), or None if invalid
+        """
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        if len(digits) == 7:
+            return f"{digits[:3]}-{digits[3:]}"
+        elif len(digits) == 8:
+            return f"{digits[:4]}-{digits[4:]}"
+        else:
+            return None
+
+    # =========================================================================
+    # Plusgiro Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_plusgiro(cls, value: str) -> bool:
+        """
+        Validate Swedish Plusgiro number.
+
+        Format: 2-8 digits with Luhn checksum
+        """
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        if len(digits) < 2 or len(digits) > 8:
+            return False
+
+        return cls.luhn_checksum(digits)
+
+    @classmethod
+    def format_plusgiro(cls, value: str) -> Optional[str]:
+        """
+        Format Plusgiro number to standard format.
+
+        Returns: XXXXXXX-X format, or None if invalid
+        """
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        if len(digits) < 2 or len(digits) > 8:
+            return None
+
+        return f"{digits[:-1]}-{digits[-1]}"
+
+    # =========================================================================
+    # OCR Number Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_ocr_number(cls, value: str, validate_checksum: bool = True) -> bool:
+        """
+        Validate Swedish OCR reference number.
+
+        - Typically 10-25 digits (anything from 5 to 25 is accepted here)
+        - Usually has Luhn checksum (but not always enforced)
+        """
+        digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
+
+        if len(digits) < 5 or len(digits) > 25:
+            return False
+
+        if validate_checksum:
+            return cls.luhn_checksum(digits)
+
+        return True
+
+    # =========================================================================
+    # Amount Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_amount(cls, value: str, min_amount: float = 0.0, max_amount: float = 10_000_000.0) -> bool:
+        """
+        Validate monetary amount.
+
+        - Must be positive (or at least >= min_amount)
+        - Should be within reasonable range
+        """
+        try:
+            # Try to parse
+            text = TextCleaner.normalize_amount_text(value)
+            # Normalize to a dot as the decimal separator
+            text = text.replace(' ', '').replace(',', '.')
+            # If there are multiple dots, keep only the last one as the decimal point
+            if text.count('.') > 1:
+                parts = text.rsplit('.', 1)
+                text = parts[0].replace('.', '') + '.' + parts[1]
+
+            amount = float(text)
+            return min_amount <= amount <= max_amount
+        except (ValueError, TypeError):
+            return False
+
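Aside (not lines from the diff): amount handling is the most format-sensitive part of this module; the behaviour fixed by the tests for is_valid_amount above and parse_amount just below can be summarised as follows (import path assumed as in this patch).

    from src.utils.validators import FieldValidators

    FieldValidators.parse_amount("1 234,56")      # 1234.56 (Swedish: space thousands, comma decimal)
    FieldValidators.parse_amount("1.234,56")      # 1234.56 (dot thousands, comma decimal)
    FieldValidators.parse_amount("1,234.56")      # 1234.56 (comma thousands, dot decimal)
    FieldValidators.is_valid_amount("-100")       # False   (below the default min_amount of 0.0)
    FieldValidators.is_valid_amount("100000000")  # False   (above the default max_amount)
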
+    @classmethod
+    def parse_amount(cls, value: str) -> Optional[float]:
+        """
+        Parse amount from string, handling various formats.
+
+        Returns float or None if parsing fails.
+        """
+        try:
+            text = TextCleaner.normalize_amount_text(value)
+            text = text.replace(' ', '')
+
+            # Detect the format and parse
+            # Swedish/German format: comma is the decimal separator
+            if re.match(r'^[\d.]+,\d{1,2}$', text):
+                text = text.replace('.', '').replace(',', '.')
+            # US format: dot is the decimal separator
+            elif re.match(r'^[\d,]+\.\d{1,2}$', text):
+                text = text.replace(',', '')
+            else:
+                # Simple format
+                text = text.replace(',', '.')
+                if text.count('.') > 1:
+                    parts = text.rsplit('.', 1)
+                    text = parts[0].replace('.', '') + '.' + parts[1]
+
+            return float(text)
+        except (ValueError, TypeError):
+            return None
+
+    # =========================================================================
+    # Date Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_date(cls, value: str, min_year: int = 2000, max_year: int = 2100) -> bool:
+        """
+        Validate date string.
+
+        - Year should be within reasonable range
+        - Month 1-12
+        - Day 1-31 (basic check)
+        """
+        parsed = cls.parse_date(value)
+        if parsed is None:
+            return False
+
+        year, month, day = parsed
+        if not (min_year <= year <= max_year):
+            return False
+        if not (1 <= month <= 12):
+            return False
+        if not (1 <= day <= 31):
+            return False
+
+        # More precise calendar validation
+        try:
+            datetime(year, month, day)
+            return True
+        except ValueError:
+            return False
+
+    @classmethod
+    def parse_date(cls, value: str) -> Optional[tuple[int, int, int]]:
+        """
+        Parse date from string.
+
+        Returns (year, month, day) tuple or None.
+        """
+        from .format_variants import FormatVariants
+        return FormatVariants._parse_date(value)
+
+    @classmethod
+    def format_date_iso(cls, value: str) -> Optional[str]:
+        """
+        Format date to ISO format (YYYY-MM-DD).
+
+        Returns formatted string or None if parsing fails.
+        """
+        parsed = cls.parse_date(value)
+        if parsed is None:
+            return None
+
+        year, month, day = parsed
+        return f"{year}-{month:02d}-{day:02d}"
+
+    # =========================================================================
+    # Invoice Number Validation
+    # =========================================================================
+
+    @classmethod
+    def is_valid_invoice_number(cls, value: str, min_length: int = 1, max_length: int = 30) -> bool:
+        """
+        Validate invoice number.
+
+        Basic validation - just length check since invoice numbers are highly variable.
+        """
+        clean = TextCleaner.clean_text(value)
+        if not clean:
+            return False
+
+        # Extract meaningful characters (letters and digits)
+        meaningful = re.sub(r'[^a-zA-Z0-9]', '', clean)
+        return min_length <= len(meaningful) <= max_length
+
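Aside (not lines from the diff): a short sketch of the date helpers above, matching the expectations in test_utils.py (import path assumed as in this patch).

    from src.utils.validators import FieldValidators

    FieldValidators.is_valid_date("29.12.2024")    # True
    FieldValidators.is_valid_date("1900-01-01")    # False (outside the 2000-2100 window)
    FieldValidators.format_date_iso("29/12/2024")  # "2024-12-29"
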
+    # =========================================================================
+    # Generic Validation
+    # =========================================================================
+
+    @classmethod
+    def validate_field(cls, field_name: str, value: str) -> tuple[bool, Optional[str]]:
+        """
+        Validate a field by name.
+
+        Returns (is_valid, error_message).
+        """
+        if not value:
+            return False, "Empty value"
+
+        field_lower = field_name.lower()
+
+        if 'organisation' in field_lower or 'org' in field_lower:
+            if cls.is_valid_organisation_number(value):
+                return True, None
+            return False, "Invalid organisation number format or checksum"
+
+        elif 'bankgiro' in field_lower:
+            if cls.is_valid_bankgiro(value):
+                return True, None
+            return False, "Invalid Bankgiro format or checksum"
+
+        elif 'plusgiro' in field_lower:
+            if cls.is_valid_plusgiro(value):
+                return True, None
+            return False, "Invalid Plusgiro format or checksum"
+
+        elif 'ocr' in field_lower:
+            if cls.is_valid_ocr_number(value, validate_checksum=False):
+                return True, None
+            return False, "Invalid OCR number length"
+
+        elif 'amount' in field_lower:
+            if cls.is_valid_amount(value):
+                return True, None
+            return False, "Invalid amount format"
+
+        elif 'date' in field_lower:
+            if cls.is_valid_date(value):
+                return True, None
+            return False, "Invalid date format"
+
+        elif 'invoice' in field_lower and 'number' in field_lower:
+            if cls.is_valid_invoice_number(value):
+                return True, None
+            return False, "Invalid invoice number"
+
+        else:
+            # Default: just check that the value is non-empty
+            if TextCleaner.clean_text(value):
+                return True, None
+            return False, "Empty value after cleaning"
diff --git a/src/web/routes.py b/src/web/routes.py
index 31e7c1c..c193bcb 100644
--- a/src/web/routes.py
+++ b/src/web/routes.py
@@ -122,6 +122,7 @@ def create_api_router(
         inference_result = InferenceResult(
             document_id=service_result.document_id,
             success=service_result.success,
+            document_type=service_result.document_type,
             fields=service_result.fields,
             confidence=service_result.confidence,
             detections=[
diff --git a/src/web/schemas.py b/src/web/schemas.py
index 352eff7..f7ed47d 100644
--- a/src/web/schemas.py
+++ b/src/web/schemas.py
@@ -30,6 +30,9 @@ class InferenceResult(BaseModel):
 
     document_id: str = Field(..., description="Document identifier")
     success: bool = Field(..., description="Whether inference succeeded")
+    document_type: str = Field(
+        default="invoice", description="Document type: 'invoice' or 'letter'"
+    )
     fields: dict[str, str | None] = Field(
         default_factory=dict, description="Extracted field values"
     )
diff --git a/src/web/services.py b/src/web/services.py
index ca3890a..bd19c0e 100644
--- a/src/web/services.py
+++ b/src/web/services.py
@@ -28,6 +28,7 @@ class ServiceResult:
 
     document_id: str
     success: bool = False
+    document_type: str = "invoice"  # "invoice" or "letter"
    fields: dict[str, str | None] = field(default_factory=dict)
     confidence: dict[str, float] = field(default_factory=dict)
     detections: list[dict] = field(default_factory=list)
@@ -145,6 +146,13 @@ class InferenceService:
         result.success = pipeline_result.success
         result.errors = pipeline_result.errors
 
+        # Determine document type based on payment_line presence
+        # If no payment_line found, it's likely a letter, not an invoice
+        if not result.fields.get('payment_line'):
+            result.document_type = "letter"
+        else:
+            result.document_type = "invoice"
+
         # Get raw detections for visualization
         result.detections = [
             {
@@ -202,6 +210,13 @@ class InferenceService:
         result.success = pipeline_result.success
         result.errors = pipeline_result.errors
 
+        # Determine document type based on payment_line presence
+        # If no payment_line found, it's likely a letter, not an invoice
+        if not result.fields.get('payment_line'):
+            result.document_type = "letter"
+        else:
+            result.document_type = "invoice"
+
         # Get raw detections
         result.detections = [
             {