restructure project

Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

@@ -0,0 +1,9 @@
PyMuPDF>=1.23.0
paddleocr>=2.7.0
Pillow>=10.0.0
numpy>=1.24.0
opencv-python>=4.8.0
psycopg2-binary>=2.9.0
python-dotenv>=1.0.0
pyyaml>=6.0
thefuzz>=0.20.0

packages/shared/setup.py

@@ -0,0 +1,19 @@
from setuptools import setup, find_packages
setup(
name="invoice-shared",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"PyMuPDF>=1.23.0",
"paddleocr>=2.7.0",
"Pillow>=10.0.0",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"psycopg2-binary>=2.9.0",
"python-dotenv>=1.0.0",
"pyyaml>=6.0",
"thefuzz>=0.20.0",
],
)

@@ -0,0 +1,2 @@
# Invoice Master POC v2
# Automatic invoice information extraction system using YOLO + OCR

@@ -0,0 +1,88 @@
"""
Configuration settings for the invoice extraction system.
"""
import os
import platform
from pathlib import Path
from dotenv import load_dotenv
# Load environment variables from .env file at project root
# Walk up from packages/shared/shared/config.py to find project root
_config_dir = Path(__file__).parent
for _candidate in [_config_dir.parent.parent.parent, _config_dir.parent.parent, _config_dir.parent]:
_env_path = _candidate / '.env'
if _env_path.exists():
load_dotenv(dotenv_path=_env_path)
break
else:
load_dotenv() # fallback: search cwd and parents
# Global DPI setting - must match training DPI for optimal model performance
DEFAULT_DPI = 150
def _is_wsl() -> bool:
"""Check if running inside WSL (Windows Subsystem for Linux)."""
if platform.system() != 'Linux':
return False
# Check for WSL-specific indicators
if os.environ.get('WSL_DISTRO_NAME'):
return True
try:
with open('/proc/version', 'r') as f:
return 'microsoft' in f.read().lower()
except (FileNotFoundError, PermissionError):
return False
# PostgreSQL Database Configuration
# Now loaded from environment variables for security
DATABASE = {
'host': os.getenv('DB_HOST', '192.168.68.31'),
'port': int(os.getenv('DB_PORT', '5432')),
'database': os.getenv('DB_NAME', 'docmaster'),
'user': os.getenv('DB_USER', 'docmaster'),
'password': os.getenv('DB_PASSWORD'), # No default for security
}
# Validate required configuration
if not DATABASE['password']:
raise ValueError(
"DB_PASSWORD environment variable is not set. "
"Please create a .env file based on .env.example and set DB_PASSWORD."
)
# Connection string for psycopg2
def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
# WSL: use native Linux filesystem for better I/O performance
PATHS = {
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
'reports_dir': 'reports', # Keep reports in project directory
}
else:
# Windows or native Linux: use relative paths
PATHS = {
'csv_dir': 'data/structured_data',
'pdf_dir': 'data/raw_pdfs',
'output_dir': 'data/dataset',
'reports_dir': 'reports',
}
# Auto-labeling Configuration
AUTOLABEL = {
'workers': 2,
'dpi': DEFAULT_DPI,
'min_confidence': 0.5,
'train_ratio': 0.8,
'val_ratio': 0.1,
'test_ratio': 0.1,
'max_records_per_report': 10000,
}
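As a usage sketch for review purposes (a minimal example, assuming a `.env` with `DB_PASSWORD` is present so the import-time check passes):
import psycopg2
from shared.config import AUTOLABEL, PATHS, get_db_connection_string

dpi = AUTOLABEL['dpi']            # must match the training DPI (150)
csv_dir = PATHS['csv_dir']        # resolved per the WSL/Windows detection above
with psycopg2.connect(get_db_connection_string()) as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")   # connectivity smoke test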

@@ -0,0 +1,3 @@
from .csv_loader import CSVLoader, InvoiceRow
__all__ = ['CSVLoader', 'InvoiceRow']

@@ -0,0 +1,372 @@
"""
CSV Data Loader
Loads and parses structured invoice data from CSV files.
Follows the CSV specification for invoice data.
"""
import csv
from dataclasses import dataclass, field
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Iterator
@dataclass
class InvoiceRow:
"""Parsed invoice data row."""
DocumentId: str
InvoiceDate: date | None = None
InvoiceNumber: str | None = None
InvoiceDueDate: date | None = None
OCR: str | None = None
Message: str | None = None
Bankgiro: str | None = None
Plusgiro: str | None = None
Amount: Decimal | None = None
# New fields
split: str | None = None # train/test split indicator
customer_number: str | None = None # Customer number (needs matching)
supplier_name: str | None = None # Supplier name (no matching)
supplier_organisation_number: str | None = None # Swedish org number (needs matching)
supplier_accounts: str | None = None # Supplier accounts (needs matching)
# Raw values for reference
raw_data: dict = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for matching."""
return {
'DocumentId': self.DocumentId,
'InvoiceDate': self.InvoiceDate.isoformat() if self.InvoiceDate else None,
'InvoiceNumber': self.InvoiceNumber,
'InvoiceDueDate': self.InvoiceDueDate.isoformat() if self.InvoiceDueDate else None,
'OCR': self.OCR,
'Bankgiro': self.Bankgiro,
'Plusgiro': self.Plusgiro,
'Amount': str(self.Amount) if self.Amount is not None else None,  # 'is not None' so a zero amount is preserved
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def get_field_value(self, field_name: str) -> str | None:
"""Get field value as string for matching."""
value = getattr(self, field_name, None)
if value is None:
return None
if isinstance(value, date):
return value.isoformat()
if isinstance(value, Decimal):
return str(value)
return str(value) if value else None
class CSVLoader:
"""Loads invoice data from CSV files."""
# Expected field mappings (CSV header -> InvoiceRow attribute)
FIELD_MAPPINGS = {
'DocumentId': 'DocumentId',
'InvoiceDate': 'InvoiceDate',
'InvoiceNumber': 'InvoiceNumber',
'InvoiceDueDate': 'InvoiceDueDate',
'OCR': 'OCR',
'Message': 'Message',
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'Amount': 'Amount',
# New fields
'split': 'split',
'customer_number': 'customer_number',
'supplier_name': 'supplier_name',
'supplier_organisation_number': 'supplier_organisation_number',
'supplier_accounts': 'supplier_accounts',
}
def __init__(
self,
csv_path: str | Path | list[str | Path],
pdf_dir: str | Path | None = None,
doc_map_path: str | Path | None = None,
encoding: str = 'utf-8'
):
"""
Initialize CSV loader.
Args:
csv_path: Path to CSV file(s). Can be:
- Single path: 'data/file.csv'
- List of paths: ['data/file1.csv', 'data/file2.csv']
- Glob pattern: 'data/*.csv' or 'data/export_*.csv'
pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
doc_map_path: Optional path to document mapping CSV
encoding: CSV file encoding (default: utf-8)
"""
# Handle multiple CSV files
if isinstance(csv_path, list):
self.csv_paths = [Path(p) for p in csv_path]
else:
csv_path = Path(csv_path)
# Check if it's a glob pattern (contains * or ?)
if '*' in str(csv_path) or '?' in str(csv_path):
parent = csv_path.parent
pattern = csv_path.name
self.csv_paths = sorted(parent.glob(pattern))
else:
self.csv_paths = [csv_path]
# For backward compatibility
self.csv_path = self.csv_paths[0] if self.csv_paths else None
self.pdf_dir = Path(pdf_dir) if pdf_dir else (self.csv_path.parent.parent / 'raw_pdfs' if self.csv_path else Path('data/raw_pdfs'))
self.doc_map_path = Path(doc_map_path) if doc_map_path else None
self.encoding = encoding
# Load document mapping if provided
self.doc_map = self._load_doc_map() if self.doc_map_path else {}
def _load_doc_map(self) -> dict[str, str]:
"""Load document ID to filename mapping."""
mapping = {}
if self.doc_map_path and self.doc_map_path.exists():
with open(self.doc_map_path, 'r', encoding=self.encoding) as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row.get('DocumentId', '').strip()
filename = row.get('FileName', '').strip()
if doc_id and filename:
mapping[doc_id] = filename
return mapping
def _parse_date(self, value: str | None) -> date | None:
"""Parse date from various formats."""
if not value or not value.strip():
return None
value = value.strip()
# Try different date formats
formats = [
'%Y-%m-%d',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
'%d/%m/%Y',
'%d.%m.%Y',
'%d-%m-%Y',
'%Y%m%d',
]
for fmt in formats:
try:
return datetime.strptime(value, fmt).date()
except ValueError:
continue
return None
def _parse_amount(self, value: str | None) -> Decimal | None:
"""Parse monetary amount from various formats."""
if not value or not value.strip():
return None
value = value.strip()
# Remove currency symbols and common suffixes
value = value.replace('SEK', '').replace('kr', '').replace(':-', '')
value = value.strip()
# Remove spaces (thousand separators)
value = value.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator (European format)
if ',' in value and '.' not in value:
value = value.replace(',', '.')
elif ',' in value and '.' in value:
# Assume comma is thousands separator, dot is decimal
value = value.replace(',', '')
try:
return Decimal(value)
except InvalidOperation:
return None
def _parse_string(self, value: str | None) -> str | None:
"""Parse string field with cleanup."""
if value is None:
return None
value = value.strip()
return value if value else None
def _get_field(self, row: dict, *keys: str) -> str | None:
"""Get field value trying multiple possible column names."""
for key in keys:
value = row.get(key)
if value is not None:
return value
return None
def _parse_row(self, row: dict) -> InvoiceRow | None:
"""Parse a single CSV row into InvoiceRow."""
doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id'))
if not doc_id:
return None
return InvoiceRow(
DocumentId=doc_id,
InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')),
InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')),
InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')),
OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')),
Message=self._parse_string(self._get_field(row, 'Message', 'message')),
Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')),
Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')),
Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')),
# New fields
split=self._parse_string(row.get('split')),
customer_number=self._parse_string(row.get('customer_number')),
supplier_name=self._parse_string(row.get('supplier_name')),
supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')),
supplier_accounts=self._parse_string(row.get('supplier_accounts')),
raw_data=dict(row)
)
def _iter_single_csv(self, csv_path: Path) -> Iterator[InvoiceRow]:
"""Iterate over rows from a single CSV file."""
# Handle BOM - try utf-8-sig first to handle BOM correctly
encodings = ['utf-8-sig', self.encoding, 'latin-1']
for enc in encodings:
try:
with open(csv_path, 'r', encoding=enc) as f:
reader = csv.DictReader(f)
for row in reader:
parsed = self._parse_row(row)
if parsed:
yield parsed
return
except UnicodeDecodeError:
continue
raise ValueError(f"Could not read CSV file {csv_path} with any supported encoding")
def load_all(self) -> list[InvoiceRow]:
"""Load all rows from CSV(s)."""
rows = []
for row in self.iter_rows():
rows.append(row)
return rows
def iter_rows(self) -> Iterator[InvoiceRow]:
"""Iterate over CSV rows from all CSV files."""
seen_doc_ids = set()
for csv_path in self.csv_paths:
if not csv_path.exists():
continue
for row in self._iter_single_csv(csv_path):
# Deduplicate by DocumentId
if row.DocumentId not in seen_doc_ids:
seen_doc_ids.add(row.DocumentId)
yield row
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
"""
Get PDF path for an invoice row.
Uses document mapping if available, otherwise assumes
DocumentId.pdf naming convention.
"""
doc_id = invoice_row.DocumentId
# Check document mapping first
if doc_id in self.doc_map:
filename = self.doc_map[doc_id]
pdf_path = self.pdf_dir / filename
if pdf_path.exists():
return pdf_path
# Try default naming patterns
patterns = [
f"{doc_id}.pdf",
f"{doc_id}.PDF",
f"{doc_id.lower()}.pdf",
f"{doc_id.lower()}.PDF",
f"{doc_id.upper()}.pdf",
f"{doc_id.upper()}.PDF",
]
for pattern in patterns:
pdf_path = self.pdf_dir / pattern
if pdf_path.exists():
return pdf_path
# Try glob patterns for partial matches (both cases)
for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.pdf"):
return pdf_file
for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.PDF"):
return pdf_file
return None
def get_row_by_id(self, doc_id: str) -> InvoiceRow | None:
"""Get a specific row by DocumentId."""
for row in self.iter_rows():
if row.DocumentId == doc_id:
return row
return None
def validate(self) -> list[dict]:
"""
Validate CSV data and return issues.
Returns:
List of validation issues
"""
issues = []
for i, row in enumerate(self.iter_rows(), start=2): # Approximate source row (header is row 1); drifts across multiple CSVs or skipped duplicates
# Check required DocumentId
if not row.DocumentId:
issues.append({
'row': i,
'field': 'DocumentId',
'issue': 'Missing required DocumentId'
})
continue
# Check if PDF exists
pdf_path = self.get_pdf_path(row)
if not pdf_path:
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'PDF',
'issue': 'PDF file not found'
})
# Check for at least one matchable field
matchable_fields = [
row.InvoiceNumber,
row.OCR,
row.Bankgiro,
row.Plusgiro,
row.Amount,
row.supplier_organisation_number,
row.supplier_accounts,
]
if not any(matchable_fields):
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'All',
'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)'
})
return issues
def load_invoice_csv(csv_path: str | Path | list[str | Path], pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
"""Convenience function to load invoice CSV(s)."""
loader = CSVLoader(csv_path, pdf_dir)
return loader.load_all()
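A minimal consumption sketch (the import path and the `data/` layout are assumptions based on the defaults above):
from shared.data import CSVLoader  # assumed package path for the __init__ above

loader = CSVLoader('data/structured_data/*.csv', pdf_dir='data/raw_pdfs')
for issue in loader.validate():
    print(issue)                                # missing PDFs, unmatchable rows
for row in loader.iter_rows():                  # deduplicated across all CSVs
    pdf_path = loader.get_pdf_path(row)
    if pdf_path is not None:
        print(row.DocumentId, row.Amount, pdf_path.name)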

@@ -0,0 +1,530 @@
"""
Database utilities for autolabel workflow.
"""
import json
import psycopg2
from psycopg2.extras import execute_values
from typing import Set, Dict, Any, Optional
from shared.config import get_db_connection_string
class DocumentDB:
"""Database interface for document processing status."""
def __init__(self, connection_string: Optional[str] = None):
self.connection_string = connection_string or get_db_connection_string()
self.conn = None
def connect(self):
"""Connect to database."""
if self.conn is None:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def create_tables(self):
"""Create database tables if they don't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
document_id TEXT PRIMARY KEY,
pdf_path TEXT,
pdf_type TEXT,
success BOOLEAN,
total_pages INTEGER,
fields_matched INTEGER,
fields_total INTEGER,
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]',
-- Extended CSV format fields
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
field_name TEXT,
csv_value TEXT,
matched BOOLEAN,
score REAL,
matched_text TEXT,
candidate_used TEXT,
bbox JSONB,
page_no INTEGER,
context_keywords JSONB DEFAULT '[]',
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def get_successful_doc_ids(self) -> Set[str]:
"""Get all document IDs that have been successfully processed."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents WHERE success = true")
return {row[0] for row in cursor.fetchall()}
def get_failed_doc_ids(self) -> Set[str]:
"""Get all document IDs that failed processing."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents WHERE success = false")
return {row[0] for row in cursor.fetchall()}
def check_document_status(self, doc_id: str) -> Optional[bool]:
"""
Check if a document exists and its success status.
Returns:
True if exists and success=true
False if exists and success=false
None if not exists
"""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute(
"SELECT success FROM documents WHERE document_id = %s",
(doc_id,)
)
row = cursor.fetchone()
if row is None:
return None
return row[0]
def check_documents_status_batch(self, doc_ids: list[str]) -> Dict[str, Optional[bool]]:
"""
Batch check document status for multiple IDs.
Returns:
Dict mapping doc_id to status:
True if exists and success=true
False if exists and success=false
(missing from dict if not exists)
"""
if not doc_ids:
return {}
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute(
"SELECT document_id, success FROM documents WHERE document_id = ANY(%s)",
(doc_ids,)
)
return {row[0]: row[1] for row in cursor.fetchall()}
def delete_document(self, doc_id: str):
"""Delete a document and its field results (for re-processing)."""
conn = self.connect()
with conn.cursor() as cursor:
# field_results will be cascade deleted
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
conn.commit()
def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
"""Get a single document with its field results."""
conn = self.connect()
with conn.cursor() as cursor:
# Get document
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
if not row:
return None
doc = {
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6],
'annotations_generated': row[7],
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
# Get field results
cursor.execute("""
SELECT field_name, csv_value, matched, score, matched_text,
candidate_used, bbox, page_no, context_keywords, error
FROM field_results WHERE document_id = %s
""", (doc_id,))
for fr in cursor.fetchall():
doc['field_results'].append({
'field_name': fr[0],
'csv_value': fr[1],
'matched': fr[2],
'score': fr[3],
'matched_text': fr[4],
'candidate_used': fr[5],
'bbox': fr[6] if isinstance(fr[6], list) else json.loads(fr[6]) if fr[6] else None,
'page_no': fr[7],
'context_keywords': fr[8] if isinstance(fr[8], list) else json.loads(fr[8] or '[]'),
'error': fr[9]
})
return doc
def get_all_documents_summary(self, success_only: bool = False, limit: Optional[int] = None) -> list[Dict[str, Any]]:
"""Get summary of all documents (without field_results for efficiency)."""
conn = self.connect()
with conn.cursor() as cursor:
query = """
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total
FROM documents
"""
params = []
if success_only:
query += " WHERE success = true"
query += " ORDER BY timestamp DESC"
if limit:
# Use parameterized query instead of f-string
query += " LIMIT %s"
params.append(limit)
cursor.execute(query, params if params else None)
return [
{
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6]
}
for row in cursor.fetchall()
]
def get_field_stats(self) -> Dict[str, Dict[str, int]]:
"""Get match statistics per field."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name,
COUNT(*) as total,
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched
FROM field_results
GROUP BY field_name
ORDER BY field_name
""")
return {
row[0]: {'total': row[1], 'matched': row[2]}
for row in cursor.fetchall()
}
def get_failed_matches(self, field_name: Optional[str] = None, limit: int = 100) -> list[Dict[str, Any]]:
"""Get field results that failed to match."""
conn = self.connect()
with conn.cursor() as cursor:
query = """
SELECT fr.document_id, fr.field_name, fr.csv_value, fr.error,
d.pdf_type
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
WHERE fr.matched = false
"""
params = []
if field_name:
query += " AND fr.field_name = %s"
params.append(field_name)
# Use parameterized query instead of f-string
query += " LIMIT %s"
params.append(limit)
cursor.execute(query, params)
return [
{
'document_id': row[0],
'field_name': row[1],
'csv_value': row[2],
'error': row[3],
'pdf_type': row[4]
}
for row in cursor.fetchall()
]
def get_documents_batch(self, doc_ids: list[str]) -> Dict[str, Dict[str, Any]]:
"""
Get multiple documents with their field results in a single batch query.
This is much more efficient than calling get_document() in a loop.
Args:
doc_ids: List of document IDs to fetch
Returns:
Dict mapping doc_id to document data (with field_results)
"""
if not doc_ids:
return {}
conn = self.connect()
result: Dict[str, Dict[str, Any]] = {}
with conn.cursor() as cursor:
# Batch fetch all documents
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = ANY(%s)
""", (doc_ids,))
for row in cursor.fetchall():
result[row[0]] = {
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6],
'annotations_generated': row[7],
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
if not result:
return {}
# Batch fetch all field results for these documents
cursor.execute("""
SELECT document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error
FROM field_results WHERE document_id = ANY(%s)
""", (list(result.keys()),))
for fr in cursor.fetchall():
doc_id = fr[0]
if doc_id in result:
result[doc_id]['field_results'].append({
'field_name': fr[1],
'csv_value': fr[2],
'matched': fr[3],
'score': fr[4],
'matched_text': fr[5],
'candidate_used': fr[6],
'bbox': fr[7] if isinstance(fr[7], list) else json.loads(fr[7]) if fr[7] else None,
'page_no': fr[8],
'context_keywords': fr[9] if isinstance(fr[9], list) else json.loads(fr[9] or '[]'),
'error': fr[10]
})
return result
def save_document(self, report: Dict[str, Any]):
"""Save or update a document and its field results using batch operations."""
conn = self.connect()
doc_id = report.get('document_id')
with conn.cursor() as cursor:
# Delete existing record if any (for update)
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
# Insert document
cursor.execute("""
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
doc_id,
report.get('pdf_path'),
report.get('pdf_type'),
report.get('success'),
report.get('total_pages'),
report.get('fields_matched'),
report.get('fields_total'),
report.get('annotations_generated'),
report.get('processing_time_ms'),
report.get('timestamp'),
json.dumps(report.get('errors', [])),
# New fields
report.get('split'),
report.get('customer_number'),
report.get('supplier_name'),
report.get('supplier_organisation_number'),
report.get('supplier_accounts'),
))
# Batch insert field results using execute_values
field_results = report.get('field_results', [])
if field_results:
field_values = [
(
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
)
for field in field_results
]
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_values)
conn.commit()
def save_documents_batch(self, reports: list[Dict[str, Any]]):
"""Save multiple documents in a batch."""
if not reports:
return
conn = self.connect()
doc_ids = [r['document_id'] for r in reports]
with conn.cursor() as cursor:
# Delete existing records
cursor.execute(
"DELETE FROM documents WHERE document_id = ANY(%s)",
(doc_ids,)
)
# Batch insert documents
doc_values = [
(
r.get('document_id'),
r.get('pdf_path'),
r.get('pdf_type'),
r.get('success'),
r.get('total_pages'),
r.get('fields_matched'),
r.get('fields_total'),
r.get('annotations_generated'),
r.get('processing_time_ms'),
r.get('timestamp'),
json.dumps(r.get('errors', [])),
# New fields
r.get('split'),
r.get('customer_number'),
r.get('supplier_name'),
r.get('supplier_organisation_number'),
r.get('supplier_accounts'),
)
for r in reports
]
execute_values(cursor, """
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
""", doc_values)
# Batch insert field results
field_values = []
for r in reports:
doc_id = r.get('document_id')
for field in r.get('field_results', []):
field_values.append((
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
))
if field_values:
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_values)
conn.commit()
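A sketch of the intended round trip (import path hypothetical; the report dict mirrors the columns `save_document` reads):
from shared.data.db import DocumentDB  # hypothetical import path

report = {
    'document_id': 'DOC-0001',          # hypothetical document
    'pdf_path': 'data/raw_pdfs/DOC-0001.pdf',
    'pdf_type': 'text',
    'success': True,
    'total_pages': 1,
    'fields_matched': 3,
    'fields_total': 5,
    'annotations_generated': 3,
    'processing_time_ms': 812.5,
    'timestamp': '2026-01-27T12:00:00+01:00',
    'errors': [],
    'field_results': [],
}
with DocumentDB() as db:                # connects on enter, closes on exit
    db.create_tables()
    db.save_document(report)
    print(db.get_field_stats())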

@@ -0,0 +1,102 @@
"""
Application-specific exceptions for invoice extraction system.
This module defines a hierarchy of custom exceptions to provide better
error handling and debugging capabilities throughout the application.
"""
class InvoiceExtractionError(Exception):
"""Base exception for all invoice extraction errors."""
def __init__(self, message: str, details: dict = None):
"""
Initialize exception with message and optional details.
Args:
message: Human-readable error message
details: Optional dict with additional error context
"""
super().__init__(message)
self.message = message
self.details = details or {}
def __str__(self):
if self.details:
details_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
return f"{self.message} ({details_str})"
return self.message
class PDFProcessingError(InvoiceExtractionError):
"""Error during PDF processing (rendering, conversion)."""
pass
class OCRError(InvoiceExtractionError):
"""Error during OCR processing."""
pass
class ModelInferenceError(InvoiceExtractionError):
"""Error during YOLO model inference."""
pass
class FieldValidationError(InvoiceExtractionError):
"""Error during field validation or normalization."""
def __init__(self, field_name: str, value: str, reason: str, details: dict = None):
"""
Initialize field validation error.
Args:
field_name: Name of the field that failed validation
value: The invalid value
reason: Why validation failed
details: Additional context
"""
message = f"Field '{field_name}' validation failed: {reason}"
super().__init__(message, details)
self.field_name = field_name
self.value = value
self.reason = reason
class DatabaseError(InvoiceExtractionError):
"""Error during database operations."""
pass
class ConfigurationError(InvoiceExtractionError):
"""Error in application configuration."""
pass
class PaymentLineParseError(InvoiceExtractionError):
"""Error parsing Swedish payment line format."""
pass
class CustomerNumberParseError(InvoiceExtractionError):
"""Error parsing Swedish customer number."""
pass
class DataLoadError(InvoiceExtractionError):
"""Error loading data from CSV or other sources."""
pass
class AnnotationError(InvoiceExtractionError):
"""Error generating or processing YOLO annotations."""
pass
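The hierarchy is designed so callers can catch at any granularity; a short sketch with hypothetical field values:
try:
    raise FieldValidationError(
        field_name='Bankgiro',
        value='12-34',
        reason='too few digits',
        details={'document_id': 'DOC-0001'},
    )
except InvoiceExtractionError as exc:
    # Catching the base class also covers every subclass;
    # __str__ appends the details dict for debugging.
    print(exc)  # Field 'Bankgiro' validation failed: too few digits (document_id=DOC-0001)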

@@ -0,0 +1,4 @@
from .field_matcher import FieldMatcher, find_field_matches
from .models import Match, TokenLike
__all__ = ['FieldMatcher', 'Match', 'TokenLike', 'find_field_matches']

@@ -0,0 +1,92 @@
"""
Context keywords for field matching.
"""
from .models import TokenLike
from .token_index import TokenIndex
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
def find_context_keywords(
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str,
context_radius: float,
token_index: TokenIndex | None = None
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
Args:
tokens: List of all tokens
target_token: The token to find context for
field_name: Name of the field
context_radius: Search radius in pixels
token_index: Optional spatial index for efficient lookup
Returns:
Tuple of (found_keywords, boost_score)
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if token_index:
nearby_tokens = token_index.find_nearby(target_token, context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
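The boost saturates, so a page full of repeated labels cannot dominate the base score; the arithmetic in isolation:
# 0 keywords -> 0.00, 1 -> 0.10, 2 -> 0.20, 3 or more -> capped at 0.25
for n in range(5):
    print(n, min(0.25, n * 0.10))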

@@ -0,0 +1,219 @@
"""
Field Matching Module - Refactored
Matches normalized field values to tokens extracted from documents.
"""
from .models import TokenLike, Match
from .token_index import TokenIndex
from .utils import bbox_overlap
from .strategies import (
ExactMatcher,
ConcatenatedMatcher,
SubstringMatcher,
FuzzyMatcher,
FlexibleDateMatcher,
)
class FieldMatcher:
"""Matches field values to document tokens."""
def __init__(
self,
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
# Initialize matching strategies
self.exact_matcher = ExactMatcher(context_radius)
self.concatenated_matcher = ConcatenatedMatcher(context_radius)
self.substring_matcher = SubstringMatcher(context_radius)
self.fuzzy_matcher = FuzzyMatcher(context_radius)
self.flexible_date_matcher = FlexibleDateMatcher(context_radius)
def find_matches(
self,
tokens: list[TokenLike],
field_name: str,
normalized_values: list[str],
page_no: int = 0
) -> list[Match]:
"""
Find all matches for a field in the token list.
Args:
tokens: List of tokens from the document
field_name: Name of the field to match
normalized_values: List of normalized value variants to search for
page_no: Page number to filter tokens
Returns:
List of Match objects sorted by score (descending)
"""
matches = []
# Filter tokens by page and exclude hidden metadata tokens
# Hidden tokens often have bbox with y < 0 or y > page_height
# These are typically PDF metadata stored as invisible text
page_tokens = [
t for t in tokens
if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self.exact_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation
concat_matches = self.concatenated_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self.fuzzy_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in (
'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
'Bankgiro', 'Plusgiro', 'supplier_organisation_number',
'supplier_accounts', 'customer_number'
):
substring_matches = self.substring_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
for value in normalized_values:
flexible_matches = self.flexible_date_matcher.find_matches(
page_tokens, value, field_name, self._token_index
)
matches.extend(flexible_matches)
# Deduplicate and sort by score
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold]
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches:
return []
# Sort by: 1) score descending, 2) prefer matches with context keywords,
# 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in invoice body rather than footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = []
for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False
cells_to_check = set()
for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
break
if not is_duplicate:
unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique
def find_field_matches(
tokens: list[TokenLike],
field_values: dict[str, str],
page_no: int = 0
) -> dict[str, list[Match]]:
"""
Convenience function to find matches for multiple fields.
Args:
tokens: List of tokens from the document
field_values: Dict of field_name -> value to search for
page_no: Page number
Returns:
Dict of field_name -> list of matches
"""
from ..normalize import normalize_field
matcher = FieldMatcher()
results = {}
for field_name, value in field_values.items():
if value is None or str(value).strip() == '':
continue
normalized_values = normalize_field(field_name, str(value))
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
results[field_name] = matches
return results
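A self-contained sketch of the convenience entry point, using a stand-in that satisfies the `TokenLike` protocol (package name, coordinates, and the sibling `normalize` module are assumptions):
from dataclasses import dataclass

@dataclass
class Token:
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0

tokens = [
    Token('Fakturanummer:', (40.0, 100.0, 160.0, 115.0)),
    Token('2465027205', (170.0, 100.0, 250.0, 115.0)),
]
from matching import find_field_matches  # hypothetical import path
results = find_field_matches(tokens, {'InvoiceNumber': '2465027205'})
for m in results.get('InvoiceNumber', []):
    print(m.matched_text, round(m.score, 2), m.context_keywords)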

@@ -0,0 +1,875 @@
"""
Field Matching Module
Matches normalized field values to tokens extracted from documents.
"""
from dataclasses import dataclass, field
from typing import Protocol
import re
from functools import cached_property
# Pre-compiled regex patterns (module-level for efficiency)
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def _normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return _DASH_PATTERN.sub('-', text)
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
class FieldMatcher:
"""Matches field values to document tokens."""
def __init__(
self,
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
def find_matches(
self,
tokens: list[TokenLike],
field_name: str,
normalized_values: list[str],
page_no: int = 0
) -> list[Match]:
"""
Find all matches for a field in the token list.
Args:
tokens: List of tokens from the document
field_name: Name of the field to match
normalized_values: List of normalized value variants to search for
page_no: Page number to filter tokens
Returns:
List of Match objects sorted by score (descending)
"""
matches = []
# Filter tokens by page and exclude hidden metadata tokens
# Hidden tokens often have bbox with y < 0 or y > page_height
# These are typically PDF metadata stored as invisible text
page_tokens = [
t for t in tokens
if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
flexible_matches = self._find_flexible_date_matches(
page_tokens, normalized_values, field_name
)
matches.extend(flexible_matches)
# Deduplicate and sort by score
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold]
def _find_exact_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts') else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use cached lowercase from index)
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
def _find_concatenated_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find value by concatenating adjacent tokens."""
matches = []
value_clean = _WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not self._tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
def _find_substring_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
matches = []
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number')
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = _normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip tokens no longer than the value (an equal-length token would already be an exact match)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, verify digit boundaries directly in the compact text
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
def _find_fuzzy_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find approximate matches for amounts and dates."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = self._parse_amount(token_text)
value_num = self._parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except Exception:
# Defensively skip tokens whose amount cannot be parsed or compared
pass
return matches
def _find_flexible_date_matches(
self,
tokens: list[TokenLike],
normalized_values: list[str],
field_name: str
) -> list[Match]:
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
from datetime import datetime, timedelta
matches = []
# Parse the target date from normalized values
target_date = None
for value in normalized_values:
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
break
except ValueError:
continue
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in _DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= self.min_score_threshold:
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
def _find_context_keywords(
self,
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str
) -> tuple[list[str], float]:
"""
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
# Use spatial index for efficient nearby token lookup
if self._token_index:
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = self._token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def _parse_amount(self, text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches:
return []
# Sort by: 1) score descending, 2) prefer matches with context keywords,
# 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in invoice body rather than footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = []
for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False
cells_to_check = set()
for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if self._bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
break
if not is_duplicate:
unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique
def _bbox_overlap(
self,
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
def find_field_matches(
tokens: list[TokenLike],
field_values: dict[str, str],
page_no: int = 0
) -> dict[str, list[Match]]:
"""
Convenience function to find matches for multiple fields.
Args:
tokens: List of tokens from the document
field_values: Dict of field_name -> value to search for
page_no: Page number
Returns:
Dict of field_name -> list of matches
"""
from ..normalize import normalize_field
matcher = FieldMatcher()
results = {}
for field_name, value in field_values.items():
if value is None or str(value).strip() == '':
continue
normalized_values = normalize_field(field_name, str(value))
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
results[field_name] = matches
return results
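# Usage sketch (illustrative only; Tok is a hypothetical stand-in satisfying
# the TokenLike protocol, and real tokens would come from the PDF/OCR layer):
#
#     from dataclasses import dataclass
#
#     @dataclass
#     class Tok:
#         text: str
#         bbox: tuple[float, float, float, float]
#         page_no: int = 0
#
#     tokens = [
#         Tok('Fakturanummer:', (50.0, 40.0, 140.0, 52.0)),
#         Tok('2465027205', (150.0, 40.0, 220.0, 52.0)),
#     ]
#     find_field_matches(tokens, {'InvoiceNumber': '2465027205'})
#     # -> {'InvoiceNumber': [Match(field='InvoiceNumber', score=1.0, ...)]}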

View File

@@ -0,0 +1,36 @@
"""
Data models for field matching.
"""
from dataclasses import dataclass
from typing import Protocol
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"

View File

@@ -0,0 +1,17 @@
"""
Matching strategies for field matching.
"""
from .exact_matcher import ExactMatcher
from .concatenated_matcher import ConcatenatedMatcher
from .substring_matcher import SubstringMatcher
from .fuzzy_matcher import FuzzyMatcher
from .flexible_date_matcher import FlexibleDateMatcher
__all__ = [
'ExactMatcher',
'ConcatenatedMatcher',
'SubstringMatcher',
'FuzzyMatcher',
'FlexibleDateMatcher',
]

View File

@@ -0,0 +1,42 @@
"""
Base class for matching strategies.
"""
from abc import ABC, abstractmethod
from ..models import TokenLike, Match
from ..token_index import TokenIndex
class BaseMatchStrategy(ABC):
"""Base class for all matching strategies."""
def __init__(self, context_radius: float = 200.0):
"""
Initialize the strategy.
Args:
context_radius: Distance to search for context keywords
"""
self.context_radius = context_radius
@abstractmethod
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""
Find matches for the given value.
Args:
tokens: List of tokens to search
value: Value to find
field_name: Name of the field
token_index: Optional spatial index for efficient lookup
Returns:
List of Match objects
"""
pass

View File

@@ -0,0 +1,73 @@
"""
Concatenated match strategy - finds value by concatenating adjacent tokens.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import WHITESPACE_PATTERN, tokens_on_same_line
class ConcatenatedMatcher(BaseMatchStrategy):
"""Find value by concatenating adjacent tokens."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find concatenated matches."""
matches = []
value_clean = WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = find_context_keywords(
tokens, start_token, field_name, self.context_radius, token_index
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
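# Usage sketch (illustrative; Tok is a hypothetical TokenLike stand-in): an
# OCR pass that splits '5565746624' into two same-line tokens closer than the
# 50px gap is stitched back together and scored at the 0.85 base plus any
# context boost, with a bbox spanning both fragments.
#
#     ConcatenatedMatcher().find_matches(
#         tokens=[Tok('556574', (50, 40, 95, 52), 0), Tok('6624', (100, 40, 130, 52), 0)],
#         value='5565746624',
#         field_name='supplier_organisation_number',
#     )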

View File

@@ -0,0 +1,65 @@
"""
Exact match strategy.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import NON_DIGIT_PATTERN
class ExactMatcher(BaseMatchStrategy):
"""Find tokens that exactly match the value."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find exact matches."""
matches = []
value_lower = value.lower()
value_digits = NON_DIGIT_PATTERN.sub('', value) if field_name in (
'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts'
) else None
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match (use cached lowercase from index)
elif token_index and token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif value_digits is not None:
token_digits = NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches

View File

@@ -0,0 +1,149 @@
"""
Flexible date match strategy - finds dates with year-month or nearby date matching.
"""
import re
from datetime import datetime
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import DATE_PATTERN
class FlexibleDateMatcher(BaseMatchStrategy):
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: prefer dates close to the CSV value (up to 30 days, at decreasing score)
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find flexible date matches."""
matches = []
# Parse the target date from normalized values
target_date = None
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
except ValueError:
pass
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token (use pre-compiled pattern)
for match in DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= 0.5: # Min threshold for flexible matching
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
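# Scoring sketch (dates assumed for illustration): for the CSV value
# '2026-01-15', a token containing '2026-01-12' is same year-month and within
# 3 days, so it starts at 0.8 before context boosts; '2026-01-30' starts at
# 0.7; '2026-02-10' (26 days away) starts at 0.55; anything beyond 30 days is
# dropped unless a context keyword is found, in which case it starts at 0.5.
#
#     FlexibleDateMatcher().find_matches(tokens, '2026-01-15', 'InvoiceDate')
#     # -> at most one Match: the highest-scoring surviving candidate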

View File

@@ -0,0 +1,52 @@
"""
Fuzzy match strategy for amounts and dates.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import parse_amount
class FuzzyMatcher(BaseMatchStrategy):
"""Find approximate matches for amounts and dates."""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find fuzzy matches."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = parse_amount(token_text)
value_num = parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except (TypeError, ValueError):
    pass
return matches

View File

@@ -0,0 +1,143 @@
"""
Substring match strategy - finds value as substring within longer tokens.
"""
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import normalize_dashes
class SubstringMatcher(BaseMatchStrategy):
"""
Find value as a substring within longer tokens.
Handles cases like:
- 'Fakturadatum: 2026-01-09' where the date is embedded
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
- 'OCR: 1234567890' where reference number is embedded
Uses a lower base score (0.70-0.75, before context boosts) than exact match, so exact matches are preferred.
Only matches if the value appears as a distinct segment (not part of a larger number).
"""
def find_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str,
token_index: TokenIndex | None = None
) -> list[Match]:
"""Find substring matches."""
matches = []
# Supported fields for substring matching
supported_fields = (
'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
'Bankgiro', 'Plusgiro', 'Amount',
'supplier_organisation_number', 'supplier_accounts', 'customer_number'
)
if field_name not in supported_fields:
return matches
# Fields where spaces/dashes should be ignored during matching
# (e.g., org number "55 65 74-6624" should match "5565746624")
ignore_spaces_fields = (
'supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts'
)
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = normalize_dashes(token_text)
# For certain fields, also try matching with spaces/dashes removed
if field_name in ignore_spaces_fields:
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
value_compact = value.replace(' ', '').replace('-', '')
else:
token_text_compact = None
value_compact = None
# Skip tokens that are not longer than the value (equal-length tokens are the exact matcher's job)
if len(token_text_normalized) <= len(value):
continue
# Check if value appears as substring (using normalized text)
# Try case-sensitive first, then case-insensitive
idx = None
case_sensitive_match = True
used_compact = False
if value in token_text_normalized:
idx = token_text_normalized.find(value)
elif value.lower() in token_text_normalized.lower():
idx = token_text_normalized.lower().find(value.lower())
case_sensitive_match = False
elif token_text_compact and value_compact in token_text_compact:
# Try compact matching (spaces/dashes removed)
idx = token_text_compact.find(value_compact)
used_compact = True
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
idx = token_text_compact.lower().find(value_compact.lower())
case_sensitive_match = False
used_compact = True
if idx is None:
continue
# For compact matching, run the boundary check on the compact text:
# the matched digits must not be flanked by further digits
if used_compact:
# Verify proper boundary in compact text
if idx > 0 and token_text_compact[idx - 1].isdigit():
continue
end_idx = idx + len(value_compact)
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
continue
else:
# Verify it's a proper boundary match (not part of a larger number)
# Check character before (if exists)
if idx > 0:
char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text_normalized):
char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = find_context_keywords(
tokens, token, field_name, self.context_radius, token_index
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
# Lower score for case-insensitive match
base_score = 0.75 if case_sensitive_match else 0.70
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, base_score + context_boost + inline_boost),
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
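# Boundary sketch (token text assumed, and 'fakturanummer' assumed to be in
# the field's CONTEXT_KEYWORDS): searching for '2465027205' inside
# 'Fakturanummer: 2465027205' matches, since the character before the digits
# is a space, and the inline keyword lifts the 0.75 base score by 0.1. Inside
# 'Ref: 124650272059' the same value is rejected because digits continue on
# both sides of the candidate span.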

View File

@@ -0,0 +1,92 @@
"""
Spatial index for fast token lookup.
"""
from .models import TokenLike
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
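# Usage sketch (Tok is a hypothetical TokenLike stand-in):
#
#     index = TokenIndex(tokens, grid_size=100.0)
#     nearby = index.find_nearby(tokens[0], radius=200.0)
#
# With 100px cells and a 200px radius, cells_to_check is 3, so the lookup
# scans at most a 7x7 block of cells and then applies the exact
# squared-distance test, instead of measuring against every token on the page.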

View File

@@ -0,0 +1,91 @@
"""
Utility functions for field matching.
"""
import re
# Pre-compiled regex patterns (module-level for efficiency)
DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
WHITESPACE_PATTERN = re.compile(r'\s+')
NON_DIGIT_PATTERN = re.compile(r'\D')
DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
def normalize_dashes(text: str) -> str:
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
return DASH_PATTERN.sub('-', text)
def parse_amount(text: str | int | float) -> float | None:
"""Try to parse text as a monetary amount."""
# Convert to string first
text = str(text)
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
# Pattern: digits + space + exactly 2 digits at end
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
if ore_match:
kronor = ore_match.group(1)
ore = ore_match.group(2)
try:
return float(f"{kronor}.{ore}")
except ValueError:
pass
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
text = re.sub(r'\s*\(.*\)', '', text)
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
text = re.sub(r'[:-]', '', text)
# Remove spaces (thousand separators) but be careful with öre format
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
# Swedish format: "500,00" means 500.00
# Need to handle cases like "500,00." (after removing "kr.")
if ',' in text:
# Remove any trailing dots first (from "kr." removal)
text = text.rstrip('.')
# Now replace comma with dot
if '.' not in text:
text = text.replace(',', '.')
# Remove any remaining non-numeric characters except dot
text = re.sub(r'[^\d.]', '', text)
try:
return float(text)
except ValueError:
return None
def tokens_on_same_line(token1, token2) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def bbox_overlap(
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
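if __name__ == '__main__':
    # Self-check sketch for the helpers above; expected values worked out by
    # hand from the parsing rules.
    assert parse_amount('239 00') == 239.00               # kronor/öre spacing
    assert parse_amount('1 234,56 kr') == 1234.56         # thousand separator + suffix
    assert parse_amount('500,00 (inkl. moms)') == 500.00  # parenthetical note stripped
    assert parse_amount('abc') is None
    assert normalize_dashes('5565\u2013746624') == '5565-746624'
    assert abs(bbox_overlap((0, 0, 10, 10), (5, 5, 15, 15)) - 25.0 / 175.0) < 1e-12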

View File

@@ -0,0 +1,3 @@
from .normalizer import normalize_field, FieldNormalizer
__all__ = ['normalize_field', 'FieldNormalizer']

View File

@@ -0,0 +1,186 @@
"""
Field Normalization Module
Normalizes field values to generate multiple candidate forms for matching.
This module now delegates to individual normalizer modules for each field type.
Each normalizer is a separate, reusable module that can be used independently.
"""
from dataclasses import dataclass
from typing import Callable
from shared.utils.text_cleaner import TextCleaner
# Import individual normalizers
from .normalizers import (
InvoiceNumberNormalizer,
OCRNormalizer,
BankgiroNormalizer,
PlusgiroNormalizer,
AmountNormalizer,
DateNormalizer,
OrganisationNumberNormalizer,
SupplierAccountsNormalizer,
CustomerNumberNormalizer,
)
@dataclass
class NormalizedValue:
"""Represents a normalized value with its variants."""
original: str
variants: list[str]
field_type: str
class FieldNormalizer:
"""
Handles normalization of different invoice field types.
This class now acts as a facade that delegates to individual
normalizer modules. Each field type has its own specialized
normalizer for better modularity and reusability.
"""
# Instantiate individual normalizers
_invoice_number = InvoiceNumberNormalizer()
_ocr_number = OCRNormalizer()
_bankgiro = BankgiroNormalizer()
_plusgiro = PlusgiroNormalizer()
_amount = AmountNormalizer()
_date = DateNormalizer()
_organisation_number = OrganisationNumberNormalizer()
_supplier_accounts = SupplierAccountsNormalizer()
_customer_number = CustomerNumberNormalizer()
# Common Swedish month names for backward compatibility
SWEDISH_MONTHS = DateNormalizer.SWEDISH_MONTHS
@staticmethod
def clean_text(text: str) -> str:
"""
Remove invisible characters and normalize whitespace and dashes.
Delegates to shared TextCleaner for consistency.
"""
return TextCleaner.clean_text(text)
@staticmethod
def normalize_invoice_number(value: str) -> list[str]:
"""
Normalize invoice number.
Delegates to InvoiceNumberNormalizer.
"""
return FieldNormalizer._invoice_number.normalize(value)
@staticmethod
def normalize_ocr_number(value: str) -> list[str]:
"""
Normalize OCR number (Swedish payment reference).
Delegates to OCRNormalizer.
"""
return FieldNormalizer._ocr_number.normalize(value)
@staticmethod
def normalize_bankgiro(value: str) -> list[str]:
"""
Normalize Bankgiro number.
Delegates to BankgiroNormalizer.
"""
return FieldNormalizer._bankgiro.normalize(value)
@staticmethod
def normalize_plusgiro(value: str) -> list[str]:
"""
Normalize Plusgiro number.
Delegates to PlusgiroNormalizer.
"""
return FieldNormalizer._plusgiro.normalize(value)
@staticmethod
def normalize_organisation_number(value: str) -> list[str]:
"""
Normalize Swedish organisation number and generate VAT number variants.
Delegates to OrganisationNumberNormalizer.
"""
return FieldNormalizer._organisation_number.normalize(value)
@staticmethod
def normalize_supplier_accounts(value: str) -> list[str]:
"""
Normalize supplier accounts field.
Delegates to SupplierAccountsNormalizer.
"""
return FieldNormalizer._supplier_accounts.normalize(value)
@staticmethod
def normalize_customer_number(value: str) -> list[str]:
"""
Normalize customer number.
Delegates to CustomerNumberNormalizer.
"""
return FieldNormalizer._customer_number.normalize(value)
@staticmethod
def normalize_amount(value: str) -> list[str]:
"""
Normalize monetary amount.
Delegates to AmountNormalizer.
"""
return FieldNormalizer._amount.normalize(value)
@staticmethod
def normalize_date(value: str) -> list[str]:
"""
Normalize date to YYYY-MM-DD and generate variants.
Delegates to DateNormalizer.
"""
return FieldNormalizer._date.normalize(value)
# Field type to normalizer mapping
NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
'InvoiceNumber': FieldNormalizer.normalize_invoice_number,
'OCR': FieldNormalizer.normalize_ocr_number,
'Bankgiro': FieldNormalizer.normalize_bankgiro,
'Plusgiro': FieldNormalizer.normalize_plusgiro,
'Amount': FieldNormalizer.normalize_amount,
'InvoiceDate': FieldNormalizer.normalize_date,
'InvoiceDueDate': FieldNormalizer.normalize_date,
'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
'customer_number': FieldNormalizer.normalize_customer_number,
}
def normalize_field(field_name: str, value: str) -> list[str]:
"""
Normalize a field value based on its type.
Args:
field_name: Name of the field (e.g., 'InvoiceNumber', 'Amount')
value: Raw value to normalize
Returns:
List of normalized variants
"""
if value is None or (isinstance(value, str) and not value.strip()):
return []
value = str(value)
normalizer = NORMALIZERS.get(field_name)
if normalizer:
return normalizer(value)
# Default: just clean the text
return [FieldNormalizer.clean_text(value)]
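# Dispatch sketch (variant lists are illustrative, not exhaustive; assumes
# TextCleaner.clean_text trims surrounding whitespace):
#
#     normalize_field('InvoiceDate', '13/12/2025')
#     # -> includes '2025-12-13', '13/12/2025', '20251213', ...
#     normalize_field('Amount', '1 234,56')
#     # -> includes '1234,56', '1234.56', '1 234,56', ...
#     normalize_field('SomeUnmappedField', ' foo ')
#     # -> ['foo']  (default path: cleaned text only)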

View File

@@ -0,0 +1,28 @@
"""
Normalizer modules for different field types.
Each normalizer is responsible for generating variants of a field value
for matching against OCR text or other data sources.
"""
from .invoice_number_normalizer import InvoiceNumberNormalizer
from .ocr_normalizer import OCRNormalizer
from .bankgiro_normalizer import BankgiroNormalizer
from .plusgiro_normalizer import PlusgiroNormalizer
from .amount_normalizer import AmountNormalizer
from .date_normalizer import DateNormalizer
from .organisation_number_normalizer import OrganisationNumberNormalizer
from .supplier_accounts_normalizer import SupplierAccountsNormalizer
from .customer_number_normalizer import CustomerNumberNormalizer
__all__ = [
'InvoiceNumberNormalizer',
'OCRNormalizer',
'BankgiroNormalizer',
'PlusgiroNormalizer',
'AmountNormalizer',
'DateNormalizer',
'OrganisationNumberNormalizer',
'SupplierAccountsNormalizer',
'CustomerNumberNormalizer',
]

View File

@@ -0,0 +1,130 @@
"""
Amount Normalizer
Normalizes monetary amounts with various formats and separators.
"""
import re
from .base import BaseNormalizer
class AmountNormalizer(BaseNormalizer):
"""
Normalizes monetary amounts.
Handles Swedish and international formats with different
thousand/decimal separators.
Examples:
'114' -> ['114', '114,00', '114.00']
'114,00' -> ['114,00', '114.00', '114']
'1 234,56' -> ['1234,56', '1234.56', '1 234,56']
'3045 52' -> ['3045.52', '3045,52', '304552']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of amount."""
value = self.clean_text(value)
# Remove currency symbols and common suffixes
value = re.sub(r'(?:\s*(?:SEK|kr|:-))+\s*$', '', value, flags=re.IGNORECASE).strip()
variants = [value]
# Check for space as decimal separator: "3045 52"
space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', value)
if space_decimal_match:
integer_part = space_decimal_match.group(1)
decimal_part = space_decimal_match.group(2)
variants.append(f"{integer_part}.{decimal_part}")
variants.append(f"{integer_part},{decimal_part}")
variants.append(f"{integer_part}{decimal_part}")
# Check for space as thousand separator: "10 571,00"
space_thousand_match = re.match(r'^(\d{1,3})[\s\xa0]+(\d{3})([,\.])(\d{2})$', value)
if space_thousand_match:
part1 = space_thousand_match.group(1)
part2 = space_thousand_match.group(2)
sep = space_thousand_match.group(3)
decimal = space_thousand_match.group(4)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(f"{combined}{decimal}")
other_sep = ',' if sep == '.' else '.'
variants.append(f"{part1} {part2}{other_sep}{decimal}")
# Handle US format: "1,390.00"
us_format_match = re.match(r'^(\d{1,3}),(\d{3})\.(\d{2})$', value)
if us_format_match:
part1 = us_format_match.group(1)
part2 = us_format_match.group(2)
decimal = us_format_match.group(3)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(combined)
variants.append(f"{part1}.{part2},{decimal}")
# Handle European format: "1.390,00"
eu_format_match = re.match(r'^(\d{1,3})\.(\d{3}),(\d{2})$', value)
if eu_format_match:
part1 = eu_format_match.group(1)
part2 = eu_format_match.group(2)
decimal = eu_format_match.group(3)
combined = f"{part1}{part2}"
variants.append(f"{combined}.{decimal}")
variants.append(f"{combined},{decimal}")
variants.append(combined)
variants.append(f"{part1},{part2}.{decimal}")
# Remove spaces (thousand separators)
no_space = value.replace(' ', '').replace('\xa0', '')
# Normalize decimal separator
if ',' in no_space:
dot_version = no_space.replace(',', '.')
variants.append(no_space)
variants.append(dot_version)
elif '.' in no_space:
comma_version = no_space.replace('.', ',')
variants.append(no_space)
variants.append(comma_version)
else:
# Integer amount - add decimal versions
variants.append(no_space)
variants.append(f"{no_space},00")
variants.append(f"{no_space}.00")
# Try to parse and get clean numeric value
try:
clean = no_space.replace(',', '.')
num = float(clean)
# Integer if no decimals
if num == int(num):
int_val = int(num)
variants.append(str(int_val))
variants.append(f"{int_val},00")
variants.append(f"{int_val}.00")
# European format with dot as thousand separator
if int_val >= 1000:
formatted = f"{int_val:,}".replace(',', '.')
variants.append(formatted)
variants.append(f"{formatted},00")
else:
variants.append(f"{num:.2f}")
variants.append(f"{num:.2f}".replace('.', ','))
# European format with dot as thousand separator
if num >= 1000:
formatted_str = f"{num:.2f}"
int_str, dec_str = formatted_str.split(".")
int_part = int(int_str)
formatted_int = f"{int_part:,}".replace(',', '.')
variants.append(f"{formatted_int},{dec_str}")
except ValueError:
pass
return list(set(v for v in variants if v))
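# Variant sketch for one input (the result is set-deduplicated, so order is
# not guaranteed):
#
#     AmountNormalizer().normalize('10 571,00')
#     # -> superset of {'10 571,00', '10571,00', '10571.00', '1057100',
#     #                 '10 571.00', '10571', '10.571', '10.571,00'}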

View File

@@ -0,0 +1,34 @@
"""
Bankgiro Number Normalizer
Normalizes Swedish Bankgiro account numbers.
"""
from .base import BaseNormalizer
from shared.utils.format_variants import FormatVariants
from shared.utils.text_cleaner import TextCleaner
class BankgiroNormalizer(BaseNormalizer):
"""
Normalizes Bankgiro numbers.
Generates format variants and OCR error variants.
Examples:
'5393-9484' -> ['5393-9484', '53939484', ...]
'53939484' -> ['53939484', '5393-9484', ...]
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of Bankgiro number."""
# Use shared module for base variants
variants = set(FormatVariants.bankgiro_variants(value))
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)

View File

@@ -0,0 +1,34 @@
"""
Base class for field normalizers.
"""
from abc import ABC, abstractmethod
from shared.utils.text_cleaner import TextCleaner
class BaseNormalizer(ABC):
"""Base class for all field normalizers."""
@staticmethod
def clean_text(text: str) -> str:
"""Clean text using shared TextCleaner."""
return TextCleaner.clean_text(text)
@abstractmethod
def normalize(self, value: str) -> list[str]:
"""
Normalize a field value and return all variants.
Args:
value: Raw field value
Returns:
List of normalized variants for matching
"""
pass
def __call__(self, value: str) -> list[str]:
"""Allow normalizer to be called as a function."""
if value is None or (isinstance(value, str) and not value.strip()):
return []
return self.normalize(str(value))

View File

@@ -0,0 +1,49 @@
"""
Customer Number Normalizer
Normalizes customer numbers (alphanumeric codes).
"""
from .base import BaseNormalizer
class CustomerNumberNormalizer(BaseNormalizer):
"""
Normalizes customer numbers.
Customer numbers can have various formats:
- Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
- Pure numbers: '12345', '123-456'
Examples:
'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
'ABC 123' -> ['ABC 123', 'ABC123']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of customer number."""
value = self.clean_text(value)
variants = [value]
# Version without spaces
no_space = value.replace(' ', '')
if no_space != value:
variants.append(no_space)
# Version without dashes
no_dash = value.replace('-', '')
if no_dash != value:
variants.append(no_dash)
# Version without spaces and dashes
clean = value.replace(' ', '').replace('-', '')
if clean != value and clean not in variants:
variants.append(clean)
# Uppercase and lowercase versions
if value.upper() != value:
variants.append(value.upper())
if value.lower() != value:
variants.append(value.lower())
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,190 @@
"""
Date Normalizer
Normalizes dates in various formats to ISO and generates variants.
"""
import re
from datetime import datetime
from .base import BaseNormalizer
class DateNormalizer(BaseNormalizer):
"""
Normalizes dates to YYYY-MM-DD and generates variants.
Handles Swedish and international date formats.
Examples:
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
"""
# Swedish month names
SWEDISH_MONTHS = {
'januari': '01', 'jan': '01',
'februari': '02', 'feb': '02',
'mars': '03', 'mar': '03',
'april': '04', 'apr': '04',
'maj': '05',
'juni': '06', 'jun': '06',
'juli': '07', 'jul': '07',
'augusti': '08', 'aug': '08',
'september': '09', 'sep': '09', 'sept': '09',
'oktober': '10', 'okt': '10',
'november': '11', 'nov': '11',
'december': '12', 'dec': '12'
}
def normalize(self, value: str) -> list[str]:
"""Generate variants of date."""
value = self.clean_text(value)
variants = [value]
parsed_dates = []
# Try unambiguous patterns first
date_patterns = [
# ISO format with optional time
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$',
lambda m: (int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYMMDD
(r'^(\d{2})(\d{2})(\d{2})$',
lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYYYMMDD
(r'^(\d{4})(\d{2})(\d{2})$',
lambda m: (int(m[1]), int(m[2]), int(m[3]))),
]
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
if match:
try:
year, month, day = extractor(match)
parsed_dates.append(datetime(year, month, day))
break
except ValueError:
continue
# Try ambiguous patterns with 4-digit year
ambiguous_patterns_4digit = [
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
]
if not parsed_dates:
for pattern in ambiguous_patterns_4digit:
match = re.match(pattern, value)
if match:
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
# Try DD/MM/YYYY (European - day first)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YYYY (US - month first) if different
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try ambiguous patterns with 2-digit year
ambiguous_patterns_2digit = [
r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
]
if not parsed_dates:
for pattern in ambiguous_patterns_2digit:
match = re.match(pattern, value)
if match:
n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
year = 2000 + yy if yy < 50 else 1900 + yy
# Try DD/MM/YY (European)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YY (US) if different
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try Swedish month names
if not parsed_dates:
for month_name, month_num in self.SWEDISH_MONTHS.items():
if month_name in value.lower():
numbers = re.findall(r'\d+', value)
if len(numbers) >= 2:
day = int(numbers[0])
year = int(numbers[-1])
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
try:
parsed_dates.append(datetime(year, int(month_num), day))
break
except ValueError:
continue
# Generate variants for all parsed dates
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
for parsed_date in parsed_dates:
iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y')
us_slash = parsed_date.strftime('%m/%d/%Y')
eu_dot = parsed_date.strftime('%d.%m.%Y')
iso_dot = parsed_date.strftime('%Y.%m.%d')
compact = parsed_date.strftime('%Y%m%d')
compact_short = parsed_date.strftime('%y%m%d')
eu_dot_short = parsed_date.strftime('%d.%m.%y')
eu_slash_short = parsed_date.strftime('%d/%m/%y')
yy_mm_dd_short = parsed_date.strftime('%y-%m-%d')
iso_middot = parsed_date.strftime('%Y·%m·%d')  # middle-dot separated, e.g. '2025·12·13'
spaced_full = parsed_date.strftime('%Y %m %d')
spaced_short = parsed_date.strftime('%y %m %d')
# Swedish month name formats
month_full = swedish_months_full[parsed_date.month - 1]
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
month_abbrev_upper = month_abbrev.upper()
swedish_hyphen_short = f"{parsed_date.day:02d}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
swedish_hyphen_short_lower = f"{parsed_date.day:02d}-{month_abbrev}-{parsed_date.strftime('%y')}"
swedish_hyphen_short_no_zero = f"{parsed_date.day}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
month_year_only = f"{month_abbrev_upper}-{parsed_date.strftime('%y')}"
swedish_spaced = f"{parsed_date.day:02d} {month_abbrev_upper} {parsed_date.strftime('%y')}"
variants.extend([
iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
eu_dot_short, eu_slash_short, yy_mm_dd_short, iso_middot, spaced_full, spaced_short,
swedish_format_full, swedish_format_abbrev,
swedish_hyphen_short, swedish_hyphen_short_lower, swedish_hyphen_short_no_zero,
month_year_only, swedish_spaced
])
return list(set(v for v in variants if v))
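# Variant sketch for one parse (subset shown; the full list also covers
# middle-dot, spaced, and short-year forms):
#
#     DateNormalizer().normalize('2025-12-13')
#     # includes '2025-12-13', '13/12/2025', '12/13/2025', '13.12.2025',
#     #          '20251213', '251213', '13 december 2025', '13-DEC-25'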

View File

@@ -0,0 +1,31 @@
"""
Invoice Number Normalizer
Normalizes invoice numbers for matching.
"""
import re
from .base import BaseNormalizer
class InvoiceNumberNormalizer(BaseNormalizer):
"""
Normalizes invoice numbers.
Keeps only digits for matching while preserving original format.
Examples:
'100017500321' -> ['100017500321']
'INV-100017500321' -> ['100017500321', 'INV-100017500321']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of invoice number."""
value = self.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only and digits_only != value:
variants.append(digits_only)
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,31 @@
"""
OCR Number Normalizer
Normalizes OCR reference numbers (Swedish payment system).
"""
import re
from .base import BaseNormalizer
class OCRNormalizer(BaseNormalizer):
"""
Normalizes OCR reference numbers.
Similar to invoice number - primarily digits.
Examples:
'94228110015950070' -> ['94228110015950070']
'OCR: 94228110015950070' -> ['94228110015950070', 'OCR: 94228110015950070']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of OCR number."""
value = self.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only and digits_only != value:
variants.append(digits_only)
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,39 @@
"""
Organisation Number Normalizer
Normalizes Swedish organisation numbers and VAT numbers.
"""
from .base import BaseNormalizer
from shared.utils.format_variants import FormatVariants
from shared.utils.text_cleaner import TextCleaner
class OrganisationNumberNormalizer(BaseNormalizer):
"""
Normalizes Swedish organisation numbers and VAT numbers.
Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
Swedish VAT format: SE + org_number (10 digits) + 01
Examples:
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of organisation number."""
# Use shared module for base variants
variants = set(FormatVariants.organisation_number_variants(value))
# Add OCR error variants for digit sequences
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits and len(digits) >= 10:
# Generate variants where OCR might have misread characters
for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]):
variants.add(ocr_var)
if len(ocr_var) == 10:
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
return list(v for v in variants if v)

View File

@@ -0,0 +1,34 @@
"""
Plusgiro Number Normalizer
Normalizes Swedish Plusgiro account numbers.
"""
from .base import BaseNormalizer
from shared.utils.format_variants import FormatVariants
from shared.utils.text_cleaner import TextCleaner
class PlusgiroNormalizer(BaseNormalizer):
"""
Normalizes Plusgiro numbers.
Generates format variants and OCR error variants.
Examples:
'1234567-8' -> ['1234567-8', '12345678', ...]
'12345678' -> ['12345678', '1234567-8', ...]
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of Plusgiro number."""
# Use shared module for base variants
variants = set(FormatVariants.plusgiro_variants(value))
# Add OCR error variants
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)

View File

@@ -0,0 +1,75 @@
"""
Supplier Accounts Normalizer
Normalizes supplier account numbers (Bankgiro/Plusgiro).
"""
import re
from .base import BaseNormalizer
class SupplierAccountsNormalizer(BaseNormalizer):
"""
Normalizes supplier accounts field.
The field may contain multiple accounts separated by ' | '.
Format examples:
'PG:48676043 | PG:49128028 | PG:8915035'
'BG:5393-9484'
Each account is normalized separately to generate variants.
Examples:
'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
"""
def normalize(self, value: str) -> list[str]:
"""Generate variants of supplier accounts."""
value = self.clean_text(value)
variants = []
# Split by ' | ' to handle multiple accounts
accounts = [acc.strip() for acc in value.split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Add original value
variants.append(account)
# Remove prefix (PG:, BG:, etc.)
if ':' in account:
prefix, number = account.split(':', 1)
number = number.strip()
variants.append(number) # Just the number without prefix
# Also add with different prefix formats
prefix_upper = prefix.strip().upper()
variants.append(f"{prefix_upper}:{number}")
variants.append(f"{prefix_upper}: {number}") # With space
else:
number = account
# Extract digits only
digits_only = re.sub(r'\D', '', number)
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also try 4-4 format for bankgiro
variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
elif len(digits_only) == 10:
# 6-4 format (like org number)
variants.append(f"{digits_only[:6]}-{digits_only[6:]}")
return list(set(v for v in variants if v))

View File

@@ -0,0 +1,16 @@
from .paddle_ocr import OCREngine, OCRResult, OCRToken, extract_ocr_tokens
from .machine_code_parser import (
MachineCodeParser,
MachineCodeResult,
parse_machine_code,
)
__all__ = [
'OCREngine',
'OCRResult',
'OCRToken',
'extract_ocr_tokens',
'MachineCodeParser',
'MachineCodeResult',
'parse_machine_code',
]

View File

@@ -0,0 +1,929 @@
"""
Machine Code Line Parser for Swedish Invoices
Parses the bottom machine-readable payment line to extract:
- OCR reference number (10-25 digits)
- Amount (payment amount in SEK)
- Bankgiro account number (XXX-XXXX or XXXX-XXXX format)
- Plusgiro account number (XXXXXXX-X format)
The machine code line is typically found at the bottom of Swedish invoices,
in the payment slip (Inbetalningskort) section. It contains machine-readable
data for automated payment processing.
## Swedish Payment Line Standard Format
The standard machine-readable payment line follows this structure:
# <OCR> # <Kronor> <Öre> <Type> > <Bankgiro>#<Control>#
Example:
# 31130954410 # 315 00 2 > 8983025#14#
Components:
- `#` - Start delimiter
- `31130954410` - OCR number (with Mod 10 check digit)
- `#` - Separator
- `315 00` - Amount: 315 SEK and 00 öre (315.00 SEK)
- `2` - Payment type / record type
- `>` - Points to recipient info
- `8983025` - Bankgiro number
- `#14#` - End marker with control code
Legacy patterns also supported:
- OCR: 8120000849965361 (10-25 consecutive digits)
- Bankgiro: 5393-9484 or 53939484
- Plusgiro: 1234567-8
- Amount: 1234,56 or 1234.56 (with decimal separator)
"""
import re
from dataclasses import dataclass, field
from typing import Optional
from shared.pdf.extractor import Token as TextToken
from shared.utils.validators import FieldValidators
@dataclass
class MachineCodeResult:
"""Result of machine code parsing."""
ocr: Optional[str] = None
amount: Optional[str] = None
bankgiro: Optional[str] = None
plusgiro: Optional[str] = None
confidence: float = 0.0
source_tokens: list[TextToken] = field(default_factory=list)
raw_line: str = ""
# Region bounding box in PDF coordinates (x0, y0, x1, y1)
region_bbox: Optional[tuple[float, float, float, float]] = None
def to_dict(self) -> dict:
"""Convert to dictionary for serialization."""
return {
'ocr': self.ocr,
'amount': self.amount,
'bankgiro': self.bankgiro,
'plusgiro': self.plusgiro,
'confidence': self.confidence,
'raw_line': self.raw_line,
'region_bbox': self.region_bbox,
}
def get_region_bbox(self) -> Optional[tuple[float, float, float, float]]:
"""
Get the bounding box of the payment slip region.
Returns:
Tuple (x0, y0, x1, y1) in PDF coordinates, or None if no region detected
"""
if self.region_bbox:
return self.region_bbox
if not self.source_tokens:
return None
# Calculate bbox from source tokens
x0 = min(t.bbox[0] for t in self.source_tokens)
y0 = min(t.bbox[1] for t in self.source_tokens)
x1 = max(t.bbox[2] for t in self.source_tokens)
y1 = max(t.bbox[3] for t in self.source_tokens)
return (x0, y0, x1, y1)
class MachineCodeParser:
"""
Parser for machine-readable payment lines on Swedish invoices.
The parser focuses on the bottom region of the invoice where
the payment slip (Inbetalningskort) is typically located.
"""
# Payment slip detection keywords (Swedish)
PAYMENT_SLIP_KEYWORDS = [
'inbetalning', 'girering', 'avi', 'betalning',
'plusgiro', 'postgiro', 'bankgiro', 'bankgirot',
'betalningsavsändare', 'betalningsmottagare',
'maskinellt', 'ändringar', # "DEN AVLÄSES MASKINELLT"
]
# Patterns for field extraction
# OCR: 10-25 consecutive digits (may have spaces or # at end)
OCR_PATTERN = re.compile(r'(?<!\d)(\d{10,25})(?!\d)')
# Bankgiro: XXX-XXXX or XXXX-XXXX (7-8 digits with optional dash)
BANKGIRO_PATTERN = re.compile(r'\b(\d{3,4}[-\s]?\d{4})\b')
# Plusgiro: XXXXXXX-X (7-8 digits with dash before last digit)
PLUSGIRO_PATTERN = re.compile(r'\b(\d{6,7}[-\s]?\d)\b')
# Amount: digits with comma or dot as decimal separator
# Supports formats: 1234,56 | 1234.56 | 1 234,56 | 1.234,56
AMOUNT_PATTERN = re.compile(
r'\b(\d+(?:[\s\.\xa0]\d{3})*[,\.]\d{2})\b'
)
# Alternative amount pattern for integers (no decimal)
AMOUNT_INTEGER_PATTERN = re.compile(r'\b(\d{2,6})\b')
# Standard Swedish payment line pattern
# Format: # <OCR> # <Kronor> <Öre> <Type> > <Bankgiro/Plusgiro>#<Control>#
# Example: # 31130954410 # 315 00 2 > 8983025#14#
# This pattern captures both Bankgiro and Plusgiro accounts
PAYMENT_LINE_PATTERN = re.compile(
r'#\s*' # Start delimiter
r'(\d{5,25})\s*' # OCR number (capture group 1)
r'#\s*' # Separator
r'(\d{1,7})\s+' # Kronor (capture group 2)
r'(\d{2})\s+' # Öre (capture group 3)
r'(\d)\s*' # Type (capture group 4)
r'>\s*' # Direction marker
r'(\d{5,10})' # Bankgiro/Plusgiro (capture group 5)
r'(?:#\d{1,3}#)?' # Optional end marker
)
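# Worked example for the pattern above, using the sample line from the
# module docstring:
#
#     m = MachineCodeParser.PAYMENT_LINE_PATTERN.search(
#         '# 31130954410 # 315 00 2 > 8983025#14#')
#     # m.group(1) == '31130954410'            OCR reference
#     # m.group(2), m.group(3) == '315', '00'  (i.e. 315.00 SEK)
#     # m.group(4) == '2'                      record type
#     # m.group(5) == '8983025'                Bankgiro/Plusgiro digits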
# Alternative pattern with different spacing
PAYMENT_LINE_PATTERN_ALT = re.compile(
r'#?\s*' # Optional start delimiter
r'(\d{8,25})\s*' # OCR number
r'#?\s*' # Optional separator
r'(\d{1,7})\s+' # Kronor
r'(\d{2})\s+' # Öre
r'\d\s*' # Type
r'>?\s*' # Optional direction marker
r'(\d{5,10})' # Bankgiro
)
# Reverse format pattern (Bankgiro first, then OCR)
# Format: <Bankgiro>#<Control># <Kronor> <Öre> <Type> > <OCR> #
# Example: 53241469#41# 2428 00 1 > 4388595300 #
PAYMENT_LINE_PATTERN_REVERSE = re.compile(
r'(\d{7,8})' # Bankgiro (capture group 1)
r'#\d{1,3}#\s+' # Control marker
r'(\d{1,7})\s+' # Kronor (capture group 2)
r'(\d{2})\s+' # Öre (capture group 3)
r'\d\s*' # Type
r'>\s*' # Direction marker
r'(\d{5,25})' # OCR number (capture group 4)
)
def __init__(self, bottom_region_ratio: float = 0.35):
"""
Initialize the parser.
Args:
bottom_region_ratio: Fraction of page height to consider as bottom region.
Default 0.35 means bottom 35% of the page.
"""
self.bottom_region_ratio = bottom_region_ratio
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
"""
Detect account type keywords in context.
Returns:
Dict with 'bankgiro' and 'plusgiro' boolean flags
"""
context_text = ' '.join(t.text.lower() for t in tokens)
return {
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
}
def _normalize_account_spaces(self, line: str) -> str:
"""
Remove spaces in account number portion after > marker.
Args:
line: Payment line text
Returns:
Line with normalized account number spacing
"""
if '>' not in line:
return line
        before, after_arrow = line.split('>', 1)
        # After >, collapse spaces between digits (keeping # markers).
        # Each substitution pass only merges adjacent digit pairs, so loop
        # until stable for sequences like "78 2 1 713".
        while re.search(r'(\d)\s+(\d)', after_arrow):
            after_arrow = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
        return before + '>' + after_arrow
def _format_account(
self,
account_digits: str,
is_plusgiro_context: bool
) -> tuple[str, str]:
"""
Format account number and determine type (bankgiro or plusgiro).
Uses context keywords first, then falls back to Luhn validation
to determine the most likely account type.
Args:
account_digits: Raw digits of account number
is_plusgiro_context: Whether context indicates Plusgiro
Returns:
Tuple of (formatted_account, account_type)
"""
if is_plusgiro_context:
# Context explicitly indicates Plusgiro
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
return formatted, 'plusgiro'
# No explicit context - use Luhn validation to determine type
# Try both formats and see which passes Luhn check
# Format as Plusgiro: XXXXXXX-X (all digits, check digit at end)
pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
# Format as Bankgiro: XXX-XXXX or XXXX-XXXX
if len(account_digits) == 7:
bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}"
elif len(account_digits) == 8:
bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}"
else:
bg_formatted = account_digits
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
# Decision logic:
# 1. If only one format passes Luhn, use that
# 2. If both pass or both fail, default to Bankgiro (more common in payment lines)
if pg_valid and not bg_valid:
return pg_formatted, 'plusgiro'
elif bg_valid and not pg_valid:
return bg_formatted, 'bankgiro'
else:
# Both valid or both invalid - default to bankgiro
return bg_formatted, 'bankgiro'
def parse(
self,
tokens: list[TextToken],
page_height: float,
page_width: float | None = None,
) -> MachineCodeResult:
"""
Parse machine code from tokens.
Args:
tokens: List of text tokens from OCR or text extraction
page_height: Height of the page in points
page_width: Width of the page in points (optional)
Returns:
MachineCodeResult with extracted fields
"""
if not tokens:
return MachineCodeResult()
# Filter to bottom region tokens
bottom_y_threshold = page_height * (1 - self.bottom_region_ratio)
bottom_tokens = [
t for t in tokens
if t.bbox[1] >= bottom_y_threshold # y0 >= threshold
]
if not bottom_tokens:
return MachineCodeResult()
# Sort by y position (top to bottom) then x (left to right)
bottom_tokens.sort(key=lambda t: (t.bbox[1], t.bbox[0]))
# Check if this looks like a payment slip region
combined_text = ' '.join(t.text for t in bottom_tokens).lower()
has_payment_keywords = any(
kw in combined_text for kw in self.PAYMENT_SLIP_KEYWORDS
)
# Build raw line from bottom tokens
raw_line = ' '.join(t.text for t in bottom_tokens)
# Try standard payment line format first and find the matching tokens
standard_result, matched_tokens = self._parse_standard_payment_line_with_tokens(
raw_line, bottom_tokens
)
if standard_result and matched_tokens:
# Calculate bbox only from tokens that contain the machine code
x0 = min(t.bbox[0] for t in matched_tokens)
y0 = min(t.bbox[1] for t in matched_tokens)
x1 = max(t.bbox[2] for t in matched_tokens)
y1 = max(t.bbox[3] for t in matched_tokens)
region_bbox = (x0, y0, x1, y1)
result = MachineCodeResult(
ocr=standard_result.get('ocr'),
amount=standard_result.get('amount'),
bankgiro=standard_result.get('bankgiro'),
plusgiro=standard_result.get('plusgiro'),
confidence=0.95,
source_tokens=matched_tokens,
raw_line=raw_line,
region_bbox=region_bbox,
)
return result
# Fall back to individual field extraction
result = MachineCodeResult(
source_tokens=bottom_tokens,
raw_line=raw_line,
)
# Extract OCR number (longest digit sequence 10-25 digits)
result.ocr = self._extract_ocr(bottom_tokens)
# Extract Bankgiro
result.bankgiro = self._extract_bankgiro(bottom_tokens)
# Extract Plusgiro (if no Bankgiro found)
if not result.bankgiro:
result.plusgiro = self._extract_plusgiro(bottom_tokens)
# Extract Amount
result.amount = self._extract_amount(bottom_tokens)
# Calculate confidence
result.confidence = self._calculate_confidence(
result, has_payment_keywords
)
# For fallback extraction, compute bbox from tokens that contain the extracted values
matched_tokens = self._find_tokens_with_values(bottom_tokens, result)
if matched_tokens:
x0 = min(t.bbox[0] for t in matched_tokens)
y0 = min(t.bbox[1] for t in matched_tokens)
x1 = max(t.bbox[2] for t in matched_tokens)
y1 = max(t.bbox[3] for t in matched_tokens)
result.region_bbox = (x0, y0, x1, y1)
result.source_tokens = matched_tokens
return result
def _find_tokens_with_values(
self,
tokens: list[TextToken],
result: MachineCodeResult
) -> list[TextToken]:
"""Find tokens that contain the extracted values (OCR, Amount, Bankgiro)."""
matched = []
values_to_find = []
if result.ocr:
values_to_find.append(result.ocr)
if result.amount:
# Amount might be just digits
amount_digits = re.sub(r'\D', '', result.amount)
values_to_find.append(amount_digits)
values_to_find.append(result.amount)
if result.bankgiro:
# Bankgiro might have dash or not
bg_digits = re.sub(r'\D', '', result.bankgiro)
values_to_find.append(bg_digits)
values_to_find.append(result.bankgiro)
if result.plusgiro:
pg_digits = re.sub(r'\D', '', result.plusgiro)
values_to_find.append(pg_digits)
values_to_find.append(result.plusgiro)
for token in tokens:
text = token.text.replace(' ', '').replace('#', '')
text_digits = re.sub(r'\D', '', token.text)
for value in values_to_find:
if value in text or value in text_digits:
if token not in matched:
matched.append(token)
break
return matched
def _find_machine_code_line_tokens(
self,
tokens: list[TextToken]
) -> list[TextToken]:
"""
Find tokens that belong to the machine code line using pure regex patterns.
The machine code line typically contains:
- Control markers like #14#, #41#
- Direction marker >
- Account numbers with # suffix
Returns:
List of tokens belonging to the machine code line
"""
# Find tokens with characteristic machine code patterns
ref_y = None
# First, find the reference y-coordinate from tokens with machine code patterns
for token in tokens:
text = token.text
# Check if token contains machine code patterns
# Priority 1: Control marker like #14#, 47304035#14#
has_control_marker = bool(re.search(r'#\d+#', text))
# Priority 2: Direction marker >
has_direction = '>' in text
if has_control_marker:
# This is very likely part of the machine code line
ref_y = token.bbox[1]
break
elif has_direction and ref_y is None:
# Direction marker is also a good indicator
ref_y = token.bbox[1]
if ref_y is None:
return []
# Collect all tokens on the same line (within 3 points of ref_y)
# Use very small tolerance because Swedish invoices often have duplicate
# machine code lines (upper and lower part of payment slip)
y_tolerance = 3
machine_code_tokens = []
for token in tokens:
if abs(token.bbox[1] - ref_y) < y_tolerance:
text = token.text
# Include token if it contains:
# - Digits (OCR, amount, account numbers)
# - # symbol (delimiters, control markers)
# - > symbol (direction marker)
if (re.search(r'\d', text) or '#' in text or '>' in text):
machine_code_tokens.append(token)
# If we found very few tokens, try to expand to nearby y values
# that might be part of the same logical line
if len(machine_code_tokens) < 3:
y_tolerance = 10
machine_code_tokens = []
for token in tokens:
if abs(token.bbox[1] - ref_y) < y_tolerance:
text = token.text
if (re.search(r'\d', text) or '#' in text or '>' in text):
machine_code_tokens.append(token)
return machine_code_tokens
def _parse_standard_payment_line_with_tokens(
self,
raw_line: str,
tokens: list[TextToken]
) -> tuple[Optional[dict], list[TextToken]]:
"""
Parse standard Swedish payment line format and find matching tokens.
Uses pure regex to identify the machine code line, then finds tokens
that are part of that line based on their position.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Bankgiro/Plusgiro>#<Control>#
Example: # 31130954410 # 315 00 2 > 8983025#14#
Returns:
Tuple of (parsed_dict, matched_tokens) or (None, [])
"""
# First find the machine code line tokens using pattern matching
machine_code_tokens = self._find_machine_code_line_tokens(tokens)
if not machine_code_tokens:
# Fall back to regex on raw_line
parsed = self._parse_standard_payment_line(raw_line, raw_line)
return parsed, []
# Build a line from just the machine code tokens (sorted by x position)
# Group tokens by approximate x position to handle duplicate OCR results
mc_tokens_sorted = sorted(machine_code_tokens, key=lambda t: t.bbox[0])
# Deduplicate tokens at similar x positions (keep the first one)
deduped_tokens = []
last_x = -100
for t in mc_tokens_sorted:
# Skip tokens that are too close to the previous one (likely duplicates)
if t.bbox[0] - last_x < 5:
continue
deduped_tokens.append(t)
last_x = t.bbox[2] # Use end x for next comparison
mc_line = ' '.join(t.text for t in deduped_tokens)
# Try to parse this line, using raw_line for context detection
parsed = self._parse_standard_payment_line(mc_line, raw_line)
if parsed:
return parsed, deduped_tokens
# If machine code line parsing failed, try the full raw_line
parsed = self._parse_standard_payment_line(raw_line, raw_line)
if parsed:
return parsed, machine_code_tokens
return None, []
def _parse_standard_payment_line(
self,
raw_line: str,
context_line: str | None = None
) -> Optional[dict]:
"""
Parse standard Swedish payment line format.
Format: # <OCR> # <Kronor> <Öre> <Type> > <Bankgiro/Plusgiro>#<Control>#
Example: # 31130954410 # 315 00 2 > 8983025#14#
Args:
raw_line: The line to parse (may be just the machine code tokens)
context_line: Optional full line for context detection (e.g., to find "plusgiro" keywords)
Returns:
Dict with 'ocr', 'amount', and 'bankgiro' or 'plusgiro' if matched, None otherwise
"""
# Use context_line for detecting Plusgiro/Bankgiro, fall back to raw_line
context = (context_line or raw_line).lower()
is_plusgiro_context = (
('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
and 'bankgiro' not in context
)
# Preprocess: remove spaces in the account number part (after >)
raw_line = self._normalize_account_spaces(raw_line)
# Try primary pattern
match = self.PAYMENT_LINE_PATTERN.search(raw_line)
if match:
ocr = match.group(1)
kronor = match.group(2)
ore = match.group(3)
account_digits = match.group(5)
# Format amount: combine kronor and öre
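            # Note: ",00" is dropped for whole-krona amounts, presumably to
            # match CSV ground truth that stores such amounts without öre.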
amount = f"{kronor},{ore}" if ore != "00" else kronor
formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
return {
'ocr': ocr,
'amount': amount,
account_type: formatted_account,
}
# Try alternative pattern
match = self.PAYMENT_LINE_PATTERN_ALT.search(raw_line)
if match:
ocr = match.group(1)
kronor = match.group(2)
ore = match.group(3)
account_digits = match.group(4)
amount = f"{kronor},{ore}" if ore != "00" else kronor
formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
return {
'ocr': ocr,
'amount': amount,
account_type: formatted_account,
}
# Try reverse pattern (Account first, then OCR)
match = self.PAYMENT_LINE_PATTERN_REVERSE.search(raw_line)
if match:
account_digits = match.group(1)
kronor = match.group(2)
ore = match.group(3)
ocr = match.group(4)
amount = f"{kronor},{ore}" if ore != "00" else kronor
formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
return {
'ocr': ocr,
'amount': amount,
account_type: formatted_account,
}
return None
def _extract_ocr(self, tokens: list[TextToken]) -> Optional[str]:
"""Extract OCR reference number."""
candidates = []
# First, collect all bankgiro-like patterns to exclude
bankgiro_digits = set()
for token in tokens:
text = token.text.strip()
bg_matches = self.BANKGIRO_PATTERN.findall(text)
for bg in bg_matches:
digits = re.sub(r'\D', '', bg)
bankgiro_digits.add(digits)
# Also add with potential check digits (common pattern)
for i in range(10):
bankgiro_digits.add(digits + str(i))
bankgiro_digits.add(digits + str(i) + str(i))
for token in tokens:
# Remove spaces and common suffixes
text = token.text.replace(' ', '').replace('#', '').strip()
# Find all digit sequences
matches = self.OCR_PATTERN.findall(text)
for match in matches:
# OCR numbers are typically 10-25 digits
if 10 <= len(match) <= 25:
# Skip if this looks like a bankgiro number with check digit
is_bankgiro_variant = any(
match.startswith(bg) or match.endswith(bg)
for bg in bankgiro_digits if len(bg) >= 7
)
# Also check if it's exactly bankgiro with 2-3 extra digits
for bg in bankgiro_digits:
if len(bg) >= 7 and (
match == bg or
(len(match) - len(bg) <= 3 and match.startswith(bg))
):
is_bankgiro_variant = True
break
if not is_bankgiro_variant:
candidates.append((match, len(match), token))
if not candidates:
return None
# Prefer longer sequences (more likely to be OCR)
candidates.sort(key=lambda x: x[1], reverse=True)
return candidates[0][0]
def _extract_bankgiro(self, tokens: list[TextToken]) -> Optional[str]:
"""Extract Bankgiro account number.
Bankgiro format: XXX-XXXX or XXXX-XXXX (dash in middle)
NOT Plusgiro: XXXXXXX-X (dash before last digit)
"""
candidates = []
context = self._detect_account_context(tokens)
# If clearly Plusgiro context (and not bankgiro), don't extract as Bankgiro
if context['plusgiro'] and not context['bankgiro']:
return None
for token in tokens:
text = token.text.strip()
# Look for Bankgiro pattern
matches = self.BANKGIRO_PATTERN.findall(text)
for match in matches:
# Check if this looks like Plusgiro format (dash before last digit)
# Plusgiro: 1234567-8 (dash at position -2)
if '-' in match:
parts = match.replace(' ', '').split('-')
if len(parts) == 2 and len(parts[1]) == 1:
# This is Plusgiro format, skip
continue
# Normalize: remove spaces, ensure dash
digits = re.sub(r'\D', '', match)
if len(digits) == 7:
normalized = f"{digits[:3]}-{digits[3:]}"
elif len(digits) == 8:
normalized = f"{digits[:4]}-{digits[4:]}"
else:
continue
candidates.append((normalized, context['bankgiro'], token))
if not candidates:
return None
# Prefer matches with bankgiro context
        candidates.sort(key=lambda x: x[1], reverse=True)
return candidates[0][0]
def _extract_plusgiro(self, tokens: list[TextToken]) -> Optional[str]:
"""Extract Plusgiro account number."""
candidates = []
context = self._detect_account_context(tokens)
for token in tokens:
text = token.text.strip()
matches = self.PLUSGIRO_PATTERN.findall(text)
for match in matches:
# Normalize: remove spaces, ensure dash before last digit
digits = re.sub(r'\D', '', match)
if 7 <= len(digits) <= 8:
normalized = f"{digits[:-1]}-{digits[-1]}"
candidates.append((normalized, context['plusgiro'], token))
if not candidates:
return None
        candidates.sort(key=lambda x: x[1], reverse=True)
return candidates[0][0]
def _extract_amount(self, tokens: list[TextToken]) -> Optional[str]:
"""Extract payment amount."""
candidates = []
for token in tokens:
text = token.text.strip()
# Try decimal amount pattern first
matches = self.AMOUNT_PATTERN.findall(text)
for match in matches:
# Normalize: remove thousand separators, use comma as decimal
normalized = match.replace(' ', '').replace('\xa0', '')
# Convert dot thousand separator to none, keep comma decimal
if '.' in normalized and ',' in normalized:
# Format like 1.234,56 -> 1234,56
normalized = normalized.replace('.', '')
elif '.' in normalized:
# Could be 1234.56 -> 1234,56
parts = normalized.split('.')
if len(parts) == 2 and len(parts[1]) == 2:
normalized = f"{parts[0]},{parts[1]}"
# Parse to verify it's a valid amount
try:
value = float(normalized.replace(',', '.'))
if 0 < value < 1000000: # Reasonable amount range
candidates.append((normalized, value, token))
except ValueError:
continue
# If no decimal amounts found, try integer amounts
# Look for "Kronor" label nearby and extract integer
if not candidates:
for i, token in enumerate(tokens):
text = token.text.strip().lower()
if 'kronor' in text or 'kr' == text or text.endswith(' kr'):
# Look at nearby tokens for amounts (wider range)
for j in range(max(0, i - 5), min(len(tokens), i + 5)):
nearby_text = tokens[j].text.strip()
# Match pure integer (1-6 digits)
int_match = re.match(r'^(\d{1,6})$', nearby_text)
if int_match:
value = int(int_match.group(1))
if 0 < value < 1000000:
candidates.append((str(value), float(value), tokens[j]))
# Also try to find amounts near "öre" label (Swedish cents)
if not candidates:
for i, token in enumerate(tokens):
text = token.text.strip().lower()
if 'öre' in text:
# Look at nearby tokens for amounts
for j in range(max(0, i - 5), min(len(tokens), i + 5)):
nearby_text = tokens[j].text.strip()
int_match = re.match(r'^(\d{1,6})$', nearby_text)
if int_match:
value = int(int_match.group(1))
if 0 < value < 1000000:
candidates.append((str(value), float(value), tokens[j]))
if not candidates:
return None
# Sort by value (prefer larger amounts - likely total)
candidates.sort(key=lambda x: x[1], reverse=True)
return candidates[0][0]
def _calculate_confidence(
self,
result: MachineCodeResult,
has_payment_keywords: bool
) -> float:
"""Calculate confidence score for the extraction."""
confidence = 0.0
# Base confidence from payment keywords
if has_payment_keywords:
confidence += 0.3
# Points for each extracted field
if result.ocr:
confidence += 0.25
# Bonus for typical OCR length (15-17 digits)
if 15 <= len(result.ocr) <= 17:
confidence += 0.1
if result.bankgiro or result.plusgiro:
confidence += 0.2
if result.amount:
confidence += 0.15
return min(confidence, 1.0)
def cross_validate(
self,
machine_result: MachineCodeResult,
csv_values: dict[str, str],
) -> dict[str, dict]:
"""
Cross-validate machine code extraction with CSV ground truth.
Args:
machine_result: Result from parse()
csv_values: Dict of field values from CSV
(keys: 'ocr', 'amount', 'bankgiro', 'plusgiro')
Returns:
Dict with validation results for each field:
{
'ocr': {
'machine': '123456789',
'csv': '123456789',
'match': True,
'use_machine': False, # CSV has value
},
...
}
"""
from shared.normalize import normalize_field
results = {}
field_mapping = [
('ocr', 'OCR', machine_result.ocr),
('amount', 'Amount', machine_result.amount),
('bankgiro', 'Bankgiro', machine_result.bankgiro),
('plusgiro', 'Plusgiro', machine_result.plusgiro),
]
for field_key, normalizer_name, machine_value in field_mapping:
csv_value = csv_values.get(field_key, '').strip()
result_entry = {
'machine': machine_value,
'csv': csv_value if csv_value else None,
'match': False,
'use_machine': False,
}
if machine_value and csv_value:
# Both have values - check if they match
machine_variants = normalize_field(normalizer_name, machine_value)
csv_variants = normalize_field(normalizer_name, csv_value)
# Check for any overlap
result_entry['match'] = bool(
set(machine_variants) & set(csv_variants)
)
# Special handling for amounts - allow rounding differences
if not result_entry['match'] and field_key == 'amount':
try:
# Parse both values as floats
machine_float = float(
machine_value.replace(' ', '')
.replace(',', '.').replace('\xa0', '')
)
csv_float = float(
csv_value.replace(' ', '')
.replace(',', '.').replace('\xa0', '')
)
# Allow 1 unit difference (rounding)
if abs(machine_float - csv_float) <= 1.0:
result_entry['match'] = True
result_entry['rounding_diff'] = True
except ValueError:
pass
elif machine_value and not csv_value:
# CSV is missing, use machine value
result_entry['use_machine'] = True
results[field_key] = result_entry
return results
def parse_machine_code(
tokens: list[TextToken],
page_height: float,
page_width: float | None = None,
bottom_ratio: float = 0.35,
) -> MachineCodeResult:
"""
Convenience function to parse machine code from tokens.
Args:
tokens: List of text tokens
page_height: Page height in points
page_width: Page width in points (optional)
bottom_ratio: Fraction of page to consider as bottom region
Returns:
MachineCodeResult with extracted fields
"""
parser = MachineCodeParser(bottom_region_ratio=bottom_ratio)
return parser.parse(tokens, page_height, page_width)
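# Minimal usage sketch (editor's illustration, not part of the pipeline).
# "invoice.pdf" is a placeholder path, and csv_values below are made-up
# ground truth in the shape cross_validate() expects.
if __name__ == "__main__":
    from shared.pdf.extractor import extract_text_tokens, get_page_dimensions
    tokens = list(extract_text_tokens("invoice.pdf", page_no=0))
    _, page_height = get_page_dimensions("invoice.pdf")
    result = parse_machine_code(tokens, page_height=page_height)
    print(result.to_dict())
    csv_values = {'ocr': '31130954410', 'amount': '315',
                  'bankgiro': '898-3025', 'plusgiro': ''}
    print(MachineCodeParser().cross_validate(result, csv_values))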

View File

@@ -0,0 +1,405 @@
"""
OCR Extraction Module using PaddleOCR
Extracts text tokens with bounding boxes from scanned PDFs.
"""
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Generator
import numpy as np
# Suppress PaddlePaddle reinitialization warnings
os.environ.setdefault('GLOG_minloglevel', '2')
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
warnings.filterwarnings('ignore', message='.*reinitialization.*')
@dataclass
class OCRToken:
"""Represents an OCR-extracted text token with its bounding box."""
text: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
confidence: float
page_no: int = 0
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
@dataclass
class OCRResult:
"""Result from OCR extraction including tokens and preprocessed image."""
tokens: list[OCRToken]
output_img: np.ndarray | None = None # Preprocessed image from PaddleOCR
class OCREngine:
"""PaddleOCR wrapper for text extraction."""
def __init__(
self,
lang: str = "en",
det_model_dir: str | None = None,
rec_model_dir: str | None = None,
use_doc_orientation_classify: bool = True,
use_doc_unwarping: bool = False
):
"""
Initialize OCR engine.
Args:
lang: Language code ('en', 'sv', 'ch', etc.)
det_model_dir: Custom detection model directory
rec_model_dir: Custom recognition model directory
use_doc_orientation_classify: Whether to auto-detect and correct document orientation.
Default True to handle rotated documents.
use_doc_unwarping: Whether to use UVDoc document unwarping for curved/warped documents.
Default False to preserve original image layout,
especially important for payment OCR lines at bottom.
Enable for severely warped documents at the cost of potentially
losing bottom content.
Note:
PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
Use `paddle.set_device('gpu')` before initialization to force GPU.
"""
# Suppress warnings during import and initialization
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
from paddleocr import PaddleOCR
# PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
init_params = {
'lang': lang,
# Enable orientation classification to handle rotated documents
'use_doc_orientation_classify': use_doc_orientation_classify,
# Disable UVDoc unwarping to preserve original image layout
# This prevents the bottom payment OCR line from being cut off
# For severely warped documents, enable this but expect potential content loss
'use_doc_unwarping': use_doc_unwarping,
}
if det_model_dir:
init_params['text_detection_model_dir'] = det_model_dir
if rec_model_dir:
init_params['text_recognition_model_dir'] = rec_model_dir
self.ocr = PaddleOCR(**init_params)
def extract_from_image(
self,
image: str | Path | np.ndarray,
page_no: int = 0,
max_size: int = 2000,
scale_to_pdf_points: float | None = None,
scan_bottom_region: bool = True,
bottom_region_ratio: float = 0.15
) -> list[OCRToken]:
"""
Extract text tokens from an image.
Args:
image: Image path or numpy array
page_no: Page number for reference
max_size: Maximum image dimension. Larger images will be scaled down
to avoid OCR issues with PaddleOCR on large images.
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
to convert from pixel to PDF point coordinates.
Use (72 / dpi) for images rendered at a specific DPI.
scan_bottom_region: If True, also scan the bottom region separately to catch
OCR payment lines that may be missed in full-page scan.
bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
Returns:
List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
"""
        # Run the full-page pass without its own bottom scan; this method
        # performs that scan below, so the region is not OCR'd twice and the
        # scan_bottom_region flag is actually honored.
        result = self.extract_with_image(
            image, page_no, max_size, scale_to_pdf_points,
            scan_bottom_region=False,
        )
tokens = result.tokens
# Optionally scan bottom region separately for Swedish OCR payment lines
if scan_bottom_region:
bottom_tokens = self._scan_bottom_region(
image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
)
tokens = self._merge_tokens(tokens, bottom_tokens)
return tokens
def _scan_bottom_region(
self,
image: str | Path | np.ndarray,
page_no: int,
max_size: int,
scale_to_pdf_points: float | None,
bottom_ratio: float
) -> list[OCRToken]:
"""Scan the bottom region of the image separately."""
from PIL import Image as PILImage
# Load image if path
if isinstance(image, (str, Path)):
img = PILImage.open(str(image))
img_array = np.array(img)
else:
img_array = image
h, w = img_array.shape[:2]
crop_y = int(h * (1 - bottom_ratio))
# Crop bottom region
bottom_crop = img_array[crop_y:h, :, :] if len(img_array.shape) == 3 else img_array[crop_y:h, :]
# OCR the cropped region (without recursive bottom scan to avoid infinite loop)
result = self.extract_with_image(
bottom_crop, page_no, max_size,
scale_to_pdf_points=None,
scan_bottom_region=False # Important: disable to prevent recursion
)
# Adjust bbox y-coordinates to full image space
adjusted_tokens = []
for token in result.tokens:
# Scale factor for coordinates
scale = scale_to_pdf_points if scale_to_pdf_points else 1.0
adjusted_bbox = (
token.bbox[0] * scale,
(token.bbox[1] + crop_y) * scale,
token.bbox[2] * scale,
(token.bbox[3] + crop_y) * scale
)
adjusted_tokens.append(OCRToken(
text=token.text,
bbox=adjusted_bbox,
confidence=token.confidence,
page_no=token.page_no
))
return adjusted_tokens
def _merge_tokens(
self,
main_tokens: list[OCRToken],
bottom_tokens: list[OCRToken]
) -> list[OCRToken]:
"""Merge tokens from main scan and bottom region scan, removing duplicates."""
if not bottom_tokens:
return main_tokens
# Create a set of existing token texts for deduplication
existing_texts = {t.text.strip() for t in main_tokens}
# Add bottom tokens that aren't duplicates
merged = list(main_tokens)
for token in bottom_tokens:
if token.text.strip() not in existing_texts:
merged.append(token)
existing_texts.add(token.text.strip())
return merged
def extract_with_image(
self,
image: str | Path | np.ndarray,
page_no: int = 0,
max_size: int = 2000,
scale_to_pdf_points: float | None = None,
scan_bottom_region: bool = True,
bottom_region_ratio: float = 0.15
) -> OCRResult:
"""
Extract text tokens from an image and return the preprocessed image.
PaddleOCR applies document preprocessing (unwarping, rotation, enhancement)
and returns coordinates relative to the preprocessed image (output_img).
This method returns both tokens and output_img so the caller can save
the correct image that matches the coordinates.
Args:
image: Image path or numpy array
page_no: Page number for reference
max_size: Maximum image dimension. Larger images will be scaled down
to avoid OCR issues with PaddleOCR on large images.
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
to convert from pixel to PDF point coordinates.
Use (72 / dpi) for images rendered at a specific DPI.
scan_bottom_region: If True, also scan the bottom region separately to catch
OCR payment lines that may be missed in full-page scan.
bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
Returns:
OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
"""
from PIL import Image as PILImage
# Load image if path
if isinstance(image, (str, Path)):
img = PILImage.open(str(image))
img_array = np.array(img)
else:
img_array = image
# Check if image needs scaling for OCR
h, w = img_array.shape[:2]
ocr_scale_factor = 1.0
if max(h, w) > max_size:
ocr_scale_factor = max_size / max(h, w)
new_w = int(w * ocr_scale_factor)
new_h = int(h * ocr_scale_factor)
# Resize image for OCR
img = PILImage.fromarray(img_array)
img = img.resize((new_w, new_h), PILImage.Resampling.LANCZOS)
img_array = np.array(img)
# PaddleOCR 3.x uses predict() method instead of ocr()
result = self.ocr.predict(img_array)
tokens = []
output_img = None
if result:
for item in result:
# PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys'
if isinstance(item, dict):
rec_texts = item.get('rec_texts', [])
rec_scores = item.get('rec_scores', [])
dt_polys = item.get('dt_polys', [])
# Get output_img from doc_preprocessor_res
# This is the preprocessed image that coordinates are relative to
doc_preproc = item.get('doc_preprocessor_res', {})
if isinstance(doc_preproc, dict):
output_img = doc_preproc.get('output_img')
# Coordinates are relative to output_img (preprocessed image)
# No rotation compensation needed - just use coordinates directly
for text, score, poly in zip(rec_texts, rec_scores, dt_polys):
# poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
x_coords = [float(p[0]) for p in poly]
y_coords = [float(p[1]) for p in poly]
                        # Apply PDF points scale if requested
                        final_scale = scale_to_pdf_points if scale_to_pdf_points is not None else 1.0
bbox = (
min(x_coords) * final_scale,
min(y_coords) * final_scale,
max(x_coords) * final_scale,
max(y_coords) * final_scale
)
tokens.append(OCRToken(
text=text,
bbox=bbox,
confidence=float(score),
page_no=page_no
))
elif isinstance(item, (list, tuple)) and len(item) == 2:
# Legacy format: [[bbox_points], (text, confidence)]
bbox_points, (text, confidence) = item
x_coords = [p[0] for p in bbox_points]
y_coords = [p[1] for p in bbox_points]
                    # Apply PDF points scale if requested
                    final_scale = scale_to_pdf_points if scale_to_pdf_points is not None else 1.0
bbox = (
min(x_coords) * final_scale,
min(y_coords) * final_scale,
max(x_coords) * final_scale,
max(y_coords) * final_scale
)
tokens.append(OCRToken(
text=text,
bbox=bbox,
confidence=confidence,
page_no=page_no
))
# If no output_img was found, use the original input array
if output_img is None:
output_img = img_array
# Optionally scan bottom region separately for Swedish OCR payment lines
if scan_bottom_region:
bottom_tokens = self._scan_bottom_region(
image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
)
tokens = self._merge_tokens(tokens, bottom_tokens)
return OCRResult(tokens=tokens, output_img=output_img)
def extract_from_pdf(
self,
pdf_path: str | Path,
dpi: int = 300
) -> Generator[list[OCRToken], None, None]:
"""
Extract text from all pages of a scanned PDF.
Args:
pdf_path: Path to the PDF file
dpi: Resolution for rendering
Yields:
List of OCRToken for each page
"""
from ..pdf.renderer import render_pdf_to_images
import io
from PIL import Image
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
# Convert bytes to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
tokens = self.extract_from_image(image_array, page_no=page_no)
yield tokens
def extract_ocr_tokens(
image_path: str | Path,
lang: str = "en",
page_no: int = 0
) -> list[OCRToken]:
"""
Convenience function to extract OCR tokens from an image.
Args:
image_path: Path to the image file
lang: Language code
page_no: Page number for reference
Returns:
List of OCRToken objects
"""
engine = OCREngine(lang=lang)
return engine.extract_from_image(image_path, page_no=page_no)
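# Minimal usage sketch (editor's illustration). "invoice_page.png" is a
# placeholder; scale_to_pdf_points=72/150 assumes the page was rendered
# at 150 DPI.
if __name__ == "__main__":
    engine = OCREngine(lang="en")
    page_tokens = engine.extract_from_image(
        "invoice_page.png", scale_to_pdf_points=72 / 150
    )
    for tok in page_tokens[:10]:
        print(f"{tok.confidence:.2f} {tok.bbox} {tok.text}")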

View File

@@ -0,0 +1,12 @@
from .detector import is_text_pdf, get_pdf_type
from .renderer import render_pdf_to_images
from .extractor import extract_text_tokens, PDFDocument, Token
__all__ = [
'is_text_pdf',
'get_pdf_type',
'render_pdf_to_images',
'extract_text_tokens',
'PDFDocument',
'Token',
]

View File

@@ -0,0 +1,150 @@
"""
PDF Type Detection Module
Automatically distinguishes between:
- Text-based PDFs (digitally generated)
- Scanned image PDFs
"""
from pathlib import Path
from typing import Literal
import fitz # PyMuPDF
PDFType = Literal["text", "scanned", "mixed"]
def extract_text_first_page(pdf_path: str | Path) -> str:
"""Extract text from the first page of a PDF."""
    doc = fitz.open(pdf_path)
    if len(doc) == 0:
        doc.close()
        return ""
first_page = doc[0]
text = first_page.get_text()
doc.close()
return text
def is_text_pdf(pdf_path: str | Path, min_chars: int = 30) -> bool:
"""
Check if PDF has extractable AND READABLE text layer.
Some PDFs have custom font encodings that produce garbled text.
This function checks both the presence and readability of text.
Args:
pdf_path: Path to the PDF file
min_chars: Minimum characters to consider it a text PDF
Returns:
True if PDF has readable text layer, False if scanned or garbled
"""
text = extract_text_first_page(pdf_path)
stripped_text = text.strip()
# First check: enough characters (basic minimum)
if len(stripped_text) <= min_chars:
return False
# Second check: text readability
# PDFs with custom font encoding often produce garbled text
# Check if common invoice-related keywords are present
text_lower = stripped_text.lower()
invoice_keywords = [
'faktura', 'invoice', 'datum', 'date', 'belopp', 'amount',
'moms', 'vat', 'bankgiro', 'plusgiro', 'ocr', 'betala',
'summa', 'total', 'pris', 'price', 'kr', 'sek'
]
found_keywords = sum(1 for kw in invoice_keywords if kw in text_lower)
# If at least 2 keywords found, likely readable text
if found_keywords >= 2:
return True
# Third check: minimum content threshold
# A real text PDF invoice should have at least 200 chars of content
# PDFs with only headers/footers (like "Brandsign") should use OCR
if len(stripped_text) < 200:
return False
# Fourth check: character readability ratio
# Count printable ASCII and common Swedish/European characters
readable_chars = 0
total_chars = len(stripped_text)
for c in stripped_text:
# Printable ASCII (32-126) or common Swedish/European chars
if 32 <= ord(c) <= 126 or c in 'åäöÅÄÖéèêëÉÈÊËüÜ':
readable_chars += 1
# If less than 70% readable, treat as garbled/scanned
readable_ratio = readable_chars / total_chars if total_chars > 0 else 0
if readable_ratio < 0.70:
return False
# Fifth check: if no keywords found but passes basic readability,
# require higher readability threshold (85%) or at least 1 keyword
# This catches garbled PDFs that have high ASCII ratio but unreadable content
# (e.g., custom font encoding that maps to different characters)
if found_keywords == 0 and readable_ratio < 0.85:
return False
return True
def get_pdf_type(pdf_path: str | Path) -> PDFType:
"""
Determine the PDF type.
Returns:
'text' - Has extractable text layer
'scanned' - Image-based, needs OCR
'mixed' - Some pages have text, some don't
"""
doc = fitz.open(pdf_path)
if len(doc) == 0:
doc.close()
return "scanned"
text_pages = 0
total_pages = len(doc)
for page in doc:
text = page.get_text().strip()
if len(text) > 30:
text_pages += 1
doc.close()
if text_pages == total_pages:
return "text"
elif text_pages == 0:
return "scanned"
else:
return "mixed"
def get_page_info(pdf_path: str | Path) -> list[dict]:
"""
Get information about each page in the PDF.
Returns:
List of dicts with page info (number, width, height, has_text)
"""
doc = fitz.open(pdf_path)
pages = []
for i, page in enumerate(doc):
text = page.get_text().strip()
rect = page.rect
pages.append({
"page_no": i,
"width": rect.width,
"height": rect.height,
"has_text": len(text) > 30,
"char_count": len(text)
})
doc.close()
return pages
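# Routing sketch (editor's illustration; "invoice.pdf" is a placeholder).
# Text PDFs can be handled by the PyMuPDF extractor; scanned or mixed
# documents should be rendered and run through the OCR engine instead.
if __name__ == "__main__":
    pdf_type = get_pdf_type("invoice.pdf")
    if pdf_type == "text":
        print("text layer available - use shared.pdf.extractor")
    else:
        print(f"{pdf_type} - render pages and use shared.ocr")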

View File

@@ -0,0 +1,323 @@
"""
PDF Text Extraction Module
Extracts text tokens with bounding boxes from text-layer PDFs.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Optional
import fitz # PyMuPDF
from .detector import is_text_pdf as _is_text_pdf_standalone
@dataclass
class Token:
"""Represents a text token with its bounding box."""
text: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def width(self) -> float:
return self.x1 - self.x0
@property
def height(self) -> float:
return self.y1 - self.y0
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
class PDFDocument:
"""
Context manager for efficient PDF document handling.
Caches the open document handle to avoid repeated open/close cycles.
Use this when you need to perform multiple operations on the same PDF.
"""
def __init__(self, pdf_path: str | Path):
self.pdf_path = Path(pdf_path)
self._doc: Optional[fitz.Document] = None
self._dimensions_cache: dict[int, tuple[float, float]] = {}
def __enter__(self) -> 'PDFDocument':
self._doc = fitz.open(self.pdf_path)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self._doc:
self._doc.close()
self._doc = None
@property
def doc(self) -> fitz.Document:
if self._doc is None:
raise RuntimeError("PDFDocument must be used within a context manager")
return self._doc
@property
def page_count(self) -> int:
return len(self.doc)
def is_text_pdf(self, min_chars: int = 30) -> bool:
"""
Check if PDF has extractable AND READABLE text layer.
Uses the improved detection from detector.py that also checks
for garbled text (custom font encoding issues).
"""
return _is_text_pdf_standalone(self.pdf_path, min_chars)
def get_page_dimensions(self, page_no: int = 0) -> tuple[float, float]:
"""Get page dimensions in points (cached)."""
if page_no not in self._dimensions_cache:
page = self.doc[page_no]
rect = page.rect
self._dimensions_cache[page_no] = (rect.width, rect.height)
return self._dimensions_cache[page_no]
def get_render_dimensions(self, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
"""Get rendered image dimensions in pixels."""
width, height = self.get_page_dimensions(page_no)
zoom = dpi / 72
return int(width * zoom), int(height * zoom)
def extract_text_tokens(self, page_no: int) -> Generator[Token, None, None]:
"""Extract text tokens from a specific page."""
page = self.doc[page_no]
text_dict = page.get_text("dict")
tokens_found = False
for block in text_dict.get("blocks", []):
if block.get("type") != 0:
continue
for line in block.get("lines", []):
for span in line.get("spans", []):
text = span.get("text", "").strip()
if not text:
continue
bbox = span.get("bbox")
if bbox and all(abs(b) < 1e9 for b in bbox):
tokens_found = True
yield Token(
text=text,
bbox=tuple(bbox),
page_no=page_no
)
# Fallback: if dict mode failed, use words mode
if not tokens_found:
words = page.get_text("words")
for word_info in words:
x0, y0, x1, y1, text, *_ = word_info
text = text.strip()
if text:
yield Token(
text=text,
bbox=(x0, y0, x1, y1),
page_no=page_no
)
def render_page(self, page_no: int, output_path: Path, dpi: int = 300) -> Path:
"""Render a page to an image file."""
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
page = self.doc[page_no]
pix = page.get_pixmap(matrix=matrix)
output_path.parent.mkdir(parents=True, exist_ok=True)
pix.save(str(output_path))
return output_path
def render_all_pages(
self,
output_dir: Path,
dpi: int = 300
) -> Generator[tuple[int, Path], None, None]:
"""Render all pages to images."""
output_dir.mkdir(parents=True, exist_ok=True)
pdf_name = self.pdf_path.stem
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
for page_no in range(self.page_count):
page = self.doc[page_no]
pix = page.get_pixmap(matrix=matrix)
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.png"
pix.save(str(image_path))
yield page_no, image_path
def extract_text_tokens(
pdf_path: str | Path,
page_no: int | None = None
) -> Generator[Token, None, None]:
"""
Extract text tokens with bounding boxes from PDF.
Args:
pdf_path: Path to the PDF file
page_no: Specific page to extract (None for all pages)
Yields:
Token objects with text and bbox
"""
doc = fitz.open(pdf_path)
pages_to_process = [page_no] if page_no is not None else range(len(doc))
for pg_no in pages_to_process:
page = doc[pg_no]
# Get text with position info using "dict" mode
text_dict = page.get_text("dict")
tokens_found = False
for block in text_dict.get("blocks", []):
if block.get("type") != 0: # Skip non-text blocks
continue
for line in block.get("lines", []):
for span in line.get("spans", []):
text = span.get("text", "").strip()
if not text:
continue
bbox = span.get("bbox")
# Check for corrupted bbox (overflow values)
if bbox and all(abs(b) < 1e9 for b in bbox):
tokens_found = True
yield Token(
text=text,
bbox=tuple(bbox),
page_no=pg_no
)
# Fallback: if dict mode failed, use words mode
if not tokens_found:
words = page.get_text("words")
for word_info in words:
x0, y0, x1, y1, text, *_ = word_info
text = text.strip()
if text:
yield Token(
text=text,
bbox=(x0, y0, x1, y1),
page_no=pg_no
)
doc.close()
def extract_words(
pdf_path: str | Path,
page_no: int | None = None
) -> Generator[Token, None, None]:
"""
Extract individual words with bounding boxes.
Uses PyMuPDF's word extraction which splits text into words.
"""
doc = fitz.open(pdf_path)
pages_to_process = [page_no] if page_no is not None else range(len(doc))
for pg_no in pages_to_process:
page = doc[pg_no]
# get_text("words") returns list of (x0, y0, x1, y1, word, block_no, line_no, word_no)
words = page.get_text("words")
for word_info in words:
x0, y0, x1, y1, text, *_ = word_info
text = text.strip()
if text:
yield Token(
text=text,
bbox=(x0, y0, x1, y1),
page_no=pg_no
)
doc.close()
def extract_lines(
pdf_path: str | Path,
page_no: int | None = None
) -> Generator[Token, None, None]:
"""
Extract text lines with bounding boxes.
"""
doc = fitz.open(pdf_path)
pages_to_process = [page_no] if page_no is not None else range(len(doc))
for pg_no in pages_to_process:
page = doc[pg_no]
text_dict = page.get_text("dict")
for block in text_dict.get("blocks", []):
if block.get("type") != 0:
continue
for line in block.get("lines", []):
spans = line.get("spans", [])
if not spans:
continue
# Combine all spans in the line
line_text = " ".join(s.get("text", "") for s in spans).strip()
if not line_text:
continue
# Get line bbox from all spans
x0 = min(s["bbox"][0] for s in spans)
y0 = min(s["bbox"][1] for s in spans)
x1 = max(s["bbox"][2] for s in spans)
y1 = max(s["bbox"][3] for s in spans)
yield Token(
text=line_text,
bbox=(x0, y0, x1, y1),
page_no=pg_no
)
doc.close()
def get_page_dimensions(pdf_path: str | Path, page_no: int = 0) -> tuple[float, float]:
"""Get the dimensions of a PDF page in points."""
doc = fitz.open(pdf_path)
page = doc[page_no]
rect = page.rect
doc.close()
return rect.width, rect.height
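# Usage sketch for the cached handle (editor's illustration; the path is a
# placeholder). One open/close cycle covers dimensions, tokens and rendering.
if __name__ == "__main__":
    with PDFDocument("invoice.pdf") as pdf:
        width, height = pdf.get_page_dimensions(0)
        print(f"page 0: {width:.0f}x{height:.0f} pt, text_pdf={pdf.is_text_pdf()}")
        for tok in pdf.extract_text_tokens(0):
            print(tok.text, tok.bbox)
        pdf.render_page(0, Path("out/invoice_page_000.png"), dpi=150)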

View File

@@ -0,0 +1,117 @@
"""
PDF Rendering Module
Converts PDF pages to images for YOLO training.
"""
from pathlib import Path
from typing import Generator
import fitz # PyMuPDF
def render_pdf_to_images(
pdf_path: str | Path,
output_dir: str | Path | None = None,
dpi: int = 300,
image_format: str = "png"
) -> Generator[tuple[int, Path | bytes], None, None]:
"""
Render PDF pages to images.
Args:
pdf_path: Path to the PDF file
output_dir: Directory to save images (if None, returns bytes)
dpi: Resolution for rendering (default 300)
image_format: Output format ('png' or 'jpg')
Yields:
Tuple of (page_number, image_path or image_bytes)
"""
doc = fitz.open(pdf_path)
# Calculate zoom factor for desired DPI (72 is base DPI for PDF)
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
if output_dir:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
pdf_name = Path(pdf_path).stem
for page_no, page in enumerate(doc):
# Render page to pixmap
pix = page.get_pixmap(matrix=matrix)
if output_dir:
# Save to file
ext = "jpg" if image_format.lower() in ("jpg", "jpeg") else "png"
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.{ext}"
if ext == "jpg":
pix.save(str(image_path), "jpeg")
else:
pix.save(str(image_path))
yield page_no, image_path
else:
# Return bytes
if image_format.lower() in ("jpg", "jpeg"):
yield page_no, pix.tobytes("jpeg")
else:
yield page_no, pix.tobytes("png")
doc.close()
def render_page_to_image(
pdf_path: str | Path,
page_no: int,
dpi: int = 300
) -> bytes:
"""
Render a single page to image bytes.
Args:
pdf_path: Path to the PDF file
page_no: Page number (0-indexed)
dpi: Resolution for rendering
Returns:
PNG image bytes
"""
doc = fitz.open(pdf_path)
if page_no >= len(doc):
doc.close()
raise ValueError(f"Page {page_no} does not exist (PDF has {len(doc)} pages)")
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
page = doc[page_no]
pix = page.get_pixmap(matrix=matrix)
image_bytes = pix.tobytes("png")
doc.close()
return image_bytes
def get_render_dimensions(pdf_path: str | Path, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
"""
Get the dimensions of a rendered page.
Returns:
(width, height) in pixels
"""
doc = fitz.open(pdf_path)
page = doc[page_no]
zoom = dpi / 72
rect = page.rect
width = int(rect.width * zoom)
height = int(rect.height * zoom)
doc.close()
return width, height
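# Usage sketch (editor's illustration; path, output directory and DPI are
# placeholders).
if __name__ == "__main__":
    for page_no, image_path in render_pdf_to_images(
        "invoice.pdf", output_dir="rendered", dpi=150
    ):
        print(f"page {page_no} -> {image_path}")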

View File

@@ -0,0 +1,34 @@
"""
Shared utilities for invoice field extraction and matching.
This module provides common functionality used by both:
- Inference stage (field_extractor.py) - extracting values from OCR text
- Matching stage (normalizer.py) - generating variants for CSV matching
Modules:
- TextCleaner: Unicode normalization and OCR error correction
- FormatVariants: Generate format variants for matching
- FieldValidators: Validate field values (Luhn, dates, amounts)
- FuzzyMatcher: Fuzzy string matching with OCR awareness
- OCRCorrections: Comprehensive OCR error correction
- ContextExtractor: Context-aware field extraction
"""
from .text_cleaner import TextCleaner
from .format_variants import FormatVariants
from .validators import FieldValidators
from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
from .ocr_corrections import OCRCorrections, CorrectionResult
from .context_extractor import ContextExtractor, ExtractionCandidate
__all__ = [
'TextCleaner',
'FormatVariants',
'FieldValidators',
'FuzzyMatcher',
'FuzzyMatchResult',
'OCRCorrections',
'CorrectionResult',
'ContextExtractor',
'ExtractionCandidate',
]
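# Typical import pattern (editor's illustration):
#   from shared.utils import TextCleaner, FieldValidators
#   digits = TextCleaner.extract_digits('BG 5393-9484', apply_ocr_correction=False)
#   ok = FieldValidators.is_valid_bankgiro(digits)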

View File

@@ -0,0 +1,433 @@
"""
Context-Aware Extraction Module
Extracts field values using contextual cues and label detection.
Improves extraction accuracy by understanding the semantic context.
"""
import re
from typing import Optional
from dataclasses import dataclass
from .text_cleaner import TextCleaner
from .validators import FieldValidators
@dataclass
class ExtractionCandidate:
"""A candidate extracted value with metadata."""
value: str
raw_text: str
context_label: str
confidence: float
position: int # Character position in source text
extraction_method: str # 'label', 'pattern', 'proximity'
class ContextExtractor:
"""
Context-aware field extraction.
Uses multiple strategies:
1. Label detection - finds values after field labels
2. Pattern matching - uses field-specific regex patterns
3. Proximity analysis - finds values near related terms
4. Validation filtering - removes invalid candidates
"""
# =========================================================================
# Swedish Label Patterns (what appears before the value)
# =========================================================================
LABEL_PATTERNS = {
'InvoiceNumber': [
# Swedish
r'(?:faktura|fakt)\.?\s*(?:nr|nummer|#)?[:\s]*',
r'(?:fakturanummer|fakturanr)[:\s]*',
r'(?:vår\s+referens)[:\s]*',
# English
r'(?:invoice)\s*(?:no|number|#)?[:\s]*',
r'inv[.:\s]*#?',
],
'Amount': [
# Swedish
r'(?:att\s+)?betala[:\s]*',
r'(?:total|totalt|summa)[:\s]*',
r'(?:belopp)[:\s]*',
r'(?:slutsumma)[:\s]*',
r'(?:att\s+erlägga)[:\s]*',
# English
r'(?:total|amount|sum)[:\s]*',
r'(?:amount\s+due)[:\s]*',
],
'InvoiceDate': [
# Swedish
r'(?:faktura)?datum[:\s]*',
r'(?:fakt\.?\s*datum)[:\s]*',
# English
r'(?:invoice\s+)?date[:\s]*',
],
'InvoiceDueDate': [
# Swedish
r'(?:förfall(?:o)?datum)[:\s]*',
r'(?:betalas\s+senast)[:\s]*',
r'(?:sista\s+betalningsdag)[:\s]*',
r'(?:förfaller)[:\s]*',
# English
r'(?:due\s+date)[:\s]*',
r'(?:payment\s+due)[:\s]*',
],
'OCR': [
r'(?:ocr)[:\s]*',
r'(?:ocr\s*-?\s*nummer)[:\s]*',
r'(?:referens(?:nummer)?)[:\s]*',
r'(?:betalningsreferens)[:\s]*',
],
'Bankgiro': [
r'(?:bankgiro|bg)[:\s]*',
r'(?:bank\s*giro)[:\s]*',
],
'Plusgiro': [
r'(?:plusgiro|pg)[:\s]*',
r'(?:plus\s*giro)[:\s]*',
r'(?:postgiro)[:\s]*',
],
'supplier_organisation_number': [
r'(?:org\.?\s*(?:nr|nummer)?)[:\s]*',
r'(?:organisationsnummer)[:\s]*',
r'(?:org\.?\s*id)[:\s]*',
r'(?:vat\s*(?:no|number|nr)?)[:\s]*',
r'(?:moms(?:reg)?\.?\s*(?:nr|nummer)?)[:\s]*',
r'(?:se)[:\s]*', # VAT prefix
],
'customer_number': [
r'(?:kund(?:nr|nummer)?)[:\s]*',
r'(?:kundnummer)[:\s]*',
r'(?:customer\s*(?:no|number|id)?)[:\s]*',
r'(?:er\s+referens)[:\s]*',
],
}
# =========================================================================
# Value Patterns (what the value looks like)
# =========================================================================
VALUE_PATTERNS = {
'InvoiceNumber': [
r'[A-Z]{0,3}\d{3,15}', # Alphanumeric: INV12345
r'\d{3,15}', # Pure digits
r'20\d{2}[-/]\d{3,8}', # Year prefix: 2024-001
],
'Amount': [
r'\d{1,3}(?:[\s.]\d{3})*[,]\d{2}', # Swedish: 1 234,56
r'\d{1,3}(?:[,]\d{3})*[.]\d{2}', # US: 1,234.56
r'\d+[,.]\d{2}', # Simple: 123,45
r'\d+', # Integer
],
'InvoiceDate': [
r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}', # ISO-like
r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}', # European
r'\d{8}', # Compact YYYYMMDD
],
'InvoiceDueDate': [
r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}',
r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}',
r'\d{8}',
],
'OCR': [
r'\d{10,25}', # Long digit sequence
],
'Bankgiro': [
r'\d{3,4}[-\s]?\d{4}', # XXX-XXXX or XXXX-XXXX
r'\d{7,8}', # Without separator
],
'Plusgiro': [
r'\d{1,7}[-\s]?\d', # XXXXXXX-X
r'\d{2,8}', # Without separator
],
'supplier_organisation_number': [
r'\d{6}[-\s]?\d{4}', # NNNNNN-NNNN
r'\d{10}', # Without separator
r'SE\s?\d{10,12}(?:\s?01)?', # VAT format
],
'customer_number': [
r'[A-Z]{0,5}\s?[-]?\s?\d{1,10}', # EMM 256-6
r'\d{3,15}', # Pure digits
],
}
# =========================================================================
# Extraction Methods
# =========================================================================
@classmethod
def extract_with_label(
cls,
text: str,
field_name: str,
validate: bool = True
) -> list[ExtractionCandidate]:
"""
Extract field values by finding labels and taking following values.
Example: "Fakturanummer: 12345" -> extracts "12345"
"""
candidates = []
label_patterns = cls.LABEL_PATTERNS.get(field_name, [])
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
for label_pattern in label_patterns:
for value_pattern in value_patterns:
# Combine label + value patterns
full_pattern = f'({label_pattern})({value_pattern})'
matches = re.finditer(full_pattern, text, re.IGNORECASE)
for match in matches:
label = match.group(1).strip()
value = match.group(2).strip()
# Validate if requested
if validate and not cls._validate_value(field_name, value):
continue
# Calculate confidence based on label specificity
confidence = cls._calculate_label_confidence(label, field_name)
candidates.append(ExtractionCandidate(
value=value,
raw_text=match.group(0),
context_label=label,
confidence=confidence,
position=match.start(),
extraction_method='label'
))
return candidates
@classmethod
def extract_with_pattern(
cls,
text: str,
field_name: str,
validate: bool = True
) -> list[ExtractionCandidate]:
"""
Extract field values using only value patterns (no label required).
This is a fallback when no labels are found.
"""
candidates = []
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
for pattern in value_patterns:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
value = match.group(0).strip()
# Validate if requested
if validate and not cls._validate_value(field_name, value):
continue
# Lower confidence for pattern-only extraction
confidence = 0.6
candidates.append(ExtractionCandidate(
value=value,
raw_text=value,
context_label='',
confidence=confidence,
position=match.start(),
extraction_method='pattern'
))
return candidates
@classmethod
def extract_field(
cls,
text: str,
field_name: str,
validate: bool = True
) -> list[ExtractionCandidate]:
"""
Extract all candidate values for a field using multiple strategies.
Returns candidates sorted by confidence (highest first).
"""
candidates = []
# Strategy 1: Label-based extraction (highest confidence)
label_candidates = cls.extract_with_label(text, field_name, validate)
candidates.extend(label_candidates)
# Strategy 2: Pattern-based extraction (fallback)
if not label_candidates:
pattern_candidates = cls.extract_with_pattern(text, field_name, validate)
candidates.extend(pattern_candidates)
# Remove duplicates (same value, keep highest confidence)
seen_values = {}
for candidate in candidates:
normalized = TextCleaner.normalize_for_comparison(candidate.value)
if normalized not in seen_values or candidate.confidence > seen_values[normalized].confidence:
seen_values[normalized] = candidate
# Sort by confidence
result = sorted(seen_values.values(), key=lambda x: x.confidence, reverse=True)
return result
@classmethod
def extract_best(
cls,
text: str,
field_name: str,
validate: bool = True
) -> Optional[ExtractionCandidate]:
"""
Extract the best (highest confidence) candidate for a field.
"""
candidates = cls.extract_field(text, field_name, validate)
return candidates[0] if candidates else None
@classmethod
def extract_all_fields(cls, text: str) -> dict[str, list[ExtractionCandidate]]:
"""
Extract all known fields from text.
Returns a dictionary mapping field names to their candidates.
"""
results = {}
for field_name in cls.LABEL_PATTERNS.keys():
candidates = cls.extract_field(text, field_name)
if candidates:
results[field_name] = candidates
return results
# =========================================================================
# Helper Methods
# =========================================================================
@classmethod
def _validate_value(cls, field_name: str, value: str) -> bool:
"""Validate a value based on field type."""
field_lower = field_name.lower()
if 'date' in field_lower:
return FieldValidators.is_valid_date(value)
elif 'amount' in field_lower:
return FieldValidators.is_valid_amount(value)
elif 'bankgiro' in field_lower:
# Basic format check, not Luhn
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return 7 <= len(digits) <= 8
elif 'plusgiro' in field_lower:
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return 2 <= len(digits) <= 8
elif 'ocr' in field_lower:
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return 10 <= len(digits) <= 25
elif 'org' in field_lower:
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
return len(digits) >= 10
else:
# For other fields, just check it's not empty
return bool(value.strip())
@classmethod
def _calculate_label_confidence(cls, label: str, field_name: str) -> float:
"""
Calculate confidence based on how specific the label is.
More specific labels = higher confidence.
"""
label_lower = label.lower()
# Very specific labels
very_specific = {
'InvoiceNumber': ['fakturanummer', 'invoice number', 'fakturanr'],
'Amount': ['att betala', 'slutsumma', 'amount due'],
'InvoiceDate': ['fakturadatum', 'invoice date'],
'InvoiceDueDate': ['förfallodatum', 'förfallodag', 'due date'],
'OCR': ['ocr', 'betalningsreferens'],
'Bankgiro': ['bankgiro'],
'Plusgiro': ['plusgiro', 'postgiro'],
'supplier_organisation_number': ['organisationsnummer', 'org nummer'],
'customer_number': ['kundnummer', 'customer number'],
}
# Check for very specific match
if field_name in very_specific:
for specific in very_specific[field_name]:
if specific in label_lower:
return 0.95
# Moderately specific
moderate = {
'InvoiceNumber': ['faktura', 'invoice', 'nr'],
'Amount': ['total', 'summa', 'belopp'],
'InvoiceDate': ['datum', 'date'],
'InvoiceDueDate': ['förfall', 'due'],
}
if field_name in moderate:
for mod in moderate[field_name]:
if mod in label_lower:
return 0.85
# Generic match
return 0.75
@classmethod
def find_field_context(cls, text: str, position: int, window: int = 50) -> str:
"""
Get the surrounding context for a position in text.
Useful for understanding what field a value belongs to.
"""
start = max(0, position - window)
end = min(len(text), position + window)
return text[start:end]
@classmethod
def identify_field_type(cls, text: str, value: str) -> Optional[str]:
"""
Try to identify what field type a value belongs to based on context.
Looks at text surrounding the value to find labels.
"""
# Find the value in text
pos = text.find(value)
if pos == -1:
return None
# Get context before the value
context_before = text[max(0, pos - 50):pos].lower()
# Check each field's labels
for field_name, patterns in cls.LABEL_PATTERNS.items():
for pattern in patterns:
if re.search(pattern, context_before, re.IGNORECASE):
return field_name
return None
# =========================================================================
# Convenience functions
# =========================================================================
def extract_field_with_context(text: str, field_name: str) -> Optional[str]:
"""Convenience function to extract a field value."""
candidate = ContextExtractor.extract_best(text, field_name)
return candidate.value if candidate else None
def extract_all_with_context(text: str) -> dict[str, str]:
"""Convenience function to extract all fields."""
all_candidates = ContextExtractor.extract_all_fields(text)
return {
field: candidates[0].value
for field, candidates in all_candidates.items()
if candidates
}
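

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the invoice snippet below is made up, and
# the output assumes the LABEL_PATTERNS defined above cover these Swedish
# labels. Run the module directly to see the extraction in action.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    sample = "Fakturanummer: 12345  Fakturadatum: 2024-12-29  Att betala: 1 234,56 kr"
    for name in ('InvoiceNumber', 'InvoiceDate', 'Amount'):
        best = ContextExtractor.extract_best(sample, name)
        if best:
            print(f"{name}: {best.value!r} (confidence={best.confidence:.2f})")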
View File
@@ -0,0 +1,610 @@
"""
Format Variants Generator
Generates multiple format variants for invoice field values.
Used by both inference (to try different extractions) and matching (to match CSV values).
"""
import re
from datetime import datetime
from typing import Optional
from .text_cleaner import TextCleaner
class FormatVariants:
"""
Generates format variants for different field types.
The same logic is used for:
- Inference: trying different formats to extract a value
- Matching: generating variants of CSV values to match against OCR text
"""
# Swedish month names for date parsing
SWEDISH_MONTHS = {
'januari': '01', 'jan': '01',
'februari': '02', 'feb': '02',
'mars': '03', 'mar': '03',
'april': '04', 'apr': '04',
'maj': '05',
'juni': '06', 'jun': '06',
'juli': '07', 'jul': '07',
'augusti': '08', 'aug': '08',
'september': '09', 'sep': '09', 'sept': '09',
'oktober': '10', 'okt': '10',
'november': '11', 'nov': '11',
'december': '12', 'dec': '12',
}
# =========================================================================
# Organization Number Variants
# =========================================================================
@classmethod
def organisation_number_variants(cls, value: str) -> list[str]:
"""
Generate all format variants for Swedish organization number.
Input formats handled:
- "556123-4567" (standard with hyphen)
- "5561234567" (no hyphen)
- "SE556123456701" (VAT format)
- "SE 556123-4567 01" (VAT with spaces)
Returns all possible variants for matching.
"""
value = TextCleaner.clean_text(value)
value_upper = value.upper()
variants = set()
        # Extract the pure digits
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
        # If this is a VAT number, extract the embedded org number
# SE + 10 digits + 01 = "SE556123456701"
if value_upper.startswith('SE') and len(digits) == 12 and digits.endswith('01'):
# VAT format: SE + org_number + 01
digits = digits[:10]
elif digits.startswith('46') and len(digits) == 14:
            # SE prefix read as digits (46 is Sweden's country calling code): 46 + 10 digits + 01
digits = digits[2:12]
if len(digits) == 12:
            # 12 digits may carry a century prefix: NNNNNNNN-NNNN (19556123-4567)
variants.add(value)
variants.add(digits) # 195561234567
            # Hyphenated format
variants.add(f"{digits[:8]}-{digits[8:]}") # 19556123-4567
            # Take the last 10 digits as the standard org number
short_digits = digits[2:] # 5561234567
variants.add(short_digits)
variants.add(f"{short_digits[:6]}-{short_digits[6:]}") # 556123-4567
            # VAT format
variants.add(f"SE{short_digits}01") # SE556123456701
return list(v for v in variants if v)
if len(digits) != 10:
            # Not the standard 10 digits: return the original value plus cleaned variants
variants.add(value)
if digits:
variants.add(digits)
return list(variants)
        # Generate all variants
        # 1. Pure digits
variants.add(digits) # 5561234567
        # 2. Standard format (NNNNNN-NNNN)
with_hyphen = f"{digits[:6]}-{digits[6:]}"
variants.add(with_hyphen) # 556123-4567
        # 3. VAT formats
vat_compact = f"SE{digits}01"
variants.add(vat_compact) # SE556123456701
variants.add(vat_compact.lower()) # se556123456701
vat_spaced = f"SE {digits[:6]}-{digits[6:]} 01"
variants.add(vat_spaced) # SE 556123-4567 01
vat_spaced_no_hyphen = f"SE {digits} 01"
variants.add(vat_spaced_no_hyphen) # SE 5561234567 01
        # 4. Sometimes with the country code but no 01 suffix
variants.add(f"SE{digits}") # SE5561234567
variants.add(f"SE {digits}") # SE 5561234567
variants.add(f"SE{digits[:6]}-{digits[6:]}") # SE556123-4567
        # 5. Possible OCR misread variants
ocr_variants = TextCleaner.generate_ocr_variants(digits)
for ocr_var in ocr_variants:
if len(ocr_var) == 10:
variants.add(ocr_var)
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
return list(v for v in variants if v)
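    # Example (expected output sketch): organisation_number_variants("556123-4567")
    # includes, among others, "5561234567", "556123-4567", "SE556123456701"
    # and "SE 556123-4567 01".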
# =========================================================================
# Bankgiro Variants
# =========================================================================
@classmethod
def bankgiro_variants(cls, value: str) -> list[str]:
"""
Generate variants for Bankgiro number.
Formats:
- 7 digits: XXX-XXXX (e.g., 123-4567)
- 8 digits: XXXX-XXXX (e.g., 1234-5678)
"""
value = TextCleaner.clean_text(value)
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
variants = set()
variants.add(value)
if not digits or len(digits) < 7 or len(digits) > 8:
return list(v for v in variants if v)
        # Pure digits
variants.add(digits)
        # Hyphenated formats
if len(digits) == 7:
variants.add(f"{digits[:3]}-{digits[3:]}") # XXX-XXXX
elif len(digits) == 8:
variants.add(f"{digits[:4]}-{digits[4:]}") # XXXX-XXXX
            # Some 8-digit numbers also use the XXX-XXXXX format
variants.add(f"{digits[:3]}-{digits[3:]}")
        # Spaced formats (OCR sometimes reads them this way)
if len(digits) == 7:
variants.add(f"{digits[:3]} {digits[3:]}")
elif len(digits) == 8:
variants.add(f"{digits[:4]} {digits[4:]}")
        # OCR misread variants
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Plusgiro Variants
# =========================================================================
@classmethod
def plusgiro_variants(cls, value: str) -> list[str]:
"""
Generate variants for Plusgiro number.
Format: XXXXXXX-X (7 digits + check digit) or shorter
Examples: 1234567-8, 12345-6, 1-8
"""
value = TextCleaner.clean_text(value)
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
variants = set()
variants.add(value)
if not digits or len(digits) < 2 or len(digits) > 8:
return list(v for v in variants if v)
        # Pure digits
variants.add(digits)
        # Plusgiro format: the last digit is a check digit, separated by a hyphen
main_part = digits[:-1]
check_digit = digits[-1]
variants.add(f"{main_part}-{check_digit}")
        # Sometimes written with a space
variants.add(f"{main_part} {check_digit}")
        # Grouped format (common for long numbers): XX XX XX-X
if len(digits) >= 6:
            # Try the XX XX XX-X format
spaced = ' '.join([digits[i:i + 2] for i in range(0, len(digits) - 1, 2)])
if len(digits) % 2 == 0:
spaced = spaced[:-1] + '-' + digits[-1]
else:
spaced = spaced + '-' + digits[-1]
variants.add(spaced.replace('- ', '-'))
        # OCR misread variants
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Amount Variants
# =========================================================================
@classmethod
def amount_variants(cls, value: str) -> list[str]:
"""
Generate variants for monetary amounts.
Handles:
- Swedish: 1 234,56 (space thousand, comma decimal)
- German: 1.234,56 (dot thousand, comma decimal)
- US/UK: 1,234.56 (comma thousand, dot decimal)
- Integer: 1234 -> 1234.00
Returns variants with different separators and with/without decimals.
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
        # Try to parse as a numeric value
amount = cls._parse_amount(value)
if amount is None:
return list(v for v in variants if v)
        # Generate variants in different formats
int_part = int(amount)
dec_part = round((amount - int_part) * 100)
        # 1. Base formats
if dec_part == 0:
variants.add(str(int_part)) # 1234
variants.add(f"{int_part}.00") # 1234.00
variants.add(f"{int_part},00") # 1234,00
else:
variants.add(f"{int_part}.{dec_part:02d}") # 1234.56
variants.add(f"{int_part},{dec_part:02d}") # 1234,56
        # 2. With thousand separators
int_str = str(int_part)
if len(int_str) > 3:
            # Insert a separator every 3 digits, right to left
parts = []
while int_str:
parts.append(int_str[-3:])
int_str = int_str[:-3]
parts.reverse()
            # Space separator (Swedish)
space_sep = ' '.join(parts)
if dec_part == 0:
variants.add(space_sep)
else:
variants.add(f"{space_sep},{dec_part:02d}")
variants.add(f"{space_sep}.{dec_part:02d}")
            # Dot separator (German)
dot_sep = '.'.join(parts)
if dec_part == 0:
variants.add(dot_sep)
else:
variants.add(f"{dot_sep},{dec_part:02d}")
            # Comma separator (US)
comma_sep = ','.join(parts)
if dec_part == 0:
variants.add(comma_sep)
else:
variants.add(f"{comma_sep}.{dec_part:02d}")
        # 3. With currency symbols
base_amounts = [f"{int_part}.{dec_part:02d}", f"{int_part},{dec_part:02d}"]
if dec_part == 0:
base_amounts.append(str(int_part))
for base in base_amounts:
variants.add(f"{base} kr")
variants.add(f"{base} SEK")
variants.add(f"{base}kr")
variants.add(f"SEK {base}")
return list(v for v in variants if v)
@classmethod
def _parse_amount(cls, text: str) -> Optional[float]:
"""Parse amount from various formats."""
text = TextCleaner.normalize_amount_text(text)
        # Strip everything except digits and separators
clean = re.sub(r'[^\d,.\s]', '', text)
if not clean:
return None
        # Detect the format
        # Swedish format: 1 234,56 or 1234,56
if re.match(r'^[\d\s]+,\d{2}$', clean):
clean = clean.replace(' ', '').replace(',', '.')
try:
return float(clean)
except ValueError:
pass
        # German format: 1.234,56
if re.match(r'^[\d.]+,\d{2}$', clean):
clean = clean.replace('.', '').replace(',', '.')
try:
return float(clean)
except ValueError:
pass
        # US format: 1,234.56
if re.match(r'^[\d,]+\.\d{2}$', clean):
clean = clean.replace(',', '')
try:
return float(clean)
except ValueError:
pass
        # Simple format
clean = clean.replace(' ', '').replace(',', '.')
        # If there are multiple dots, keep only the last one
if clean.count('.') > 1:
parts = clean.rsplit('.', 1)
clean = parts[0].replace('.', '') + '.' + parts[1]
try:
return float(clean)
except ValueError:
return None
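    # Examples: "1 234,56" (Swedish), "1.234,56" (German) and "1,234.56" (US)
    # all parse to 1234.56; a bare "1234" parses to 1234.0.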
# =========================================================================
# Date Variants
# =========================================================================
@classmethod
def date_variants(cls, value: str) -> list[str]:
"""
Generate variants for dates.
Input can be:
- ISO: 2024-12-29
- European: 29/12/2024, 29.12.2024
- Swedish text: "29 december 2024"
- Compact: 20241229
Returns all format variants.
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
        # Try to parse the date
parsed = cls._parse_date(value)
if parsed is None:
return list(v for v in variants if v)
year, month, day = parsed
        # Generate all format variants
# ISO
variants.add(f"{year}-{month:02d}-{day:02d}")
variants.add(f"{year}-{month}-{day}") # 不补零
        # Dot-separated (common in Sweden)
variants.add(f"{year}.{month:02d}.{day:02d}")
variants.add(f"{day:02d}.{month:02d}.{year}")
        # Slash-separated
variants.add(f"{day:02d}/{month:02d}/{year}")
variants.add(f"{month:02d}/{day:02d}/{year}") # US format
variants.add(f"{year}/{month:02d}/{day:02d}")
        # Compact format
variants.add(f"{year}{month:02d}{day:02d}")
        # With Swedish month names
for month_name, month_num in cls.SWEDISH_MONTHS.items():
if month_num == f"{month:02d}":
variants.add(f"{day} {month_name} {year}")
variants.add(f"{day:02d} {month_name} {year}")
                # Capitalized
variants.add(f"{day} {month_name.capitalize()} {year}")
        # Two-digit year
short_year = str(year)[2:]
variants.add(f"{day:02d}.{month:02d}.{short_year}")
variants.add(f"{day:02d}/{month:02d}/{short_year}")
variants.add(f"{short_year}-{month:02d}-{day:02d}")
return list(v for v in variants if v)
@classmethod
def _parse_date(cls, text: str) -> Optional[tuple[int, int, int]]:
"""
Parse date from text, returns (year, month, day) or None.
"""
text = TextCleaner.clean_text(text).lower()
# ISO: 2024-12-29
match = re.search(r'(\d{4})-(\d{1,2})-(\d{1,2})', text)
if match:
return int(match.group(1)), int(match.group(2)), int(match.group(3))
# Dot format: 2024.12.29
match = re.search(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', text)
if match:
return int(match.group(1)), int(match.group(2)), int(match.group(3))
# European: 29/12/2024 or 29.12.2024
match = re.search(r'(\d{1,2})[/.](\d{1,2})[/.](\d{4})', text)
if match:
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
            # Sanity-check day and month
if 1 <= day <= 31 and 1 <= month <= 12:
return year, month, day
# Compact: 20241229
match = re.search(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', text)
if match:
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
if 2000 <= year <= 2100 and 1 <= month <= 12 and 1 <= day <= 31:
return year, month, day
# Swedish month name: "29 december 2024"
for month_name, month_num in cls.SWEDISH_MONTHS.items():
pattern = rf'(\d{{1,2}})\s*{month_name}\s*(\d{{4}})'
match = re.search(pattern, text)
if match:
day, year = int(match.group(1)), int(match.group(2))
return year, int(month_num), day
return None
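    # Examples: "2024-12-29", "29.12.2024", "20241229" and "29 december 2024"
    # all parse to (2024, 12, 29).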
# =========================================================================
# Invoice Number Variants
# =========================================================================
@classmethod
def invoice_number_variants(cls, value: str) -> list[str]:
"""
Generate variants for invoice numbers.
Invoice numbers are highly variable:
- Pure digits: 12345678
- Alphanumeric: A3861, INV-2024-001
- With separators: 2024/001
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
        # Extract the digit part
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
variants.add(digits)
        # Case variants
variants.add(value.upper())
variants.add(value.lower())
        # Without separators
no_sep = re.sub(r'[-/\s]', '', value)
variants.add(no_sep)
variants.add(no_sep.upper())
        # OCR misread variants
for ocr_var in TextCleaner.generate_ocr_variants(value):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# OCR Number Variants
# =========================================================================
@classmethod
def ocr_number_variants(cls, value: str) -> list[str]:
"""
Generate variants for OCR reference numbers.
OCR numbers are typically 10-25 digits.
"""
value = TextCleaner.clean_text(value)
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
variants = set()
variants.add(value)
if digits:
variants.add(digits)
        # Some OCR numbers are written in space-separated groups
if len(digits) > 4:
            # Groups of 4 digits
spaced = ' '.join([digits[i:i + 4] for i in range(0, len(digits), 4)])
variants.add(spaced)
        # OCR misread variants
for ocr_var in TextCleaner.generate_ocr_variants(digits):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Customer Number Variants
# =========================================================================
@classmethod
def customer_number_variants(cls, value: str) -> list[str]:
"""
Generate variants for customer numbers.
Customer numbers can be very diverse:
- Pure digits: 12345
- Alphanumeric: ABC123, EMM 256-6
- With separators: 123-456
"""
value = TextCleaner.clean_text(value)
variants = set()
variants.add(value)
        # Case variants
variants.add(value.upper())
variants.add(value.lower())
        # Remove all separators and spaces
compact = re.sub(r'[-/\s]', '', value)
variants.add(compact)
variants.add(compact.upper())
variants.add(compact.lower())
        # Pure digits
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
if digits:
variants.add(digits)
        # Letters + digits, recombined in common layouts
letters = re.sub(r'[^a-zA-Z]', '', value)
if letters and digits:
variants.add(f"{letters}{digits}")
variants.add(f"{letters.upper()}{digits}")
variants.add(f"{letters} {digits}")
variants.add(f"{letters.upper()} {digits}")
variants.add(f"{letters}-{digits}")
variants.add(f"{letters.upper()}-{digits}")
        # OCR misread variants
for ocr_var in TextCleaner.generate_ocr_variants(value):
variants.add(ocr_var)
return list(v for v in variants if v)
# =========================================================================
# Generic Field Variants
# =========================================================================
@classmethod
def get_variants(cls, field_name: str, value: str) -> list[str]:
"""
Get variants for a field by name.
This is the main entry point - dispatches to specific variant generators.
"""
if not value:
return []
field_lower = field_name.lower()
        # Map the field name to the matching variant generator
if 'organisation' in field_lower or 'org' in field_lower:
return cls.organisation_number_variants(value)
elif 'bankgiro' in field_lower or field_lower == 'bg':
return cls.bankgiro_variants(value)
elif 'plusgiro' in field_lower or field_lower == 'pg':
return cls.plusgiro_variants(value)
elif 'amount' in field_lower or 'belopp' in field_lower:
return cls.amount_variants(value)
elif 'date' in field_lower or 'datum' in field_lower:
return cls.date_variants(value)
elif 'invoice' in field_lower and 'number' in field_lower:
return cls.invoice_number_variants(value)
elif field_lower == 'invoicenumber':
return cls.invoice_number_variants(value)
elif 'ocr' in field_lower:
return cls.ocr_number_variants(value)
elif 'customer' in field_lower:
return cls.customer_number_variants(value)
else:
            # Default: return the original value plus a basic cleaned form
return [value, TextCleaner.clean_text(value)]
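

# ---------------------------------------------------------------------------
# Usage sketch (illustrative values; "5050-1055" is a Luhn-valid Bankgiro):
# run the module directly to inspect the generated variants per field.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    samples = [
        ('supplier_organisation_number', '556123-4567'),
        ('Bankgiro', '5050-1055'),
        ('Amount', '1 234,56'),
        ('InvoiceDate', '2024-12-29'),
    ]
    for field, value in samples:
        print(field, '->', sorted(FormatVariants.get_variants(field, value))[:8])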
View File
@@ -0,0 +1,417 @@
"""
Fuzzy Matching Module
Provides fuzzy string matching with OCR-aware similarity scoring.
Handles common OCR errors and format variations in invoice fields.
"""
import re
from typing import Optional
from dataclasses import dataclass
from .text_cleaner import TextCleaner
@dataclass
class FuzzyMatchResult:
"""Result of a fuzzy match operation."""
matched: bool
score: float # 0.0 to 1.0
ocr_value: str
expected_value: str
normalized_ocr: str
normalized_expected: str
match_type: str # 'exact', 'normalized', 'fuzzy', 'ocr_corrected'
class FuzzyMatcher:
"""
Fuzzy string matcher optimized for OCR text matching.
Provides multiple matching strategies:
1. Exact match
2. Normalized match (case-insensitive, whitespace-normalized)
3. OCR-corrected match (applying common OCR error corrections)
4. Edit distance based fuzzy match
5. Digit-sequence match (for numeric fields)
"""
# Minimum similarity threshold for fuzzy matches
DEFAULT_THRESHOLD = 0.85
# Field-specific thresholds (some fields need stricter matching)
FIELD_THRESHOLDS = {
'InvoiceNumber': 0.90,
'OCR': 0.95, # OCR numbers need high precision
'Amount': 0.95,
'Bankgiro': 0.90,
'Plusgiro': 0.90,
'InvoiceDate': 0.90,
'InvoiceDueDate': 0.90,
'supplier_organisation_number': 0.85,
'customer_number': 0.80, # More lenient for customer numbers
}
@classmethod
def get_threshold(cls, field_name: str) -> float:
"""Get the matching threshold for a specific field."""
return cls.FIELD_THRESHOLDS.get(field_name, cls.DEFAULT_THRESHOLD)
@classmethod
def levenshtein_distance(cls, s1: str, s2: str) -> int:
"""
Calculate Levenshtein (edit) distance between two strings.
This is the minimum number of single-character edits
(insertions, deletions, substitutions) needed to change s1 into s2.
"""
if len(s1) < len(s2):
return cls.levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
# Cost is 0 if characters match, 1 otherwise
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
@classmethod
def similarity_ratio(cls, s1: str, s2: str) -> float:
"""
Calculate similarity ratio between two strings.
Returns a value between 0.0 (completely different) and 1.0 (identical).
Based on Levenshtein distance normalized by the length of the longer string.
"""
if not s1 and not s2:
return 1.0
if not s1 or not s2:
return 0.0
max_len = max(len(s1), len(s2))
distance = cls.levenshtein_distance(s1, s2)
return 1.0 - (distance / max_len)
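    # Worked example: levenshtein_distance("12345", "12845") == 1 (a single
    # substitution), the longer string has length 5, so the ratio is 1 - 1/5 = 0.8.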
@classmethod
def ocr_aware_similarity(cls, ocr_text: str, expected: str) -> float:
"""
Calculate similarity with OCR error awareness.
This method considers common OCR errors when calculating similarity,
giving higher scores when differences are likely OCR mistakes.
"""
if not ocr_text or not expected:
return 0.0 if ocr_text != expected else 1.0
# First try exact match
if ocr_text == expected:
return 1.0
# Try with OCR corrections applied to ocr_text
corrected = TextCleaner.apply_ocr_digit_corrections(ocr_text)
if corrected == expected:
return 0.98 # Slightly less than exact match
# Try normalized comparison
norm_ocr = TextCleaner.normalize_for_comparison(ocr_text)
norm_expected = TextCleaner.normalize_for_comparison(expected)
if norm_ocr == norm_expected:
return 0.95
# Calculate base similarity
base_sim = cls.similarity_ratio(norm_ocr, norm_expected)
# Boost score if differences are common OCR errors
boost = cls._calculate_ocr_error_boost(ocr_text, expected)
return min(1.0, base_sim + boost)
@classmethod
def _calculate_ocr_error_boost(cls, ocr_text: str, expected: str) -> float:
"""
Calculate a score boost based on whether differences are likely OCR errors.
Returns a value between 0.0 and 0.1.
"""
if len(ocr_text) != len(expected):
return 0.0
ocr_errors = 0
total_diffs = 0
for oc, ec in zip(ocr_text, expected):
if oc != ec:
total_diffs += 1
# Check if this is a known OCR confusion pair
if cls._is_ocr_confusion_pair(oc, ec):
ocr_errors += 1
if total_diffs == 0:
return 0.0
# Boost proportional to how many differences are OCR errors
ocr_error_ratio = ocr_errors / total_diffs
return ocr_error_ratio * 0.1
@classmethod
def _is_ocr_confusion_pair(cls, char1: str, char2: str) -> bool:
"""Check if two characters are commonly confused in OCR."""
confusion_pairs = {
('0', 'O'), ('0', 'o'), ('0', 'D'), ('0', 'Q'),
('1', 'l'), ('1', 'I'), ('1', 'i'), ('1', '|'),
('2', 'Z'), ('2', 'z'),
('5', 'S'), ('5', 's'),
('6', 'G'), ('6', 'b'),
('8', 'B'),
('9', 'g'), ('9', 'q'),
}
pair = (char1, char2)
return pair in confusion_pairs or (char2, char1) in confusion_pairs
@classmethod
def match_digits(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
"""
Match digit sequences with OCR error tolerance.
Optimized for numeric fields like OCR numbers, amounts, etc.
"""
# Extract digits
ocr_digits = TextCleaner.extract_digits(ocr_text, apply_ocr_correction=True)
expected_digits = TextCleaner.extract_digits(expected, apply_ocr_correction=False)
# Exact match after extraction
if ocr_digits == expected_digits:
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_digits,
normalized_expected=expected_digits,
match_type='exact'
)
# Calculate similarity
score = cls.ocr_aware_similarity(ocr_digits, expected_digits)
return FuzzyMatchResult(
matched=score >= threshold,
score=score,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_digits,
normalized_expected=expected_digits,
match_type='fuzzy' if score >= threshold else 'no_match'
)
@classmethod
def match_amount(cls, ocr_text: str, expected: str, threshold: float = 0.95) -> FuzzyMatchResult:
"""
Match monetary amounts with format tolerance.
Handles different decimal separators (. vs ,) and thousand separators.
"""
from .validators import FieldValidators
# Parse both amounts
ocr_amount = FieldValidators.parse_amount(ocr_text)
expected_amount = FieldValidators.parse_amount(expected)
if ocr_amount is None or expected_amount is None:
# Can't parse, fall back to string matching
return cls.match_string(ocr_text, expected, threshold)
# Compare numeric values
if abs(ocr_amount - expected_amount) < 0.01: # Within 1 cent
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=f"{ocr_amount:.2f}",
normalized_expected=f"{expected_amount:.2f}",
match_type='exact'
)
# Calculate relative difference
max_val = max(abs(ocr_amount), abs(expected_amount))
if max_val > 0:
diff_ratio = abs(ocr_amount - expected_amount) / max_val
score = max(0.0, 1.0 - diff_ratio)
else:
score = 1.0 if ocr_amount == expected_amount else 0.0
return FuzzyMatchResult(
matched=score >= threshold,
score=score,
ocr_value=ocr_text,
expected_value=expected,
            normalized_ocr=f"{ocr_amount:.2f}" if ocr_amount is not None else ocr_text,
            normalized_expected=f"{expected_amount:.2f}" if expected_amount is not None else expected,
match_type='fuzzy' if score >= threshold else 'no_match'
)
@classmethod
def match_date(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
"""
Match dates with format tolerance.
Handles different date formats (ISO, European, compact, etc.)
"""
from .validators import FieldValidators
# Parse both dates to ISO format
ocr_iso = FieldValidators.format_date_iso(ocr_text)
expected_iso = FieldValidators.format_date_iso(expected)
if ocr_iso and expected_iso:
if ocr_iso == expected_iso:
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_iso,
normalized_expected=expected_iso,
match_type='exact'
)
# Fall back to string matching on digits
return cls.match_digits(ocr_text, expected, threshold)
@classmethod
def match_string(cls, ocr_text: str, expected: str, threshold: float = 0.85) -> FuzzyMatchResult:
"""
General string matching with multiple strategies.
Tries exact, normalized, and fuzzy matching in order.
"""
# Clean both strings
ocr_clean = TextCleaner.clean_text(ocr_text)
expected_clean = TextCleaner.clean_text(expected)
# Strategy 1: Exact match
if ocr_clean == expected_clean:
return FuzzyMatchResult(
matched=True,
score=1.0,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_clean,
normalized_expected=expected_clean,
match_type='exact'
)
# Strategy 2: Case-insensitive match
if ocr_clean.lower() == expected_clean.lower():
return FuzzyMatchResult(
matched=True,
score=0.98,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_clean,
normalized_expected=expected_clean,
match_type='normalized'
)
# Strategy 3: OCR-corrected match
ocr_corrected = TextCleaner.apply_ocr_digit_corrections(ocr_clean)
if ocr_corrected == expected_clean:
return FuzzyMatchResult(
matched=True,
score=0.95,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_corrected,
normalized_expected=expected_clean,
match_type='ocr_corrected'
)
# Strategy 4: Fuzzy match
score = cls.ocr_aware_similarity(ocr_clean, expected_clean)
return FuzzyMatchResult(
matched=score >= threshold,
score=score,
ocr_value=ocr_text,
expected_value=expected,
normalized_ocr=ocr_clean,
normalized_expected=expected_clean,
match_type='fuzzy' if score >= threshold else 'no_match'
)
@classmethod
def match_field(
cls,
field_name: str,
ocr_value: str,
expected_value: str,
threshold: Optional[float] = None
) -> FuzzyMatchResult:
"""
Match a field value using field-appropriate strategy.
Automatically selects the best matching strategy based on field type.
"""
if threshold is None:
threshold = cls.get_threshold(field_name)
field_lower = field_name.lower()
# Route to appropriate matcher
if 'amount' in field_lower or 'belopp' in field_lower:
return cls.match_amount(ocr_value, expected_value, threshold)
if 'date' in field_lower or 'datum' in field_lower:
return cls.match_date(ocr_value, expected_value, threshold)
if any(x in field_lower for x in ['ocr', 'bankgiro', 'plusgiro', 'org']):
# Numeric fields with OCR errors
return cls.match_digits(ocr_value, expected_value, threshold)
if 'invoice' in field_lower and 'number' in field_lower:
# Invoice numbers can be alphanumeric
return cls.match_string(ocr_value, expected_value, threshold)
# Default to string matching
return cls.match_string(ocr_value, expected_value, threshold)
@classmethod
def find_best_match(
cls,
ocr_value: str,
candidates: list[str],
field_name: str = '',
threshold: Optional[float] = None
) -> Optional[tuple[str, FuzzyMatchResult]]:
"""
Find the best matching candidate from a list.
Returns (matched_value, match_result) or None if no match above threshold.
"""
if threshold is None:
threshold = cls.get_threshold(field_name) if field_name else cls.DEFAULT_THRESHOLD
best_match = None
best_result = None
for candidate in candidates:
result = cls.match_field(field_name, ocr_value, candidate, threshold=0.0)
if result.score >= threshold:
if best_result is None or result.score > best_result.score:
best_match = candidate
best_result = result
if best_match:
return (best_match, best_result)
return None
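

# ---------------------------------------------------------------------------
# Usage sketch (illustrative values): "556l234567" contains the classic l->1
# OCR confusion and should match the expected value exactly after correction.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    result = FuzzyMatcher.match_field('OCR', '556l234567', '5561234567')
    print(result.matched, f"{result.score:.2f}", result.match_type)  # True 1.00 exact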
View File
@@ -0,0 +1,384 @@
"""
OCR Error Corrections Module
Provides comprehensive OCR error correction tables and correction functions.
Based on common OCR recognition errors in Swedish invoice documents.
"""
import re
from typing import Optional
from dataclasses import dataclass
@dataclass
class CorrectionResult:
"""Result of an OCR correction operation."""
original: str
corrected: str
corrections_applied: list[tuple[int, str, str]] # (position, from_char, to_char)
confidence: float # How confident we are in the correction
class OCRCorrections:
"""
Comprehensive OCR error correction utilities.
Provides:
- Character-level corrections for digits
- Word-level corrections for common Swedish terms
- Context-aware corrections
- Multiple correction strategies
"""
# =========================================================================
# Character-level OCR errors (digit fields)
# =========================================================================
# Characters commonly misread as digits
CHAR_TO_DIGIT = {
# Letters that look like digits
'O': '0', 'o': '0', # O -> 0
'Q': '0', # Q -> 0 (less common)
'D': '0', # D -> 0 (in some fonts)
'l': '1', 'I': '1', # l/I -> 1
'i': '1', # i without dot -> 1
'|': '1', # pipe -> 1
'!': '1', # exclamation -> 1
'Z': '2', 'z': '2', # Z -> 2
'E': '3', # E -> 3 (rare)
'A': '4', 'h': '4', # A/h -> 4 (in some fonts)
'S': '5', 's': '5', # S -> 5
'G': '6', 'b': '6', # G/b -> 6
'T': '7', 't': '7', # T -> 7 (rare)
'B': '8', # B -> 8
'g': '9', 'q': '9', # g/q -> 9
}
# Digits commonly misread as other characters
DIGIT_TO_CHAR = {
'0': ['O', 'o', 'D', 'Q'],
'1': ['l', 'I', 'i', '|', '!'],
'2': ['Z', 'z'],
'3': ['E'],
'4': ['A', 'h'],
'5': ['S', 's'],
'6': ['G', 'b'],
'7': ['T', 't'],
'8': ['B'],
'9': ['g', 'q'],
}
# Bidirectional confusion pairs (either direction is possible)
CONFUSION_PAIRS = [
('0', 'O'), ('0', 'o'), ('0', 'D'),
('1', 'l'), ('1', 'I'), ('1', '|'),
('2', 'Z'), ('2', 'z'),
('5', 'S'), ('5', 's'),
('6', 'G'), ('6', 'b'),
('8', 'B'),
('9', 'g'), ('9', 'q'),
]
# =========================================================================
# Word-level OCR errors (Swedish invoice terms)
# =========================================================================
# Common Swedish invoice terms and their OCR misreadings
SWEDISH_TERM_CORRECTIONS = {
# Faktura (Invoice)
'faktura': ['Faktura', 'FAKTURA', 'faktúra', 'faKtura'],
'fakturanummer': ['Fakturanummer', 'FAKTURANUMMER', 'fakturanr', 'fakt.nr'],
'fakturadatum': ['Fakturadatum', 'FAKTURADATUM', 'fakt.datum'],
# Belopp (Amount)
'belopp': ['Belopp', 'BELOPP', 'be1opp', 'bel0pp'],
'summa': ['Summa', 'SUMMA', '5umma'],
'total': ['Total', 'TOTAL', 'tota1', 't0tal'],
'moms': ['Moms', 'MOMS', 'm0ms'],
# Dates
'förfallodatum': ['Förfallodatum', 'FÖRFALLODATUM', 'förfa11odatum'],
'datum': ['Datum', 'DATUM', 'dátum'],
# Payment
'bankgiro': ['Bankgiro', 'BANKGIRO', 'BG', 'bg', 'bank giro'],
'plusgiro': ['Plusgiro', 'PLUSGIRO', 'PG', 'pg', 'plus giro'],
'postgiro': ['Postgiro', 'POSTGIRO'],
'ocr': ['OCR', 'ocr', '0CR', 'OcR'],
# Organization
'organisationsnummer': ['Organisationsnummer', 'ORGANISATIONSNUMMER', 'org.nr', 'orgnr'],
'kundnummer': ['Kundnummer', 'KUNDNUMMER', 'kund nr', 'kundnr'],
# Currency
'kronor': ['Kronor', 'KRONOR', 'kr', 'KR', 'SEK', 'sek'],
'öre': ['Öre', 'ÖRE', 'ore', 'ORE'],
}
# =========================================================================
# Context patterns
# =========================================================================
# Patterns that indicate the following/preceding text is a specific field
CONTEXT_INDICATORS = {
'invoice_number': [
r'faktura\s*(?:nr|nummer)?[:\s]*',
r'invoice\s*(?:no|number)?[:\s]*',
r'fakt\.?\s*nr[:\s]*',
r'inv[:\s]*#?',
],
'amount': [
r'(?:att\s+)?betala[:\s]*',
r'total[t]?[:\s]*',
r'summa[:\s]*',
r'belopp[:\s]*',
r'amount[:\s]*',
],
'date': [
r'datum[:\s]*',
r'date[:\s]*',
r'förfall(?:o)?datum[:\s]*',
r'fakturadatum[:\s]*',
],
'ocr': [
r'ocr[:\s]*',
r'referens[:\s]*',
r'betalningsreferens[:\s]*',
],
'bankgiro': [
r'bankgiro[:\s]*',
r'bg[:\s]*',
r'bank\s*giro[:\s]*',
],
'plusgiro': [
r'plusgiro[:\s]*',
r'pg[:\s]*',
r'plus\s*giro[:\s]*',
r'postgiro[:\s]*',
],
'org_number': [
r'org\.?\s*(?:nr|nummer)?[:\s]*',
r'organisationsnummer[:\s]*',
r'vat[:\s]*',
r'moms(?:reg)?\.?\s*(?:nr|nummer)?[:\s]*',
],
}
# =========================================================================
# Correction Methods
# =========================================================================
@classmethod
def correct_digits(cls, text: str, aggressive: bool = False) -> CorrectionResult:
"""
Apply digit corrections to text.
Args:
text: Input text
aggressive: If True, correct all potential digit-like characters.
If False, only correct characters adjacent to digits.
Returns:
CorrectionResult with original, corrected text, and details.
"""
corrections = []
result = []
for i, char in enumerate(text):
if char.isdigit():
result.append(char)
elif char in cls.CHAR_TO_DIGIT:
if aggressive:
# Always correct
corrected_char = cls.CHAR_TO_DIGIT[char]
corrections.append((i, char, corrected_char))
result.append(corrected_char)
else:
# Only correct if adjacent to digit
prev_is_digit = i > 0 and (text[i-1].isdigit() or text[i-1] in cls.CHAR_TO_DIGIT)
next_is_digit = i < len(text) - 1 and (text[i+1].isdigit() or text[i+1] in cls.CHAR_TO_DIGIT)
if prev_is_digit or next_is_digit:
corrected_char = cls.CHAR_TO_DIGIT[char]
corrections.append((i, char, corrected_char))
result.append(corrected_char)
else:
result.append(char)
else:
result.append(char)
corrected = ''.join(result)
confidence = 1.0 - (len(corrections) * 0.05) # Decrease confidence per correction
return CorrectionResult(
original=text,
corrected=corrected,
corrections_applied=corrections,
confidence=max(0.5, confidence)
)
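    # Example: correct_digits("556l23-4S67").corrected == "556123-4567";
    # 'l' and 'S' sit next to digits, so they are corrected even in
    # non-aggressive mode.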
@classmethod
def generate_digit_variants(cls, text: str) -> list[str]:
"""
Generate all possible digit interpretations of a text.
Useful for matching when we don't know which direction the OCR error went.
"""
if not text:
return [text]
variants = {text}
# For each character that could be confused
for i, char in enumerate(text):
new_variants = set()
for existing in variants:
# If it's a digit, add letter variants
if char.isdigit() and char in cls.DIGIT_TO_CHAR:
for replacement in cls.DIGIT_TO_CHAR[char]:
new_variants.add(existing[:i] + replacement + existing[i+1:])
# If it's a letter that looks like a digit, add digit variant
if char in cls.CHAR_TO_DIGIT:
new_variants.add(existing[:i] + cls.CHAR_TO_DIGIT[char] + existing[i+1:])
variants.update(new_variants)
# Limit explosion - only keep reasonable number
if len(variants) > 100:
break
return list(variants)
@classmethod
def correct_swedish_term(cls, text: str) -> str:
"""
Correct common Swedish invoice terms that may have OCR errors.
"""
text_lower = text.lower()
for canonical, variants in cls.SWEDISH_TERM_CORRECTIONS.items():
for variant in variants:
if variant.lower() in text_lower:
# Replace with canonical form (preserving case of first letter)
pattern = re.compile(re.escape(variant), re.IGNORECASE)
if text[0].isupper():
replacement = canonical.capitalize()
else:
replacement = canonical
text = pattern.sub(replacement, text)
return text
@classmethod
def extract_with_context(cls, text: str, field_type: str) -> Optional[str]:
"""
Extract a field value using context indicators.
Looks for patterns like "Fakturanr: 12345" and extracts "12345".
"""
patterns = cls.CONTEXT_INDICATORS.get(field_type, [])
for pattern in patterns:
# Look for pattern followed by value
full_pattern = pattern + r'([^\s,;]+)'
match = re.search(full_pattern, text, re.IGNORECASE)
if match:
return match.group(1)
return None
@classmethod
def is_likely_ocr_error(cls, char1: str, char2: str) -> bool:
"""
Check if two characters are commonly confused in OCR.
"""
pair = (char1, char2)
reverse_pair = (char2, char1)
for p in cls.CONFUSION_PAIRS:
if pair == p or reverse_pair == p:
return True
return False
@classmethod
def count_potential_ocr_errors(cls, s1: str, s2: str) -> tuple[int, int]:
"""
Count how many character differences between two strings
are likely OCR errors vs other differences.
Returns: (ocr_errors, other_errors)
"""
if len(s1) != len(s2):
return (0, abs(len(s1) - len(s2)))
ocr_errors = 0
other_errors = 0
for c1, c2 in zip(s1, s2):
if c1 != c2:
if cls.is_likely_ocr_error(c1, c2):
ocr_errors += 1
else:
other_errors += 1
return (ocr_errors, other_errors)
@classmethod
def suggest_corrections(cls, text: str, expected_type: str = 'digit') -> list[tuple[str, float]]:
"""
Suggest possible corrections for a text with confidence scores.
Returns list of (corrected_text, confidence) tuples, sorted by confidence.
"""
suggestions = []
if expected_type == 'digit':
# Apply digit corrections with different levels of aggressiveness
mild = cls.correct_digits(text, aggressive=False)
if mild.corrected != text:
suggestions.append((mild.corrected, mild.confidence))
aggressive = cls.correct_digits(text, aggressive=True)
if aggressive.corrected != text and aggressive.corrected != mild.corrected:
suggestions.append((aggressive.corrected, aggressive.confidence * 0.9))
# Generate variants
variants = cls.generate_digit_variants(text)
for variant in variants[:10]: # Limit to top 10
if variant != text and variant not in [s[0] for s in suggestions]:
# Lower confidence for variants
suggestions.append((variant, 0.7))
# Sort by confidence
suggestions.sort(key=lambda x: x[1], reverse=True)
return suggestions
# =========================================================================
# Convenience functions
# =========================================================================
def correct_ocr_digits(text: str, aggressive: bool = False) -> str:
"""Convenience function to correct OCR digit errors."""
return OCRCorrections.correct_digits(text, aggressive).corrected
def generate_ocr_variants(text: str) -> list[str]:
"""Convenience function to generate OCR variants."""
return OCRCorrections.generate_digit_variants(text)
def is_ocr_confusion(char1: str, char2: str) -> bool:
"""Convenience function to check if characters are OCR confusable."""
return OCRCorrections.is_likely_ocr_error(char1, char2)
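

# ---------------------------------------------------------------------------
# Usage sketch (illustrative strings): run the module directly to see the
# corrections and a few generated confusion variants.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    print(correct_ocr_digits('556l23-4S67'))   # -> 556123-4567
    print(is_ocr_confusion('O', '0'))          # -> True
    print(generate_ocr_variants('105')[:5])    # e.g. variants with I/O/S swaps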
View File
@@ -0,0 +1,244 @@
"""
Text Cleaning Module
Provides text normalization and OCR error correction utilities.
Used by both inference (field_extractor) and matching (normalizer) stages.
"""
import re
from typing import Optional
class TextCleaner:
"""
Unified text cleaning utilities for invoice processing.
Handles:
- Unicode normalization (zero-width chars, dash variants)
- OCR error correction (O/0, l/1, etc.)
- Whitespace normalization
- Swedish-specific character handling
"""
    # Common OCR misread corrections (for digit fields)
    # These characters are frequently misrecognized where digits are expected
OCR_DIGIT_CORRECTIONS = {
        'O': '0', 'o': '0',  # letter O -> digit 0
        'Q': '0',            # Q sometimes looks like 0
        'l': '1', 'I': '1',  # lowercase L / uppercase I -> digit 1
        '|': '1',            # vertical bar -> 1
        'i': '1',            # lowercase i -> 1
        'S': '5', 's': '5',  # S -> 5
        'B': '8',            # B -> 8
        'Z': '2', 'z': '2',  # Z -> 2
        'G': '6', 'g': '6',  # G -> 6 (in some fonts)
        'A': '4',            # A -> 4 (in some fonts)
        'T': '7',            # T -> 7 (in some fonts)
        'q': '9',            # q -> 9
        'D': '0',            # D -> 0
}
    # Reverse mapping: digits misread as letters (for alphanumeric fields)
OCR_LETTER_CORRECTIONS = {
'0': 'O',
'1': 'I',
'5': 'S',
'8': 'B',
'2': 'Z',
}
    # Normalization of special Unicode characters
UNICODE_NORMALIZATIONS = {
        # Dash and hyphen variants -> standard hyphen
'\u2013': '-', # en-dash
'\u2014': '-', # em-dash —
'\u2212': '-', # minus sign
'\u00b7': '-', # middle dot ·
'\u2010': '-', # hyphen
'\u2011': '-', # non-breaking hyphen
'\u2012': '-', # figure dash
'\u2015': '-', # horizontal bar ―
        # Space variants -> standard space
'\u00a0': ' ', # non-breaking space
'\u2002': ' ', # en space
'\u2003': ' ', # em space
'\u2009': ' ', # thin space
'\u200a': ' ', # hair space
        # Zero-width characters -> removed
'\u200b': '', # zero-width space
'\u200c': '', # zero-width non-joiner
'\u200d': '', # zero-width joiner
'\ufeff': '', # BOM / zero-width no-break space
        # Quote variants -> standard quotes
'\u2018': "'", # left single quote '
'\u2019': "'", # right single quote '
'\u201c': '"', # left double quote "
'\u201d': '"', # right double quote "
}
@classmethod
def clean_unicode(cls, text: str) -> str:
"""
Normalize Unicode characters to ASCII equivalents.
Handles:
- Various dash types -> standard hyphen (-)
- Various spaces -> standard space
- Zero-width characters -> removed
- Various quotes -> standard quotes
"""
for unicode_char, replacement in cls.UNICODE_NORMALIZATIONS.items():
text = text.replace(unicode_char, replacement)
return text
@classmethod
def normalize_whitespace(cls, text: str) -> str:
"""Collapse multiple whitespace to single space and strip."""
return ' '.join(text.split())
@classmethod
def clean_text(cls, text: str) -> str:
"""
Full text cleaning pipeline.
1. Normalize Unicode
2. Normalize whitespace
3. Strip
This is safe for all field types.
"""
text = cls.clean_unicode(text)
text = cls.normalize_whitespace(text)
return text.strip()
@classmethod
def apply_ocr_digit_corrections(cls, text: str) -> str:
"""
Apply OCR error corrections for digit-only fields.
Use this when the field is expected to contain only digits
(e.g., OCR number, organization number digits, etc.)
Example:
"556l23-4S67" -> "556123-4567"
"""
result = []
for char in text:
if char in cls.OCR_DIGIT_CORRECTIONS:
result.append(cls.OCR_DIGIT_CORRECTIONS[char])
else:
result.append(char)
return ''.join(result)
@classmethod
def extract_digits(cls, text: str, apply_ocr_correction: bool = True) -> str:
"""
Extract only digits from text.
Args:
text: Input text
apply_ocr_correction: If True, apply OCR corrections ONLY to characters
that are adjacent to digits (not standalone letters)
Returns:
String containing only digits
"""
if apply_ocr_correction:
            # Apply OCR corrections only to characters that appear inside digit runs,
            # e.g. the O in "556O23" should be corrected, but the ABC in "ABC 123" should not
result = []
for i, char in enumerate(text):
if char.isdigit():
result.append(char)
elif char in cls.OCR_DIGIT_CORRECTIONS:
                    # Check whether a neighbouring character is a digit
prev_is_digit = i > 0 and (text[i - 1].isdigit() or text[i - 1] in cls.OCR_DIGIT_CORRECTIONS)
next_is_digit = i < len(text) - 1 and (text[i + 1].isdigit() or text[i + 1] in cls.OCR_DIGIT_CORRECTIONS)
if prev_is_digit or next_is_digit:
result.append(cls.OCR_DIGIT_CORRECTIONS[char])
                # All other characters are skipped
return ''.join(result)
else:
return re.sub(r'\D', '', text)
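    # Example: extract_digits("556O23") -> "556023" (the O sits inside a digit
    # run), while extract_digits("Faktura 123", apply_ocr_correction=False)
    # simply strips non-digits and returns "123".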
@classmethod
def clean_for_digits(cls, text: str) -> str:
"""
Clean text that should primarily contain digits.
Pipeline:
1. Clean Unicode
2. Apply OCR digit corrections
3. Normalize whitespace
Preserves separators (-, /) for formatted numbers like "556123-4567"
"""
text = cls.clean_unicode(text)
text = cls.apply_ocr_digit_corrections(text)
text = cls.normalize_whitespace(text)
return text.strip()
@classmethod
def generate_ocr_variants(cls, text: str) -> list[str]:
"""
Generate possible OCR error variants of the input text.
This is useful for matching: if we have a CSV value,
we generate variants that might appear in OCR output.
Example:
"5561234567" -> ["5561234567", "556I234567", "5561234S67", ...]
"""
variants = {text}
        # Generate letter variants only for digits
for digit, letter in cls.OCR_LETTER_CORRECTIONS.items():
if digit in text:
variants.add(text.replace(digit, letter))
        # Generate digit variants for letters
for letter, digit in cls.OCR_DIGIT_CORRECTIONS.items():
if letter in text:
variants.add(text.replace(letter, digit))
return list(variants)
@classmethod
def normalize_amount_text(cls, text: str) -> str:
"""
Normalize amount text for parsing.
- Removes currency symbols and labels
- Normalizes separators
- Handles Swedish format (space as thousand separator)
"""
text = cls.clean_text(text)
        # Remove currency symbols and labels (word boundaries ensure whole-word matches)
text = re.sub(r'(?i)\b(kr|sek|kronor|öre)\b', '', text)
        # Remove thousand-separator spaces (Swedish: "1 234,56" -> "1234,56")
        # while keeping the digits on either side intact
text = re.sub(r'(\d)\s+(\d)', r'\1\2', text)
return text.strip()
@classmethod
def normalize_for_comparison(cls, text: str) -> str:
"""
Normalize text for loose comparison.
- Lowercase
- Remove all non-alphanumeric
- Apply OCR corrections
This is the most aggressive normalization, used for fuzzy matching.
"""
text = cls.clean_text(text)
text = text.lower()
text = cls.apply_ocr_digit_corrections(text)
text = re.sub(r'[^a-z0-9]', '', text)
return text
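

# ---------------------------------------------------------------------------
# Usage sketch (illustrative string with a non-breaking space, an O misread
# for a 0, and an en-dash): run the module directly to see the pipeline.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    raw = 'Faktura nr:\u00a0505O\u20131055'
    print(TextCleaner.clean_text(raw))                # 'Faktura nr: 505O-1055'
    print(TextCleaner.extract_digits(raw))            # '50501055'
    print(TextCleaner.normalize_for_comparison(raw))  # 'fakturanr50501055'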
View File
@@ -0,0 +1,393 @@
"""
Field Validators Module
Provides validation functions for Swedish invoice fields.
Used by both inference (to validate extracted values) and matching (to filter candidates).
"""
import re
from datetime import datetime
from typing import Optional
from .text_cleaner import TextCleaner
class FieldValidators:
"""
Validators for Swedish invoice field values.
Includes:
- Luhn (Mod10) checksum validation
- Format validation for specific field types
- Range validation for dates and amounts
"""
# =========================================================================
# Luhn (Mod10) Checksum
# =========================================================================
@classmethod
def luhn_checksum(cls, digits: str) -> bool:
"""
Validate using Luhn (Mod10) algorithm.
Used for:
- Bankgiro numbers
- Plusgiro numbers
- OCR reference numbers
- Swedish organization numbers
The checksum is valid if the total modulo 10 equals 0.
"""
        # Keep digits only
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
if not digits or not digits.isdigit():
return False
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
            if i % 2 == 1:  # double every second digit, counting from the right
digit *= 2
if digit > 9:
digit -= 9
total += digit
return total % 10 == 0
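    # Worked example for the digits "50501055" (a Luhn-valid Bankgiro):
    # reversed, every second digit is doubled (with 9 subtracted when > 9),
    # giving 5 + 1 + 0 + 2 + 0 + 1 + 0 + 1 = 10, and 10 % 10 == 0 -> valid.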
@classmethod
def calculate_luhn_check_digit(cls, digits: str) -> int:
"""
Calculate the Luhn check digit for a number.
Given a number without check digit, returns the digit that would make it valid.
"""
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
        # Compute the Luhn sum of the existing digits
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
            if i % 2 == 0:  # even positions double here because a check digit will be appended
digit *= 2
if digit > 9:
digit -= 9
total += digit
        # Compute the required check digit
check_digit = (10 - (total % 10)) % 10
return check_digit
# =========================================================================
# Organisation Number Validation
# =========================================================================
@classmethod
def is_valid_organisation_number(cls, value: str) -> bool:
"""
Validate Swedish organisation number.
Format: NNNNNN-NNNN (10 digits)
- First digit: 1-9
- Third digit: >= 2 (distinguishes from personal numbers)
- Last digit: Luhn check digit
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
        # Handle VAT formats
if len(digits) == 12 and digits.endswith('01'):
digits = digits[:10]
elif len(digits) == 14 and digits.startswith('46') and digits.endswith('01'):
digits = digits[2:12]
if len(digits) != 10:
return False
        # First digit must be 1-9
if digits[0] == '0':
return False
        # Third digit >= 2 distinguishes organisation numbers from personal numbers.
        # Note: some special organisations do not follow this rule, so it is relaxed here:
        # if int(digits[2]) < 2:
        #     return False
        # Luhn checksum
return cls.luhn_checksum(digits)
# =========================================================================
# Bankgiro Validation
# =========================================================================
@classmethod
def is_valid_bankgiro(cls, value: str) -> bool:
"""
Validate Swedish Bankgiro number.
Format: 7 or 8 digits with Luhn checksum
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 7 or len(digits) > 8:
return False
return cls.luhn_checksum(digits)
@classmethod
def format_bankgiro(cls, value: str) -> Optional[str]:
"""
Format Bankgiro number to standard format.
Returns: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits), or None if invalid
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) == 7:
return f"{digits[:3]}-{digits[3:]}"
elif len(digits) == 8:
return f"{digits[:4]}-{digits[4:]}"
else:
return None
# =========================================================================
# Plusgiro Validation
# =========================================================================
@classmethod
def is_valid_plusgiro(cls, value: str) -> bool:
"""
Validate Swedish Plusgiro number.
Format: 2-8 digits with Luhn checksum
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 2 or len(digits) > 8:
return False
return cls.luhn_checksum(digits)
@classmethod
def format_plusgiro(cls, value: str) -> Optional[str]:
"""
Format Plusgiro number to standard format.
Returns: XXXXXXX-X format, or None if invalid
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 2 or len(digits) > 8:
return None
return f"{digits[:-1]}-{digits[-1]}"
# =========================================================================
# OCR Number Validation
# =========================================================================
@classmethod
def is_valid_ocr_number(cls, value: str, validate_checksum: bool = True) -> bool:
"""
Validate Swedish OCR reference number.
- Typically 10-25 digits
- Usually has Luhn checksum (but not always enforced)
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 5 or len(digits) > 25:
return False
if validate_checksum:
return cls.luhn_checksum(digits)
return True
# =========================================================================
# Amount Validation
# =========================================================================
@classmethod
def is_valid_amount(cls, value: str, min_amount: float = 0.0, max_amount: float = 10_000_000.0) -> bool:
"""
Validate monetary amount.
- Must be positive (or at least >= min_amount)
- Should be within reasonable range
"""
try:
            # Try to parse
text = TextCleaner.normalize_amount_text(value)
            # Normalize to a dot as the decimal separator
text = text.replace(' ', '').replace(',', '.')
            # If there are multiple dots, keep only the last one
if text.count('.') > 1:
parts = text.rsplit('.', 1)
text = parts[0].replace('.', '') + '.' + parts[1]
amount = float(text)
return min_amount <= amount <= max_amount
except (ValueError, TypeError):
return False
@classmethod
def parse_amount(cls, value: str) -> Optional[float]:
"""
Parse amount from string, handling various formats.
Returns float or None if parsing fails.
"""
try:
text = TextCleaner.normalize_amount_text(value)
text = text.replace(' ', '')
            # Detect the format and parse
            # Swedish/German format: comma is the decimal separator
if re.match(r'^[\d.]+,\d{1,2}$', text):
text = text.replace('.', '').replace(',', '.')
            # US format: dot is the decimal separator
elif re.match(r'^[\d,]+\.\d{1,2}$', text):
text = text.replace(',', '')
else:
                # Simple format
text = text.replace(',', '.')
if text.count('.') > 1:
parts = text.rsplit('.', 1)
text = parts[0].replace('.', '') + '.' + parts[1]
return float(text)
except (ValueError, TypeError):
return None
# =========================================================================
# Date Validation
# =========================================================================
@classmethod
def is_valid_date(cls, value: str, min_year: int = 2000, max_year: int = 2100) -> bool:
"""
Validate date string.
- Year should be within reasonable range
- Month 1-12
- Day 1-31 (basic check)
"""
parsed = cls.parse_date(value)
if parsed is None:
return False
year, month, day = parsed
if not (min_year <= year <= max_year):
return False
if not (1 <= month <= 12):
return False
if not (1 <= day <= 31):
return False
        # Stricter calendar validation
try:
datetime(year, month, day)
return True
except ValueError:
return False
@classmethod
def parse_date(cls, value: str) -> Optional[tuple[int, int, int]]:
"""
Parse date from string.
Returns (year, month, day) tuple or None.
"""
from .format_variants import FormatVariants
return FormatVariants._parse_date(value)
@classmethod
def format_date_iso(cls, value: str) -> Optional[str]:
"""
Format date to ISO format (YYYY-MM-DD).
Returns formatted string or None if parsing fails.
"""
parsed = cls.parse_date(value)
if parsed is None:
return None
year, month, day = parsed
return f"{year}-{month:02d}-{day:02d}"
# =========================================================================
# Invoice Number Validation
# =========================================================================
@classmethod
def is_valid_invoice_number(cls, value: str, min_length: int = 1, max_length: int = 30) -> bool:
"""
Validate invoice number.
Basic validation - just length check since invoice numbers are highly variable.
"""
clean = TextCleaner.clean_text(value)
if not clean:
return False
        # Extract the meaningful characters (letters and digits)
meaningful = re.sub(r'[^a-zA-Z0-9]', '', clean)
return min_length <= len(meaningful) <= max_length
# =========================================================================
# Generic Validation
# =========================================================================
@classmethod
def validate_field(cls, field_name: str, value: str) -> tuple[bool, Optional[str]]:
"""
Validate a field by name.
Returns (is_valid, error_message).
"""
if not value:
return False, "Empty value"
field_lower = field_name.lower()
if 'organisation' in field_lower or 'org' in field_lower:
if cls.is_valid_organisation_number(value):
return True, None
return False, "Invalid organisation number format or checksum"
elif 'bankgiro' in field_lower:
if cls.is_valid_bankgiro(value):
return True, None
return False, "Invalid Bankgiro format or checksum"
elif 'plusgiro' in field_lower:
if cls.is_valid_plusgiro(value):
return True, None
return False, "Invalid Plusgiro format or checksum"
elif 'ocr' in field_lower:
if cls.is_valid_ocr_number(value, validate_checksum=False):
return True, None
return False, "Invalid OCR number length"
elif 'amount' in field_lower:
if cls.is_valid_amount(value):
return True, None
return False, "Invalid amount format"
elif 'date' in field_lower:
if cls.is_valid_date(value):
return True, None
return False, "Invalid date format"
elif 'invoice' in field_lower and 'number' in field_lower:
if cls.is_valid_invoice_number(value):
return True, None
return False, "Invalid invoice number"
else:
            # Default: just check the value is non-empty
if TextCleaner.clean_text(value):
return True, None
return False, "Empty value after cleaning"
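

# ---------------------------------------------------------------------------
# Usage sketch (illustrative values; "5050-1055" passes the Luhn check):
# run the module directly to exercise the validators.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    print(FieldValidators.validate_field('Bankgiro', '5050-1055'))      # (True, None)
    print(FieldValidators.validate_field('Amount', '1 234,56'))         # (True, None)
    print(FieldValidators.validate_field('InvoiceDate', '2024-12-29'))  # (True, None)
    print(FieldValidators.format_date_iso('29 december 2024'))          # 2024-12-29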