restructure project
packages/shared/requirements.txt (new file, 9 lines)
@@ -0,0 +1,9 @@
PyMuPDF>=1.23.0
paddleocr>=2.7.0
Pillow>=10.0.0
numpy>=1.24.0
opencv-python>=4.8.0
psycopg2-binary>=2.9.0
python-dotenv>=1.0.0
pyyaml>=6.0
thefuzz>=0.20.0
packages/shared/setup.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from setuptools import setup, find_packages

setup(
    name="invoice-shared",
    version="0.1.0",
    packages=find_packages(),
    python_requires=">=3.11",
    install_requires=[
        "PyMuPDF>=1.23.0",
        "paddleocr>=2.7.0",
        "Pillow>=10.0.0",
        "numpy>=1.24.0",
        "opencv-python>=4.8.0",
        "psycopg2-binary>=2.9.0",
        "python-dotenv>=1.0.0",
        "pyyaml>=6.0",
        "thefuzz>=0.20.0",
    ],
)
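Illustrative note (a sketch, not part of this diff): with this layout the package is typically installed into a virtualenv in editable mode, e.g. pip install -e packages/shared, after which the modules below import as shared.*. A minimal, hypothetical sanity check:

# Hypothetical check after an editable install of packages/shared.
# Importing shared.config additionally requires DB_PASSWORD (see shared/config.py below).
from shared.data import CSVLoader, InvoiceRow
from shared.exceptions import InvoiceExtractionError

print(CSVLoader, InvoiceRow, InvoiceExtractionError)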
packages/shared/shared/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# Invoice Master POC v2
# Automatic invoice information extraction system using YOLO + OCR
packages/shared/shared/config.py (new file, 88 lines)
@@ -0,0 +1,88 @@
"""
Configuration settings for the invoice extraction system.
"""

import os
import platform
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env file at project root
# Walk up from packages/shared/shared/config.py to find project root
_config_dir = Path(__file__).parent
for _candidate in [_config_dir.parent.parent.parent, _config_dir.parent.parent, _config_dir.parent]:
    _env_path = _candidate / '.env'
    if _env_path.exists():
        load_dotenv(dotenv_path=_env_path)
        break
else:
    load_dotenv()  # fallback: search cwd and parents

# Global DPI setting - must match training DPI for optimal model performance
DEFAULT_DPI = 150


def _is_wsl() -> bool:
    """Check if running inside WSL (Windows Subsystem for Linux)."""
    if platform.system() != 'Linux':
        return False
    # Check for WSL-specific indicators
    if os.environ.get('WSL_DISTRO_NAME'):
        return True
    try:
        with open('/proc/version', 'r') as f:
            return 'microsoft' in f.read().lower()
    except (FileNotFoundError, PermissionError):
        return False


# PostgreSQL Database Configuration
# Now loaded from environment variables for security
DATABASE = {
    'host': os.getenv('DB_HOST', '192.168.68.31'),
    'port': int(os.getenv('DB_PORT', '5432')),
    'database': os.getenv('DB_NAME', 'docmaster'),
    'user': os.getenv('DB_USER', 'docmaster'),
    'password': os.getenv('DB_PASSWORD'),  # No default for security
}

# Validate required configuration
if not DATABASE['password']:
    raise ValueError(
        "DB_PASSWORD environment variable is not set. "
        "Please create a .env file based on .env.example and set DB_PASSWORD."
    )


# Connection string for psycopg2
def get_db_connection_string():
    return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"


# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
    # WSL: use native Linux filesystem for better I/O performance
    PATHS = {
        'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
        'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
        'output_dir': os.path.expanduser('~/invoice-data/dataset'),
        'reports_dir': 'reports',  # Keep reports in project directory
    }
else:
    # Windows or native Linux: use relative paths
    PATHS = {
        'csv_dir': 'data/structured_data',
        'pdf_dir': 'data/raw_pdfs',
        'output_dir': 'data/dataset',
        'reports_dir': 'reports',
    }

# Auto-labeling Configuration
AUTOLABEL = {
    'workers': 2,
    'dpi': DEFAULT_DPI,
    'min_confidence': 0.5,
    'train_ratio': 0.8,
    'val_ratio': 0.1,
    'test_ratio': 0.1,
    'max_records_per_report': 10000,
}
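Illustrative usage of this configuration (a sketch, not part of the diff; it assumes a .env providing DB_PASSWORD and a reachable PostgreSQL instance):

# Hypothetical consumer of shared.config.
import psycopg2
from shared.config import AUTOLABEL, DATABASE, get_db_connection_string

conn = psycopg2.connect(get_db_connection_string())  # postgresql://user:pass@host:port/db
print(DATABASE['host'], AUTOLABEL['dpi'])            # host from DB_HOST, dpi follows DEFAULT_DPI
conn.close()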
packages/shared/shared/data/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .csv_loader import CSVLoader, InvoiceRow

__all__ = ['CSVLoader', 'InvoiceRow']
packages/shared/shared/data/csv_loader.py (new file, 372 lines)
@@ -0,0 +1,372 @@
"""
CSV Data Loader

Loads and parses structured invoice data from CSV files.
Follows the CSV specification for invoice data.
"""

import csv
from dataclasses import dataclass, field
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Iterator


@dataclass
class InvoiceRow:
    """Parsed invoice data row."""
    DocumentId: str
    InvoiceDate: date | None = None
    InvoiceNumber: str | None = None
    InvoiceDueDate: date | None = None
    OCR: str | None = None
    Message: str | None = None
    Bankgiro: str | None = None
    Plusgiro: str | None = None
    Amount: Decimal | None = None
    # New fields
    split: str | None = None  # train/test split indicator
    customer_number: str | None = None  # Customer number (needs matching)
    supplier_name: str | None = None  # Supplier name (no matching)
    supplier_organisation_number: str | None = None  # Swedish org number (needs matching)
    supplier_accounts: str | None = None  # Supplier accounts (needs matching)

    # Raw values for reference
    raw_data: dict = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for matching."""
        return {
            'DocumentId': self.DocumentId,
            'InvoiceDate': self.InvoiceDate.isoformat() if self.InvoiceDate else None,
            'InvoiceNumber': self.InvoiceNumber,
            'InvoiceDueDate': self.InvoiceDueDate.isoformat() if self.InvoiceDueDate else None,
            'OCR': self.OCR,
            'Bankgiro': self.Bankgiro,
            'Plusgiro': self.Plusgiro,
            'Amount': str(self.Amount) if self.Amount else None,
            'supplier_organisation_number': self.supplier_organisation_number,
            'supplier_accounts': self.supplier_accounts,
        }

    def get_field_value(self, field_name: str) -> str | None:
        """Get field value as string for matching."""
        value = getattr(self, field_name, None)
        if value is None:
            return None
        if isinstance(value, date):
            return value.isoformat()
        if isinstance(value, Decimal):
            return str(value)
        return str(value) if value else None


class CSVLoader:
    """Loads invoice data from CSV files."""

    # Expected field mappings (CSV header -> InvoiceRow attribute)
    FIELD_MAPPINGS = {
        'DocumentId': 'DocumentId',
        'InvoiceDate': 'InvoiceDate',
        'InvoiceNumber': 'InvoiceNumber',
        'InvoiceDueDate': 'InvoiceDueDate',
        'OCR': 'OCR',
        'Message': 'Message',
        'Bankgiro': 'Bankgiro',
        'Plusgiro': 'Plusgiro',
        'Amount': 'Amount',
        # New fields
        'split': 'split',
        'customer_number': 'customer_number',
        'supplier_name': 'supplier_name',
        'supplier_organisation_number': 'supplier_organisation_number',
        'supplier_accounts': 'supplier_accounts',
    }

    def __init__(
        self,
        csv_path: str | Path | list[str | Path],
        pdf_dir: str | Path | None = None,
        doc_map_path: str | Path | None = None,
        encoding: str = 'utf-8'
    ):
        """
        Initialize CSV loader.

        Args:
            csv_path: Path to CSV file(s). Can be:
                - Single path: 'data/file.csv'
                - List of paths: ['data/file1.csv', 'data/file2.csv']
                - Glob pattern: 'data/*.csv' or 'data/export_*.csv'
            pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
            doc_map_path: Optional path to document mapping CSV
            encoding: CSV file encoding (default: utf-8)
        """
        # Handle multiple CSV files
        if isinstance(csv_path, list):
            self.csv_paths = [Path(p) for p in csv_path]
        else:
            csv_path = Path(csv_path)
            # Check if it's a glob pattern (contains * or ?)
            if '*' in str(csv_path) or '?' in str(csv_path):
                parent = csv_path.parent
                pattern = csv_path.name
                self.csv_paths = sorted(parent.glob(pattern))
            else:
                self.csv_paths = [csv_path]

        # For backward compatibility
        self.csv_path = self.csv_paths[0] if self.csv_paths else None

        self.pdf_dir = Path(pdf_dir) if pdf_dir else (self.csv_path.parent.parent / 'raw_pdfs' if self.csv_path else Path('data/raw_pdfs'))
        self.doc_map_path = Path(doc_map_path) if doc_map_path else None
        self.encoding = encoding

        # Load document mapping if provided
        self.doc_map = self._load_doc_map() if self.doc_map_path else {}

    def _load_doc_map(self) -> dict[str, str]:
        """Load document ID to filename mapping."""
        mapping = {}
        if self.doc_map_path and self.doc_map_path.exists():
            with open(self.doc_map_path, 'r', encoding=self.encoding) as f:
                reader = csv.DictReader(f)
                for row in reader:
                    doc_id = row.get('DocumentId', '').strip()
                    filename = row.get('FileName', '').strip()
                    if doc_id and filename:
                        mapping[doc_id] = filename
        return mapping

    def _parse_date(self, value: str | None) -> date | None:
        """Parse date from various formats."""
        if not value or not value.strip():
            return None

        value = value.strip()

        # Try different date formats
        formats = [
            '%Y-%m-%d',
            '%Y-%m-%d %H:%M:%S',
            '%Y-%m-%d %H:%M:%S.%f',
            '%d/%m/%Y',
            '%d.%m.%Y',
            '%d-%m-%Y',
            '%Y%m%d',
        ]

        for fmt in formats:
            try:
                return datetime.strptime(value, fmt).date()
            except ValueError:
                continue

        return None

    def _parse_amount(self, value: str | None) -> Decimal | None:
        """Parse monetary amount from various formats."""
        if not value or not value.strip():
            return None

        value = value.strip()

        # Remove currency symbols and common suffixes
        value = value.replace('SEK', '').replace('kr', '').replace(':-', '')
        value = value.strip()

        # Remove spaces (thousand separators)
        value = value.replace(' ', '').replace('\xa0', '')

        # Handle comma as decimal separator (European format)
        if ',' in value and '.' not in value:
            value = value.replace(',', '.')
        elif ',' in value and '.' in value:
            # Assume comma is thousands separator, dot is decimal
            value = value.replace(',', '')

        try:
            return Decimal(value)
        except InvalidOperation:
            return None

    def _parse_string(self, value: str | None) -> str | None:
        """Parse string field with cleanup."""
        if value is None:
            return None
        value = value.strip()
        return value if value else None

    def _get_field(self, row: dict, *keys: str) -> str | None:
        """Get field value trying multiple possible column names."""
        for key in keys:
            value = row.get(key)
            if value is not None:
                return value
        return None

    def _parse_row(self, row: dict) -> InvoiceRow | None:
        """Parse a single CSV row into InvoiceRow."""
        doc_id = self._parse_string(self._get_field(row, 'DocumentId', 'document_id'))
        if not doc_id:
            return None

        return InvoiceRow(
            DocumentId=doc_id,
            InvoiceDate=self._parse_date(self._get_field(row, 'InvoiceDate', 'invoice_date')),
            InvoiceNumber=self._parse_string(self._get_field(row, 'InvoiceNumber', 'invoice_number')),
            InvoiceDueDate=self._parse_date(self._get_field(row, 'InvoiceDueDate', 'invoice_due_date')),
            OCR=self._parse_string(self._get_field(row, 'OCR', 'ocr')),
            Message=self._parse_string(self._get_field(row, 'Message', 'message')),
            Bankgiro=self._parse_string(self._get_field(row, 'Bankgiro', 'bankgiro')),
            Plusgiro=self._parse_string(self._get_field(row, 'Plusgiro', 'plusgiro')),
            Amount=self._parse_amount(self._get_field(row, 'Amount', 'amount', 'invoice_data_amount')),
            # New fields
            split=self._parse_string(row.get('split')),
            customer_number=self._parse_string(row.get('customer_number')),
            supplier_name=self._parse_string(row.get('supplier_name')),
            supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')),
            supplier_accounts=self._parse_string(row.get('supplier_accounts')),
            raw_data=dict(row)
        )

    def _iter_single_csv(self, csv_path: Path) -> Iterator[InvoiceRow]:
        """Iterate over rows from a single CSV file."""
        # Handle BOM - try utf-8-sig first to handle BOM correctly
        encodings = ['utf-8-sig', self.encoding, 'latin-1']

        for enc in encodings:
            try:
                with open(csv_path, 'r', encoding=enc) as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        parsed = self._parse_row(row)
                        if parsed:
                            yield parsed
                return
            except UnicodeDecodeError:
                continue

        raise ValueError(f"Could not read CSV file {csv_path} with any supported encoding")

    def load_all(self) -> list[InvoiceRow]:
        """Load all rows from CSV(s)."""
        rows = []
        for row in self.iter_rows():
            rows.append(row)
        return rows

    def iter_rows(self) -> Iterator[InvoiceRow]:
        """Iterate over CSV rows from all CSV files."""
        seen_doc_ids = set()

        for csv_path in self.csv_paths:
            if not csv_path.exists():
                continue
            for row in self._iter_single_csv(csv_path):
                # Deduplicate by DocumentId
                if row.DocumentId not in seen_doc_ids:
                    seen_doc_ids.add(row.DocumentId)
                    yield row

    def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
        """
        Get PDF path for an invoice row.

        Uses document mapping if available, otherwise assumes
        DocumentId.pdf naming convention.
        """
        doc_id = invoice_row.DocumentId

        # Check document mapping first
        if doc_id in self.doc_map:
            filename = self.doc_map[doc_id]
            pdf_path = self.pdf_dir / filename
            if pdf_path.exists():
                return pdf_path

        # Try default naming patterns
        patterns = [
            f"{doc_id}.pdf",
            f"{doc_id}.PDF",
            f"{doc_id.lower()}.pdf",
            f"{doc_id.lower()}.PDF",
            f"{doc_id.upper()}.pdf",
            f"{doc_id.upper()}.PDF",
        ]

        for pattern in patterns:
            pdf_path = self.pdf_dir / pattern
            if pdf_path.exists():
                return pdf_path

        # Try glob patterns for partial matches (both cases)
        for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.pdf"):
            return pdf_file
        for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.PDF"):
            return pdf_file

        return None

    def get_row_by_id(self, doc_id: str) -> InvoiceRow | None:
        """Get a specific row by DocumentId."""
        for row in self.iter_rows():
            if row.DocumentId == doc_id:
                return row
        return None

    def validate(self) -> list[dict]:
        """
        Validate CSV data and return issues.

        Returns:
            List of validation issues
        """
        issues = []

        for i, row in enumerate(self.iter_rows(), start=2):  # Start at 2 (header is row 1)
            # Check required DocumentId
            if not row.DocumentId:
                issues.append({
                    'row': i,
                    'field': 'DocumentId',
                    'issue': 'Missing required DocumentId'
                })
                continue

            # Check if PDF exists
            pdf_path = self.get_pdf_path(row)
            if not pdf_path:
                issues.append({
                    'row': i,
                    'doc_id': row.DocumentId,
                    'field': 'PDF',
                    'issue': 'PDF file not found'
                })

            # Check for at least one matchable field
            matchable_fields = [
                row.InvoiceNumber,
                row.OCR,
                row.Bankgiro,
                row.Plusgiro,
                row.Amount,
                row.supplier_organisation_number,
                row.supplier_accounts,
            ]
            if not any(matchable_fields):
                issues.append({
                    'row': i,
                    'doc_id': row.DocumentId,
                    'field': 'All',
                    'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)'
                })

        return issues


def load_invoice_csv(csv_path: str | Path | list[str | Path], pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
    """Convenience function to load invoice CSV(s)."""
    loader = CSVLoader(csv_path, pdf_dir)
    return loader.load_all()
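Illustrative usage of the loader (a sketch, not part of the diff; the file names are hypothetical):

# Hypothetical example: load one or more export CSVs and resolve each row's PDF.
from shared.data import CSVLoader

loader = CSVLoader('data/structured_data/export_*.csv', pdf_dir='data/raw_pdfs')
for row in loader.iter_rows():           # rows are deduplicated by DocumentId
    pdf = loader.get_pdf_path(row)       # doc-map lookup first, then DocumentId.pdf patterns
    print(row.DocumentId, row.Amount, pdf)

# _parse_amount accepts Swedish formats, e.g. "1 234,56 kr" -> Decimal("1234.56"),
# and _parse_date accepts e.g. "2024-01-31", "31.01.2024" or "20240131".
issues = loader.validate()               # rows lacking a PDF or any matchable field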
packages/shared/shared/data/db.py (new file, 530 lines)
@@ -0,0 +1,530 @@
"""
Database utilities for autolabel workflow.
"""

import json
import psycopg2
from psycopg2.extras import execute_values
from typing import Set, Dict, Any, Optional
import sys
from pathlib import Path

from shared.config import get_db_connection_string


class DocumentDB:
    """Database interface for document processing status."""

    def __init__(self, connection_string: str = None):
        self.connection_string = connection_string or get_db_connection_string()
        self.conn = None

    def connect(self):
        """Connect to database."""
        if self.conn is None:
            self.conn = psycopg2.connect(self.connection_string)
        return self.conn

    def create_tables(self):
        """Create database tables if they don't exist."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS documents (
                    document_id TEXT PRIMARY KEY,
                    pdf_path TEXT,
                    pdf_type TEXT,
                    success BOOLEAN,
                    total_pages INTEGER,
                    fields_matched INTEGER,
                    fields_total INTEGER,
                    annotations_generated INTEGER,
                    processing_time_ms REAL,
                    timestamp TIMESTAMPTZ,
                    errors JSONB DEFAULT '[]',
                    -- Extended CSV format fields
                    split TEXT,
                    customer_number TEXT,
                    supplier_name TEXT,
                    supplier_organisation_number TEXT,
                    supplier_accounts TEXT
                );

                CREATE TABLE IF NOT EXISTS field_results (
                    id SERIAL PRIMARY KEY,
                    document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
                    field_name TEXT,
                    csv_value TEXT,
                    matched BOOLEAN,
                    score REAL,
                    matched_text TEXT,
                    candidate_used TEXT,
                    bbox JSONB,
                    page_no INTEGER,
                    context_keywords JSONB DEFAULT '[]',
                    error TEXT
                );

                CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
                CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
                CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
                CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);

                -- Add new columns to existing tables if they don't exist (for migration)
                DO $$
                BEGIN
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
                        ALTER TABLE documents ADD COLUMN split TEXT;
                    END IF;
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
                        ALTER TABLE documents ADD COLUMN customer_number TEXT;
                    END IF;
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
                        ALTER TABLE documents ADD COLUMN supplier_name TEXT;
                    END IF;
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
                        ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
                    END IF;
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
                        ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
                    END IF;
                END $$;
            """)
            conn.commit()

    def close(self):
        """Close database connection."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def get_successful_doc_ids(self) -> Set[str]:
        """Get all document IDs that have been successfully processed."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("SELECT document_id FROM documents WHERE success = true")
            return {row[0] for row in cursor.fetchall()}

    def get_failed_doc_ids(self) -> Set[str]:
        """Get all document IDs that failed processing."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("SELECT document_id FROM documents WHERE success = false")
            return {row[0] for row in cursor.fetchall()}

    def check_document_status(self, doc_id: str) -> Optional[bool]:
        """
        Check if a document exists and its success status.

        Returns:
            True if exists and success=true
            False if exists and success=false
            None if not exists
        """
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT success FROM documents WHERE document_id = %s",
                (doc_id,)
            )
            row = cursor.fetchone()
            if row is None:
                return None
            return row[0]

    def check_documents_status_batch(self, doc_ids: list[str]) -> Dict[str, Optional[bool]]:
        """
        Batch check document status for multiple IDs.

        Returns:
            Dict mapping doc_id to status:
                True if exists and success=true
                False if exists and success=false
                (missing from dict if not exists)
        """
        if not doc_ids:
            return {}

        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT document_id, success FROM documents WHERE document_id = ANY(%s)",
                (doc_ids,)
            )
            return {row[0]: row[1] for row in cursor.fetchall()}

    def delete_document(self, doc_id: str):
        """Delete a document and its field results (for re-processing)."""
        conn = self.connect()
        with conn.cursor() as cursor:
            # field_results will be cascade deleted
            cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
            conn.commit()

    def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Get a single document with its field results."""
        conn = self.connect()
        with conn.cursor() as cursor:
            # Get document
            cursor.execute("""
                SELECT document_id, pdf_path, pdf_type, success, total_pages,
                       fields_matched, fields_total, annotations_generated,
                       processing_time_ms, timestamp, errors,
                       split, customer_number, supplier_name,
                       supplier_organisation_number, supplier_accounts
                FROM documents WHERE document_id = %s
            """, (doc_id,))
            row = cursor.fetchone()
            if not row:
                return None

            doc = {
                'document_id': row[0],
                'pdf_path': row[1],
                'pdf_type': row[2],
                'success': row[3],
                'total_pages': row[4],
                'fields_matched': row[5],
                'fields_total': row[6],
                'annotations_generated': row[7],
                'processing_time_ms': row[8],
                'timestamp': str(row[9]) if row[9] else None,
                'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
                # New fields
                'split': row[11],
                'customer_number': row[12],
                'supplier_name': row[13],
                'supplier_organisation_number': row[14],
                'supplier_accounts': row[15],
                'field_results': []
            }

            # Get field results
            cursor.execute("""
                SELECT field_name, csv_value, matched, score, matched_text,
                       candidate_used, bbox, page_no, context_keywords, error
                FROM field_results WHERE document_id = %s
            """, (doc_id,))

            for fr in cursor.fetchall():
                doc['field_results'].append({
                    'field_name': fr[0],
                    'csv_value': fr[1],
                    'matched': fr[2],
                    'score': fr[3],
                    'matched_text': fr[4],
                    'candidate_used': fr[5],
                    'bbox': fr[6] if isinstance(fr[6], list) else json.loads(fr[6]) if fr[6] else None,
                    'page_no': fr[7],
                    'context_keywords': fr[8] if isinstance(fr[8], list) else json.loads(fr[8] or '[]'),
                    'error': fr[9]
                })

            return doc

    def get_all_documents_summary(self, success_only: bool = False, limit: int = None) -> list[Dict[str, Any]]:
        """Get summary of all documents (without field_results for efficiency)."""
        conn = self.connect()
        with conn.cursor() as cursor:
            query = """
                SELECT document_id, pdf_path, pdf_type, success, total_pages,
                       fields_matched, fields_total
                FROM documents
            """
            params = []
            if success_only:
                query += " WHERE success = true"
            query += " ORDER BY timestamp DESC"
            if limit:
                # Use parameterized query instead of f-string
                query += " LIMIT %s"
                params.append(limit)

            cursor.execute(query, params if params else None)
            return [
                {
                    'document_id': row[0],
                    'pdf_path': row[1],
                    'pdf_type': row[2],
                    'success': row[3],
                    'total_pages': row[4],
                    'fields_matched': row[5],
                    'fields_total': row[6]
                }
                for row in cursor.fetchall()
            ]

    def get_field_stats(self) -> Dict[str, Dict[str, int]]:
        """Get match statistics per field."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT field_name,
                       COUNT(*) as total,
                       SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched
                FROM field_results
                GROUP BY field_name
                ORDER BY field_name
            """)
            return {
                row[0]: {'total': row[1], 'matched': row[2]}
                for row in cursor.fetchall()
            }

    def get_failed_matches(self, field_name: str = None, limit: int = 100) -> list[Dict[str, Any]]:
        """Get field results that failed to match."""
        conn = self.connect()
        with conn.cursor() as cursor:
            query = """
                SELECT fr.document_id, fr.field_name, fr.csv_value, fr.error,
                       d.pdf_type
                FROM field_results fr
                JOIN documents d ON fr.document_id = d.document_id
                WHERE fr.matched = false
            """
            params = []
            if field_name:
                query += " AND fr.field_name = %s"
                params.append(field_name)
            # Use parameterized query instead of f-string
            query += " LIMIT %s"
            params.append(limit)

            cursor.execute(query, params)
            return [
                {
                    'document_id': row[0],
                    'field_name': row[1],
                    'csv_value': row[2],
                    'error': row[3],
                    'pdf_type': row[4]
                }
                for row in cursor.fetchall()
            ]

    def get_documents_batch(self, doc_ids: list[str]) -> Dict[str, Dict[str, Any]]:
        """
        Get multiple documents with their field results in a single batch query.

        This is much more efficient than calling get_document() in a loop.

        Args:
            doc_ids: List of document IDs to fetch

        Returns:
            Dict mapping doc_id to document data (with field_results)
        """
        if not doc_ids:
            return {}

        conn = self.connect()
        result: Dict[str, Dict[str, Any]] = {}

        with conn.cursor() as cursor:
            # Batch fetch all documents
            cursor.execute("""
                SELECT document_id, pdf_path, pdf_type, success, total_pages,
                       fields_matched, fields_total, annotations_generated,
                       processing_time_ms, timestamp, errors,
                       split, customer_number, supplier_name,
                       supplier_organisation_number, supplier_accounts
                FROM documents WHERE document_id = ANY(%s)
            """, (doc_ids,))

            for row in cursor.fetchall():
                result[row[0]] = {
                    'document_id': row[0],
                    'pdf_path': row[1],
                    'pdf_type': row[2],
                    'success': row[3],
                    'total_pages': row[4],
                    'fields_matched': row[5],
                    'fields_total': row[6],
                    'annotations_generated': row[7],
                    'processing_time_ms': row[8],
                    'timestamp': str(row[9]) if row[9] else None,
                    'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
                    # New fields
                    'split': row[11],
                    'customer_number': row[12],
                    'supplier_name': row[13],
                    'supplier_organisation_number': row[14],
                    'supplier_accounts': row[15],
                    'field_results': []
                }

            if not result:
                return {}

            # Batch fetch all field results for these documents
            cursor.execute("""
                SELECT document_id, field_name, csv_value, matched, score,
                       matched_text, candidate_used, bbox, page_no, context_keywords, error
                FROM field_results WHERE document_id = ANY(%s)
            """, (list(result.keys()),))

            for fr in cursor.fetchall():
                doc_id = fr[0]
                if doc_id in result:
                    result[doc_id]['field_results'].append({
                        'field_name': fr[1],
                        'csv_value': fr[2],
                        'matched': fr[3],
                        'score': fr[4],
                        'matched_text': fr[5],
                        'candidate_used': fr[6],
                        'bbox': fr[7] if isinstance(fr[7], list) else json.loads(fr[7]) if fr[7] else None,
                        'page_no': fr[8],
                        'context_keywords': fr[9] if isinstance(fr[9], list) else json.loads(fr[9] or '[]'),
                        'error': fr[10]
                    })

        return result

    def save_document(self, report: Dict[str, Any]):
        """Save or update a document and its field results using batch operations."""
        conn = self.connect()
        doc_id = report.get('document_id')

        with conn.cursor() as cursor:
            # Delete existing record if any (for update)
            cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))

            # Insert document
            cursor.execute("""
                INSERT INTO documents
                    (document_id, pdf_path, pdf_type, success, total_pages,
                     fields_matched, fields_total, annotations_generated,
                     processing_time_ms, timestamp, errors,
                     split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """, (
                doc_id,
                report.get('pdf_path'),
                report.get('pdf_type'),
                report.get('success'),
                report.get('total_pages'),
                report.get('fields_matched'),
                report.get('fields_total'),
                report.get('annotations_generated'),
                report.get('processing_time_ms'),
                report.get('timestamp'),
                json.dumps(report.get('errors', [])),
                # New fields
                report.get('split'),
                report.get('customer_number'),
                report.get('supplier_name'),
                report.get('supplier_organisation_number'),
                report.get('supplier_accounts'),
            ))

            # Batch insert field results using execute_values
            field_results = report.get('field_results', [])
            if field_results:
                field_values = [
                    (
                        doc_id,
                        field.get('field_name'),
                        field.get('csv_value'),
                        field.get('matched'),
                        field.get('score'),
                        field.get('matched_text'),
                        field.get('candidate_used'),
                        json.dumps(field.get('bbox')) if field.get('bbox') else None,
                        field.get('page_no'),
                        json.dumps(field.get('context_keywords', [])),
                        field.get('error')
                    )
                    for field in field_results
                ]
                execute_values(cursor, """
                    INSERT INTO field_results
                        (document_id, field_name, csv_value, matched, score,
                         matched_text, candidate_used, bbox, page_no, context_keywords, error)
                    VALUES %s
                """, field_values)

            conn.commit()

    def save_documents_batch(self, reports: list[Dict[str, Any]]):
        """Save multiple documents in a batch."""
        if not reports:
            return

        conn = self.connect()
        doc_ids = [r['document_id'] for r in reports]

        with conn.cursor() as cursor:
            # Delete existing records
            cursor.execute(
                "DELETE FROM documents WHERE document_id = ANY(%s)",
                (doc_ids,)
            )

            # Batch insert documents
            doc_values = [
                (
                    r.get('document_id'),
                    r.get('pdf_path'),
                    r.get('pdf_type'),
                    r.get('success'),
                    r.get('total_pages'),
                    r.get('fields_matched'),
                    r.get('fields_total'),
                    r.get('annotations_generated'),
                    r.get('processing_time_ms'),
                    r.get('timestamp'),
                    json.dumps(r.get('errors', [])),
                    # New fields
                    r.get('split'),
                    r.get('customer_number'),
                    r.get('supplier_name'),
                    r.get('supplier_organisation_number'),
                    r.get('supplier_accounts'),
                )
                for r in reports
            ]
            execute_values(cursor, """
                INSERT INTO documents
                    (document_id, pdf_path, pdf_type, success, total_pages,
                     fields_matched, fields_total, annotations_generated,
                     processing_time_ms, timestamp, errors,
                     split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
                VALUES %s
            """, doc_values)

            # Batch insert field results
            field_values = []
            for r in reports:
                doc_id = r.get('document_id')
                for field in r.get('field_results', []):
                    field_values.append((
                        doc_id,
                        field.get('field_name'),
                        field.get('csv_value'),
                        field.get('matched'),
                        field.get('score'),
                        field.get('matched_text'),
                        field.get('candidate_used'),
                        json.dumps(field.get('bbox')) if field.get('bbox') else None,
                        field.get('page_no'),
                        json.dumps(field.get('context_keywords', [])),
                        field.get('error')
                    ))

            if field_values:
                execute_values(cursor, """
                    INSERT INTO field_results
                        (document_id, field_name, csv_value, matched, score,
                         matched_text, candidate_used, bbox, page_no, context_keywords, error)
                    VALUES %s
                """, field_values)

            conn.commit()
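A short usage sketch for the database layer (illustrative only; the document id, path and numbers are hypothetical, and the report dict mirrors the columns above):

# Hypothetical example: record one processed document and inspect per-field stats.
from shared.data.db import DocumentDB

with DocumentDB() as db:            # connection string comes from shared.config
    db.create_tables()              # idempotent; also migrates in the new CSV columns
    db.save_document({
        'document_id': 'DOC-1', 'pdf_path': 'data/raw_pdfs/DOC-1.pdf',
        'pdf_type': 'scanned', 'success': True, 'total_pages': 1,
        'fields_matched': 3, 'fields_total': 5, 'annotations_generated': 3,
        'processing_time_ms': 412.0, 'timestamp': '2024-01-01T00:00:00Z',
        'errors': [], 'field_results': [],
    })
    stats = db.get_field_stats()    # e.g. {'Amount': {'total': 1200, 'matched': 1043}, ...}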
packages/shared/shared/exceptions.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""
Application-specific exceptions for invoice extraction system.

This module defines a hierarchy of custom exceptions to provide better
error handling and debugging capabilities throughout the application.
"""


class InvoiceExtractionError(Exception):
    """Base exception for all invoice extraction errors."""

    def __init__(self, message: str, details: dict = None):
        """
        Initialize exception with message and optional details.

        Args:
            message: Human-readable error message
            details: Optional dict with additional error context
        """
        super().__init__(message)
        self.message = message
        self.details = details or {}

    def __str__(self):
        if self.details:
            details_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
            return f"{self.message} ({details_str})"
        return self.message


class PDFProcessingError(InvoiceExtractionError):
    """Error during PDF processing (rendering, conversion)."""

    pass


class OCRError(InvoiceExtractionError):
    """Error during OCR processing."""

    pass


class ModelInferenceError(InvoiceExtractionError):
    """Error during YOLO model inference."""

    pass


class FieldValidationError(InvoiceExtractionError):
    """Error during field validation or normalization."""

    def __init__(self, field_name: str, value: str, reason: str, details: dict = None):
        """
        Initialize field validation error.

        Args:
            field_name: Name of the field that failed validation
            value: The invalid value
            reason: Why validation failed
            details: Additional context
        """
        message = f"Field '{field_name}' validation failed: {reason}"
        super().__init__(message, details)
        self.field_name = field_name
        self.value = value
        self.reason = reason


class DatabaseError(InvoiceExtractionError):
    """Error during database operations."""

    pass


class ConfigurationError(InvoiceExtractionError):
    """Error in application configuration."""

    pass


class PaymentLineParseError(InvoiceExtractionError):
    """Error parsing Swedish payment line format."""

    pass


class CustomerNumberParseError(InvoiceExtractionError):
    """Error parsing Swedish customer number."""

    pass


class DataLoadError(InvoiceExtractionError):
    """Error loading data from CSV or other sources."""

    pass


class AnnotationError(InvoiceExtractionError):
    """Error generating or processing YOLO annotations."""

    pass
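Sketch of how the hierarchy is meant to be used (illustrative; the validation rule below is made up, only the exception signatures come from the module above):

# Hypothetical example: raise a specific error, catch anything from this system at the top level.
from shared.exceptions import FieldValidationError, InvoiceExtractionError

def check_bankgiro(value: str) -> str:
    digits = value.replace('-', '')
    if not digits.isdigit() or len(digits) not in (7, 8):   # illustrative rule only
        raise FieldValidationError('Bankgiro', value, 'expected 7-8 digits', {'raw': value})
    return digits

try:
    check_bankgiro('12-34')
except InvoiceExtractionError as exc:   # base class catches all subclasses
    print(exc)                          # "Field 'Bankgiro' validation failed: ... (raw=12-34)"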
packages/shared/shared/matcher/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .field_matcher import FieldMatcher, find_field_matches
from .models import Match, TokenLike

__all__ = ['FieldMatcher', 'Match', 'TokenLike', 'find_field_matches']
packages/shared/shared/matcher/context.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""
Context keywords for field matching.
"""

from .models import TokenLike
from .token_index import TokenIndex


# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
    'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
    'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
    'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
                       'förfallodag', 'oss tillhanda senast', 'senast'],
    'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
    'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
    'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
    'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
    'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
                                     'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
    'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}


def find_context_keywords(
    tokens: list[TokenLike],
    target_token: TokenLike,
    field_name: str,
    context_radius: float,
    token_index: TokenIndex | None = None
) -> tuple[list[str], float]:
    """
    Find context keywords near the target token.

    Uses spatial index for O(1) average lookup instead of O(n) scan.

    Args:
        tokens: List of all tokens
        target_token: The token to find context for
        field_name: Name of the field
        context_radius: Search radius in pixels
        token_index: Optional spatial index for efficient lookup

    Returns:
        Tuple of (found_keywords, boost_score)
    """
    keywords = CONTEXT_KEYWORDS.get(field_name, [])
    if not keywords:
        return [], 0.0

    found_keywords = []

    # Use spatial index for efficient nearby token lookup
    if token_index:
        nearby_tokens = token_index.find_nearby(target_token, context_radius)
        for token in nearby_tokens:
            # Use cached lowercase text
            token_lower = token_index.get_text_lower(token)
            for keyword in keywords:
                if keyword in token_lower:
                    found_keywords.append(keyword)
    else:
        # Fallback to O(n) scan if no index available
        target_center = (
            (target_token.bbox[0] + target_token.bbox[2]) / 2,
            (target_token.bbox[1] + target_token.bbox[3]) / 2
        )

        for token in tokens:
            if token is target_token:
                continue

            token_center = (
                (token.bbox[0] + token.bbox[2]) / 2,
                (token.bbox[1] + token.bbox[3]) / 2
            )

            distance = (
                (target_center[0] - token_center[0]) ** 2 +
                (target_center[1] - token_center[1]) ** 2
            ) ** 0.5

            if distance <= context_radius:
                token_lower = token.text.lower()
                for keyword in keywords:
                    if keyword in token_lower:
                        found_keywords.append(keyword)

    # Calculate boost based on keywords found
    # Increased boost to better differentiate matches with/without context
    boost = min(0.25, len(found_keywords) * 0.10)
    return found_keywords, boost
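Worked example of the boost formula above (illustrative):

# Each context keyword found within the radius adds 0.10, capped at 0.25:
# 0 keywords -> 0.0, 1 -> 0.10, 2 -> 0.20, 3 or more -> 0.25.
for n in range(5):
    print(n, min(0.25, n * 0.10))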
packages/shared/shared/matcher/field_matcher.py (new file, 219 lines)
@@ -0,0 +1,219 @@
"""
Field Matching Module - Refactored

Matches normalized field values to tokens extracted from documents.
"""

from .models import TokenLike, Match
from .token_index import TokenIndex
from .utils import bbox_overlap
from .strategies import (
    ExactMatcher,
    ConcatenatedMatcher,
    SubstringMatcher,
    FuzzyMatcher,
    FlexibleDateMatcher,
)


class FieldMatcher:
    """Matches field values to document tokens."""

    def __init__(
        self,
        context_radius: float = 200.0,  # pixels - increased to handle label-value spacing in scanned PDFs
        min_score_threshold: float = 0.5
    ):
        """
        Initialize the matcher.

        Args:
            context_radius: Distance to search for context keywords (default 200px to handle
                            typical label-value spacing in scanned invoices at 150 DPI)
            min_score_threshold: Minimum score to consider a match valid
        """
        self.context_radius = context_radius
        self.min_score_threshold = min_score_threshold
        self._token_index: TokenIndex | None = None

        # Initialize matching strategies
        self.exact_matcher = ExactMatcher(context_radius)
        self.concatenated_matcher = ConcatenatedMatcher(context_radius)
        self.substring_matcher = SubstringMatcher(context_radius)
        self.fuzzy_matcher = FuzzyMatcher(context_radius)
        self.flexible_date_matcher = FlexibleDateMatcher(context_radius)

    def find_matches(
        self,
        tokens: list[TokenLike],
        field_name: str,
        normalized_values: list[str],
        page_no: int = 0
    ) -> list[Match]:
        """
        Find all matches for a field in the token list.

        Args:
            tokens: List of tokens from the document
            field_name: Name of the field to match
            normalized_values: List of normalized value variants to search for
            page_no: Page number to filter tokens

        Returns:
            List of Match objects sorted by score (descending)
        """
        matches = []
        # Filter tokens by page and exclude hidden metadata tokens
        # Hidden tokens often have bbox with y < 0 or y > page_height
        # These are typically PDF metadata stored as invisible text
        page_tokens = [
            t for t in tokens
            if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
        ]

        # Build spatial index for efficient nearby token lookup (O(n) -> O(1))
        self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)

        for value in normalized_values:
            # Strategy 1: Exact token match
            exact_matches = self.exact_matcher.find_matches(
                page_tokens, value, field_name, self._token_index
            )
            matches.extend(exact_matches)

            # Strategy 2: Multi-token concatenation
            concat_matches = self.concatenated_matcher.find_matches(
                page_tokens, value, field_name, self._token_index
            )
            matches.extend(concat_matches)

            # Strategy 3: Fuzzy match (for amounts and dates only)
            if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
                fuzzy_matches = self.fuzzy_matcher.find_matches(
                    page_tokens, value, field_name, self._token_index
                )
                matches.extend(fuzzy_matches)

            # Strategy 4: Substring match (for values embedded in longer text)
            # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
            # Note: Amount is excluded because short numbers like "451" can incorrectly match
            # in OCR payment lines or other unrelated text
            if field_name in (
                'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
                'Bankgiro', 'Plusgiro', 'supplier_organisation_number',
                'supplier_accounts', 'customer_number'
            ):
                substring_matches = self.substring_matcher.find_matches(
                    page_tokens, value, field_name, self._token_index
                )
                matches.extend(substring_matches)

        # Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
        # Only if no exact matches found for date fields
        if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
            for value in normalized_values:
                flexible_matches = self.flexible_date_matcher.find_matches(
                    page_tokens, value, field_name, self._token_index
                )
                matches.extend(flexible_matches)

        # Deduplicate and sort by score
        matches = self._deduplicate_matches(matches)
        matches.sort(key=lambda m: m.score, reverse=True)

        # Clear token index to free memory
        self._token_index = None

        return [m for m in matches if m.score >= self.min_score_threshold]

    def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
        """
        Remove duplicate matches based on bbox overlap.

        Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
        """
        if not matches:
            return []

        # Sort by: 1) score descending, 2) prefer matches with context keywords,
        # 3) prefer upper positions (smaller y) for same-score matches
        # This helps select the "main" occurrence in invoice body rather than footer
        matches.sort(key=lambda m: (
            -m.score,
            -len(m.context_keywords),  # More keywords = better
            m.bbox[1]  # Smaller y (upper position) = better
        ))

        # Use spatial grid for efficient overlap checking
        # Grid cell size based on typical bbox size
        grid_size = 50.0  # pixels
        grid: dict[tuple[int, int], list[Match]] = {}
        unique = []

        for match in matches:
            bbox = match.bbox
            # Calculate grid cells this bbox touches
            min_gx = int(bbox[0] / grid_size)
            min_gy = int(bbox[1] / grid_size)
            max_gx = int(bbox[2] / grid_size)
            max_gy = int(bbox[3] / grid_size)

            # Check for overlap only with matches in nearby grid cells
            is_duplicate = False
            cells_to_check = set()
            for gx in range(min_gx - 1, max_gx + 2):
                for gy in range(min_gy - 1, max_gy + 2):
                    cells_to_check.add((gx, gy))

            for cell in cells_to_check:
                if cell in grid:
                    for existing in grid[cell]:
                        if bbox_overlap(bbox, existing.bbox) > 0.7:
                            is_duplicate = True
                            break
                if is_duplicate:
                    break

            if not is_duplicate:
                unique.append(match)
                # Add to all grid cells this bbox touches
                for gx in range(min_gx, max_gx + 1):
                    for gy in range(min_gy, max_gy + 1):
                        key = (gx, gy)
                        if key not in grid:
                            grid[key] = []
                        grid[key].append(match)

        return unique


def find_field_matches(
    tokens: list[TokenLike],
    field_values: dict[str, str],
    page_no: int = 0
) -> dict[str, list[Match]]:
    """
    Convenience function to find matches for multiple fields.

    Args:
        tokens: List of tokens from the document
        field_values: Dict of field_name -> value to search for
        page_no: Page number

    Returns:
        Dict of field_name -> list of matches
    """
    from ..normalize import normalize_field

    matcher = FieldMatcher()
    results = {}

    for field_name, value in field_values.items():
        if value is None or str(value).strip() == '':
            continue

        normalized_values = normalize_field(field_name, str(value))
        matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
        results[field_name] = matches

    return results
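A sketch of calling the refactored matcher from outside (illustrative; the Token class and the token values below are hypothetical stand-ins for whatever OCR or PDF token type satisfies TokenLike):

# Hypothetical example: any object with text/bbox/page_no satisfies the TokenLike protocol.
from dataclasses import dataclass
from shared.matcher import find_field_matches

@dataclass
class Token:
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0

tokens = [
    Token('Fakturanummer:', (50, 100, 160, 115)),
    Token('2465027205', (170, 100, 250, 115)),
    Token('Att betala', (50, 400, 120, 415)),
    Token('1 234,56', (170, 400, 230, 415)),
]
matches = find_field_matches(tokens, {'InvoiceNumber': '2465027205', 'Amount': '1234.56'})
for field_name, found in matches.items():
    for m in found:
        print(field_name, m.matched_text, round(m.score, 2), m.bbox)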
875
packages/shared/shared/matcher/field_matcher_old.py
Normal file
875
packages/shared/shared/matcher/field_matcher_old.py
Normal file
@@ -0,0 +1,875 @@
|
||||
"""
|
||||
Field Matching Module
|
||||
|
||||
Matches normalized field values to tokens extracted from documents.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol
|
||||
import re
|
||||
from functools import cached_property
|
||||
|
||||
|
||||
# Pre-compiled regex patterns (module-level for efficiency)
|
||||
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
|
||||
_WHITESPACE_PATTERN = re.compile(r'\s+')
|
||||
_NON_DIGIT_PATTERN = re.compile(r'\D')
|
||||
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
|
||||
|
||||
|
||||
def _normalize_dashes(text: str) -> str:
|
||||
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
|
||||
return _DASH_PATTERN.sub('-', text)
|
||||
|
||||
|
||||
class TokenLike(Protocol):
|
||||
"""Protocol for token objects."""
|
||||
text: str
|
||||
bbox: tuple[float, float, float, float]
|
||||
page_no: int
|
||||
|
||||
|
||||
class TokenIndex:
|
||||
"""
|
||||
Spatial index for tokens to enable fast nearby token lookup.
|
||||
|
||||
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
|
||||
"""
|
||||
|
||||
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
|
||||
"""
|
||||
Build spatial index from tokens.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens to index
|
||||
grid_size: Size of grid cells in pixels
|
||||
"""
|
||||
self.tokens = tokens
|
||||
self.grid_size = grid_size
|
||||
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
|
||||
self._token_centers: dict[int, tuple[float, float]] = {}
|
||||
self._token_text_lower: dict[int, str] = {}
|
||||
|
||||
# Build index
|
||||
for i, token in enumerate(tokens):
|
||||
# Cache center coordinates
|
||||
center_x = (token.bbox[0] + token.bbox[2]) / 2
|
||||
center_y = (token.bbox[1] + token.bbox[3]) / 2
|
||||
self._token_centers[id(token)] = (center_x, center_y)
|
||||
|
||||
# Cache lowercased text
|
||||
self._token_text_lower[id(token)] = token.text.lower()
|
||||
|
||||
# Add to grid cell
|
||||
grid_x = int(center_x / grid_size)
|
||||
grid_y = int(center_y / grid_size)
|
||||
key = (grid_x, grid_y)
|
||||
if key not in self._grid:
|
||||
self._grid[key] = []
|
||||
self._grid[key].append(token)
|
||||
|
||||
def get_center(self, token: TokenLike) -> tuple[float, float]:
|
||||
"""Get cached center coordinates for token."""
|
||||
return self._token_centers.get(id(token), (
|
||||
(token.bbox[0] + token.bbox[2]) / 2,
|
||||
(token.bbox[1] + token.bbox[3]) / 2
|
||||
))
|
||||
|
||||
def get_text_lower(self, token: TokenLike) -> str:
|
||||
"""Get cached lowercased text for token."""
|
||||
return self._token_text_lower.get(id(token), token.text.lower())
|
||||
|
||||
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
|
||||
"""
|
||||
Find all tokens within radius of the given token.
|
||||
|
||||
Uses grid-based lookup for O(1) average case instead of O(n).
|
||||
"""
|
||||
center = self.get_center(token)
|
||||
center_x, center_y = center
|
||||
|
||||
# Determine which grid cells to search
|
||||
cells_to_check = int(radius / self.grid_size) + 1
|
||||
grid_x = int(center_x / self.grid_size)
|
||||
grid_y = int(center_y / self.grid_size)
|
||||
|
||||
nearby = []
|
||||
radius_sq = radius * radius
|
||||
|
||||
# Check all nearby grid cells
|
||||
for dx in range(-cells_to_check, cells_to_check + 1):
|
||||
for dy in range(-cells_to_check, cells_to_check + 1):
|
||||
key = (grid_x + dx, grid_y + dy)
|
||||
if key not in self._grid:
|
||||
continue
|
||||
|
||||
for other in self._grid[key]:
|
||||
if other is token:
|
||||
continue
|
||||
|
||||
other_center = self.get_center(other)
|
||||
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
|
||||
|
||||
if dist_sq <= radius_sq:
|
||||
nearby.append(other)
|
||||
|
||||
return nearby
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
"""Represents a matched field in the document."""
|
||||
field: str
|
||||
value: str
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||
page_no: int
|
||||
score: float # 0-1 confidence score
|
||||
matched_text: str # Actual text that matched
|
||||
context_keywords: list[str] # Nearby keywords that boosted confidence
|
||||
|
||||
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
|
||||
"""Convert to YOLO annotation format."""
|
||||
x0, y0, x1, y1 = self.bbox
|
||||
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
y_center = (y0 + y1) / 2 / image_height
|
||||
width = (x1 - x0) / image_width
|
||||
height = (y1 - y0) / image_height
|
||||
|
||||
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||
|
||||
|
||||
# Context keywords for each field type (Swedish invoice terms)
|
||||
CONTEXT_KEYWORDS = {
|
||||
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
|
||||
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
|
||||
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
|
||||
'förfallodag', 'oss tillhanda senast', 'senast'],
|
||||
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
|
||||
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
|
||||
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
|
||||
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
|
||||
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
|
||||
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
|
||||
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
|
||||
}
|
||||
|
||||
|
||||
class FieldMatcher:
    """Matches field values to document tokens."""

    def __init__(
        self,
        context_radius: float = 200.0,  # pixels - increased to handle label-value spacing in scanned PDFs
        min_score_threshold: float = 0.5
    ):
        """
        Initialize the matcher.

        Args:
            context_radius: Distance to search for context keywords (default 200px to handle
                typical label-value spacing in scanned invoices at 150 DPI)
            min_score_threshold: Minimum score to consider a match valid
        """
        self.context_radius = context_radius
        self.min_score_threshold = min_score_threshold
        self._token_index: TokenIndex | None = None

    def find_matches(
        self,
        tokens: list[TokenLike],
        field_name: str,
        normalized_values: list[str],
        page_no: int = 0
    ) -> list[Match]:
        """
        Find all matches for a field in the token list.

        Args:
            tokens: List of tokens from the document
            field_name: Name of the field to match
            normalized_values: List of normalized value variants to search for
            page_no: Page number to filter tokens

        Returns:
            List of Match objects sorted by score (descending)
        """
        matches = []
        # Filter tokens by page and exclude hidden metadata tokens.
        # Hidden tokens often have a bbox with y < 0 or y > page_height;
        # these are typically PDF metadata stored as invisible text.
        page_tokens = [
            t for t in tokens
            if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
        ]

        # Build spatial index for efficient nearby token lookup (O(n) -> O(1))
        self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)

        for value in normalized_values:
            # Strategy 1: Exact token match
            exact_matches = self._find_exact_matches(page_tokens, value, field_name)
            matches.extend(exact_matches)

            # Strategy 2: Multi-token concatenation
            concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
            matches.extend(concat_matches)

            # Strategy 3: Fuzzy match (for amounts and dates only)
            if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
                fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
                matches.extend(fuzzy_matches)

            # Strategy 4: Substring match (for values embedded in longer text),
            # e.g. "Fakturanummer: 2465027205" should match OCR value "2465027205".
            # Note: Amount is excluded because short numbers like "451" can incorrectly match
            # in OCR payment lines or other unrelated text.
            if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
                              'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
                substring_matches = self._find_substring_matches(page_tokens, value, field_name)
                matches.extend(substring_matches)

        # Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection),
        # only if no exact matches were found for date fields
        if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
            flexible_matches = self._find_flexible_date_matches(
                page_tokens, normalized_values, field_name
            )
            matches.extend(flexible_matches)

        # Deduplicate and sort by score
        matches = self._deduplicate_matches(matches)
        matches.sort(key=lambda m: m.score, reverse=True)

        # Clear token index to free memory
        self._token_index = None

        return [m for m in matches if m.score >= self.min_score_threshold]

    def _find_exact_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str
    ) -> list[Match]:
        """Find tokens that exactly match the value."""
        matches = []
        value_lower = value.lower()
        value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in (
            'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
            'supplier_organisation_number', 'supplier_accounts'
        ) else None

        for token in tokens:
            token_text = token.text.strip()

            # Exact match
            if token_text == value:
                score = 1.0
            # Case-insensitive match (use cached lowercase from index)
            elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
                score = 0.95
            # Digits-only match for numeric fields
            elif value_digits is not None:
                token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
                if token_digits and token_digits == value_digits:
                    score = 0.9
                else:
                    continue
            else:
                continue

            # Boost score if context keywords are nearby
            context_keywords, context_boost = self._find_context_keywords(
                tokens, token, field_name
            )
            score = min(1.0, score + context_boost)

            matches.append(Match(
                field=field_name,
                value=value,
                bbox=token.bbox,
                page_no=token.page_no,
                score=score,
                matched_text=token_text,
                context_keywords=context_keywords
            ))

        return matches

    def _find_concatenated_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str
    ) -> list[Match]:
        """Find value by concatenating adjacent tokens."""
        matches = []
        value_clean = _WHITESPACE_PATTERN.sub('', value)

        # Sort tokens by position (top-to-bottom, left-to-right)
        sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))

        for i, start_token in enumerate(sorted_tokens):
            # Try to build the value by concatenating nearby tokens
            concat_text = start_token.text.strip()
            concat_bbox = list(start_token.bbox)
            used_tokens = [start_token]

            for j in range(i + 1, min(i + 5, len(sorted_tokens))):  # Max 5 tokens
                next_token = sorted_tokens[j]

                # Check if tokens are on the same line (y overlap)
                if not self._tokens_on_same_line(start_token, next_token):
                    break

                # Check horizontal proximity
                if next_token.bbox[0] - concat_bbox[2] > 50:  # Max 50px gap
                    break

                concat_text += next_token.text.strip()
                used_tokens.append(next_token)

                # Update bounding box
                concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
                concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
                concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
                concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])

                # Check for match
                concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
                if concat_clean == value_clean:
                    context_keywords, context_boost = self._find_context_keywords(
                        tokens, start_token, field_name
                    )

                    matches.append(Match(
                        field=field_name,
                        value=value,
                        bbox=tuple(concat_bbox),
                        page_no=start_token.page_no,
                        score=min(1.0, 0.85 + context_boost),  # Slightly lower base score
                        matched_text=concat_text,
                        context_keywords=context_keywords
                    ))
                    break

        return matches

    def _find_substring_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str
    ) -> list[Match]:
        """
        Find value as a substring within longer tokens.

        Handles cases like:
        - 'Fakturadatum: 2026-01-09' where the date is embedded
        - 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
        - 'OCR: 1234567890' where reference number is embedded

        Uses a lower score (0.75-0.85) than exact match to prefer exact matches.
        Only matches if the value appears as a distinct segment (not part of a larger number).
        """
        matches = []

        # Supported fields for substring matching
        supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
                            'supplier_organisation_number', 'supplier_accounts', 'customer_number')
        if field_name not in supported_fields:
            return matches

        # Fields where spaces/dashes should be ignored during matching
        # (e.g., org number "55 65 74-6624" should match "5565746624")
        ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')

        for token in tokens:
            token_text = token.text.strip()
            # Normalize different dash types to hyphen-minus for matching
            token_text_normalized = _normalize_dashes(token_text)

            # For certain fields, also try matching with spaces/dashes removed
            if field_name in ignore_spaces_fields:
                token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
                value_compact = value.replace(' ', '').replace('-', '')
            else:
                token_text_compact = None
                value_compact = None

            # Skip if token is the same length as value (would be an exact match)
            if len(token_text_normalized) <= len(value):
                continue

            # Check if value appears as a substring (using normalized text).
            # Try case-sensitive first, then case-insensitive
            idx = None
            case_sensitive_match = True
            used_compact = False

            if value in token_text_normalized:
                idx = token_text_normalized.find(value)
            elif value.lower() in token_text_normalized.lower():
                idx = token_text_normalized.lower().find(value.lower())
                case_sensitive_match = False
            elif token_text_compact and value_compact in token_text_compact:
                # Try compact matching (spaces/dashes removed)
                idx = token_text_compact.find(value_compact)
                used_compact = True
            elif token_text_compact and value_compact.lower() in token_text_compact.lower():
                idx = token_text_compact.lower().find(value_compact.lower())
                case_sensitive_match = False
                used_compact = True

            if idx is None:
                continue

            # For compact matching, the boundary check runs on the compacted text:
            # the matched digits must not be part of a longer digit run
            if used_compact:
                # Verify proper boundary in compact text
                if idx > 0 and token_text_compact[idx - 1].isdigit():
                    continue
                end_idx = idx + len(value_compact)
                if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
                    continue
            else:
                # Verify it's a proper boundary match (not part of a larger number)
                # Check character before (if exists)
                if idx > 0:
                    char_before = token_text_normalized[idx - 1]
                    # Must be non-digit (allow : space - etc)
                    if char_before.isdigit():
                        continue

                # Check character after (if exists)
                end_idx = idx + len(value)
                if end_idx < len(token_text_normalized):
                    char_after = token_text_normalized[end_idx]
                    # Must be non-digit
                    if char_after.isdigit():
                        continue

            # Found valid substring match
            context_keywords, context_boost = self._find_context_keywords(
                tokens, token, field_name
            )

            # Check if context keyword is in the same token (like "Fakturadatum:")
            token_lower = token_text.lower()
            inline_context = []
            for keyword in CONTEXT_KEYWORDS.get(field_name, []):
                if keyword in token_lower:
                    inline_context.append(keyword)

            # Boost score if keyword is inline
            inline_boost = 0.1 if inline_context else 0

            # Lower score for case-insensitive match
            base_score = 0.75 if case_sensitive_match else 0.70

            matches.append(Match(
                field=field_name,
                value=value,
                bbox=token.bbox,  # Use full token bbox
                page_no=token.page_no,
                score=min(1.0, base_score + context_boost + inline_boost),
                matched_text=token_text,
                context_keywords=context_keywords + inline_context
            ))

        return matches

    def _find_fuzzy_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str
    ) -> list[Match]:
        """Find approximate matches for amounts and dates."""
        matches = []

        for token in tokens:
            token_text = token.text.strip()

            if field_name == 'Amount':
                # Try to parse both as numbers
                try:
                    token_num = self._parse_amount(token_text)
                    value_num = self._parse_amount(value)

                    if token_num is not None and value_num is not None:
                        if abs(token_num - value_num) < 0.01:  # Within 1 cent
                            context_keywords, context_boost = self._find_context_keywords(
                                tokens, token, field_name
                            )

                            matches.append(Match(
                                field=field_name,
                                value=value,
                                bbox=token.bbox,
                                page_no=token.page_no,
                                score=min(1.0, 0.8 + context_boost),
                                matched_text=token_text,
                                context_keywords=context_keywords
                            ))
                except (ValueError, TypeError):
                    # _parse_amount already returns None on failure; only guard
                    # against unexpected operand types rather than using a bare except
                    pass

        return matches

    def _find_flexible_date_matches(
        self,
        tokens: list[TokenLike],
        normalized_values: list[str],
        field_name: str
    ) -> list[Match]:
        """
        Flexible date matching when exact match fails.

        Strategies:
        1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
        2. Nearby date match: Match dates within 7 days of CSV value
        3. Heuristic selection: Use context keywords to select the best date

        This handles cases where CSV InvoiceDate doesn't exactly match the PDF,
        but we can still find a reasonable date to label.
        """
        from datetime import datetime

        matches = []

        # Parse the target date from normalized values
        target_date = None
        for value in normalized_values:
            # Try to parse YYYY-MM-DD format
            date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
            if date_match:
                try:
                    target_date = datetime(
                        int(date_match.group(1)),
                        int(date_match.group(2)),
                        int(date_match.group(3))
                    )
                    break
                except ValueError:
                    continue

        if not target_date:
            return matches

        # Find all date-like tokens in the document
        date_candidates = []

        for token in tokens:
            token_text = token.text.strip()

            # Search for date pattern in token (use pre-compiled pattern)
            for match in _DATE_PATTERN.finditer(token_text):
                try:
                    found_date = datetime(
                        int(match.group(1)),
                        int(match.group(2)),
                        int(match.group(3))
                    )
                    date_str = match.group(0)

                    # Calculate date difference
                    days_diff = abs((found_date - target_date).days)

                    # Check for context keywords
                    context_keywords, context_boost = self._find_context_keywords(
                        tokens, token, field_name
                    )

                    # Check if keyword is in the same token
                    token_lower = token_text.lower()
                    inline_keywords = []
                    for keyword in CONTEXT_KEYWORDS.get(field_name, []):
                        if keyword in token_lower:
                            inline_keywords.append(keyword)

                    date_candidates.append({
                        'token': token,
                        'date': found_date,
                        'date_str': date_str,
                        'matched_text': token_text,
                        'days_diff': days_diff,
                        'context_keywords': context_keywords + inline_keywords,
                        'context_boost': context_boost + (0.1 if inline_keywords else 0),
                        'same_year_month': (found_date.year == target_date.year and
                                            found_date.month == target_date.month),
                    })
                except ValueError:
                    continue

        if not date_candidates:
            return matches

        # Score and rank candidates
        for candidate in date_candidates:
            score = 0.0

            # Strategy 1: Same year-month gets higher score
            if candidate['same_year_month']:
                score = 0.7
                # Bonus if day is close
                if candidate['days_diff'] <= 7:
                    score = 0.75
                if candidate['days_diff'] <= 3:
                    score = 0.8
            # Strategy 2: Nearby dates (within 14 days)
            elif candidate['days_diff'] <= 14:
                score = 0.6
            elif candidate['days_diff'] <= 30:
                score = 0.55
            else:
                # Too far apart, skip unless it has strong context
                if not candidate['context_keywords']:
                    continue
                score = 0.5

            # Strategy 3: Boost with context keywords
            score = min(1.0, score + candidate['context_boost'])

            # For InvoiceDate, prefer dates that appear near invoice-related keywords;
            # for InvoiceDueDate, prefer dates near due-date keywords
            if candidate['context_keywords']:
                score = min(1.0, score + 0.05)

            if score >= self.min_score_threshold:
                matches.append(Match(
                    field=field_name,
                    value=candidate['date_str'],
                    bbox=candidate['token'].bbox,
                    page_no=candidate['token'].page_no,
                    score=score,
                    matched_text=candidate['matched_text'],
                    context_keywords=candidate['context_keywords']
                ))

        # Sort by score and return best matches
        matches.sort(key=lambda m: m.score, reverse=True)

        # Only return the best match to avoid multiple labels for same field
        return matches[:1] if matches else []

    def _find_context_keywords(
        self,
        tokens: list[TokenLike],
        target_token: TokenLike,
        field_name: str
    ) -> tuple[list[str], float]:
        """
        Find context keywords near the target token.

        Uses spatial index for O(1) average lookup instead of O(n) scan.
        """
        keywords = CONTEXT_KEYWORDS.get(field_name, [])
        if not keywords:
            return [], 0.0

        found_keywords = []

        # Use spatial index for efficient nearby token lookup
        if self._token_index:
            nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
            for token in nearby_tokens:
                # Use cached lowercase text
                token_lower = self._token_index.get_text_lower(token)
                for keyword in keywords:
                    if keyword in token_lower:
                        found_keywords.append(keyword)
        else:
            # Fallback to O(n) scan if no index available
            target_center = (
                (target_token.bbox[0] + target_token.bbox[2]) / 2,
                (target_token.bbox[1] + target_token.bbox[3]) / 2
            )

            for token in tokens:
                if token is target_token:
                    continue

                token_center = (
                    (token.bbox[0] + token.bbox[2]) / 2,
                    (token.bbox[1] + token.bbox[3]) / 2
                )

                distance = (
                    (target_center[0] - token_center[0]) ** 2 +
                    (target_center[1] - token_center[1]) ** 2
                ) ** 0.5

                if distance <= self.context_radius:
                    token_lower = token.text.lower()
                    for keyword in keywords:
                        if keyword in token_lower:
                            found_keywords.append(keyword)

        # Calculate boost based on keywords found.
        # Increased boost to better differentiate matches with/without context
        boost = min(0.25, len(found_keywords) * 0.10)
        return found_keywords, boost

    def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
        """Check if two tokens are on the same line."""
        # Check vertical overlap
        y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
        min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
        return y_overlap > min_height * 0.5

    def _parse_amount(self, text: str | int | float) -> float | None:
        """Try to parse text as a monetary amount."""
        # Convert to string first
        text = str(text)

        # First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
        # Pattern: digits + space + exactly 2 digits at end
        ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
        if ore_match:
            kronor = ore_match.group(1)
            ore = ore_match.group(2)
            try:
                return float(f"{kronor}.{ore}")
            except ValueError:
                pass

        # Remove everything after and including parentheses (e.g., "(inkl. moms)")
        text = re.sub(r'\s*\(.*\)', '', text)

        # Remove currency symbols and common suffixes (including trailing dots from "kr.")
        text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
        text = re.sub(r'[:-]', '', text)

        # Remove spaces (thousand separators) but be careful with öre format
        text = text.replace(' ', '').replace('\xa0', '')

        # Handle comma as decimal separator.
        # Swedish format: "500,00" means 500.00.
        # Need to handle cases like "500,00." (after removing "kr.")
        if ',' in text:
            # Remove any trailing dots first (from "kr." removal)
            text = text.rstrip('.')
            # Now replace comma with dot
            if '.' not in text:
                text = text.replace(',', '.')

        # Remove any remaining non-numeric characters except dot
        text = re.sub(r'[^\d.]', '', text)

        try:
            return float(text)
        except ValueError:
            return None

    def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
        """
        Remove duplicate matches based on bbox overlap.

        Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
        """
        if not matches:
            return []

        # Sort by: 1) score descending, 2) prefer matches with context keywords,
        # 3) prefer upper positions (smaller y) for same-score matches.
        # This helps select the "main" occurrence in the invoice body rather than the footer.
        matches.sort(key=lambda m: (
            -m.score,
            -len(m.context_keywords),  # More keywords = better
            m.bbox[1]  # Smaller y (upper position) = better
        ))

        # Use spatial grid for efficient overlap checking.
        # Grid cell size based on typical bbox size
        grid_size = 50.0  # pixels
        grid: dict[tuple[int, int], list[Match]] = {}
        unique = []

        for match in matches:
            bbox = match.bbox
            # Calculate grid cells this bbox touches
            min_gx = int(bbox[0] / grid_size)
            min_gy = int(bbox[1] / grid_size)
            max_gx = int(bbox[2] / grid_size)
            max_gy = int(bbox[3] / grid_size)

            # Check for overlap only with matches in nearby grid cells
            is_duplicate = False
            cells_to_check = set()
            for gx in range(min_gx - 1, max_gx + 2):
                for gy in range(min_gy - 1, max_gy + 2):
                    cells_to_check.add((gx, gy))

            for cell in cells_to_check:
                if cell in grid:
                    for existing in grid[cell]:
                        if self._bbox_overlap(bbox, existing.bbox) > 0.7:
                            is_duplicate = True
                            break
                if is_duplicate:
                    break

            if not is_duplicate:
                unique.append(match)
                # Add to all grid cells this bbox touches
                for gx in range(min_gx, max_gx + 1):
                    for gy in range(min_gy, max_gy + 1):
                        key = (gx, gy)
                        if key not in grid:
                            grid[key] = []
                        grid[key].append(match)

        return unique

    def _bbox_overlap(
        self,
        bbox1: tuple[float, float, float, float],
        bbox2: tuple[float, float, float, float]
    ) -> float:
        """Calculate IoU (Intersection over Union) of two bounding boxes."""
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])

        if x2 <= x1 or y2 <= y1:
            return 0.0

        intersection = float(x2 - x1) * float(y2 - y1)
        area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
        area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0


def find_field_matches(
    tokens: list[TokenLike],
    field_values: dict[str, str],
    page_no: int = 0
) -> dict[str, list[Match]]:
    """
    Convenience function to find matches for multiple fields.

    Args:
        tokens: List of tokens from the document
        field_values: Dict of field_name -> value to search for
        page_no: Page number

    Returns:
        Dict of field_name -> list of matches
    """
    from ..normalize import normalize_field

    matcher = FieldMatcher()
    results = {}

    for field_name, value in field_values.items():
        if value is None or str(value).strip() == '':
            continue

        normalized_values = normalize_field(field_name, str(value))
        matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
        results[field_name] = matches

    return results

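A usage sketch for the convenience wrapper above. The `Token` dataclass and the sample values are hypothetical, assumed only for illustration; any object with `text`, `bbox` and `page_no` satisfies `TokenLike`, and the result assumes the invoice-number normalizer keeps the raw digits as one of its variants.

```python
from dataclasses import dataclass

@dataclass
class Token:
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0

tokens = [
    Token("Fakturanr:", (50.0, 40.0, 130.0, 55.0)),
    Token("2465027205", (140.0, 40.0, 230.0, 55.0)),
]
results = find_field_matches(tokens, {"InvoiceNumber": "2465027205"})
best = results["InvoiceNumber"][0]
print(best.score, best.bbox)  # exact match, boosted by the nearby "Fakturanr:" keyword
```
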
36
packages/shared/shared/matcher/models.py
Normal file
36
packages/shared/shared/matcher/models.py
Normal file
@@ -0,0 +1,36 @@
"""
Data models for field matching.
"""

from dataclasses import dataclass
from typing import Protocol


class TokenLike(Protocol):
    """Protocol for token objects."""
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int


@dataclass
class Match:
    """Represents a matched field in the document."""
    field: str
    value: str
    bbox: tuple[float, float, float, float]  # (x0, y0, x1, y1)
    page_no: int
    score: float  # 0-1 confidence score
    matched_text: str  # Actual text that matched
    context_keywords: list[str]  # Nearby keywords that boosted confidence

    def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
        """Convert to YOLO annotation format."""
        x0, y0, x1, y1 = self.bbox

        x_center = (x0 + x1) / 2 / image_width
        y_center = (y0 + y1) / 2 / image_height
        width = (x1 - x0) / image_width
        height = (y1 - y0) / image_height

        return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"

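For reference, `to_yolo_format` emits the usual normalized `class x_center y_center width height` record. A quick check with made-up numbers (the bbox, page size and `class_id` are illustrative only):

```python
m = Match(field='Amount', value='295,00', bbox=(100.0, 200.0, 300.0, 250.0),
          page_no=0, score=0.95, matched_text='295,00 kr', context_keywords=['att betala'])
print(m.to_yolo_format(image_width=1000.0, image_height=1000.0, class_id=6))
# -> "6 0.200000 0.225000 0.200000 0.050000"
```
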
17
packages/shared/shared/matcher/strategies/__init__.py
Normal file
17
packages/shared/shared/matcher/strategies/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""
Matching strategies for field matching.
"""

from .exact_matcher import ExactMatcher
from .concatenated_matcher import ConcatenatedMatcher
from .substring_matcher import SubstringMatcher
from .fuzzy_matcher import FuzzyMatcher
from .flexible_date_matcher import FlexibleDateMatcher

__all__ = [
    'ExactMatcher',
    'ConcatenatedMatcher',
    'SubstringMatcher',
    'FuzzyMatcher',
    'FlexibleDateMatcher',
]

42
packages/shared/shared/matcher/strategies/base.py
Normal file
42
packages/shared/shared/matcher/strategies/base.py
Normal file
@@ -0,0 +1,42 @@
"""
Base class for matching strategies.
"""

from abc import ABC, abstractmethod
from ..models import TokenLike, Match
from ..token_index import TokenIndex


class BaseMatchStrategy(ABC):
    """Base class for all matching strategies."""

    def __init__(self, context_radius: float = 200.0):
        """
        Initialize the strategy.

        Args:
            context_radius: Distance to search for context keywords
        """
        self.context_radius = context_radius

    @abstractmethod
    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        """
        Find matches for the given value.

        Args:
            tokens: List of tokens to search
            value: Value to find
            field_name: Name of the field
            token_index: Optional spatial index for efficient lookup

        Returns:
            List of Match objects
        """
        pass

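The abstract interface above is what the concrete matchers below implement. A minimal, purely hypothetical subclass (not part of the commit) would look like this:

```python
class UppercaseMatcher(BaseMatchStrategy):
    """Toy strategy: match tokens whose uppercased text equals the value."""

    def find_matches(self, tokens, value, field_name, token_index=None):
        return [
            Match(field=field_name, value=value, bbox=t.bbox, page_no=t.page_no,
                  score=0.6, matched_text=t.text, context_keywords=[])
            for t in tokens if t.text.strip().upper() == value.upper()
        ]
```
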
73
packages/shared/shared/matcher/strategies/concatenated_matcher.py
Normal file
@@ -0,0 +1,73 @@
"""
Concatenated match strategy - finds value by concatenating adjacent tokens.
"""

from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import WHITESPACE_PATTERN, tokens_on_same_line


class ConcatenatedMatcher(BaseMatchStrategy):
    """Find value by concatenating adjacent tokens."""

    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        """Find concatenated matches."""
        matches = []
        value_clean = WHITESPACE_PATTERN.sub('', value)

        # Sort tokens by position (top-to-bottom, left-to-right)
        sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))

        for i, start_token in enumerate(sorted_tokens):
            # Try to build the value by concatenating nearby tokens
            concat_text = start_token.text.strip()
            concat_bbox = list(start_token.bbox)
            used_tokens = [start_token]

            for j in range(i + 1, min(i + 5, len(sorted_tokens))):  # Max 5 tokens
                next_token = sorted_tokens[j]

                # Check if tokens are on the same line (y overlap)
                if not tokens_on_same_line(start_token, next_token):
                    break

                # Check horizontal proximity
                if next_token.bbox[0] - concat_bbox[2] > 50:  # Max 50px gap
                    break

                concat_text += next_token.text.strip()
                used_tokens.append(next_token)

                # Update bounding box
                concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
                concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
                concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
                concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])

                # Check for match
                concat_clean = WHITESPACE_PATTERN.sub('', concat_text)
                if concat_clean == value_clean:
                    context_keywords, context_boost = find_context_keywords(
                        tokens, start_token, field_name, self.context_radius, token_index
                    )

                    matches.append(Match(
                        field=field_name,
                        value=value,
                        bbox=tuple(concat_bbox),
                        page_no=start_token.page_no,
                        score=min(1.0, 0.85 + context_boost),  # Slightly lower base score
                        matched_text=concat_text,
                        context_keywords=context_keywords
                    ))
                    break

        return matches

65
packages/shared/shared/matcher/strategies/exact_matcher.py
Normal file
65
packages/shared/shared/matcher/strategies/exact_matcher.py
Normal file
@@ -0,0 +1,65 @@
"""
Exact match strategy.
"""

from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import NON_DIGIT_PATTERN


class ExactMatcher(BaseMatchStrategy):
    """Find tokens that exactly match the value."""

    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        """Find exact matches."""
        matches = []
        value_lower = value.lower()
        value_digits = NON_DIGIT_PATTERN.sub('', value) if field_name in (
            'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
            'supplier_organisation_number', 'supplier_accounts'
        ) else None

        for token in tokens:
            token_text = token.text.strip()

            # Exact match
            if token_text == value:
                score = 1.0
            # Case-insensitive match (use cached lowercase from index)
            elif token_index and token_index.get_text_lower(token).strip() == value_lower:
                score = 0.95
            # Digits-only match for numeric fields
            elif value_digits is not None:
                token_digits = NON_DIGIT_PATTERN.sub('', token_text)
                if token_digits and token_digits == value_digits:
                    score = 0.9
                else:
                    continue
            else:
                continue

            # Boost score if context keywords are nearby
            context_keywords, context_boost = find_context_keywords(
                tokens, token, field_name, self.context_radius, token_index
            )
            score = min(1.0, score + context_boost)

            matches.append(Match(
                field=field_name,
                value=value,
                bbox=token.bbox,
                page_no=token.page_no,
                score=score,
                matched_text=token_text,
                context_keywords=context_keywords
            ))

        return matches

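The score ladder (1.0 exact, 0.95 case-insensitive, 0.9 digits-only) in practice, reusing the hypothetical `Token` dataclass from the earlier sketch:

```python
matcher = ExactMatcher(context_radius=200.0)
tokens = [Token("12345", (0.0, 0.0, 50.0, 12.0)), Token("Nr 12345", (0.0, 20.0, 80.0, 32.0))]
for m in matcher.find_matches(tokens, "12345", "InvoiceNumber"):
    print(round(m.score, 2), m.matched_text)  # 1.0 for "12345", 0.9 for digits-only "Nr 12345"
```
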
149
packages/shared/shared/matcher/strategies/flexible_date_matcher.py
Normal file
@@ -0,0 +1,149 @@
"""
Flexible date match strategy - finds dates with year-month or nearby date matching.
"""

import re
from datetime import datetime
from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import DATE_PATTERN


class FlexibleDateMatcher(BaseMatchStrategy):
    """
    Flexible date matching when exact match fails.

    Strategies:
    1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
    2. Nearby date match: Match dates within 7 days of CSV value
    3. Heuristic selection: Use context keywords to select the best date

    This handles cases where CSV InvoiceDate doesn't exactly match PDF,
    but we can still find a reasonable date to label.
    """

    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        """Find flexible date matches."""
        matches = []

        # Parse the target date from normalized values
        target_date = None

        # Try to parse YYYY-MM-DD format
        date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
        if date_match:
            try:
                target_date = datetime(
                    int(date_match.group(1)),
                    int(date_match.group(2)),
                    int(date_match.group(3))
                )
            except ValueError:
                pass

        if not target_date:
            return matches

        # Find all date-like tokens in the document
        date_candidates = []

        for token in tokens:
            token_text = token.text.strip()

            # Search for date pattern in token (use pre-compiled pattern)
            for match in DATE_PATTERN.finditer(token_text):
                try:
                    found_date = datetime(
                        int(match.group(1)),
                        int(match.group(2)),
                        int(match.group(3))
                    )
                    date_str = match.group(0)

                    # Calculate date difference
                    days_diff = abs((found_date - target_date).days)

                    # Check for context keywords
                    context_keywords, context_boost = find_context_keywords(
                        tokens, token, field_name, self.context_radius, token_index
                    )

                    # Check if keyword is in the same token
                    token_lower = token_text.lower()
                    inline_keywords = []
                    for keyword in CONTEXT_KEYWORDS.get(field_name, []):
                        if keyword in token_lower:
                            inline_keywords.append(keyword)

                    date_candidates.append({
                        'token': token,
                        'date': found_date,
                        'date_str': date_str,
                        'matched_text': token_text,
                        'days_diff': days_diff,
                        'context_keywords': context_keywords + inline_keywords,
                        'context_boost': context_boost + (0.1 if inline_keywords else 0),
                        'same_year_month': (found_date.year == target_date.year and
                                            found_date.month == target_date.month),
                    })
                except ValueError:
                    continue

        if not date_candidates:
            return matches

        # Score and rank candidates
        for candidate in date_candidates:
            score = 0.0

            # Strategy 1: Same year-month gets higher score
            if candidate['same_year_month']:
                score = 0.7
                # Bonus if day is close
                if candidate['days_diff'] <= 7:
                    score = 0.75
                if candidate['days_diff'] <= 3:
                    score = 0.8
            # Strategy 2: Nearby dates (within 14 days)
            elif candidate['days_diff'] <= 14:
                score = 0.6
            elif candidate['days_diff'] <= 30:
                score = 0.55
            else:
                # Too far apart, skip unless it has strong context
                if not candidate['context_keywords']:
                    continue
                score = 0.5

            # Strategy 3: Boost with context keywords
            score = min(1.0, score + candidate['context_boost'])

            # For InvoiceDate, prefer dates that appear near invoice-related keywords;
            # for InvoiceDueDate, prefer dates near due-date keywords
            if candidate['context_keywords']:
                score = min(1.0, score + 0.05)

            if score >= 0.5:  # Min threshold for flexible matching
                matches.append(Match(
                    field=field_name,
                    value=candidate['date_str'],
                    bbox=candidate['token'].bbox,
                    page_no=candidate['token'].page_no,
                    score=score,
                    matched_text=candidate['matched_text'],
                    context_keywords=candidate['context_keywords']
                ))

        # Sort by score and return best matches
        matches.sort(key=lambda m: m.score, reverse=True)

        # Only return the best match to avoid multiple labels for same field
        return matches[:1] if matches else []

52
packages/shared/shared/matcher/strategies/fuzzy_matcher.py
Normal file
52
packages/shared/shared/matcher/strategies/fuzzy_matcher.py
Normal file
@@ -0,0 +1,52 @@
"""
Fuzzy match strategy for amounts and dates.
"""

from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords
from ..utils import parse_amount


class FuzzyMatcher(BaseMatchStrategy):
    """Find approximate matches for amounts and dates."""

    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        """Find fuzzy matches."""
        matches = []

        for token in tokens:
            token_text = token.text.strip()

            if field_name == 'Amount':
                # Try to parse both as numbers
                try:
                    token_num = parse_amount(token_text)
                    value_num = parse_amount(value)

                    if token_num is not None and value_num is not None:
                        if abs(token_num - value_num) < 0.01:  # Within 1 cent
                            context_keywords, context_boost = find_context_keywords(
                                tokens, token, field_name, self.context_radius, token_index
                            )

                            matches.append(Match(
                                field=field_name,
                                value=value,
                                bbox=token.bbox,
                                page_no=token.page_no,
                                score=min(1.0, 0.8 + context_boost),
                                matched_text=token_text,
                                context_keywords=context_keywords
                            ))
                except (ValueError, TypeError):
                    # parse_amount already returns None on failure; only guard
                    # against unexpected operand types rather than using a bare except
                    pass

        return matches

143
packages/shared/shared/matcher/strategies/substring_matcher.py
Normal file
143
packages/shared/shared/matcher/strategies/substring_matcher.py
Normal file
@@ -0,0 +1,143 @@
"""
Substring match strategy - finds value as substring within longer tokens.
"""

from .base import BaseMatchStrategy
from ..models import TokenLike, Match
from ..token_index import TokenIndex
from ..context import find_context_keywords, CONTEXT_KEYWORDS
from ..utils import normalize_dashes


class SubstringMatcher(BaseMatchStrategy):
    """
    Find value as a substring within longer tokens.

    Handles cases like:
    - 'Fakturadatum: 2026-01-09' where the date is embedded
    - 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
    - 'OCR: 1234567890' where reference number is embedded

    Uses a lower score (0.75-0.85) than exact match to prefer exact matches.
    Only matches if the value appears as a distinct segment (not part of a larger number).
    """

    def find_matches(
        self,
        tokens: list[TokenLike],
        value: str,
        field_name: str,
        token_index: TokenIndex | None = None
    ) -> list[Match]:
        """Find substring matches."""
        matches = []

        # Supported fields for substring matching
        supported_fields = (
            'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
            'Bankgiro', 'Plusgiro', 'Amount',
            'supplier_organisation_number', 'supplier_accounts', 'customer_number'
        )
        if field_name not in supported_fields:
            return matches

        # Fields where spaces/dashes should be ignored during matching
        # (e.g., org number "55 65 74-6624" should match "5565746624")
        ignore_spaces_fields = (
            'supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts'
        )

        for token in tokens:
            token_text = token.text.strip()
            # Normalize different dash types to hyphen-minus for matching
            token_text_normalized = normalize_dashes(token_text)

            # For certain fields, also try matching with spaces/dashes removed
            if field_name in ignore_spaces_fields:
                token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
                value_compact = value.replace(' ', '').replace('-', '')
            else:
                token_text_compact = None
                value_compact = None

            # Skip if token is the same length as value (would be an exact match)
            if len(token_text_normalized) <= len(value):
                continue

            # Check if value appears as a substring (using normalized text).
            # Try case-sensitive first, then case-insensitive
            idx = None
            case_sensitive_match = True
            used_compact = False

            if value in token_text_normalized:
                idx = token_text_normalized.find(value)
            elif value.lower() in token_text_normalized.lower():
                idx = token_text_normalized.lower().find(value.lower())
                case_sensitive_match = False
            elif token_text_compact and value_compact in token_text_compact:
                # Try compact matching (spaces/dashes removed)
                idx = token_text_compact.find(value_compact)
                used_compact = True
            elif token_text_compact and value_compact.lower() in token_text_compact.lower():
                idx = token_text_compact.lower().find(value_compact.lower())
                case_sensitive_match = False
                used_compact = True

            if idx is None:
                continue

            # For compact matching, the boundary check runs on the compacted text:
            # the matched digits must not be part of a longer digit run
            if used_compact:
                # Verify proper boundary in compact text
                if idx > 0 and token_text_compact[idx - 1].isdigit():
                    continue
                end_idx = idx + len(value_compact)
                if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
                    continue
            else:
                # Verify it's a proper boundary match (not part of a larger number)
                # Check character before (if exists)
                if idx > 0:
                    char_before = token_text_normalized[idx - 1]
                    # Must be non-digit (allow : space - etc)
                    if char_before.isdigit():
                        continue

                # Check character after (if exists)
                end_idx = idx + len(value)
                if end_idx < len(token_text_normalized):
                    char_after = token_text_normalized[end_idx]
                    # Must be non-digit
                    if char_after.isdigit():
                        continue

            # Found valid substring match
            context_keywords, context_boost = find_context_keywords(
                tokens, token, field_name, self.context_radius, token_index
            )

            # Check if context keyword is in the same token (like "Fakturadatum:")
            token_lower = token_text.lower()
            inline_context = []
            for keyword in CONTEXT_KEYWORDS.get(field_name, []):
                if keyword in token_lower:
                    inline_context.append(keyword)

            # Boost score if keyword is inline
            inline_boost = 0.1 if inline_context else 0

            # Lower score for case-insensitive match
            base_score = 0.75 if case_sensitive_match else 0.70

            matches.append(Match(
                field=field_name,
                value=value,
                bbox=token.bbox,  # Use full token bbox
                page_no=token.page_no,
                score=min(1.0, base_score + context_boost + inline_boost),
                matched_text=token_text,
                context_keywords=context_keywords + inline_context
            ))

        return matches

92
packages/shared/shared/matcher/token_index.py
Normal file
92
packages/shared/shared/matcher/token_index.py
Normal file
@@ -0,0 +1,92 @@
"""
Spatial index for fast token lookup.
"""

from .models import TokenLike


class TokenIndex:
    """
    Spatial index for tokens to enable fast nearby token lookup.

    Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
    """

    def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
        """
        Build spatial index from tokens.

        Args:
            tokens: List of tokens to index
            grid_size: Size of grid cells in pixels
        """
        self.tokens = tokens
        self.grid_size = grid_size
        self._grid: dict[tuple[int, int], list[TokenLike]] = {}
        self._token_centers: dict[int, tuple[float, float]] = {}
        self._token_text_lower: dict[int, str] = {}

        # Build index
        for i, token in enumerate(tokens):
            # Cache center coordinates
            center_x = (token.bbox[0] + token.bbox[2]) / 2
            center_y = (token.bbox[1] + token.bbox[3]) / 2
            self._token_centers[id(token)] = (center_x, center_y)

            # Cache lowercased text
            self._token_text_lower[id(token)] = token.text.lower()

            # Add to grid cell
            grid_x = int(center_x / grid_size)
            grid_y = int(center_y / grid_size)
            key = (grid_x, grid_y)
            if key not in self._grid:
                self._grid[key] = []
            self._grid[key].append(token)

    def get_center(self, token: TokenLike) -> tuple[float, float]:
        """Get cached center coordinates for token."""
        return self._token_centers.get(id(token), (
            (token.bbox[0] + token.bbox[2]) / 2,
            (token.bbox[1] + token.bbox[3]) / 2
        ))

    def get_text_lower(self, token: TokenLike) -> str:
        """Get cached lowercased text for token."""
        return self._token_text_lower.get(id(token), token.text.lower())

    def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
        """
        Find all tokens within radius of the given token.

        Uses grid-based lookup for O(1) average case instead of O(n).
        """
        center = self.get_center(token)
        center_x, center_y = center

        # Determine which grid cells to search
        cells_to_check = int(radius / self.grid_size) + 1
        grid_x = int(center_x / self.grid_size)
        grid_y = int(center_y / self.grid_size)

        nearby = []
        radius_sq = radius * radius

        # Check all nearby grid cells
        for dx in range(-cells_to_check, cells_to_check + 1):
            for dy in range(-cells_to_check, cells_to_check + 1):
                key = (grid_x + dx, grid_y + dy)
                if key not in self._grid:
                    continue

                for other in self._grid[key]:
                    if other is token:
                        continue

                    other_center = self.get_center(other)
                    dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2

                    if dist_sq <= radius_sq:
                        nearby.append(other)

        return nearby

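A small sketch of how the index is queried (hypothetical tokens, reusing the `Token` dataclass from the earlier sketch; `grid_size` is normally set to the matcher's context radius):

```python
idx = TokenIndex(tokens, grid_size=200.0)
label_token = tokens[0]
for other in idx.find_nearby(label_token, radius=200.0):
    print(idx.get_text_lower(other), idx.get_center(other))
```
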
91
packages/shared/shared/matcher/utils.py
Normal file
91
packages/shared/shared/matcher/utils.py
Normal file
@@ -0,0 +1,91 @@
"""
Utility functions for field matching.
"""

import re


# Pre-compiled regex patterns (module-level for efficiency)
DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
WHITESPACE_PATTERN = re.compile(r'\s+')
NON_DIGIT_PATTERN = re.compile(r'\D')
DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]')  # en-dash, em-dash, minus sign, middle dot


def normalize_dashes(text: str) -> str:
    """Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
    return DASH_PATTERN.sub('-', text)


def parse_amount(text: str | int | float) -> float | None:
    """Try to parse text as a monetary amount."""
    # Convert to string first
    text = str(text)

    # First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
    # Pattern: digits + space + exactly 2 digits at end
    ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
    if ore_match:
        kronor = ore_match.group(1)
        ore = ore_match.group(2)
        try:
            return float(f"{kronor}.{ore}")
        except ValueError:
            pass

    # Remove everything after and including parentheses (e.g., "(inkl. moms)")
    text = re.sub(r'\s*\(.*\)', '', text)

    # Remove currency symbols and common suffixes (including trailing dots from "kr.")
    text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
    text = re.sub(r'[:-]', '', text)

    # Remove spaces (thousand separators) but be careful with öre format
    text = text.replace(' ', '').replace('\xa0', '')

    # Handle comma as decimal separator.
    # Swedish format: "500,00" means 500.00.
    # Need to handle cases like "500,00." (after removing "kr.")
    if ',' in text:
        # Remove any trailing dots first (from "kr." removal)
        text = text.rstrip('.')
        # Now replace comma with dot
        if '.' not in text:
            text = text.replace(',', '.')

    # Remove any remaining non-numeric characters except dot
    text = re.sub(r'[^\d.]', '', text)

    try:
        return float(text)
    except ValueError:
        return None


def tokens_on_same_line(token1, token2) -> bool:
    """Check if two tokens are on the same line."""
    # Check vertical overlap
    y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
    min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
    return y_overlap > min_height * 0.5


def bbox_overlap(
    bbox1: tuple[float, float, float, float],
    bbox2: tuple[float, float, float, float]
) -> float:
    """Calculate IoU (Intersection over Union) of two bounding boxes."""
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])

    if x2 <= x1 or y2 <= y1:
        return 0.0

    intersection = float(x2 - x1) * float(y2 - y1)
    area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
    area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
    union = area1 + area2 - intersection

    return intersection / union if union > 0 else 0.0

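Expected behaviour of `parse_amount` on the formats called out in its comments (the input strings are illustrative, not taken from test data):

```python
assert parse_amount("239 00") == 239.00        # Swedish öre format
assert parse_amount("1 234,56 kr") == 1234.56  # thousand separator + comma decimal
assert parse_amount("500,00 SEK (inkl. moms)") == 500.0
assert parse_amount("ej angivet") is None      # unparseable text
```
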
3
packages/shared/shared/normalize/__init__.py
Normal file
3
packages/shared/shared/normalize/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .normalizer import normalize_field, FieldNormalizer

__all__ = ['normalize_field', 'FieldNormalizer']

186
packages/shared/shared/normalize/normalizer.py
Normal file
186
packages/shared/shared/normalize/normalizer.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Field Normalization Module
|
||||
|
||||
Normalizes field values to generate multiple candidate forms for matching.
|
||||
|
||||
This module now delegates to individual normalizer modules for each field type.
|
||||
Each normalizer is a separate, reusable module that can be used independently.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
from shared.utils.text_cleaner import TextCleaner
|
||||
|
||||
# Import individual normalizers
|
||||
from .normalizers import (
|
||||
InvoiceNumberNormalizer,
|
||||
OCRNormalizer,
|
||||
BankgiroNormalizer,
|
||||
PlusgiroNormalizer,
|
||||
AmountNormalizer,
|
||||
DateNormalizer,
|
||||
OrganisationNumberNormalizer,
|
||||
SupplierAccountsNormalizer,
|
||||
CustomerNumberNormalizer,
|
||||
)
|
||||
|
||||
@dataclass
class NormalizedValue:
    """Represents a normalized value with its variants."""
    original: str
    variants: list[str]
    field_type: str


class FieldNormalizer:
    """
    Handles normalization of different invoice field types.

    This class now acts as a facade that delegates to individual
    normalizer modules. Each field type has its own specialized
    normalizer for better modularity and reusability.
    """

    # Instantiate individual normalizers
    _invoice_number = InvoiceNumberNormalizer()
    _ocr_number = OCRNormalizer()
    _bankgiro = BankgiroNormalizer()
    _plusgiro = PlusgiroNormalizer()
    _amount = AmountNormalizer()
    _date = DateNormalizer()
    _organisation_number = OrganisationNumberNormalizer()
    _supplier_accounts = SupplierAccountsNormalizer()
    _customer_number = CustomerNumberNormalizer()

    # Common Swedish month names for backward compatibility
    SWEDISH_MONTHS = DateNormalizer.SWEDISH_MONTHS

    @staticmethod
    def clean_text(text: str) -> str:
        """
        Remove invisible characters and normalize whitespace and dashes.

        Delegates to shared TextCleaner for consistency.
        """
        return TextCleaner.clean_text(text)

    @staticmethod
    def normalize_invoice_number(value: str) -> list[str]:
        """
        Normalize invoice number.

        Delegates to InvoiceNumberNormalizer.
        """
        return FieldNormalizer._invoice_number.normalize(value)

    @staticmethod
    def normalize_ocr_number(value: str) -> list[str]:
        """
        Normalize OCR number (Swedish payment reference).

        Delegates to OCRNormalizer.
        """
        return FieldNormalizer._ocr_number.normalize(value)

    @staticmethod
    def normalize_bankgiro(value: str) -> list[str]:
        """
        Normalize Bankgiro number.

        Delegates to BankgiroNormalizer.
        """
        return FieldNormalizer._bankgiro.normalize(value)

    @staticmethod
    def normalize_plusgiro(value: str) -> list[str]:
        """
        Normalize Plusgiro number.

        Delegates to PlusgiroNormalizer.
        """
        return FieldNormalizer._plusgiro.normalize(value)

    @staticmethod
    def normalize_organisation_number(value: str) -> list[str]:
        """
        Normalize Swedish organisation number and generate VAT number variants.

        Delegates to OrganisationNumberNormalizer.
        """
        return FieldNormalizer._organisation_number.normalize(value)

    @staticmethod
    def normalize_supplier_accounts(value: str) -> list[str]:
        """
        Normalize supplier accounts field.

        Delegates to SupplierAccountsNormalizer.
        """
        return FieldNormalizer._supplier_accounts.normalize(value)

    @staticmethod
    def normalize_customer_number(value: str) -> list[str]:
        """
        Normalize customer number.

        Delegates to CustomerNumberNormalizer.
        """
        return FieldNormalizer._customer_number.normalize(value)

    @staticmethod
    def normalize_amount(value: str) -> list[str]:
        """
        Normalize monetary amount.

        Delegates to AmountNormalizer.
        """
        return FieldNormalizer._amount.normalize(value)

    @staticmethod
    def normalize_date(value: str) -> list[str]:
        """
        Normalize date to YYYY-MM-DD and generate variants.

        Delegates to DateNormalizer.
        """
        return FieldNormalizer._date.normalize(value)


# Field type to normalizer mapping
NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
    'InvoiceNumber': FieldNormalizer.normalize_invoice_number,
    'OCR': FieldNormalizer.normalize_ocr_number,
    'Bankgiro': FieldNormalizer.normalize_bankgiro,
    'Plusgiro': FieldNormalizer.normalize_plusgiro,
    'Amount': FieldNormalizer.normalize_amount,
    'InvoiceDate': FieldNormalizer.normalize_date,
    'InvoiceDueDate': FieldNormalizer.normalize_date,
    'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
    'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
    'customer_number': FieldNormalizer.normalize_customer_number,
}


def normalize_field(field_name: str, value: str) -> list[str]:
    """
    Normalize a field value based on its type.

    Args:
        field_name: Name of the field (e.g., 'InvoiceNumber', 'Amount')
        value: Raw value to normalize

    Returns:
        List of normalized variants
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return []

    value = str(value)
    normalizer = NORMALIZERS.get(field_name)

    if normalizer:
        return normalizer(value)

    # Default: just clean the text
    return [FieldNormalizer.clean_text(value)]
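A minimal usage sketch of the facade above (illustrative only; the exact variant set depends on the individual normalizers):

# normalize_field is re-exported from shared.normalize, as used elsewhere in this commit.
from shared.normalize import normalize_field

variants = normalize_field('Amount', '1 234,56')
# Expected to include at least the separator variants, e.g. '1234,56' and '1234.56'.

fallback = normalize_field('SomeUnknownField', '  Hello  ')
# Unknown field names fall through to the default branch: just the cleaned text.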
28
packages/shared/shared/normalize/normalizers/__init__.py
Normal file
28
packages/shared/shared/normalize/normalizers/__init__.py
Normal file
@@ -0,0 +1,28 @@
"""
Normalizer modules for different field types.

Each normalizer is responsible for generating variants of a field value
for matching against OCR text or other data sources.
"""

from .invoice_number_normalizer import InvoiceNumberNormalizer
from .ocr_normalizer import OCRNormalizer
from .bankgiro_normalizer import BankgiroNormalizer
from .plusgiro_normalizer import PlusgiroNormalizer
from .amount_normalizer import AmountNormalizer
from .date_normalizer import DateNormalizer
from .organisation_number_normalizer import OrganisationNumberNormalizer
from .supplier_accounts_normalizer import SupplierAccountsNormalizer
from .customer_number_normalizer import CustomerNumberNormalizer

__all__ = [
    'InvoiceNumberNormalizer',
    'OCRNormalizer',
    'BankgiroNormalizer',
    'PlusgiroNormalizer',
    'AmountNormalizer',
    'DateNormalizer',
    'OrganisationNumberNormalizer',
    'SupplierAccountsNormalizer',
    'CustomerNumberNormalizer',
]
@@ -0,0 +1,130 @@
"""
Amount Normalizer

Normalizes monetary amounts with various formats and separators.
"""

import re
from .base import BaseNormalizer


class AmountNormalizer(BaseNormalizer):
    """
    Normalizes monetary amounts.

    Handles Swedish and international formats with different
    thousand/decimal separators.

    Examples:
        '114' -> ['114', '114,00', '114.00']
        '114,00' -> ['114,00', '114.00', '114']
        '1 234,56' -> ['1234,56', '1234.56', '1 234,56']
        '3045 52' -> ['3045.52', '3045,52', '304552']
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of amount."""
        value = self.clean_text(value)

        # Remove currency symbols and common suffixes
        value = re.sub(r'[SEK|kr|:-]+$', '', value, flags=re.IGNORECASE).strip()

        variants = [value]

        # Check for space as decimal separator: "3045 52"
        space_decimal_match = re.match(r'^(\d+)\s+(\d{2})$', value)
        if space_decimal_match:
            integer_part = space_decimal_match.group(1)
            decimal_part = space_decimal_match.group(2)
            variants.append(f"{integer_part}.{decimal_part}")
            variants.append(f"{integer_part},{decimal_part}")
            variants.append(f"{integer_part}{decimal_part}")

        # Check for space as thousand separator: "10 571,00"
        space_thousand_match = re.match(r'^(\d{1,3})[\s\xa0]+(\d{3})([,\.])(\d{2})$', value)
        if space_thousand_match:
            part1 = space_thousand_match.group(1)
            part2 = space_thousand_match.group(2)
            sep = space_thousand_match.group(3)
            decimal = space_thousand_match.group(4)
            combined = f"{part1}{part2}"
            variants.append(f"{combined}.{decimal}")
            variants.append(f"{combined},{decimal}")
            variants.append(f"{combined}{decimal}")
            other_sep = ',' if sep == '.' else '.'
            variants.append(f"{part1} {part2}{other_sep}{decimal}")

        # Handle US format: "1,390.00"
        us_format_match = re.match(r'^(\d{1,3}),(\d{3})\.(\d{2})$', value)
        if us_format_match:
            part1 = us_format_match.group(1)
            part2 = us_format_match.group(2)
            decimal = us_format_match.group(3)
            combined = f"{part1}{part2}"
            variants.append(f"{combined}.{decimal}")
            variants.append(f"{combined},{decimal}")
            variants.append(combined)
            variants.append(f"{part1}.{part2},{decimal}")

        # Handle European format: "1.390,00"
        eu_format_match = re.match(r'^(\d{1,3})\.(\d{3}),(\d{2})$', value)
        if eu_format_match:
            part1 = eu_format_match.group(1)
            part2 = eu_format_match.group(2)
            decimal = eu_format_match.group(3)
            combined = f"{part1}{part2}"
            variants.append(f"{combined}.{decimal}")
            variants.append(f"{combined},{decimal}")
            variants.append(combined)
            variants.append(f"{part1},{part2}.{decimal}")

        # Remove spaces (thousand separators)
        no_space = value.replace(' ', '').replace('\xa0', '')

        # Normalize decimal separator
        if ',' in no_space:
            dot_version = no_space.replace(',', '.')
            variants.append(no_space)
            variants.append(dot_version)
        elif '.' in no_space:
            comma_version = no_space.replace('.', ',')
            variants.append(no_space)
            variants.append(comma_version)
        else:
            # Integer amount - add decimal versions
            variants.append(no_space)
            variants.append(f"{no_space},00")
            variants.append(f"{no_space}.00")

        # Try to parse and get clean numeric value
        try:
            clean = no_space.replace(',', '.')
            num = float(clean)

            # Integer if no decimals
            if num == int(num):
                int_val = int(num)
                variants.append(str(int_val))
                variants.append(f"{int_val},00")
                variants.append(f"{int_val}.00")

                # European format with dot as thousand separator
                if int_val >= 1000:
                    formatted = f"{int_val:,}".replace(',', '.')
                    variants.append(formatted)
                    variants.append(f"{formatted},00")
            else:
                variants.append(f"{num:.2f}")
                variants.append(f"{num:.2f}".replace('.', ','))

                # European format with dot as thousand separator
                if num >= 1000:
                    formatted_str = f"{num:.2f}"
                    int_str, dec_str = formatted_str.split(".")
                    int_part = int(int_str)
                    formatted_int = f"{int_part:,}".replace(',', '.')
                    variants.append(f"{formatted_int},{dec_str}")
        except ValueError:
            pass

        return list(set(v for v in variants if v))
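As a worked example of the branches above (illustrative; the returned set is unordered):

AmountNormalizer().normalize('1.390,00')
# The European-format branch matches with part1='1', part2='390', decimal='00',
# so alongside the original the variants include '1390,00', '1390.00', '1390'
# and the US-style '1,390.00'.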
@@ -0,0 +1,34 @@
"""
Bankgiro Number Normalizer

Normalizes Swedish Bankgiro account numbers.
"""

from .base import BaseNormalizer
from shared.utils.format_variants import FormatVariants
from shared.utils.text_cleaner import TextCleaner


class BankgiroNormalizer(BaseNormalizer):
    """
    Normalizes Bankgiro numbers.

    Generates format variants and OCR error variants.

    Examples:
        '5393-9484' -> ['5393-9484', '53939484', ...]
        '53939484' -> ['53939484', '5393-9484', ...]
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of Bankgiro number."""
        # Use shared module for base variants
        variants = set(FormatVariants.bankgiro_variants(value))

        # Add OCR error variants
        digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
        if digits:
            for ocr_var in TextCleaner.generate_ocr_variants(digits):
                variants.add(ocr_var)

        return list(v for v in variants if v)
34
packages/shared/shared/normalize/normalizers/base.py
Normal file
34
packages/shared/shared/normalize/normalizers/base.py
Normal file
@@ -0,0 +1,34 @@
"""
Base class for field normalizers.
"""

from abc import ABC, abstractmethod
from shared.utils.text_cleaner import TextCleaner


class BaseNormalizer(ABC):
    """Base class for all field normalizers."""

    @staticmethod
    def clean_text(text: str) -> str:
        """Clean text using shared TextCleaner."""
        return TextCleaner.clean_text(text)

    @abstractmethod
    def normalize(self, value: str) -> list[str]:
        """
        Normalize a field value and return all variants.

        Args:
            value: Raw field value

        Returns:
            List of normalized variants for matching
        """
        pass

    def __call__(self, value: str) -> list[str]:
        """Allow normalizer to be called as a function."""
        if value is None or (isinstance(value, str) and not value.strip()):
            return []
        return self.normalize(str(value))
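A new field type is added by subclassing BaseNormalizer and implementing normalize(); a minimal hypothetical sketch (the VATNumberNormalizer name is invented for illustration, not part of this commit):

from .base import BaseNormalizer

class VATNumberNormalizer(BaseNormalizer):
    """Hypothetical example: keep the cleaned value plus a compact uppercase variant."""

    def normalize(self, value: str) -> list[str]:
        value = self.clean_text(value)
        variants = [value, value.upper().replace(' ', '')]
        return list(set(v for v in variants if v))

# Empty or None input is already handled by BaseNormalizer.__call__, which
# returns [] before normalize() is reached.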
@@ -0,0 +1,49 @@
"""
Customer Number Normalizer

Normalizes customer numbers (alphanumeric codes).
"""

from .base import BaseNormalizer


class CustomerNumberNormalizer(BaseNormalizer):
    """
    Normalizes customer numbers.

    Customer numbers can have various formats:
    - Alphanumeric codes: 'EMM 256-6', 'ABC123', 'A-1234'
    - Pure numbers: '12345', '123-456'

    Examples:
        'EMM 256-6' -> ['EMM 256-6', 'EMM256-6', 'EMM2566']
        'ABC 123' -> ['ABC 123', 'ABC123']
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of customer number."""
        value = self.clean_text(value)
        variants = [value]

        # Version without spaces
        no_space = value.replace(' ', '')
        if no_space != value:
            variants.append(no_space)

        # Version without dashes
        no_dash = value.replace('-', '')
        if no_dash != value:
            variants.append(no_dash)

        # Version without spaces and dashes
        clean = value.replace(' ', '').replace('-', '')
        if clean != value and clean not in variants:
            variants.append(clean)

        # Uppercase and lowercase versions
        if value.upper() != value:
            variants.append(value.upper())
        if value.lower() != value:
            variants.append(value.lower())

        return list(set(v for v in variants if v))
190
packages/shared/shared/normalize/normalizers/date_normalizer.py
Normal file
190
packages/shared/shared/normalize/normalizers/date_normalizer.py
Normal file
@@ -0,0 +1,190 @@
"""
Date Normalizer

Normalizes dates in various formats to ISO and generates variants.
"""

import re
from datetime import datetime
from .base import BaseNormalizer


class DateNormalizer(BaseNormalizer):
    """
    Normalizes dates to YYYY-MM-DD and generates variants.

    Handles Swedish and international date formats.

    Examples:
        '2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
        '13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
        '13 december 2025' -> ['2025-12-13', ...]
    """

    # Swedish month names
    SWEDISH_MONTHS = {
        'januari': '01', 'jan': '01',
        'februari': '02', 'feb': '02',
        'mars': '03', 'mar': '03',
        'april': '04', 'apr': '04',
        'maj': '05',
        'juni': '06', 'jun': '06',
        'juli': '07', 'jul': '07',
        'augusti': '08', 'aug': '08',
        'september': '09', 'sep': '09', 'sept': '09',
        'oktober': '10', 'okt': '10',
        'november': '11', 'nov': '11',
        'december': '12', 'dec': '12'
    }

    def normalize(self, value: str) -> list[str]:
        """Generate variants of date."""
        value = self.clean_text(value)
        variants = [value]
        parsed_dates = []

        # Try unambiguous patterns first
        date_patterns = [
            # ISO format with optional time
            (r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$',
             lambda m: (int(m[1]), int(m[2]), int(m[3]))),
            # Swedish format: YYMMDD
            (r'^(\d{2})(\d{2})(\d{2})$',
             lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
            # Swedish format: YYYYMMDD
            (r'^(\d{4})(\d{2})(\d{2})$',
             lambda m: (int(m[1]), int(m[2]), int(m[3]))),
        ]

        for pattern, extractor in date_patterns:
            match = re.match(pattern, value)
            if match:
                try:
                    year, month, day = extractor(match)
                    parsed_dates.append(datetime(year, month, day))
                    break
                except ValueError:
                    continue

        # Try ambiguous patterns with 4-digit year
        ambiguous_patterns_4digit = [
            r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
            r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
            r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
        ]

        if not parsed_dates:
            for pattern in ambiguous_patterns_4digit:
                match = re.match(pattern, value)
                if match:
                    n1, n2, year = int(match[1]), int(match[2]), int(match[3])

                    # Try DD/MM/YYYY (European - day first)
                    try:
                        parsed_dates.append(datetime(year, n2, n1))
                    except ValueError:
                        pass

                    # Try MM/DD/YYYY (US - month first) if different
                    if n1 != n2:
                        try:
                            parsed_dates.append(datetime(year, n1, n2))
                        except ValueError:
                            pass

                    if parsed_dates:
                        break

        # Try ambiguous patterns with 2-digit year
        ambiguous_patterns_2digit = [
            r'^(\d{1,2})\.(\d{1,2})\.(\d{2})$',
            r'^(\d{1,2})/(\d{1,2})/(\d{2})$',
            r'^(\d{1,2})-(\d{1,2})-(\d{2})$',
        ]

        if not parsed_dates:
            for pattern in ambiguous_patterns_2digit:
                match = re.match(pattern, value)
                if match:
                    n1, n2, yy = int(match[1]), int(match[2]), int(match[3])
                    year = 2000 + yy if yy < 50 else 1900 + yy

                    # Try DD/MM/YY (European)
                    try:
                        parsed_dates.append(datetime(year, n2, n1))
                    except ValueError:
                        pass

                    # Try MM/DD/YY (US) if different
                    if n1 != n2:
                        try:
                            parsed_dates.append(datetime(year, n1, n2))
                        except ValueError:
                            pass

                    if parsed_dates:
                        break

        # Try Swedish month names
        if not parsed_dates:
            for month_name, month_num in self.SWEDISH_MONTHS.items():
                if month_name in value.lower():
                    numbers = re.findall(r'\d+', value)
                    if len(numbers) >= 2:
                        day = int(numbers[0])
                        year = int(numbers[-1])
                        if year < 100:
                            year = 2000 + year if year < 50 else 1900 + year
                        try:
                            parsed_dates.append(datetime(year, int(month_num), day))
                            break
                        except ValueError:
                            continue

        # Generate variants for all parsed dates
        swedish_months_full = [
            'januari', 'februari', 'mars', 'april', 'maj', 'juni',
            'juli', 'augusti', 'september', 'oktober', 'november', 'december'
        ]
        swedish_months_abbrev = [
            'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
            'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
        ]

        for parsed_date in parsed_dates:
            iso = parsed_date.strftime('%Y-%m-%d')
            eu_slash = parsed_date.strftime('%d/%m/%Y')
            us_slash = parsed_date.strftime('%m/%d/%Y')
            eu_dot = parsed_date.strftime('%d.%m.%Y')
            iso_dot = parsed_date.strftime('%Y.%m.%d')
            compact = parsed_date.strftime('%Y%m%d')
            compact_short = parsed_date.strftime('%y%m%d')
            eu_dot_short = parsed_date.strftime('%d.%m.%y')
            eu_slash_short = parsed_date.strftime('%d/%m/%y')
            yy_mm_dd_short = parsed_date.strftime('%y-%m-%d')
            iso_middot = parsed_date.strftime('%Y·%m·%d')
            spaced_full = parsed_date.strftime('%Y %m %d')
            spaced_short = parsed_date.strftime('%y %m %d')

            # Swedish month name formats
            month_full = swedish_months_full[parsed_date.month - 1]
            month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
            swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
            swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"

            month_abbrev_upper = month_abbrev.upper()
            swedish_hyphen_short = f"{parsed_date.day:02d}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
            swedish_hyphen_short_lower = f"{parsed_date.day:02d}-{month_abbrev}-{parsed_date.strftime('%y')}"
            swedish_hyphen_short_no_zero = f"{parsed_date.day}-{month_abbrev_upper}-{parsed_date.strftime('%y')}"
            month_year_only = f"{month_abbrev_upper}-{parsed_date.strftime('%y')}"
            swedish_spaced = f"{parsed_date.day:02d} {month_abbrev_upper} {parsed_date.strftime('%y')}"

            variants.extend([
                iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
                eu_dot_short, eu_slash_short, yy_mm_dd_short, iso_middot, spaced_full, spaced_short,
                swedish_format_full, swedish_format_abbrev,
                swedish_hyphen_short, swedish_hyphen_short_lower, swedish_hyphen_short_no_zero,
                month_year_only, swedish_spaced
            ])

        return list(set(v for v in variants if v))
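For a concrete sense of the expansion, an input such as '2025-12-13' hits the ISO pattern and the strftime calls above produce, among others:

DateNormalizer().normalize('2025-12-13')
# includes '2025-12-13', '13/12/2025', '12/13/2025', '13.12.2025',
# '20251213', '251213', '13 december 2025', '13-DEC-25', ...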
@@ -0,0 +1,31 @@
"""
Invoice Number Normalizer

Normalizes invoice numbers for matching.
"""

import re
from .base import BaseNormalizer


class InvoiceNumberNormalizer(BaseNormalizer):
    """
    Normalizes invoice numbers.

    Keeps only digits for matching while preserving original format.

    Examples:
        '100017500321' -> ['100017500321']
        'INV-100017500321' -> ['100017500321', 'INV-100017500321']
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of invoice number."""
        value = self.clean_text(value)
        digits_only = re.sub(r'\D', '', value)

        variants = [value]
        if digits_only and digits_only != value:
            variants.append(digits_only)

        return list(set(v for v in variants if v))
@@ -0,0 +1,31 @@
"""
OCR Number Normalizer

Normalizes OCR reference numbers (Swedish payment system).
"""

import re
from .base import BaseNormalizer


class OCRNormalizer(BaseNormalizer):
    """
    Normalizes OCR reference numbers.

    Similar to invoice number - primarily digits.

    Examples:
        '94228110015950070' -> ['94228110015950070']
        'OCR: 94228110015950070' -> ['94228110015950070', 'OCR: 94228110015950070']
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of OCR number."""
        value = self.clean_text(value)
        digits_only = re.sub(r'\D', '', value)

        variants = [value]
        if digits_only and digits_only != value:
            variants.append(digits_only)

        return list(set(v for v in variants if v))
@@ -0,0 +1,39 @@
"""
Organisation Number Normalizer

Normalizes Swedish organisation numbers and VAT numbers.
"""

from .base import BaseNormalizer
from shared.utils.format_variants import FormatVariants
from shared.utils.text_cleaner import TextCleaner


class OrganisationNumberNormalizer(BaseNormalizer):
    """
    Normalizes Swedish organisation numbers and VAT numbers.

    Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
    Swedish VAT format: SE + org_number (10 digits) + 01

    Examples:
        '556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
        '5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
        'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of organisation number."""
        # Use shared module for base variants
        variants = set(FormatVariants.organisation_number_variants(value))

        # Add OCR error variants for digit sequences
        digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
        if digits and len(digits) >= 10:
            # Generate variants where OCR might have misread characters
            for ocr_var in TextCleaner.generate_ocr_variants(digits[:10]):
                variants.add(ocr_var)
                if len(ocr_var) == 10:
                    variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")

        return list(v for v in variants if v)
@@ -0,0 +1,34 @@
"""
Plusgiro Number Normalizer

Normalizes Swedish Plusgiro account numbers.
"""

from .base import BaseNormalizer
from shared.utils.format_variants import FormatVariants
from shared.utils.text_cleaner import TextCleaner


class PlusgiroNormalizer(BaseNormalizer):
    """
    Normalizes Plusgiro numbers.

    Generates format variants and OCR error variants.

    Examples:
        '1234567-8' -> ['1234567-8', '12345678', ...]
        '12345678' -> ['12345678', '1234567-8', ...]
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of Plusgiro number."""
        # Use shared module for base variants
        variants = set(FormatVariants.plusgiro_variants(value))

        # Add OCR error variants
        digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
        if digits:
            for ocr_var in TextCleaner.generate_ocr_variants(digits):
                variants.add(ocr_var)

        return list(v for v in variants if v)
@@ -0,0 +1,75 @@
"""
Supplier Accounts Normalizer

Normalizes supplier account numbers (Bankgiro/Plusgiro).
"""

import re
from .base import BaseNormalizer


class SupplierAccountsNormalizer(BaseNormalizer):
    """
    Normalizes supplier accounts field.

    The field may contain multiple accounts separated by ' | '.
    Format examples:
        'PG:48676043 | PG:49128028 | PG:8915035'
        'BG:5393-9484'

    Each account is normalized separately to generate variants.

    Examples:
        'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
        'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
    """

    def normalize(self, value: str) -> list[str]:
        """Generate variants of supplier accounts."""
        value = self.clean_text(value)
        variants = []

        # Split by ' | ' to handle multiple accounts
        accounts = [acc.strip() for acc in value.split('|')]

        for account in accounts:
            account = account.strip()
            if not account:
                continue

            # Add original value
            variants.append(account)

            # Remove prefix (PG:, BG:, etc.)
            if ':' in account:
                prefix, number = account.split(':', 1)
                number = number.strip()
                variants.append(number)  # Just the number without prefix

                # Also add with different prefix formats
                prefix_upper = prefix.strip().upper()
                variants.append(f"{prefix_upper}:{number}")
                variants.append(f"{prefix_upper}: {number}")  # With space
            else:
                number = account

            # Extract digits only
            digits_only = re.sub(r'\D', '', number)

            if digits_only:
                variants.append(digits_only)

                # Plusgiro format: XXXXXXX-X (7 digits + check digit)
                if len(digits_only) == 8:
                    with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
                    variants.append(with_dash)
                    # Also try 4-4 format for bankgiro
                    variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
                elif len(digits_only) == 7:
                    with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
                    variants.append(with_dash)
                elif len(digits_only) == 10:
                    # 6-4 format (like org number)
                    variants.append(f"{digits_only[:6]}-{digits_only[6:]}")

        return list(set(v for v in variants if v))
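A quick trace of the splitting logic above for a two-account value (not an exhaustive listing of the returned set):

SupplierAccountsNormalizer().normalize('PG:48676043 | BG:5393-9484')
# 'PG:48676043' contributes '48676043', the Plusgiro form '4867604-3' and the
# 4-4 split '4867-6043'; 'BG:5393-9484' contributes '5393-9484' and '53939484',
# each alongside the prefixed originals.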
16
packages/shared/shared/ocr/__init__.py
Normal file
16
packages/shared/shared/ocr/__init__.py
Normal file
@@ -0,0 +1,16 @@
from .paddle_ocr import OCREngine, OCRResult, OCRToken, extract_ocr_tokens
from .machine_code_parser import (
    MachineCodeParser,
    MachineCodeResult,
    parse_machine_code,
)

__all__ = [
    'OCREngine',
    'OCRResult',
    'OCRToken',
    'extract_ocr_tokens',
    'MachineCodeParser',
    'MachineCodeResult',
    'parse_machine_code',
]
929
packages/shared/shared/ocr/machine_code_parser.py
Normal file
929
packages/shared/shared/ocr/machine_code_parser.py
Normal file
@@ -0,0 +1,929 @@
|
||||
"""
|
||||
Machine Code Line Parser for Swedish Invoices
|
||||
|
||||
Parses the bottom machine-readable payment line to extract:
|
||||
- OCR reference number (10-25 digits)
|
||||
- Amount (payment amount in SEK)
|
||||
- Bankgiro account number (XXX-XXXX or XXXX-XXXX format)
|
||||
- Plusgiro account number (XXXXXXX-X format)
|
||||
|
||||
The machine code line is typically found at the bottom of Swedish invoices,
|
||||
in the payment slip (Inbetalningskort) section. It contains machine-readable
|
||||
data for automated payment processing.
|
||||
|
||||
## Swedish Payment Line Standard Format
|
||||
|
||||
The standard machine-readable payment line follows this structure:
|
||||
|
||||
# <OCR> # <Kronor> <Öre> <Type> > <Bankgiro>#<Control>#
|
||||
|
||||
Example:
|
||||
# 31130954410 # 315 00 2 > 8983025#14#
|
||||
|
||||
Components:
|
||||
- `#` - Start delimiter
|
||||
- `31130954410` - OCR number (with Mod 10 check digit)
|
||||
- `#` - Separator
|
||||
- `315 00` - Amount: 315 SEK and 00 öre (315.00 SEK)
|
||||
- `2` - Payment type / record type
|
||||
- `>` - Points to recipient info
|
||||
- `8983025` - Bankgiro number
|
||||
- `#14#` - End marker with control code
|
||||
|
||||
Legacy patterns also supported:
|
||||
- OCR: 8120000849965361 (10-25 consecutive digits)
|
||||
- Bankgiro: 5393-9484 or 53939484
|
||||
- Plusgiro: 1234567-8
|
||||
- Amount: 1234,56 or 1234.56 (with decimal separator)
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from shared.pdf.extractor import Token as TextToken
|
||||
from shared.utils.validators import FieldValidators
|
||||
|
||||
|
||||
@dataclass
|
||||
class MachineCodeResult:
|
||||
"""Result of machine code parsing."""
|
||||
ocr: Optional[str] = None
|
||||
amount: Optional[str] = None
|
||||
bankgiro: Optional[str] = None
|
||||
plusgiro: Optional[str] = None
|
||||
confidence: float = 0.0
|
||||
source_tokens: list[TextToken] = field(default_factory=list)
|
||||
raw_line: str = ""
|
||||
# Region bounding box in PDF coordinates (x0, y0, x1, y1)
|
||||
region_bbox: Optional[tuple[float, float, float, float]] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
'ocr': self.ocr,
|
||||
'amount': self.amount,
|
||||
'bankgiro': self.bankgiro,
|
||||
'plusgiro': self.plusgiro,
|
||||
'confidence': self.confidence,
|
||||
'raw_line': self.raw_line,
|
||||
'region_bbox': self.region_bbox,
|
||||
}
|
||||
|
||||
def get_region_bbox(self) -> Optional[tuple[float, float, float, float]]:
|
||||
"""
|
||||
Get the bounding box of the payment slip region.
|
||||
|
||||
Returns:
|
||||
Tuple (x0, y0, x1, y1) in PDF coordinates, or None if no region detected
|
||||
"""
|
||||
if self.region_bbox:
|
||||
return self.region_bbox
|
||||
|
||||
if not self.source_tokens:
|
||||
return None
|
||||
|
||||
# Calculate bbox from source tokens
|
||||
x0 = min(t.bbox[0] for t in self.source_tokens)
|
||||
y0 = min(t.bbox[1] for t in self.source_tokens)
|
||||
x1 = max(t.bbox[2] for t in self.source_tokens)
|
||||
y1 = max(t.bbox[3] for t in self.source_tokens)
|
||||
|
||||
return (x0, y0, x1, y1)
|
||||
|
||||
|
||||
class MachineCodeParser:
|
||||
"""
|
||||
Parser for machine-readable payment lines on Swedish invoices.
|
||||
|
||||
The parser focuses on the bottom region of the invoice where
|
||||
the payment slip (Inbetalningskort) is typically located.
|
||||
"""
|
||||
|
||||
# Payment slip detection keywords (Swedish)
|
||||
PAYMENT_SLIP_KEYWORDS = [
|
||||
'inbetalning', 'girering', 'avi', 'betalning',
|
||||
'plusgiro', 'postgiro', 'bankgiro', 'bankgirot',
|
||||
'betalningsavsändare', 'betalningsmottagare',
|
||||
'maskinellt', 'ändringar', # "DEN AVLÄSES MASKINELLT"
|
||||
]
|
||||
|
||||
# Patterns for field extraction
|
||||
# OCR: 10-25 consecutive digits (may have spaces or # at end)
|
||||
OCR_PATTERN = re.compile(r'(?<!\d)(\d{10,25})(?!\d)')
|
||||
|
||||
# Bankgiro: XXX-XXXX or XXXX-XXXX (7-8 digits with optional dash)
|
||||
BANKGIRO_PATTERN = re.compile(r'\b(\d{3,4}[-\s]?\d{4})\b')
|
||||
|
||||
# Plusgiro: XXXXXXX-X (7-8 digits with dash before last digit)
|
||||
PLUSGIRO_PATTERN = re.compile(r'\b(\d{6,7}[-\s]?\d)\b')
|
||||
|
||||
# Amount: digits with comma or dot as decimal separator
|
||||
# Supports formats: 1234,56 | 1234.56 | 1 234,56 | 1.234,56
|
||||
AMOUNT_PATTERN = re.compile(
|
||||
r'\b(\d{1,3}(?:[\s\.\xa0]\d{3})*[,\.]\d{2})\b'
|
||||
)
|
||||
|
||||
# Alternative amount pattern for integers (no decimal)
|
||||
AMOUNT_INTEGER_PATTERN = re.compile(r'\b(\d{2,6})\b')
|
||||
|
||||
# Standard Swedish payment line pattern
|
||||
# Format: # <OCR> # <Kronor> <Öre> <Type> > <Bankgiro/Plusgiro>#<Control>#
|
||||
# Example: # 31130954410 # 315 00 2 > 8983025#14#
|
||||
# This pattern captures both Bankgiro and Plusgiro accounts
|
||||
PAYMENT_LINE_PATTERN = re.compile(
|
||||
r'#\s*' # Start delimiter
|
||||
r'(\d{5,25})\s*' # OCR number (capture group 1)
|
||||
r'#\s*' # Separator
|
||||
r'(\d{1,7})\s+' # Kronor (capture group 2)
|
||||
r'(\d{2})\s+' # Öre (capture group 3)
|
||||
r'(\d)\s*' # Type (capture group 4)
|
||||
r'>\s*' # Direction marker
|
||||
r'(\d{5,10})' # Bankgiro/Plusgiro (capture group 5)
|
||||
r'(?:#\d{1,3}#)?' # Optional end marker
|
||||
)
|
||||
|
||||
# Alternative pattern with different spacing
|
||||
PAYMENT_LINE_PATTERN_ALT = re.compile(
|
||||
r'#?\s*' # Optional start delimiter
|
||||
r'(\d{8,25})\s*' # OCR number
|
||||
r'#?\s*' # Optional separator
|
||||
r'(\d{1,7})\s+' # Kronor
|
||||
r'(\d{2})\s+' # Öre
|
||||
r'\d\s*' # Type
|
||||
r'>?\s*' # Optional direction marker
|
||||
r'(\d{5,10})' # Bankgiro
|
||||
)
|
||||
|
||||
# Reverse format pattern (Bankgiro first, then OCR)
|
||||
# Format: <Bankgiro>#<Control># <Kronor> <Öre> <Type> > <OCR> #
|
||||
# Example: 53241469#41# 2428 00 1 > 4388595300 #
|
||||
PAYMENT_LINE_PATTERN_REVERSE = re.compile(
|
||||
r'(\d{7,8})' # Bankgiro (capture group 1)
|
||||
r'#\d{1,3}#\s+' # Control marker
|
||||
r'(\d{1,7})\s+' # Kronor (capture group 2)
|
||||
r'(\d{2})\s+' # Öre (capture group 3)
|
||||
r'\d\s*' # Type
|
||||
r'>\s*' # Direction marker
|
||||
r'(\d{5,25})' # OCR number (capture group 4)
|
||||
)
|
||||
|
||||
def __init__(self, bottom_region_ratio: float = 0.35):
|
||||
"""
|
||||
Initialize the parser.
|
||||
|
||||
Args:
|
||||
bottom_region_ratio: Fraction of page height to consider as bottom region.
|
||||
Default 0.35 means bottom 35% of the page.
|
||||
"""
|
||||
self.bottom_region_ratio = bottom_region_ratio
|
||||
|
||||
def _detect_account_context(self, tokens: list[TextToken]) -> dict[str, bool]:
|
||||
"""
|
||||
Detect account type keywords in context.
|
||||
|
||||
Returns:
|
||||
Dict with 'bankgiro' and 'plusgiro' boolean flags
|
||||
"""
|
||||
context_text = ' '.join(t.text.lower() for t in tokens)
|
||||
|
||||
return {
|
||||
'bankgiro': any(kw in context_text for kw in ['bankgiro', 'bg:', 'bg ']),
|
||||
'plusgiro': any(kw in context_text for kw in ['plusgiro', 'postgiro', 'plusgirokonto', 'pg:', 'pg ']),
|
||||
}
|
||||
|
||||
def _normalize_account_spaces(self, line: str) -> str:
|
||||
"""
|
||||
Remove spaces in account number portion after > marker.
|
||||
|
||||
Args:
|
||||
line: Payment line text
|
||||
|
||||
Returns:
|
||||
Line with normalized account number spacing
|
||||
"""
|
||||
if '>' not in line:
|
||||
return line
|
||||
|
||||
parts = line.split('>', 1)
|
||||
# After >, remove spaces between digits (but keep # markers)
|
||||
after_arrow = parts[1]
|
||||
# Extract digits and # markers, remove spaces between digits
|
||||
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', after_arrow)
|
||||
# May need multiple passes for sequences like "78 2 1 713"
|
||||
while re.search(r'(\d)\s+(\d)', normalized):
|
||||
normalized = re.sub(r'(\d)\s+(\d)', r'\1\2', normalized)
|
||||
return parts[0] + '>' + normalized
|
||||
|
||||
def _format_account(
|
||||
self,
|
||||
account_digits: str,
|
||||
is_plusgiro_context: bool
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Format account number and determine type (bankgiro or plusgiro).
|
||||
|
||||
Uses context keywords first, then falls back to Luhn validation
|
||||
to determine the most likely account type.
|
||||
|
||||
Args:
|
||||
account_digits: Raw digits of account number
|
||||
is_plusgiro_context: Whether context indicates Plusgiro
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_account, account_type)
|
||||
"""
|
||||
if is_plusgiro_context:
|
||||
# Context explicitly indicates Plusgiro
|
||||
formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
|
||||
return formatted, 'plusgiro'
|
||||
|
||||
# No explicit context - use Luhn validation to determine type
|
||||
# Try both formats and see which passes Luhn check
|
||||
|
||||
# Format as Plusgiro: XXXXXXX-X (all digits, check digit at end)
|
||||
pg_formatted = f"{account_digits[:-1]}-{account_digits[-1]}"
|
||||
pg_valid = FieldValidators.is_valid_plusgiro(account_digits)
|
||||
|
||||
# Format as Bankgiro: XXX-XXXX or XXXX-XXXX
|
||||
if len(account_digits) == 7:
|
||||
bg_formatted = f"{account_digits[:3]}-{account_digits[3:]}"
|
||||
elif len(account_digits) == 8:
|
||||
bg_formatted = f"{account_digits[:4]}-{account_digits[4:]}"
|
||||
else:
|
||||
bg_formatted = account_digits
|
||||
bg_valid = FieldValidators.is_valid_bankgiro(account_digits)
|
||||
|
||||
# Decision logic:
|
||||
# 1. If only one format passes Luhn, use that
|
||||
# 2. If both pass or both fail, default to Bankgiro (more common in payment lines)
|
||||
if pg_valid and not bg_valid:
|
||||
return pg_formatted, 'plusgiro'
|
||||
elif bg_valid and not pg_valid:
|
||||
return bg_formatted, 'bankgiro'
|
||||
else:
|
||||
# Both valid or both invalid - default to bankgiro
|
||||
return bg_formatted, 'bankgiro'
|
||||
|
||||
def parse(
|
||||
self,
|
||||
tokens: list[TextToken],
|
||||
page_height: float,
|
||||
page_width: float | None = None,
|
||||
) -> MachineCodeResult:
|
||||
"""
|
||||
Parse machine code from tokens.
|
||||
|
||||
Args:
|
||||
tokens: List of text tokens from OCR or text extraction
|
||||
page_height: Height of the page in points
|
||||
page_width: Width of the page in points (optional)
|
||||
|
||||
Returns:
|
||||
MachineCodeResult with extracted fields
|
||||
"""
|
||||
if not tokens:
|
||||
return MachineCodeResult()
|
||||
|
||||
# Filter to bottom region tokens
|
||||
bottom_y_threshold = page_height * (1 - self.bottom_region_ratio)
|
||||
bottom_tokens = [
|
||||
t for t in tokens
|
||||
if t.bbox[1] >= bottom_y_threshold # y0 >= threshold
|
||||
]
|
||||
|
||||
if not bottom_tokens:
|
||||
return MachineCodeResult()
|
||||
|
||||
# Sort by y position (top to bottom) then x (left to right)
|
||||
bottom_tokens.sort(key=lambda t: (t.bbox[1], t.bbox[0]))
|
||||
|
||||
# Check if this looks like a payment slip region
|
||||
combined_text = ' '.join(t.text for t in bottom_tokens).lower()
|
||||
has_payment_keywords = any(
|
||||
kw in combined_text for kw in self.PAYMENT_SLIP_KEYWORDS
|
||||
)
|
||||
|
||||
# Build raw line from bottom tokens
|
||||
raw_line = ' '.join(t.text for t in bottom_tokens)
|
||||
|
||||
# Try standard payment line format first and find the matching tokens
|
||||
standard_result, matched_tokens = self._parse_standard_payment_line_with_tokens(
|
||||
raw_line, bottom_tokens
|
||||
)
|
||||
|
||||
if standard_result and matched_tokens:
|
||||
# Calculate bbox only from tokens that contain the machine code
|
||||
x0 = min(t.bbox[0] for t in matched_tokens)
|
||||
y0 = min(t.bbox[1] for t in matched_tokens)
|
||||
x1 = max(t.bbox[2] for t in matched_tokens)
|
||||
y1 = max(t.bbox[3] for t in matched_tokens)
|
||||
region_bbox = (x0, y0, x1, y1)
|
||||
|
||||
result = MachineCodeResult(
|
||||
ocr=standard_result.get('ocr'),
|
||||
amount=standard_result.get('amount'),
|
||||
bankgiro=standard_result.get('bankgiro'),
|
||||
plusgiro=standard_result.get('plusgiro'),
|
||||
confidence=0.95,
|
||||
source_tokens=matched_tokens,
|
||||
raw_line=raw_line,
|
||||
region_bbox=region_bbox,
|
||||
)
|
||||
return result
|
||||
|
||||
# Fall back to individual field extraction
|
||||
result = MachineCodeResult(
|
||||
source_tokens=bottom_tokens,
|
||||
raw_line=raw_line,
|
||||
)
|
||||
|
||||
# Extract OCR number (longest digit sequence 10-25 digits)
|
||||
result.ocr = self._extract_ocr(bottom_tokens)
|
||||
|
||||
# Extract Bankgiro
|
||||
result.bankgiro = self._extract_bankgiro(bottom_tokens)
|
||||
|
||||
# Extract Plusgiro (if no Bankgiro found)
|
||||
if not result.bankgiro:
|
||||
result.plusgiro = self._extract_plusgiro(bottom_tokens)
|
||||
|
||||
# Extract Amount
|
||||
result.amount = self._extract_amount(bottom_tokens)
|
||||
|
||||
# Calculate confidence
|
||||
result.confidence = self._calculate_confidence(
|
||||
result, has_payment_keywords
|
||||
)
|
||||
|
||||
# For fallback extraction, compute bbox from tokens that contain the extracted values
|
||||
matched_tokens = self._find_tokens_with_values(bottom_tokens, result)
|
||||
if matched_tokens:
|
||||
x0 = min(t.bbox[0] for t in matched_tokens)
|
||||
y0 = min(t.bbox[1] for t in matched_tokens)
|
||||
x1 = max(t.bbox[2] for t in matched_tokens)
|
||||
y1 = max(t.bbox[3] for t in matched_tokens)
|
||||
result.region_bbox = (x0, y0, x1, y1)
|
||||
result.source_tokens = matched_tokens
|
||||
|
||||
return result
|
||||
|
||||
def _find_tokens_with_values(
|
||||
self,
|
||||
tokens: list[TextToken],
|
||||
result: MachineCodeResult
|
||||
) -> list[TextToken]:
|
||||
"""Find tokens that contain the extracted values (OCR, Amount, Bankgiro)."""
|
||||
matched = []
|
||||
values_to_find = []
|
||||
|
||||
if result.ocr:
|
||||
values_to_find.append(result.ocr)
|
||||
if result.amount:
|
||||
# Amount might be just digits
|
||||
amount_digits = re.sub(r'\D', '', result.amount)
|
||||
values_to_find.append(amount_digits)
|
||||
values_to_find.append(result.amount)
|
||||
if result.bankgiro:
|
||||
# Bankgiro might have dash or not
|
||||
bg_digits = re.sub(r'\D', '', result.bankgiro)
|
||||
values_to_find.append(bg_digits)
|
||||
values_to_find.append(result.bankgiro)
|
||||
if result.plusgiro:
|
||||
pg_digits = re.sub(r'\D', '', result.plusgiro)
|
||||
values_to_find.append(pg_digits)
|
||||
values_to_find.append(result.plusgiro)
|
||||
|
||||
for token in tokens:
|
||||
text = token.text.replace(' ', '').replace('#', '')
|
||||
text_digits = re.sub(r'\D', '', token.text)
|
||||
|
||||
for value in values_to_find:
|
||||
if value in text or value in text_digits:
|
||||
if token not in matched:
|
||||
matched.append(token)
|
||||
break
|
||||
|
||||
return matched
|
||||
|
||||
def _find_machine_code_line_tokens(
|
||||
self,
|
||||
tokens: list[TextToken]
|
||||
) -> list[TextToken]:
|
||||
"""
|
||||
Find tokens that belong to the machine code line using pure regex patterns.
|
||||
|
||||
The machine code line typically contains:
|
||||
- Control markers like #14#, #41#
|
||||
- Direction marker >
|
||||
- Account numbers with # suffix
|
||||
|
||||
Returns:
|
||||
List of tokens belonging to the machine code line
|
||||
"""
|
||||
# Find tokens with characteristic machine code patterns
|
||||
ref_y = None
|
||||
|
||||
# First, find the reference y-coordinate from tokens with machine code patterns
|
||||
for token in tokens:
|
||||
text = token.text
|
||||
|
||||
# Check if token contains machine code patterns
|
||||
# Priority 1: Control marker like #14#, 47304035#14#
|
||||
has_control_marker = bool(re.search(r'#\d+#', text))
|
||||
# Priority 2: Direction marker >
|
||||
has_direction = '>' in text
|
||||
|
||||
if has_control_marker:
|
||||
# This is very likely part of the machine code line
|
||||
ref_y = token.bbox[1]
|
||||
break
|
||||
elif has_direction and ref_y is None:
|
||||
# Direction marker is also a good indicator
|
||||
ref_y = token.bbox[1]
|
||||
|
||||
if ref_y is None:
|
||||
return []
|
||||
|
||||
# Collect all tokens on the same line (within 3 points of ref_y)
|
||||
# Use very small tolerance because Swedish invoices often have duplicate
|
||||
# machine code lines (upper and lower part of payment slip)
|
||||
y_tolerance = 3
|
||||
machine_code_tokens = []
|
||||
for token in tokens:
|
||||
if abs(token.bbox[1] - ref_y) < y_tolerance:
|
||||
text = token.text
|
||||
# Include token if it contains:
|
||||
# - Digits (OCR, amount, account numbers)
|
||||
# - # symbol (delimiters, control markers)
|
||||
# - > symbol (direction marker)
|
||||
if (re.search(r'\d', text) or '#' in text or '>' in text):
|
||||
machine_code_tokens.append(token)
|
||||
|
||||
# If we found very few tokens, try to expand to nearby y values
|
||||
# that might be part of the same logical line
|
||||
if len(machine_code_tokens) < 3:
|
||||
y_tolerance = 10
|
||||
machine_code_tokens = []
|
||||
for token in tokens:
|
||||
if abs(token.bbox[1] - ref_y) < y_tolerance:
|
||||
text = token.text
|
||||
if (re.search(r'\d', text) or '#' in text or '>' in text):
|
||||
machine_code_tokens.append(token)
|
||||
|
||||
return machine_code_tokens
|
||||
|
||||
def _parse_standard_payment_line_with_tokens(
|
||||
self,
|
||||
raw_line: str,
|
||||
tokens: list[TextToken]
|
||||
) -> tuple[Optional[dict], list[TextToken]]:
|
||||
"""
|
||||
Parse standard Swedish payment line format and find matching tokens.
|
||||
|
||||
Uses pure regex to identify the machine code line, then finds tokens
|
||||
that are part of that line based on their position.
|
||||
|
||||
Format: # <OCR> # <Kronor> <Öre> <Type> > <Bankgiro/Plusgiro>#<Control>#
|
||||
Example: # 31130954410 # 315 00 2 > 8983025#14#
|
||||
|
||||
Returns:
|
||||
Tuple of (parsed_dict, matched_tokens) or (None, [])
|
||||
"""
|
||||
# First find the machine code line tokens using pattern matching
|
||||
machine_code_tokens = self._find_machine_code_line_tokens(tokens)
|
||||
|
||||
if not machine_code_tokens:
|
||||
# Fall back to regex on raw_line
|
||||
parsed = self._parse_standard_payment_line(raw_line, raw_line)
|
||||
return parsed, []
|
||||
|
||||
# Build a line from just the machine code tokens (sorted by x position)
|
||||
# Group tokens by approximate x position to handle duplicate OCR results
|
||||
mc_tokens_sorted = sorted(machine_code_tokens, key=lambda t: t.bbox[0])
|
||||
|
||||
# Deduplicate tokens at similar x positions (keep the first one)
|
||||
deduped_tokens = []
|
||||
last_x = -100
|
||||
for t in mc_tokens_sorted:
|
||||
# Skip tokens that are too close to the previous one (likely duplicates)
|
||||
if t.bbox[0] - last_x < 5:
|
||||
continue
|
||||
deduped_tokens.append(t)
|
||||
last_x = t.bbox[2] # Use end x for next comparison
|
||||
|
||||
mc_line = ' '.join(t.text for t in deduped_tokens)
|
||||
|
||||
# Try to parse this line, using raw_line for context detection
|
||||
parsed = self._parse_standard_payment_line(mc_line, raw_line)
|
||||
if parsed:
|
||||
return parsed, deduped_tokens
|
||||
|
||||
# If machine code line parsing failed, try the full raw_line
|
||||
parsed = self._parse_standard_payment_line(raw_line, raw_line)
|
||||
if parsed:
|
||||
return parsed, machine_code_tokens
|
||||
|
||||
return None, []
|
||||
|
||||
def _parse_standard_payment_line(
|
||||
self,
|
||||
raw_line: str,
|
||||
context_line: str | None = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Parse standard Swedish payment line format.
|
||||
|
||||
Format: # <OCR> # <Kronor> <Öre> <Type> > <Bankgiro/Plusgiro>#<Control>#
|
||||
Example: # 31130954410 # 315 00 2 > 8983025#14#
|
||||
|
||||
Args:
|
||||
raw_line: The line to parse (may be just the machine code tokens)
|
||||
context_line: Optional full line for context detection (e.g., to find "plusgiro" keywords)
|
||||
|
||||
Returns:
|
||||
Dict with 'ocr', 'amount', and 'bankgiro' or 'plusgiro' if matched, None otherwise
|
||||
"""
|
||||
# Use context_line for detecting Plusgiro/Bankgiro, fall back to raw_line
|
||||
context = (context_line or raw_line).lower()
|
||||
is_plusgiro_context = (
|
||||
('plusgiro' in context or 'postgiro' in context or 'plusgirokonto' in context)
|
||||
and 'bankgiro' not in context
|
||||
)
|
||||
|
||||
# Preprocess: remove spaces in the account number part (after >)
|
||||
raw_line = self._normalize_account_spaces(raw_line)
|
||||
|
||||
# Try primary pattern
|
||||
match = self.PAYMENT_LINE_PATTERN.search(raw_line)
|
||||
if match:
|
||||
ocr = match.group(1)
|
||||
kronor = match.group(2)
|
||||
ore = match.group(3)
|
||||
account_digits = match.group(5)
|
||||
|
||||
# Format amount: combine kronor and öre
|
||||
amount = f"{kronor},{ore}" if ore != "00" else kronor
|
||||
|
||||
formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
|
||||
|
||||
return {
|
||||
'ocr': ocr,
|
||||
'amount': amount,
|
||||
account_type: formatted_account,
|
||||
}
|
||||
|
||||
# Try alternative pattern
|
||||
match = self.PAYMENT_LINE_PATTERN_ALT.search(raw_line)
|
||||
if match:
|
||||
ocr = match.group(1)
|
||||
kronor = match.group(2)
|
||||
ore = match.group(3)
|
||||
account_digits = match.group(4)
|
||||
|
||||
amount = f"{kronor},{ore}" if ore != "00" else kronor
|
||||
|
||||
formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
|
||||
|
||||
return {
|
||||
'ocr': ocr,
|
||||
'amount': amount,
|
||||
account_type: formatted_account,
|
||||
}
|
||||
|
||||
# Try reverse pattern (Account first, then OCR)
|
||||
match = self.PAYMENT_LINE_PATTERN_REVERSE.search(raw_line)
|
||||
if match:
|
||||
account_digits = match.group(1)
|
||||
kronor = match.group(2)
|
||||
ore = match.group(3)
|
||||
ocr = match.group(4)
|
||||
|
||||
amount = f"{kronor},{ore}" if ore != "00" else kronor
|
||||
|
||||
formatted_account, account_type = self._format_account(account_digits, is_plusgiro_context)
|
||||
|
||||
return {
|
||||
'ocr': ocr,
|
||||
'amount': amount,
|
||||
account_type: formatted_account,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def _extract_ocr(self, tokens: list[TextToken]) -> Optional[str]:
|
||||
"""Extract OCR reference number."""
|
||||
candidates = []
|
||||
|
||||
# First, collect all bankgiro-like patterns to exclude
|
||||
bankgiro_digits = set()
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
bg_matches = self.BANKGIRO_PATTERN.findall(text)
|
||||
for bg in bg_matches:
|
||||
digits = re.sub(r'\D', '', bg)
|
||||
bankgiro_digits.add(digits)
|
||||
# Also add with potential check digits (common pattern)
|
||||
for i in range(10):
|
||||
bankgiro_digits.add(digits + str(i))
|
||||
bankgiro_digits.add(digits + str(i) + str(i))
|
||||
|
||||
for token in tokens:
|
||||
# Remove spaces and common suffixes
|
||||
text = token.text.replace(' ', '').replace('#', '').strip()
|
||||
|
||||
# Find all digit sequences
|
||||
matches = self.OCR_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
# OCR numbers are typically 10-25 digits
|
||||
if 10 <= len(match) <= 25:
|
||||
# Skip if this looks like a bankgiro number with check digit
|
||||
is_bankgiro_variant = any(
|
||||
match.startswith(bg) or match.endswith(bg)
|
||||
for bg in bankgiro_digits if len(bg) >= 7
|
||||
)
|
||||
|
||||
# Also check if it's exactly bankgiro with 2-3 extra digits
|
||||
for bg in bankgiro_digits:
|
||||
if len(bg) >= 7 and (
|
||||
match == bg or
|
||||
(len(match) - len(bg) <= 3 and match.startswith(bg))
|
||||
):
|
||||
is_bankgiro_variant = True
|
||||
break
|
||||
|
||||
if not is_bankgiro_variant:
|
||||
candidates.append((match, len(match), token))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Prefer longer sequences (more likely to be OCR)
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
return candidates[0][0]
|
||||
|
||||
def _extract_bankgiro(self, tokens: list[TextToken]) -> Optional[str]:
|
||||
"""Extract Bankgiro account number.
|
||||
|
||||
Bankgiro format: XXX-XXXX or XXXX-XXXX (dash in middle)
|
||||
NOT Plusgiro: XXXXXXX-X (dash before last digit)
|
||||
"""
|
||||
candidates = []
|
||||
context = self._detect_account_context(tokens)
|
||||
|
||||
# If clearly Plusgiro context (and not bankgiro), don't extract as Bankgiro
|
||||
if context['plusgiro'] and not context['bankgiro']:
|
||||
return None
|
||||
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
|
||||
# Look for Bankgiro pattern
|
||||
matches = self.BANKGIRO_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
# Check if this looks like Plusgiro format (dash before last digit)
|
||||
# Plusgiro: 1234567-8 (dash at position -2)
|
||||
if '-' in match:
|
||||
parts = match.replace(' ', '').split('-')
|
||||
if len(parts) == 2 and len(parts[1]) == 1:
|
||||
# This is Plusgiro format, skip
|
||||
continue
|
||||
|
||||
# Normalize: remove spaces, ensure dash
|
||||
digits = re.sub(r'\D', '', match)
|
||||
if len(digits) == 7:
|
||||
normalized = f"{digits[:3]}-{digits[3:]}"
|
||||
elif len(digits) == 8:
|
||||
normalized = f"{digits[:4]}-{digits[4:]}"
|
||||
else:
|
||||
continue
|
||||
|
||||
candidates.append((normalized, context['bankgiro'], token))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Prefer matches with bankgiro context
|
||||
candidates.sort(key=lambda x: (x[1], 1), reverse=True)
|
||||
return candidates[0][0]
|
||||
|
||||
def _extract_plusgiro(self, tokens: list[TextToken]) -> Optional[str]:
|
||||
"""Extract Plusgiro account number."""
|
||||
candidates = []
|
||||
context = self._detect_account_context(tokens)
|
||||
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
|
||||
matches = self.PLUSGIRO_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
# Normalize: remove spaces, ensure dash before last digit
|
||||
digits = re.sub(r'\D', '', match)
|
||||
if 7 <= len(digits) <= 8:
|
||||
normalized = f"{digits[:-1]}-{digits[-1]}"
|
||||
candidates.append((normalized, context['plusgiro'], token))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
candidates.sort(key=lambda x: (x[1], 1), reverse=True)
|
||||
return candidates[0][0]
|
||||
|
||||
def _extract_amount(self, tokens: list[TextToken]) -> Optional[str]:
|
||||
"""Extract payment amount."""
|
||||
candidates = []
|
||||
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
|
||||
# Try decimal amount pattern first
|
||||
matches = self.AMOUNT_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
# Normalize: remove thousand separators, use comma as decimal
|
||||
normalized = match.replace(' ', '').replace('\xa0', '')
|
||||
# Convert dot thousand separator to none, keep comma decimal
|
||||
if '.' in normalized and ',' in normalized:
|
||||
# Format like 1.234,56 -> 1234,56
|
||||
normalized = normalized.replace('.', '')
|
||||
elif '.' in normalized:
|
||||
# Could be 1234.56 -> 1234,56
|
||||
parts = normalized.split('.')
|
||||
if len(parts) == 2 and len(parts[1]) == 2:
|
||||
normalized = f"{parts[0]},{parts[1]}"
|
||||
|
||||
# Parse to verify it's a valid amount
|
||||
try:
|
||||
value = float(normalized.replace(',', '.'))
|
||||
if 0 < value < 1000000: # Reasonable amount range
|
||||
candidates.append((normalized, value, token))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# If no decimal amounts found, try integer amounts
|
||||
# Look for "Kronor" label nearby and extract integer
|
||||
if not candidates:
|
||||
for i, token in enumerate(tokens):
|
||||
text = token.text.strip().lower()
|
||||
if 'kronor' in text or 'kr' == text or text.endswith(' kr'):
|
||||
# Look at nearby tokens for amounts (wider range)
|
||||
for j in range(max(0, i - 5), min(len(tokens), i + 5)):
|
||||
nearby_text = tokens[j].text.strip()
|
||||
# Match pure integer (1-6 digits)
|
||||
int_match = re.match(r'^(\d{1,6})$', nearby_text)
|
||||
if int_match:
|
||||
value = int(int_match.group(1))
|
||||
if 0 < value < 1000000:
|
||||
candidates.append((str(value), float(value), tokens[j]))
|
||||
|
||||
# Also try to find amounts near "öre" label (Swedish cents)
|
||||
if not candidates:
|
||||
for i, token in enumerate(tokens):
|
||||
text = token.text.strip().lower()
|
||||
if 'öre' in text:
|
||||
# Look at nearby tokens for amounts
|
||||
for j in range(max(0, i - 5), min(len(tokens), i + 5)):
|
||||
nearby_text = tokens[j].text.strip()
|
||||
int_match = re.match(r'^(\d{1,6})$', nearby_text)
|
||||
if int_match:
|
||||
value = int(int_match.group(1))
|
||||
if 0 < value < 1000000:
|
||||
candidates.append((str(value), float(value), tokens[j]))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Sort by value (prefer larger amounts - likely total)
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
return candidates[0][0]
|
||||
|
||||
def _calculate_confidence(
|
||||
self,
|
||||
result: MachineCodeResult,
|
||||
has_payment_keywords: bool
|
||||
) -> float:
|
||||
"""Calculate confidence score for the extraction."""
|
||||
confidence = 0.0
|
||||
|
||||
# Base confidence from payment keywords
|
||||
if has_payment_keywords:
|
||||
confidence += 0.3
|
||||
|
||||
# Points for each extracted field
|
||||
if result.ocr:
|
||||
confidence += 0.25
|
||||
# Bonus for typical OCR length (15-17 digits)
|
||||
if 15 <= len(result.ocr) <= 17:
|
||||
confidence += 0.1
|
||||
|
||||
if result.bankgiro or result.plusgiro:
|
||||
confidence += 0.2
|
||||
|
||||
if result.amount:
|
||||
confidence += 0.15
|
||||
|
||||
return min(confidence, 1.0)
|
||||
|
||||
    def cross_validate(
        self,
        machine_result: MachineCodeResult,
        csv_values: dict[str, str],
    ) -> dict[str, dict]:
        """
        Cross-validate machine code extraction with CSV ground truth.

        Args:
            machine_result: Result from parse()
            csv_values: Dict of field values from CSV
                (keys: 'ocr', 'amount', 'bankgiro', 'plusgiro')

        Returns:
            Dict with validation results for each field:
            {
                'ocr': {
                    'machine': '123456789',
                    'csv': '123456789',
                    'match': True,
                    'use_machine': False,  # CSV has value
                },
                ...
            }
        """
        from shared.normalize import normalize_field

        results = {}

        field_mapping = [
            ('ocr', 'OCR', machine_result.ocr),
            ('amount', 'Amount', machine_result.amount),
            ('bankgiro', 'Bankgiro', machine_result.bankgiro),
            ('plusgiro', 'Plusgiro', machine_result.plusgiro),
        ]

        for field_key, normalizer_name, machine_value in field_mapping:
            csv_value = csv_values.get(field_key, '').strip()

            result_entry = {
                'machine': machine_value,
                'csv': csv_value if csv_value else None,
                'match': False,
                'use_machine': False,
            }

            if machine_value and csv_value:
                # Both have values - check if they match
                machine_variants = normalize_field(normalizer_name, machine_value)
                csv_variants = normalize_field(normalizer_name, csv_value)

                # Check for any overlap
                result_entry['match'] = bool(
                    set(machine_variants) & set(csv_variants)
                )

                # Special handling for amounts - allow rounding differences
                if not result_entry['match'] and field_key == 'amount':
                    try:
                        # Parse both values as floats
                        machine_float = float(
                            machine_value.replace(' ', '')
                            .replace(',', '.').replace('\xa0', '')
                        )
                        csv_float = float(
                            csv_value.replace(' ', '')
                            .replace(',', '.').replace('\xa0', '')
                        )
                        # Allow 1 unit difference (rounding)
                        if abs(machine_float - csv_float) <= 1.0:
                            result_entry['match'] = True
                            result_entry['rounding_diff'] = True
                    except ValueError:
                        pass

            elif machine_value and not csv_value:
                # CSV is missing, use machine value
                result_entry['use_machine'] = True

            results[field_key] = result_entry

        return results


def parse_machine_code(
    tokens: list[TextToken],
    page_height: float,
    page_width: float | None = None,
    bottom_ratio: float = 0.35,
) -> MachineCodeResult:
    """
    Convenience function to parse machine code from tokens.

    Args:
        tokens: List of text tokens
        page_height: Page height in points
        page_width: Page width in points (optional)
        bottom_ratio: Fraction of page to consider as bottom region

    Returns:
        MachineCodeResult with extracted fields
    """
    parser = MachineCodeParser(bottom_region_ratio=bottom_ratio)
    return parser.parse(tokens, page_height, page_width)
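A minimal end-to-end sketch of how the two entry points above are meant to be used together. The PDF helpers are imported from the pdf package added in this commit; the file name and the CSV row are invented, and the extractor's Token objects are assumed to satisfy the TextToken interface the parser expects:

    from shared.pdf import extract_text_tokens
    from shared.pdf.extractor import get_page_dimensions

    pdf_path = "invoice.pdf"  # hypothetical input
    tokens = list(extract_text_tokens(pdf_path, page_no=0))
    width, height = get_page_dimensions(pdf_path)

    result = parse_machine_code(tokens, page_height=height, page_width=width)

    # Cross-check against an (invented) CSV ground-truth row
    csv_row = {'ocr': '1234567890123456', 'amount': '1 234,56', 'bankgiro': '', 'plusgiro': ''}
    validation = MachineCodeParser().cross_validate(result, csv_row)
    for field, entry in validation.items():
        print(field, entry['match'], entry['use_machine'])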
405
packages/shared/shared/ocr/paddle_ocr.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
OCR Extraction Module using PaddleOCR
|
||||
|
||||
Extracts text tokens with bounding boxes from scanned PDFs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
import numpy as np
|
||||
|
||||
# Suppress PaddlePaddle reinitialization warnings
|
||||
os.environ.setdefault('GLOG_minloglevel', '2')
|
||||
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
|
||||
warnings.filterwarnings('ignore', message='.*reinitialization.*')
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRToken:
|
||||
"""Represents an OCR-extracted text token with its bounding box."""
|
||||
text: str
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||
confidence: float
|
||||
page_no: int = 0
|
||||
|
||||
@property
|
||||
def x0(self) -> float:
|
||||
return self.bbox[0]
|
||||
|
||||
@property
|
||||
def y0(self) -> float:
|
||||
return self.bbox[1]
|
||||
|
||||
@property
|
||||
def x1(self) -> float:
|
||||
return self.bbox[2]
|
||||
|
||||
@property
|
||||
def y1(self) -> float:
|
||||
return self.bbox[3]
|
||||
|
||||
@property
|
||||
def center(self) -> tuple[float, float]:
|
||||
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Result from OCR extraction including tokens and preprocessed image."""
|
||||
tokens: list[OCRToken]
|
||||
output_img: np.ndarray | None = None # Preprocessed image from PaddleOCR
|
||||
|
||||
|
||||
class OCREngine:
|
||||
"""PaddleOCR wrapper for text extraction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lang: str = "en",
|
||||
det_model_dir: str | None = None,
|
||||
rec_model_dir: str | None = None,
|
||||
use_doc_orientation_classify: bool = True,
|
||||
use_doc_unwarping: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize OCR engine.
|
||||
|
||||
Args:
|
||||
lang: Language code ('en', 'sv', 'ch', etc.)
|
||||
det_model_dir: Custom detection model directory
|
||||
rec_model_dir: Custom recognition model directory
|
||||
use_doc_orientation_classify: Whether to auto-detect and correct document orientation.
|
||||
Default True to handle rotated documents.
|
||||
use_doc_unwarping: Whether to use UVDoc document unwarping for curved/warped documents.
|
||||
Default False to preserve original image layout,
|
||||
especially important for payment OCR lines at bottom.
|
||||
Enable for severely warped documents at the cost of potentially
|
||||
losing bottom content.
|
||||
|
||||
Note:
|
||||
PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
|
||||
Use `paddle.set_device('gpu')` before initialization to force GPU.
|
||||
"""
|
||||
# Suppress warnings during import and initialization
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
# PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
|
||||
init_params = {
|
||||
'lang': lang,
|
||||
# Enable orientation classification to handle rotated documents
|
||||
'use_doc_orientation_classify': use_doc_orientation_classify,
|
||||
# Disable UVDoc unwarping to preserve original image layout
|
||||
# This prevents the bottom payment OCR line from being cut off
|
||||
# For severely warped documents, enable this but expect potential content loss
|
||||
'use_doc_unwarping': use_doc_unwarping,
|
||||
}
|
||||
if det_model_dir:
|
||||
init_params['text_detection_model_dir'] = det_model_dir
|
||||
if rec_model_dir:
|
||||
init_params['text_recognition_model_dir'] = rec_model_dir
|
||||
|
||||
self.ocr = PaddleOCR(**init_params)
|
||||
|
||||
def extract_from_image(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0,
|
||||
max_size: int = 2000,
|
||||
scale_to_pdf_points: float | None = None,
|
||||
scan_bottom_region: bool = True,
|
||||
bottom_region_ratio: float = 0.15
|
||||
) -> list[OCRToken]:
|
||||
"""
|
||||
Extract text tokens from an image.
|
||||
|
||||
Args:
|
||||
image: Image path or numpy array
|
||||
page_no: Page number for reference
|
||||
max_size: Maximum image dimension. Larger images will be scaled down
|
||||
to avoid OCR issues with PaddleOCR on large images.
|
||||
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
|
||||
to convert from pixel to PDF point coordinates.
|
||||
Use (72 / dpi) for images rendered at a specific DPI.
|
||||
scan_bottom_region: If True, also scan the bottom region separately to catch
|
||||
OCR payment lines that may be missed in full-page scan.
|
||||
bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
|
||||
|
||||
Returns:
|
||||
List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
|
||||
"""
|
||||
result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points)
|
||||
tokens = result.tokens
|
||||
|
||||
# Optionally scan bottom region separately for Swedish OCR payment lines
|
||||
if scan_bottom_region:
|
||||
bottom_tokens = self._scan_bottom_region(
|
||||
image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
|
||||
)
|
||||
tokens = self._merge_tokens(tokens, bottom_tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def _scan_bottom_region(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int,
|
||||
max_size: int,
|
||||
scale_to_pdf_points: float | None,
|
||||
bottom_ratio: float
|
||||
) -> list[OCRToken]:
|
||||
"""Scan the bottom region of the image separately."""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Load image if path
|
||||
if isinstance(image, (str, Path)):
|
||||
img = PILImage.open(str(image))
|
||||
img_array = np.array(img)
|
||||
else:
|
||||
img_array = image
|
||||
|
||||
h, w = img_array.shape[:2]
|
||||
crop_y = int(h * (1 - bottom_ratio))
|
||||
|
||||
# Crop bottom region
|
||||
bottom_crop = img_array[crop_y:h, :, :] if len(img_array.shape) == 3 else img_array[crop_y:h, :]
|
||||
|
||||
# OCR the cropped region (without recursive bottom scan to avoid infinite loop)
|
||||
result = self.extract_with_image(
|
||||
bottom_crop, page_no, max_size,
|
||||
scale_to_pdf_points=None,
|
||||
scan_bottom_region=False # Important: disable to prevent recursion
|
||||
)
|
||||
|
||||
# Adjust bbox y-coordinates to full image space
|
||||
adjusted_tokens = []
|
||||
for token in result.tokens:
|
||||
# Scale factor for coordinates
|
||||
scale = scale_to_pdf_points if scale_to_pdf_points else 1.0
|
||||
|
||||
adjusted_bbox = (
|
||||
token.bbox[0] * scale,
|
||||
(token.bbox[1] + crop_y) * scale,
|
||||
token.bbox[2] * scale,
|
||||
(token.bbox[3] + crop_y) * scale
|
||||
)
|
||||
adjusted_tokens.append(OCRToken(
|
||||
text=token.text,
|
||||
bbox=adjusted_bbox,
|
||||
confidence=token.confidence,
|
||||
page_no=token.page_no
|
||||
))
|
||||
|
||||
return adjusted_tokens
|
||||
|
||||
def _merge_tokens(
|
||||
self,
|
||||
main_tokens: list[OCRToken],
|
||||
bottom_tokens: list[OCRToken]
|
||||
) -> list[OCRToken]:
|
||||
"""Merge tokens from main scan and bottom region scan, removing duplicates."""
|
||||
if not bottom_tokens:
|
||||
return main_tokens
|
||||
|
||||
# Create a set of existing token texts for deduplication
|
||||
existing_texts = {t.text.strip() for t in main_tokens}
|
||||
|
||||
# Add bottom tokens that aren't duplicates
|
||||
merged = list(main_tokens)
|
||||
for token in bottom_tokens:
|
||||
if token.text.strip() not in existing_texts:
|
||||
merged.append(token)
|
||||
existing_texts.add(token.text.strip())
|
||||
|
||||
return merged
|
||||
|
||||
def extract_with_image(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0,
|
||||
max_size: int = 2000,
|
||||
scale_to_pdf_points: float | None = None,
|
||||
scan_bottom_region: bool = True,
|
||||
bottom_region_ratio: float = 0.15
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Extract text tokens from an image and return the preprocessed image.
|
||||
|
||||
PaddleOCR applies document preprocessing (unwarping, rotation, enhancement)
|
||||
and returns coordinates relative to the preprocessed image (output_img).
|
||||
This method returns both tokens and output_img so the caller can save
|
||||
the correct image that matches the coordinates.
|
||||
|
||||
Args:
|
||||
image: Image path or numpy array
|
||||
page_no: Page number for reference
|
||||
max_size: Maximum image dimension. Larger images will be scaled down
|
||||
to avoid OCR issues with PaddleOCR on large images.
|
||||
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
|
||||
to convert from pixel to PDF point coordinates.
|
||||
Use (72 / dpi) for images rendered at a specific DPI.
|
||||
scan_bottom_region: If True, also scan the bottom region separately to catch
|
||||
OCR payment lines that may be missed in full-page scan.
|
||||
bottom_region_ratio: Ratio of page height to scan as bottom region (default 0.15 = 15%)
|
||||
|
||||
Returns:
|
||||
OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
|
||||
"""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Load image if path
|
||||
if isinstance(image, (str, Path)):
|
||||
img = PILImage.open(str(image))
|
||||
img_array = np.array(img)
|
||||
else:
|
||||
img_array = image
|
||||
|
||||
# Check if image needs scaling for OCR
|
||||
h, w = img_array.shape[:2]
|
||||
ocr_scale_factor = 1.0
|
||||
|
||||
if max(h, w) > max_size:
|
||||
ocr_scale_factor = max_size / max(h, w)
|
||||
new_w = int(w * ocr_scale_factor)
|
||||
new_h = int(h * ocr_scale_factor)
|
||||
# Resize image for OCR
|
||||
img = PILImage.fromarray(img_array)
|
||||
img = img.resize((new_w, new_h), PILImage.Resampling.LANCZOS)
|
||||
img_array = np.array(img)
|
||||
|
||||
# PaddleOCR 3.x uses predict() method instead of ocr()
|
||||
result = self.ocr.predict(img_array)
|
||||
|
||||
tokens = []
|
||||
output_img = None
|
||||
|
||||
if result:
|
||||
for item in result:
|
||||
# PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys'
|
||||
if isinstance(item, dict):
|
||||
rec_texts = item.get('rec_texts', [])
|
||||
rec_scores = item.get('rec_scores', [])
|
||||
dt_polys = item.get('dt_polys', [])
|
||||
|
||||
# Get output_img from doc_preprocessor_res
|
||||
# This is the preprocessed image that coordinates are relative to
|
||||
doc_preproc = item.get('doc_preprocessor_res', {})
|
||||
if isinstance(doc_preproc, dict):
|
||||
output_img = doc_preproc.get('output_img')
|
||||
|
||||
# Coordinates are relative to output_img (preprocessed image)
|
||||
# No rotation compensation needed - just use coordinates directly
|
||||
for text, score, poly in zip(rec_texts, rec_scores, dt_polys):
|
||||
# poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
x_coords = [float(p[0]) for p in poly]
|
||||
y_coords = [float(p[1]) for p in poly]
|
||||
|
||||
# Apply PDF points scale if requested
|
||||
if scale_to_pdf_points is not None:
|
||||
final_scale = scale_to_pdf_points
|
||||
else:
|
||||
final_scale = 1.0
|
||||
|
||||
bbox = (
|
||||
min(x_coords) * final_scale,
|
||||
min(y_coords) * final_scale,
|
||||
max(x_coords) * final_scale,
|
||||
max(y_coords) * final_scale
|
||||
)
|
||||
|
||||
tokens.append(OCRToken(
|
||||
text=text,
|
||||
bbox=bbox,
|
||||
confidence=float(score),
|
||||
page_no=page_no
|
||||
))
|
||||
elif isinstance(item, (list, tuple)) and len(item) == 2:
|
||||
# Legacy format: [[bbox_points], (text, confidence)]
|
||||
bbox_points, (text, confidence) = item
|
||||
|
||||
x_coords = [p[0] for p in bbox_points]
|
||||
y_coords = [p[1] for p in bbox_points]
|
||||
|
||||
# Apply PDF points scale if requested
|
||||
if scale_to_pdf_points is not None:
|
||||
final_scale = scale_to_pdf_points
|
||||
else:
|
||||
final_scale = 1.0
|
||||
|
||||
bbox = (
|
||||
min(x_coords) * final_scale,
|
||||
min(y_coords) * final_scale,
|
||||
max(x_coords) * final_scale,
|
||||
max(y_coords) * final_scale
|
||||
)
|
||||
|
||||
tokens.append(OCRToken(
|
||||
text=text,
|
||||
bbox=bbox,
|
||||
confidence=confidence,
|
||||
page_no=page_no
|
||||
))
|
||||
|
||||
# If no output_img was found, use the original input array
|
||||
if output_img is None:
|
||||
output_img = img_array
|
||||
|
||||
# Optionally scan bottom region separately for Swedish OCR payment lines
|
||||
if scan_bottom_region:
|
||||
bottom_tokens = self._scan_bottom_region(
|
||||
image, page_no, max_size, scale_to_pdf_points, bottom_region_ratio
|
||||
)
|
||||
tokens = self._merge_tokens(tokens, bottom_tokens)
|
||||
|
||||
return OCRResult(tokens=tokens, output_img=output_img)
|
||||
|
||||
def extract_from_pdf(
|
||||
self,
|
||||
pdf_path: str | Path,
|
||||
dpi: int = 300
|
||||
) -> Generator[list[OCRToken], None, None]:
|
||||
"""
|
||||
Extract text from all pages of a scanned PDF.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to the PDF file
|
||||
dpi: Resolution for rendering
|
||||
|
||||
Yields:
|
||||
List of OCRToken for each page
|
||||
"""
|
||||
from ..pdf.renderer import render_pdf_to_images
|
||||
import io
|
||||
from PIL import Image
|
||||
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
|
||||
# Convert bytes to numpy array
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
|
||||
tokens = self.extract_from_image(image_array, page_no=page_no)
|
||||
yield tokens
|
||||
|
||||
|
||||
def extract_ocr_tokens(
|
||||
image_path: str | Path,
|
||||
lang: str = "en",
|
||||
page_no: int = 0
|
||||
) -> list[OCRToken]:
|
||||
"""
|
||||
Convenience function to extract OCR tokens from an image.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
lang: Language code
|
||||
page_no: Page number for reference
|
||||
|
||||
Returns:
|
||||
List of OCRToken objects
|
||||
"""
|
||||
engine = OCREngine(lang=lang)
|
||||
return engine.extract_from_image(image_path, page_no=page_no)
|
||||
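A hedged usage sketch of the engine above, combining it with the renderer added elsewhere in this commit and using the 72 / dpi convention from the docstrings to get bounding boxes back in PDF points (file names and DPI are placeholders):

    from shared.ocr.paddle_ocr import OCREngine
    from shared.pdf import render_pdf_to_images

    dpi = 150  # match the DPI the models were trained at
    engine = OCREngine(lang="en")

    for page_no, image_path in render_pdf_to_images("invoice.pdf", output_dir="out", dpi=dpi):
        tokens = engine.extract_from_image(
            image_path,
            page_no=page_no,
            scale_to_pdf_points=72 / dpi,  # pixel coords -> PDF points
        )
        for t in tokens[:5]:
            print(round(t.confidence, 2), t.bbox, t.text)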
12
packages/shared/shared/pdf/__init__.py
Normal file
@@ -0,0 +1,12 @@
from .detector import is_text_pdf, get_pdf_type
from .renderer import render_pdf_to_images
from .extractor import extract_text_tokens, PDFDocument, Token

__all__ = [
    'is_text_pdf',
    'get_pdf_type',
    'render_pdf_to_images',
    'extract_text_tokens',
    'PDFDocument',
    'Token',
]
150
packages/shared/shared/pdf/detector.py
Normal file
@@ -0,0 +1,150 @@
"""
PDF Type Detection Module

Automatically distinguishes between:
- Text-based PDFs (digitally generated)
- Scanned image PDFs
"""

from pathlib import Path
from typing import Literal
import fitz  # PyMuPDF


PDFType = Literal["text", "scanned", "mixed"]


def extract_text_first_page(pdf_path: str | Path) -> str:
    """Extract text from the first page of a PDF."""
    doc = fitz.open(pdf_path)
    if len(doc) == 0:
        return ""

    first_page = doc[0]
    text = first_page.get_text()
    doc.close()
    return text


def is_text_pdf(pdf_path: str | Path, min_chars: int = 30) -> bool:
    """
    Check if PDF has extractable AND READABLE text layer.

    Some PDFs have custom font encodings that produce garbled text.
    This function checks both the presence and readability of text.

    Args:
        pdf_path: Path to the PDF file
        min_chars: Minimum characters to consider it a text PDF

    Returns:
        True if PDF has readable text layer, False if scanned or garbled
    """
    text = extract_text_first_page(pdf_path)
    stripped_text = text.strip()

    # First check: enough characters (basic minimum)
    if len(stripped_text) <= min_chars:
        return False

    # Second check: text readability
    # PDFs with custom font encoding often produce garbled text
    # Check if common invoice-related keywords are present
    text_lower = stripped_text.lower()
    invoice_keywords = [
        'faktura', 'invoice', 'datum', 'date', 'belopp', 'amount',
        'moms', 'vat', 'bankgiro', 'plusgiro', 'ocr', 'betala',
        'summa', 'total', 'pris', 'price', 'kr', 'sek'
    ]
    found_keywords = sum(1 for kw in invoice_keywords if kw in text_lower)

    # If at least 2 keywords found, likely readable text
    if found_keywords >= 2:
        return True

    # Third check: minimum content threshold
    # A real text PDF invoice should have at least 200 chars of content
    # PDFs with only headers/footers (like "Brandsign") should use OCR
    if len(stripped_text) < 200:
        return False

    # Fourth check: character readability ratio
    # Count printable ASCII and common Swedish/European characters
    readable_chars = 0
    total_chars = len(stripped_text)

    for c in stripped_text:
        # Printable ASCII (32-126) or common Swedish/European chars
        if 32 <= ord(c) <= 126 or c in 'åäöÅÄÖéèêëÉÈÊËüÜ':
            readable_chars += 1

    # If less than 70% readable, treat as garbled/scanned
    readable_ratio = readable_chars / total_chars if total_chars > 0 else 0
    if readable_ratio < 0.70:
        return False

    # Fifth check: if no keywords found but passes basic readability,
    # require higher readability threshold (85%) or at least 1 keyword
    # This catches garbled PDFs that have high ASCII ratio but unreadable content
    # (e.g., custom font encoding that maps to different characters)
    if found_keywords == 0 and readable_ratio < 0.85:
        return False

    return True


def get_pdf_type(pdf_path: str | Path) -> PDFType:
    """
    Determine the PDF type.

    Returns:
        'text' - Has extractable text layer
        'scanned' - Image-based, needs OCR
        'mixed' - Some pages have text, some don't
    """
    doc = fitz.open(pdf_path)

    if len(doc) == 0:
        doc.close()
        return "scanned"

    text_pages = 0
    total_pages = len(doc)
    for page in doc:
        text = page.get_text().strip()
        if len(text) > 30:
            text_pages += 1

    doc.close()

    if text_pages == total_pages:
        return "text"
    elif text_pages == 0:
        return "scanned"
    else:
        return "mixed"


def get_page_info(pdf_path: str | Path) -> list[dict]:
    """
    Get information about each page in the PDF.

    Returns:
        List of dicts with page info (number, width, height, has_text)
    """
    doc = fitz.open(pdf_path)
    pages = []

    for i, page in enumerate(doc):
        text = page.get_text().strip()
        rect = page.rect
        pages.append({
            "page_no": i,
            "width": rect.width,
            "height": rect.height,
            "has_text": len(text) > 30,
            "char_count": len(text)
        })

    doc.close()
    return pages
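A short sketch of the routing decision this module is for; the OCR fallback is the PaddleOCR engine defined elsewhere in this commit, and the file name is a placeholder:

    from shared.pdf import get_pdf_type, is_text_pdf, extract_text_tokens

    pdf_path = "invoice.pdf"  # hypothetical input
    pdf_type = get_pdf_type(pdf_path)

    if pdf_type == "text" and is_text_pdf(pdf_path):
        # Readable text layer: extract tokens directly with PyMuPDF
        tokens = list(extract_text_tokens(pdf_path))
    else:
        # Scanned, mixed, or garbled text layer: fall back to OCR
        # (shared.ocr.paddle_ocr.OCREngine)
        tokens = []
    print(pdf_type, len(tokens))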
323
packages/shared/shared/pdf/extractor.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
PDF Text Extraction Module
|
||||
|
||||
Extracts text tokens with bounding boxes from text-layer PDFs.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional
|
||||
import fitz # PyMuPDF
|
||||
|
||||
from .detector import is_text_pdf as _is_text_pdf_standalone
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
"""Represents a text token with its bounding box."""
|
||||
text: str
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||
page_no: int
|
||||
|
||||
@property
|
||||
def x0(self) -> float:
|
||||
return self.bbox[0]
|
||||
|
||||
@property
|
||||
def y0(self) -> float:
|
||||
return self.bbox[1]
|
||||
|
||||
@property
|
||||
def x1(self) -> float:
|
||||
return self.bbox[2]
|
||||
|
||||
@property
|
||||
def y1(self) -> float:
|
||||
return self.bbox[3]
|
||||
|
||||
@property
|
||||
def width(self) -> float:
|
||||
return self.x1 - self.x0
|
||||
|
||||
@property
|
||||
def height(self) -> float:
|
||||
return self.y1 - self.y0
|
||||
|
||||
@property
|
||||
def center(self) -> tuple[float, float]:
|
||||
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||
|
||||
|
||||
class PDFDocument:
|
||||
"""
|
||||
Context manager for efficient PDF document handling.
|
||||
|
||||
Caches the open document handle to avoid repeated open/close cycles.
|
||||
Use this when you need to perform multiple operations on the same PDF.
|
||||
"""
|
||||
|
||||
def __init__(self, pdf_path: str | Path):
|
||||
self.pdf_path = Path(pdf_path)
|
||||
self._doc: Optional[fitz.Document] = None
|
||||
self._dimensions_cache: dict[int, tuple[float, float]] = {}
|
||||
|
||||
def __enter__(self) -> 'PDFDocument':
|
||||
self._doc = fitz.open(self.pdf_path)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._doc:
|
||||
self._doc.close()
|
||||
self._doc = None
|
||||
|
||||
@property
|
||||
def doc(self) -> fitz.Document:
|
||||
if self._doc is None:
|
||||
raise RuntimeError("PDFDocument must be used within a context manager")
|
||||
return self._doc
|
||||
|
||||
@property
|
||||
def page_count(self) -> int:
|
||||
return len(self.doc)
|
||||
|
||||
def is_text_pdf(self, min_chars: int = 30) -> bool:
|
||||
"""
|
||||
Check if PDF has extractable AND READABLE text layer.
|
||||
|
||||
Uses the improved detection from detector.py that also checks
|
||||
for garbled text (custom font encoding issues).
|
||||
"""
|
||||
return _is_text_pdf_standalone(self.pdf_path, min_chars)
|
||||
|
||||
def get_page_dimensions(self, page_no: int = 0) -> tuple[float, float]:
|
||||
"""Get page dimensions in points (cached)."""
|
||||
if page_no not in self._dimensions_cache:
|
||||
page = self.doc[page_no]
|
||||
rect = page.rect
|
||||
self._dimensions_cache[page_no] = (rect.width, rect.height)
|
||||
return self._dimensions_cache[page_no]
|
||||
|
||||
def get_render_dimensions(self, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
|
||||
"""Get rendered image dimensions in pixels."""
|
||||
width, height = self.get_page_dimensions(page_no)
|
||||
zoom = dpi / 72
|
||||
return int(width * zoom), int(height * zoom)
|
||||
|
||||
def extract_text_tokens(self, page_no: int) -> Generator[Token, None, None]:
|
||||
"""Extract text tokens from a specific page."""
|
||||
page = self.doc[page_no]
|
||||
|
||||
text_dict = page.get_text("dict")
|
||||
tokens_found = False
|
||||
|
||||
for block in text_dict.get("blocks", []):
|
||||
if block.get("type") != 0:
|
||||
continue
|
||||
|
||||
for line in block.get("lines", []):
|
||||
for span in line.get("spans", []):
|
||||
text = span.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
bbox = span.get("bbox")
|
||||
if bbox and all(abs(b) < 1e9 for b in bbox):
|
||||
tokens_found = True
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=tuple(bbox),
|
||||
page_no=page_no
|
||||
)
|
||||
|
||||
# Fallback: if dict mode failed, use words mode
|
||||
if not tokens_found:
|
||||
words = page.get_text("words")
|
||||
for word_info in words:
|
||||
x0, y0, x1, y1, text, *_ = word_info
|
||||
text = text.strip()
|
||||
if text:
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=(x0, y0, x1, y1),
|
||||
page_no=page_no
|
||||
)
|
||||
|
||||
def render_page(self, page_no: int, output_path: Path, dpi: int = 300) -> Path:
|
||||
"""Render a page to an image file."""
|
||||
zoom = dpi / 72
|
||||
matrix = fitz.Matrix(zoom, zoom)
|
||||
|
||||
page = self.doc[page_no]
|
||||
pix = page.get_pixmap(matrix=matrix)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pix.save(str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
def render_all_pages(
|
||||
self,
|
||||
output_dir: Path,
|
||||
dpi: int = 300
|
||||
) -> Generator[tuple[int, Path], None, None]:
|
||||
"""Render all pages to images."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
pdf_name = self.pdf_path.stem
|
||||
|
||||
zoom = dpi / 72
|
||||
matrix = fitz.Matrix(zoom, zoom)
|
||||
|
||||
for page_no in range(self.page_count):
|
||||
page = self.doc[page_no]
|
||||
pix = page.get_pixmap(matrix=matrix)
|
||||
|
||||
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
yield page_no, image_path
|
||||
|
||||
|
||||
def extract_text_tokens(
|
||||
pdf_path: str | Path,
|
||||
page_no: int | None = None
|
||||
) -> Generator[Token, None, None]:
|
||||
"""
|
||||
Extract text tokens with bounding boxes from PDF.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to the PDF file
|
||||
page_no: Specific page to extract (None for all pages)
|
||||
|
||||
Yields:
|
||||
Token objects with text and bbox
|
||||
"""
|
||||
doc = fitz.open(pdf_path)
|
||||
|
||||
pages_to_process = [page_no] if page_no is not None else range(len(doc))
|
||||
|
||||
for pg_no in pages_to_process:
|
||||
page = doc[pg_no]
|
||||
|
||||
# Get text with position info using "dict" mode
|
||||
text_dict = page.get_text("dict")
|
||||
|
||||
tokens_found = False
|
||||
for block in text_dict.get("blocks", []):
|
||||
if block.get("type") != 0: # Skip non-text blocks
|
||||
continue
|
||||
|
||||
for line in block.get("lines", []):
|
||||
for span in line.get("spans", []):
|
||||
text = span.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
bbox = span.get("bbox")
|
||||
# Check for corrupted bbox (overflow values)
|
||||
if bbox and all(abs(b) < 1e9 for b in bbox):
|
||||
tokens_found = True
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=tuple(bbox),
|
||||
page_no=pg_no
|
||||
)
|
||||
|
||||
# Fallback: if dict mode failed, use words mode
|
||||
if not tokens_found:
|
||||
words = page.get_text("words")
|
||||
for word_info in words:
|
||||
x0, y0, x1, y1, text, *_ = word_info
|
||||
text = text.strip()
|
||||
if text:
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=(x0, y0, x1, y1),
|
||||
page_no=pg_no
|
||||
)
|
||||
|
||||
doc.close()
|
||||
|
||||
|
||||
def extract_words(
|
||||
pdf_path: str | Path,
|
||||
page_no: int | None = None
|
||||
) -> Generator[Token, None, None]:
|
||||
"""
|
||||
Extract individual words with bounding boxes.
|
||||
|
||||
Uses PyMuPDF's word extraction which splits text into words.
|
||||
"""
|
||||
doc = fitz.open(pdf_path)
|
||||
|
||||
pages_to_process = [page_no] if page_no is not None else range(len(doc))
|
||||
|
||||
for pg_no in pages_to_process:
|
||||
page = doc[pg_no]
|
||||
|
||||
# get_text("words") returns list of (x0, y0, x1, y1, word, block_no, line_no, word_no)
|
||||
words = page.get_text("words")
|
||||
|
||||
for word_info in words:
|
||||
x0, y0, x1, y1, text, *_ = word_info
|
||||
text = text.strip()
|
||||
if text:
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=(x0, y0, x1, y1),
|
||||
page_no=pg_no
|
||||
)
|
||||
|
||||
doc.close()
|
||||
|
||||
|
||||
def extract_lines(
|
||||
pdf_path: str | Path,
|
||||
page_no: int | None = None
|
||||
) -> Generator[Token, None, None]:
|
||||
"""
|
||||
Extract text lines with bounding boxes.
|
||||
"""
|
||||
doc = fitz.open(pdf_path)
|
||||
|
||||
pages_to_process = [page_no] if page_no is not None else range(len(doc))
|
||||
|
||||
for pg_no in pages_to_process:
|
||||
page = doc[pg_no]
|
||||
text_dict = page.get_text("dict")
|
||||
|
||||
for block in text_dict.get("blocks", []):
|
||||
if block.get("type") != 0:
|
||||
continue
|
||||
|
||||
for line in block.get("lines", []):
|
||||
spans = line.get("spans", [])
|
||||
if not spans:
|
||||
continue
|
||||
|
||||
# Combine all spans in the line
|
||||
line_text = " ".join(s.get("text", "") for s in spans).strip()
|
||||
if not line_text:
|
||||
continue
|
||||
|
||||
# Get line bbox from all spans
|
||||
x0 = min(s["bbox"][0] for s in spans)
|
||||
y0 = min(s["bbox"][1] for s in spans)
|
||||
x1 = max(s["bbox"][2] for s in spans)
|
||||
y1 = max(s["bbox"][3] for s in spans)
|
||||
|
||||
yield Token(
|
||||
text=line_text,
|
||||
bbox=(x0, y0, x1, y1),
|
||||
page_no=pg_no
|
||||
)
|
||||
|
||||
doc.close()
|
||||
|
||||
|
||||
def get_page_dimensions(pdf_path: str | Path, page_no: int = 0) -> tuple[float, float]:
|
||||
"""Get the dimensions of a PDF page in points."""
|
||||
doc = fitz.open(pdf_path)
|
||||
page = doc[page_no]
|
||||
rect = page.rect
|
||||
doc.close()
|
||||
return rect.width, rect.height
|
||||
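The standalone functions above reopen the document on every call, so the PDFDocument context manager is the cheaper path when several operations hit the same PDF. A hedged usage sketch (paths are placeholders):

    from pathlib import Path
    from shared.pdf import PDFDocument

    with PDFDocument("invoice.pdf") as pdf:
        if pdf.is_text_pdf():
            for page_no in range(pdf.page_count):
                width, height = pdf.get_page_dimensions(page_no)
                tokens = list(pdf.extract_text_tokens(page_no))
                pdf.render_page(page_no, Path(f"out/page_{page_no:03d}.png"), dpi=150)
                print(page_no, (width, height), len(tokens))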
117
packages/shared/shared/pdf/renderer.py
Normal file
@@ -0,0 +1,117 @@
"""
PDF Rendering Module

Converts PDF pages to images for YOLO training.
"""

from pathlib import Path
from typing import Generator
import fitz  # PyMuPDF


def render_pdf_to_images(
    pdf_path: str | Path,
    output_dir: str | Path | None = None,
    dpi: int = 300,
    image_format: str = "png"
) -> Generator[tuple[int, Path | bytes], None, None]:
    """
    Render PDF pages to images.

    Args:
        pdf_path: Path to the PDF file
        output_dir: Directory to save images (if None, returns bytes)
        dpi: Resolution for rendering (default 300)
        image_format: Output format ('png' or 'jpg')

    Yields:
        Tuple of (page_number, image_path or image_bytes)
    """
    doc = fitz.open(pdf_path)

    # Calculate zoom factor for desired DPI (72 is base DPI for PDF)
    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    pdf_name = Path(pdf_path).stem

    for page_no, page in enumerate(doc):
        # Render page to pixmap
        pix = page.get_pixmap(matrix=matrix)

        if output_dir:
            # Save to file
            ext = "jpg" if image_format.lower() in ("jpg", "jpeg") else "png"
            image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.{ext}"

            if ext == "jpg":
                pix.save(str(image_path), "jpeg")
            else:
                pix.save(str(image_path))

            yield page_no, image_path
        else:
            # Return bytes
            if image_format.lower() in ("jpg", "jpeg"):
                yield page_no, pix.tobytes("jpeg")
            else:
                yield page_no, pix.tobytes("png")

    doc.close()


def render_page_to_image(
    pdf_path: str | Path,
    page_no: int,
    dpi: int = 300
) -> bytes:
    """
    Render a single page to image bytes.

    Args:
        pdf_path: Path to the PDF file
        page_no: Page number (0-indexed)
        dpi: Resolution for rendering

    Returns:
        PNG image bytes
    """
    doc = fitz.open(pdf_path)

    if page_no >= len(doc):
        doc.close()
        raise ValueError(f"Page {page_no} does not exist (PDF has {len(doc)} pages)")

    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    page = doc[page_no]
    pix = page.get_pixmap(matrix=matrix)
    image_bytes = pix.tobytes("png")

    doc.close()
    return image_bytes


def get_render_dimensions(pdf_path: str | Path, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
    """
    Get the dimensions of a rendered page.

    Returns:
        (width, height) in pixels
    """
    doc = fitz.open(pdf_path)
    page = doc[page_no]

    zoom = dpi / 72
    rect = page.rect

    width = int(rect.width * zoom)
    height = int(rect.height * zoom)

    doc.close()
    return width, height
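Because every function here derives its zoom factor as dpi / 72, a detection made on a rendered image maps back to PDF points with a single division. A small sketch (numbers are made up):

    dpi = 150
    zoom = dpi / 72  # same factor used by render_pdf_to_images above

    pixel_bbox = (300.0, 450.0, 900.0, 480.0)          # hypothetical detection on the 150 DPI image
    point_bbox = tuple(v / zoom for v in pixel_bbox)   # back to PDF points
    print(point_bbox)                                   # approximately (144.0, 216.0, 432.0, 230.4)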
34
packages/shared/shared/utils/__init__.py
Normal file
@@ -0,0 +1,34 @@
"""
Shared utilities for invoice field extraction and matching.

This module provides common functionality used by both:
- Inference stage (field_extractor.py) - extracting values from OCR text
- Matching stage (normalizer.py) - generating variants for CSV matching

Modules:
- TextCleaner: Unicode normalization and OCR error correction
- FormatVariants: Generate format variants for matching
- FieldValidators: Validate field values (Luhn, dates, amounts)
- FuzzyMatcher: Fuzzy string matching with OCR awareness
- OCRCorrections: Comprehensive OCR error correction
- ContextExtractor: Context-aware field extraction
"""

from .text_cleaner import TextCleaner
from .format_variants import FormatVariants
from .validators import FieldValidators
from .fuzzy_matcher import FuzzyMatcher, FuzzyMatchResult
from .ocr_corrections import OCRCorrections, CorrectionResult
from .context_extractor import ContextExtractor, ExtractionCandidate

__all__ = [
    'TextCleaner',
    'FormatVariants',
    'FieldValidators',
    'FuzzyMatcher',
    'FuzzyMatchResult',
    'OCRCorrections',
    'CorrectionResult',
    'ContextExtractor',
    'ExtractionCandidate',
]
433
packages/shared/shared/utils/context_extractor.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Context-Aware Extraction Module
|
||||
|
||||
Extracts field values using contextual cues and label detection.
|
||||
Improves extraction accuracy by understanding the semantic context.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, NamedTuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
from .validators import FieldValidators
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionCandidate:
|
||||
"""A candidate extracted value with metadata."""
|
||||
value: str
|
||||
raw_text: str
|
||||
context_label: str
|
||||
confidence: float
|
||||
position: int # Character position in source text
|
||||
extraction_method: str # 'label', 'pattern', 'proximity'
|
||||
|
||||
|
||||
class ContextExtractor:
|
||||
"""
|
||||
Context-aware field extraction.
|
||||
|
||||
Uses multiple strategies:
|
||||
1. Label detection - finds values after field labels
|
||||
2. Pattern matching - uses field-specific regex patterns
|
||||
3. Proximity analysis - finds values near related terms
|
||||
4. Validation filtering - removes invalid candidates
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Swedish Label Patterns (what appears before the value)
|
||||
# =========================================================================
|
||||
|
||||
LABEL_PATTERNS = {
|
||||
'InvoiceNumber': [
|
||||
# Swedish
|
||||
r'(?:faktura|fakt)\.?\s*(?:nr|nummer|#)?[:\s]*',
|
||||
r'(?:fakturanummer|fakturanr)[:\s]*',
|
||||
r'(?:vår\s+referens)[:\s]*',
|
||||
# English
|
||||
r'(?:invoice)\s*(?:no|number|#)?[:\s]*',
|
||||
r'inv[.:\s]*#?',
|
||||
],
|
||||
'Amount': [
|
||||
# Swedish
|
||||
r'(?:att\s+)?betala[:\s]*',
|
||||
r'(?:total|totalt|summa)[:\s]*',
|
||||
r'(?:belopp)[:\s]*',
|
||||
r'(?:slutsumma)[:\s]*',
|
||||
r'(?:att\s+erlägga)[:\s]*',
|
||||
# English
|
||||
r'(?:total|amount|sum)[:\s]*',
|
||||
r'(?:amount\s+due)[:\s]*',
|
||||
],
|
||||
'InvoiceDate': [
|
||||
# Swedish
|
||||
r'(?:faktura)?datum[:\s]*',
|
||||
r'(?:fakt\.?\s*datum)[:\s]*',
|
||||
# English
|
||||
r'(?:invoice\s+)?date[:\s]*',
|
||||
],
|
||||
'InvoiceDueDate': [
|
||||
# Swedish
|
||||
r'(?:förfall(?:o)?datum)[:\s]*',
|
||||
r'(?:betalas\s+senast)[:\s]*',
|
||||
r'(?:sista\s+betalningsdag)[:\s]*',
|
||||
r'(?:förfaller)[:\s]*',
|
||||
# English
|
||||
r'(?:due\s+date)[:\s]*',
|
||||
r'(?:payment\s+due)[:\s]*',
|
||||
],
|
||||
'OCR': [
|
||||
r'(?:ocr)[:\s]*',
|
||||
r'(?:ocr\s*-?\s*nummer)[:\s]*',
|
||||
r'(?:referens(?:nummer)?)[:\s]*',
|
||||
r'(?:betalningsreferens)[:\s]*',
|
||||
],
|
||||
'Bankgiro': [
|
||||
r'(?:bankgiro|bg)[:\s]*',
|
||||
r'(?:bank\s*giro)[:\s]*',
|
||||
],
|
||||
'Plusgiro': [
|
||||
r'(?:plusgiro|pg)[:\s]*',
|
||||
r'(?:plus\s*giro)[:\s]*',
|
||||
r'(?:postgiro)[:\s]*',
|
||||
],
|
||||
'supplier_organisation_number': [
|
||||
r'(?:org\.?\s*(?:nr|nummer)?)[:\s]*',
|
||||
r'(?:organisationsnummer)[:\s]*',
|
||||
r'(?:org\.?\s*id)[:\s]*',
|
||||
r'(?:vat\s*(?:no|number|nr)?)[:\s]*',
|
||||
r'(?:moms(?:reg)?\.?\s*(?:nr|nummer)?)[:\s]*',
|
||||
r'(?:se)[:\s]*', # VAT prefix
|
||||
],
|
||||
'customer_number': [
|
||||
r'(?:kund(?:nr|nummer)?)[:\s]*',
|
||||
r'(?:kundnummer)[:\s]*',
|
||||
r'(?:customer\s*(?:no|number|id)?)[:\s]*',
|
||||
r'(?:er\s+referens)[:\s]*',
|
||||
],
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Value Patterns (what the value looks like)
|
||||
# =========================================================================
|
||||
|
||||
VALUE_PATTERNS = {
|
||||
'InvoiceNumber': [
|
||||
r'[A-Z]{0,3}\d{3,15}', # Alphanumeric: INV12345
|
||||
r'\d{3,15}', # Pure digits
|
||||
r'20\d{2}[-/]\d{3,8}', # Year prefix: 2024-001
|
||||
],
|
||||
'Amount': [
|
||||
r'\d{1,3}(?:[\s.]\d{3})*[,]\d{2}', # Swedish: 1 234,56
|
||||
r'\d{1,3}(?:[,]\d{3})*[.]\d{2}', # US: 1,234.56
|
||||
r'\d+[,.]\d{2}', # Simple: 123,45
|
||||
r'\d+', # Integer
|
||||
],
|
||||
'InvoiceDate': [
|
||||
r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}', # ISO-like
|
||||
r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}', # European
|
||||
r'\d{8}', # Compact YYYYMMDD
|
||||
],
|
||||
'InvoiceDueDate': [
|
||||
r'\d{4}[-/.]\d{1,2}[-/.]\d{1,2}',
|
||||
r'\d{1,2}[-/.]\d{1,2}[-/.]\d{4}',
|
||||
r'\d{8}',
|
||||
],
|
||||
'OCR': [
|
||||
r'\d{10,25}', # Long digit sequence
|
||||
],
|
||||
'Bankgiro': [
|
||||
r'\d{3,4}[-\s]?\d{4}', # XXX-XXXX or XXXX-XXXX
|
||||
r'\d{7,8}', # Without separator
|
||||
],
|
||||
'Plusgiro': [
|
||||
r'\d{1,7}[-\s]?\d', # XXXXXXX-X
|
||||
r'\d{2,8}', # Without separator
|
||||
],
|
||||
'supplier_organisation_number': [
|
||||
r'\d{6}[-\s]?\d{4}', # NNNNNN-NNNN
|
||||
r'\d{10}', # Without separator
|
||||
r'SE\s?\d{10,12}(?:\s?01)?', # VAT format
|
||||
],
|
||||
'customer_number': [
|
||||
r'[A-Z]{0,5}\s?[-]?\s?\d{1,10}', # EMM 256-6
|
||||
r'\d{3,15}', # Pure digits
|
||||
],
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Extraction Methods
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def extract_with_label(
|
||||
cls,
|
||||
text: str,
|
||||
field_name: str,
|
||||
validate: bool = True
|
||||
) -> list[ExtractionCandidate]:
|
||||
"""
|
||||
Extract field values by finding labels and taking following values.
|
||||
|
||||
Example: "Fakturanummer: 12345" -> extracts "12345"
|
||||
"""
|
||||
candidates = []
|
||||
label_patterns = cls.LABEL_PATTERNS.get(field_name, [])
|
||||
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
|
||||
|
||||
for label_pattern in label_patterns:
|
||||
for value_pattern in value_patterns:
|
||||
# Combine label + value patterns
|
||||
full_pattern = f'({label_pattern})({value_pattern})'
|
||||
matches = re.finditer(full_pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
label = match.group(1).strip()
|
||||
value = match.group(2).strip()
|
||||
|
||||
# Validate if requested
|
||||
if validate and not cls._validate_value(field_name, value):
|
||||
continue
|
||||
|
||||
# Calculate confidence based on label specificity
|
||||
confidence = cls._calculate_label_confidence(label, field_name)
|
||||
|
||||
candidates.append(ExtractionCandidate(
|
||||
value=value,
|
||||
raw_text=match.group(0),
|
||||
context_label=label,
|
||||
confidence=confidence,
|
||||
position=match.start(),
|
||||
extraction_method='label'
|
||||
))
|
||||
|
||||
return candidates
|
||||
|
||||
@classmethod
|
||||
def extract_with_pattern(
|
||||
cls,
|
||||
text: str,
|
||||
field_name: str,
|
||||
validate: bool = True
|
||||
) -> list[ExtractionCandidate]:
|
||||
"""
|
||||
Extract field values using only value patterns (no label required).
|
||||
|
||||
This is a fallback when no labels are found.
|
||||
"""
|
||||
candidates = []
|
||||
value_patterns = cls.VALUE_PATTERNS.get(field_name, [])
|
||||
|
||||
for pattern in value_patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
value = match.group(0).strip()
|
||||
|
||||
# Validate if requested
|
||||
if validate and not cls._validate_value(field_name, value):
|
||||
continue
|
||||
|
||||
# Lower confidence for pattern-only extraction
|
||||
confidence = 0.6
|
||||
|
||||
candidates.append(ExtractionCandidate(
|
||||
value=value,
|
||||
raw_text=value,
|
||||
context_label='',
|
||||
confidence=confidence,
|
||||
position=match.start(),
|
||||
extraction_method='pattern'
|
||||
))
|
||||
|
||||
return candidates
|
||||
|
||||
@classmethod
|
||||
def extract_field(
|
||||
cls,
|
||||
text: str,
|
||||
field_name: str,
|
||||
validate: bool = True
|
||||
) -> list[ExtractionCandidate]:
|
||||
"""
|
||||
Extract all candidate values for a field using multiple strategies.
|
||||
|
||||
Returns candidates sorted by confidence (highest first).
|
||||
"""
|
||||
candidates = []
|
||||
|
||||
# Strategy 1: Label-based extraction (highest confidence)
|
||||
label_candidates = cls.extract_with_label(text, field_name, validate)
|
||||
candidates.extend(label_candidates)
|
||||
|
||||
# Strategy 2: Pattern-based extraction (fallback)
|
||||
if not label_candidates:
|
||||
pattern_candidates = cls.extract_with_pattern(text, field_name, validate)
|
||||
candidates.extend(pattern_candidates)
|
||||
|
||||
# Remove duplicates (same value, keep highest confidence)
|
||||
seen_values = {}
|
||||
for candidate in candidates:
|
||||
normalized = TextCleaner.normalize_for_comparison(candidate.value)
|
||||
if normalized not in seen_values or candidate.confidence > seen_values[normalized].confidence:
|
||||
seen_values[normalized] = candidate
|
||||
|
||||
# Sort by confidence
|
||||
result = sorted(seen_values.values(), key=lambda x: x.confidence, reverse=True)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_best(
|
||||
cls,
|
||||
text: str,
|
||||
field_name: str,
|
||||
validate: bool = True
|
||||
) -> Optional[ExtractionCandidate]:
|
||||
"""
|
||||
Extract the best (highest confidence) candidate for a field.
|
||||
"""
|
||||
candidates = cls.extract_field(text, field_name, validate)
|
||||
return candidates[0] if candidates else None
|
||||
|
||||
@classmethod
|
||||
def extract_all_fields(cls, text: str) -> dict[str, list[ExtractionCandidate]]:
|
||||
"""
|
||||
Extract all known fields from text.
|
||||
|
||||
Returns a dictionary mapping field names to their candidates.
|
||||
"""
|
||||
results = {}
|
||||
for field_name in cls.LABEL_PATTERNS.keys():
|
||||
candidates = cls.extract_field(text, field_name)
|
||||
if candidates:
|
||||
results[field_name] = candidates
|
||||
return results
|
||||
|
||||
# =========================================================================
|
||||
# Helper Methods
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def _validate_value(cls, field_name: str, value: str) -> bool:
|
||||
"""Validate a value based on field type."""
|
||||
field_lower = field_name.lower()
|
||||
|
||||
if 'date' in field_lower:
|
||||
return FieldValidators.is_valid_date(value)
|
||||
elif 'amount' in field_lower:
|
||||
return FieldValidators.is_valid_amount(value)
|
||||
elif 'bankgiro' in field_lower:
|
||||
# Basic format check, not Luhn
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
return 7 <= len(digits) <= 8
|
||||
elif 'plusgiro' in field_lower:
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
return 2 <= len(digits) <= 8
|
||||
elif 'ocr' in field_lower:
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
return 10 <= len(digits) <= 25
|
||||
elif 'org' in field_lower:
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
return len(digits) >= 10
|
||||
else:
|
||||
# For other fields, just check it's not empty
|
||||
return bool(value.strip())
|
||||
|
||||
@classmethod
|
||||
def _calculate_label_confidence(cls, label: str, field_name: str) -> float:
|
||||
"""
|
||||
Calculate confidence based on how specific the label is.
|
||||
|
||||
More specific labels = higher confidence.
|
||||
"""
|
||||
label_lower = label.lower()
|
||||
|
||||
# Very specific labels
|
||||
very_specific = {
|
||||
'InvoiceNumber': ['fakturanummer', 'invoice number', 'fakturanr'],
|
||||
'Amount': ['att betala', 'slutsumma', 'amount due'],
|
||||
'InvoiceDate': ['fakturadatum', 'invoice date'],
|
||||
'InvoiceDueDate': ['förfallodatum', 'förfallodag', 'due date'],
|
||||
'OCR': ['ocr', 'betalningsreferens'],
|
||||
'Bankgiro': ['bankgiro'],
|
||||
'Plusgiro': ['plusgiro', 'postgiro'],
|
||||
'supplier_organisation_number': ['organisationsnummer', 'org nummer'],
|
||||
'customer_number': ['kundnummer', 'customer number'],
|
||||
}
|
||||
|
||||
# Check for very specific match
|
||||
if field_name in very_specific:
|
||||
for specific in very_specific[field_name]:
|
||||
if specific in label_lower:
|
||||
return 0.95
|
||||
|
||||
# Moderately specific
|
||||
moderate = {
|
||||
'InvoiceNumber': ['faktura', 'invoice', 'nr'],
|
||||
'Amount': ['total', 'summa', 'belopp'],
|
||||
'InvoiceDate': ['datum', 'date'],
|
||||
'InvoiceDueDate': ['förfall', 'due'],
|
||||
}
|
||||
|
||||
if field_name in moderate:
|
||||
for mod in moderate[field_name]:
|
||||
if mod in label_lower:
|
||||
return 0.85
|
||||
|
||||
# Generic match
|
||||
return 0.75
|
||||
|
||||
@classmethod
|
||||
def find_field_context(cls, text: str, position: int, window: int = 50) -> str:
|
||||
"""
|
||||
Get the surrounding context for a position in text.
|
||||
|
||||
Useful for understanding what field a value belongs to.
|
||||
"""
|
||||
start = max(0, position - window)
|
||||
end = min(len(text), position + window)
|
||||
return text[start:end]
|
||||
|
||||
@classmethod
|
||||
def identify_field_type(cls, text: str, value: str) -> Optional[str]:
|
||||
"""
|
||||
Try to identify what field type a value belongs to based on context.
|
||||
|
||||
Looks at text surrounding the value to find labels.
|
||||
"""
|
||||
# Find the value in text
|
||||
pos = text.find(value)
|
||||
if pos == -1:
|
||||
return None
|
||||
|
||||
# Get context before the value
|
||||
context_before = text[max(0, pos - 50):pos].lower()
|
||||
|
||||
# Check each field's labels
|
||||
for field_name, patterns in cls.LABEL_PATTERNS.items():
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, context_before, re.IGNORECASE):
|
||||
return field_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Convenience functions
|
||||
# =========================================================================
|
||||
|
||||
def extract_field_with_context(text: str, field_name: str) -> Optional[str]:
|
||||
"""Convenience function to extract a field value."""
|
||||
candidate = ContextExtractor.extract_best(text, field_name)
|
||||
return candidate.value if candidate else None
|
||||
|
||||
|
||||
def extract_all_with_context(text: str) -> dict[str, str]:
|
||||
"""Convenience function to extract all fields."""
|
||||
all_candidates = ContextExtractor.extract_all_fields(text)
|
||||
return {
|
||||
field: candidates[0].value
|
||||
for field, candidates in all_candidates.items()
|
||||
if candidates
|
||||
}
|
||||
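A short usage sketch of the extractor on an invented fragment of Swedish invoice text (values and printed output are illustrative only):

    from shared.utils import ContextExtractor
    from shared.utils.context_extractor import extract_all_with_context

    sample = (
        "Fakturanummer: 10234  Fakturadatum: 2024-03-01  "
        "Förfallodatum: 2024-03-31  Att betala: 1 234,56 kr  Bankgiro: 5393-9484"
    )

    best = ContextExtractor.extract_best(sample, 'InvoiceNumber')
    if best:
        print(best.value, best.confidence, best.extraction_method)  # e.g. 10234 0.95 label

    print(extract_all_with_context(sample))
    # e.g. {'InvoiceNumber': '10234', 'InvoiceDate': '2024-03-01', ...}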
610
packages/shared/shared/utils/format_variants.py
Normal file
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
Format Variants Generator
|
||||
|
||||
Generates multiple format variants for invoice field values.
|
||||
Used by both inference (to try different extractions) and matching (to match CSV values).
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
|
||||
|
||||
class FormatVariants:
|
||||
"""
|
||||
Generates format variants for different field types.
|
||||
|
||||
The same logic is used for:
|
||||
- Inference: trying different formats to extract a value
|
||||
- Matching: generating variants of CSV values to match against OCR text
|
||||
"""
|
||||
|
||||
# Swedish month names for date parsing
|
||||
SWEDISH_MONTHS = {
|
||||
'januari': '01', 'jan': '01',
|
||||
'februari': '02', 'feb': '02',
|
||||
'mars': '03', 'mar': '03',
|
||||
'april': '04', 'apr': '04',
|
||||
'maj': '05',
|
||||
'juni': '06', 'jun': '06',
|
||||
'juli': '07', 'jul': '07',
|
||||
'augusti': '08', 'aug': '08',
|
||||
'september': '09', 'sep': '09', 'sept': '09',
|
||||
'oktober': '10', 'okt': '10',
|
||||
'november': '11', 'nov': '11',
|
||||
'december': '12', 'dec': '12',
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Organization Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def organisation_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate all format variants for Swedish organization number.
|
||||
|
||||
Input formats handled:
|
||||
- "556123-4567" (standard with hyphen)
|
||||
- "5561234567" (no hyphen)
|
||||
- "SE556123456701" (VAT format)
|
||||
- "SE 556123-4567 01" (VAT with spaces)
|
||||
|
||||
Returns all possible variants for matching.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
value_upper = value.upper()
|
||||
variants = set()
|
||||
|
||||
# Extract pure digits
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
# If it is a VAT format, extract the org number in the middle
|
||||
# SE + 10 digits + 01 = "SE556123456701"
|
||||
if value_upper.startswith('SE') and len(digits) == 12 and digits.endswith('01'):
|
||||
# VAT format: SE + org_number + 01
|
||||
digits = digits[:10]
|
||||
elif digits.startswith('46') and len(digits) == 14:
|
||||
# SE prefix in numeric (46 is SE in phone code): 46 + 10 digits + 01
|
||||
digits = digits[2:12]
|
||||
|
||||
if len(digits) == 12:
|
||||
# 12 digits may carry a century prefix: NNNNNNNN-NNNN (19556123-4567)
|
||||
variants.add(value)
|
||||
variants.add(digits) # 195561234567
|
||||
# Hyphenated format
|
||||
variants.add(f"{digits[:8]}-{digits[8:]}") # 19556123-4567
|
||||
# Take the last 10 digits as the standard org number
|
||||
short_digits = digits[2:] # 5561234567
|
||||
variants.add(short_digits)
|
||||
variants.add(f"{short_digits[:6]}-{short_digits[6:]}") # 556123-4567
|
||||
# VAT format
|
||||
variants.add(f"SE{short_digits}01") # SE556123456701
|
||||
return list(v for v in variants if v)
|
||||
|
||||
if len(digits) != 10:
|
||||
# Not a standard 10-digit number: return the original value plus cleaned variants
|
||||
variants.add(value)
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
return list(variants)
|
||||
|
||||
# Generate all variants
|
||||
# 1. Pure digits
|
||||
variants.add(digits) # 5561234567
|
||||
|
||||
# 2. Standard format (NNNNNN-NNNN)
|
||||
with_hyphen = f"{digits[:6]}-{digits[6:]}"
|
||||
variants.add(with_hyphen) # 556123-4567
|
||||
|
||||
# 3. VAT format
|
||||
vat_compact = f"SE{digits}01"
|
||||
variants.add(vat_compact) # SE556123456701
|
||||
variants.add(vat_compact.lower()) # se556123456701
|
||||
|
||||
vat_spaced = f"SE {digits[:6]}-{digits[6:]} 01"
|
||||
variants.add(vat_spaced) # SE 556123-4567 01
|
||||
|
||||
vat_spaced_no_hyphen = f"SE {digits} 01"
|
||||
variants.add(vat_spaced_no_hyphen) # SE 5561234567 01
|
||||
|
||||
# 4. Sometimes with country code but without the 01 suffix
|
||||
variants.add(f"SE{digits}") # SE5561234567
|
||||
variants.add(f"SE {digits}") # SE 5561234567
|
||||
variants.add(f"SE{digits[:6]}-{digits[6:]}") # SE556123-4567
|
||||
|
||||
# 5. Possible OCR error variants
|
||||
ocr_variants = TextCleaner.generate_ocr_variants(digits)
|
||||
for ocr_var in ocr_variants:
|
||||
if len(ocr_var) == 10:
|
||||
variants.add(ocr_var)
|
||||
variants.add(f"{ocr_var[:6]}-{ocr_var[6:]}")
|
||||
|
||||
return list(v for v in variants if v)
|
||||
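    # Illustrative expectation for the standard 10-digit case (not part of the module);
    # the full set also contains OCR-error variants from TextCleaner.generate_ocr_variants():
    #
    #     >>> v = FormatVariants.organisation_number_variants("556123-4567")
    #     >>> "5561234567" in v and "556123-4567" in v and "SE556123456701" in v
    #     True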
|
||||
# =========================================================================
|
||||
# Bankgiro Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def bankgiro_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for Bankgiro number.
|
||||
|
||||
Formats:
|
||||
- 7 digits: XXX-XXXX (e.g., 123-4567)
|
||||
- 8 digits: XXXX-XXXX (e.g., 1234-5678)
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
variants = set()
|
||||
|
||||
variants.add(value)
|
||||
|
||||
if not digits or len(digits) < 7 or len(digits) > 8:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# Pure digits
|
||||
variants.add(digits)
|
||||
|
||||
# Hyphenated format
|
||||
if len(digits) == 7:
|
||||
variants.add(f"{digits[:3]}-{digits[3:]}") # XXX-XXXX
|
||||
elif len(digits) == 8:
|
||||
variants.add(f"{digits[:4]}-{digits[4:]}") # XXXX-XXXX
|
||||
# Some 8-digit numbers also use the XXX-XXXXX format
|
||||
variants.add(f"{digits[:3]}-{digits[3:]}")
|
||||
|
||||
# Space-separated format (OCR sometimes reads it this way)
|
||||
if len(digits) == 7:
|
||||
variants.add(f"{digits[:3]} {digits[3:]}")
|
||||
elif len(digits) == 8:
|
||||
variants.add(f"{digits[:4]} {digits[4:]}")
|
||||
|
||||
# OCR error variants
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Plusgiro Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def plusgiro_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for Plusgiro number.
|
||||
|
||||
Format: XXXXXXX-X (7 digits + check digit) or shorter
|
||||
Examples: 1234567-8, 12345-6, 1-8
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
variants = set()
|
||||
|
||||
variants.add(value)
|
||||
|
||||
if not digits or len(digits) < 2 or len(digits) > 8:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# Plain digits
|
||||
variants.add(digits)
|
||||
|
||||
# Plusgiro format: the last digit is a check digit, separated by a hyphen
|
||||
main_part = digits[:-1]
|
||||
check_digit = digits[-1]
|
||||
variants.add(f"{main_part}-{check_digit}")
|
||||
|
||||
# Sometimes written with a space
|
||||
variants.add(f"{main_part} {check_digit}")
|
||||
|
||||
# Grouped format (common for long numbers): XX XX XX-X
|
||||
if len(digits) >= 6:
|
||||
# Try the XX XX XX-X format
|
||||
spaced = ' '.join([digits[i:i + 2] for i in range(0, len(digits) - 1, 2)])
|
||||
if len(digits) % 2 == 0:
|
||||
spaced = spaced[:-1] + '-' + digits[-1]
|
||||
else:
|
||||
spaced = spaced + '-' + digits[-1]
|
||||
variants.add(spaced.replace('- ', '-'))
|
||||
|
||||
# OCR error variants
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Amount Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def amount_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for monetary amounts.
|
||||
|
||||
Handles:
|
||||
- Swedish: 1 234,56 (space thousand, comma decimal)
|
||||
- German: 1.234,56 (dot thousand, comma decimal)
|
||||
- US/UK: 1,234.56 (comma thousand, dot decimal)
|
||||
- Integer: 1234 -> 1234.00
|
||||
|
||||
Returns variants with different separators and with/without decimals.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# Try to parse as a numeric value
|
||||
amount = cls._parse_amount(value)
|
||||
if amount is None:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# Generate variants in different formats
|
||||
int_part = int(amount)
|
||||
dec_part = round((amount - int_part) * 100)
|
||||
|
||||
# 1. Basic formats
|
||||
if dec_part == 0:
|
||||
variants.add(str(int_part)) # 1234
|
||||
variants.add(f"{int_part}.00") # 1234.00
|
||||
variants.add(f"{int_part},00") # 1234,00
|
||||
else:
|
||||
variants.add(f"{int_part}.{dec_part:02d}") # 1234.56
|
||||
variants.add(f"{int_part},{dec_part:02d}") # 1234,56
|
||||
|
||||
# 2. With thousand separators
|
||||
int_str = str(int_part)
|
||||
if len(int_str) > 3:
|
||||
# Insert a separator every 3 digits, from right to left
|
||||
parts = []
|
||||
while int_str:
|
||||
parts.append(int_str[-3:])
|
||||
int_str = int_str[:-3]
|
||||
parts.reverse()
|
||||
|
||||
# Space separator (Swedish)
|
||||
space_sep = ' '.join(parts)
|
||||
if dec_part == 0:
|
||||
variants.add(space_sep)
|
||||
else:
|
||||
variants.add(f"{space_sep},{dec_part:02d}")
|
||||
variants.add(f"{space_sep}.{dec_part:02d}")
|
||||
|
||||
# Dot separator (German)
|
||||
dot_sep = '.'.join(parts)
|
||||
if dec_part == 0:
|
||||
variants.add(dot_sep)
|
||||
else:
|
||||
variants.add(f"{dot_sep},{dec_part:02d}")
|
||||
|
||||
# Comma separator (US)
|
||||
comma_sep = ','.join(parts)
|
||||
if dec_part == 0:
|
||||
variants.add(comma_sep)
|
||||
else:
|
||||
variants.add(f"{comma_sep}.{dec_part:02d}")
|
||||
|
||||
# 3. With currency symbols
|
||||
base_amounts = [f"{int_part}.{dec_part:02d}", f"{int_part},{dec_part:02d}"]
|
||||
if dec_part == 0:
|
||||
base_amounts.append(str(int_part))
|
||||
|
||||
for base in base_amounts:
|
||||
variants.add(f"{base} kr")
|
||||
variants.add(f"{base} SEK")
|
||||
variants.add(f"{base}kr")
|
||||
variants.add(f"SEK {base}")
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
@classmethod
|
||||
def _parse_amount(cls, text: str) -> Optional[float]:
|
||||
"""Parse amount from various formats."""
|
||||
text = TextCleaner.normalize_amount_text(text)
|
||||
|
||||
# Remove everything except digits and separators
|
||||
clean = re.sub(r'[^\d,.\s]', '', text)
|
||||
if not clean:
|
||||
return None
|
||||
|
||||
# Detect the format
|
||||
# Swedish format: 1 234,56 or 1234,56
|
||||
if re.match(r'^[\d\s]+,\d{2}$', clean):
|
||||
clean = clean.replace(' ', '').replace(',', '.')
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# German format: 1.234,56
|
||||
if re.match(r'^[\d.]+,\d{2}$', clean):
|
||||
clean = clean.replace('.', '').replace(',', '.')
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# US format: 1,234.56
|
||||
if re.match(r'^[\d,]+\.\d{2}$', clean):
|
||||
clean = clean.replace(',', '')
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Simple format
|
||||
clean = clean.replace(' ', '').replace(',', '.')
|
||||
# If there are multiple dots, keep only the last one as the decimal point
|
||||
if clean.count('.') > 1:
|
||||
parts = clean.rsplit('.', 1)
|
||||
clean = parts[0].replace('.', '') + '.' + parts[1]
|
||||
|
||||
try:
|
||||
return float(clean)
|
||||
except ValueError:
|
||||
return None
|
||||
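# Hand-checked examples of the parsing above: "1 234,56 kr" is normalized to
# "1234,56", matched by the Swedish pattern and returned as 1234.56, while
# "1,234.56" matches the US pattern and also returns 1234.56.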
|
||||
# =========================================================================
|
||||
# Date Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def date_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for dates.
|
||||
|
||||
Input can be:
|
||||
- ISO: 2024-12-29
|
||||
- European: 29/12/2024, 29.12.2024
|
||||
- Swedish text: "29 december 2024"
|
||||
- Compact: 20241229
|
||||
|
||||
Returns all format variants.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# Try to parse the date
|
||||
parsed = cls._parse_date(value)
|
||||
if parsed is None:
|
||||
return list(v for v in variants if v)
|
||||
|
||||
year, month, day = parsed
|
||||
|
||||
# Generate all format variants
|
||||
# ISO
|
||||
variants.add(f"{year}-{month:02d}-{day:02d}")
|
||||
variants.add(f"{year}-{month}-{day}") # 不补零
|
||||
|
||||
# Dot separator (common in Sweden)
|
||||
variants.add(f"{year}.{month:02d}.{day:02d}")
|
||||
variants.add(f"{day:02d}.{month:02d}.{year}")
|
||||
|
||||
# Slash separator
|
||||
variants.add(f"{day:02d}/{month:02d}/{year}")
|
||||
variants.add(f"{month:02d}/{day:02d}/{year}") # US format
|
||||
variants.add(f"{year}/{month:02d}/{day:02d}")
|
||||
|
||||
# Compact format
|
||||
variants.add(f"{year}{month:02d}{day:02d}")
|
||||
|
||||
# With month name (Swedish)
|
||||
for month_name, month_num in cls.SWEDISH_MONTHS.items():
|
||||
if month_num == f"{month:02d}":
|
||||
variants.add(f"{day} {month_name} {year}")
|
||||
variants.add(f"{day:02d} {month_name} {year}")
|
||||
# Capitalize the first letter
|
||||
variants.add(f"{day} {month_name.capitalize()} {year}")
|
||||
|
||||
# Short (two-digit) year
|
||||
short_year = str(year)[2:]
|
||||
variants.add(f"{day:02d}.{month:02d}.{short_year}")
|
||||
variants.add(f"{day:02d}/{month:02d}/{short_year}")
|
||||
variants.add(f"{short_year}-{month:02d}-{day:02d}")
|
||||
|
||||
return list(v for v in variants if v)
|
||||
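# Hand-checked sketch (assuming SWEDISH_MONTHS maps "december" to "12"):
# date_variants("2024-12-29") includes "2024-12-29", "2024.12.29",
# "29.12.2024", "29/12/2024", "20241229", "29 december 2024" and the
# short-year forms "29.12.24" and "24-12-29".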
|
||||
@classmethod
|
||||
def _parse_date(cls, text: str) -> Optional[tuple[int, int, int]]:
|
||||
"""
|
||||
Parse date from text, returns (year, month, day) or None.
|
||||
"""
|
||||
text = TextCleaner.clean_text(text).lower()
|
||||
|
||||
# ISO: 2024-12-29
|
||||
match = re.search(r'(\d{4})-(\d{1,2})-(\d{1,2})', text)
|
||||
if match:
|
||||
return int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
|
||||
# Dot format: 2024.12.29
|
||||
match = re.search(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', text)
|
||||
if match:
|
||||
return int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
|
||||
# European: 29/12/2024 or 29.12.2024
|
||||
match = re.search(r'(\d{1,2})[/.](\d{1,2})[/.](\d{4})', text)
|
||||
if match:
|
||||
day, month, year = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
# Sanity-check the date
|
||||
if 1 <= day <= 31 and 1 <= month <= 12:
|
||||
return year, month, day
|
||||
|
||||
# Compact: 20241229
|
||||
match = re.search(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', text)
|
||||
if match:
|
||||
year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
if 2000 <= year <= 2100 and 1 <= month <= 12 and 1 <= day <= 31:
|
||||
return year, month, day
|
||||
|
||||
# Swedish month name: "29 december 2024"
|
||||
for month_name, month_num in cls.SWEDISH_MONTHS.items():
|
||||
pattern = rf'(\d{{1,2}})\s*{month_name}\s*(\d{{4}})'
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
day, year = int(match.group(1)), int(match.group(2))
|
||||
return year, int(month_num), day
|
||||
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def invoice_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for invoice numbers.
|
||||
|
||||
Invoice numbers are highly variable:
|
||||
- Pure digits: 12345678
|
||||
- Alphanumeric: A3861, INV-2024-001
|
||||
- With separators: 2024/001
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# Extract the digit part
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
|
||||
# Case variants
|
||||
variants.add(value.upper())
|
||||
variants.add(value.lower())
|
||||
|
||||
# Remove separators
|
||||
no_sep = re.sub(r'[-/\s]', '', value)
|
||||
variants.add(no_sep)
|
||||
variants.add(no_sep.upper())
|
||||
|
||||
# OCR error variants
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(value):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# OCR Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def ocr_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for OCR reference numbers.
|
||||
|
||||
OCR numbers are typically 10-25 digits.
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
variants = set()
|
||||
|
||||
variants.add(value)
|
||||
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
|
||||
# Some OCR numbers are printed in space-separated groups
|
||||
if len(digits) > 4:
|
||||
# Group every 4 digits
|
||||
spaced = ' '.join([digits[i:i + 4] for i in range(0, len(digits), 4)])
|
||||
variants.add(spaced)
|
||||
|
||||
# OCR error variants
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(digits):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Customer Number Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def customer_number_variants(cls, value: str) -> list[str]:
|
||||
"""
|
||||
Generate variants for customer numbers.
|
||||
|
||||
Customer numbers can be very diverse:
|
||||
- Pure digits: 12345
|
||||
- Alphanumeric: ABC123, EMM 256-6
|
||||
- With separators: 123-456
|
||||
"""
|
||||
value = TextCleaner.clean_text(value)
|
||||
variants = set()
|
||||
variants.add(value)
|
||||
|
||||
# Case variants
|
||||
variants.add(value.upper())
|
||||
variants.add(value.lower())
|
||||
|
||||
# Remove all separators and whitespace
|
||||
compact = re.sub(r'[-/\s]', '', value)
|
||||
variants.add(compact)
|
||||
variants.add(compact.upper())
|
||||
variants.add(compact.lower())
|
||||
|
||||
# Plain digits
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=False)
|
||||
if digits:
|
||||
variants.add(digits)
|
||||
|
||||
# Letters and digits, extracted separately and recombined
|
||||
letters = re.sub(r'[^a-zA-Z]', '', value)
|
||||
if letters and digits:
|
||||
variants.add(f"{letters}{digits}")
|
||||
variants.add(f"{letters.upper()}{digits}")
|
||||
variants.add(f"{letters} {digits}")
|
||||
variants.add(f"{letters.upper()} {digits}")
|
||||
variants.add(f"{letters}-{digits}")
|
||||
variants.add(f"{letters.upper()}-{digits}")
|
||||
|
||||
# OCR error variants
|
||||
for ocr_var in TextCleaner.generate_ocr_variants(value):
|
||||
variants.add(ocr_var)
|
||||
|
||||
return list(v for v in variants if v)
|
||||
|
||||
# =========================================================================
|
||||
# Generic Field Variants
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def get_variants(cls, field_name: str, value: str) -> list[str]:
|
||||
"""
|
||||
Get variants for a field by name.
|
||||
|
||||
This is the main entry point - dispatches to specific variant generators.
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
field_lower = field_name.lower()
|
||||
|
||||
# Map the field name to its variant generator
|
||||
if 'organisation' in field_lower or 'org' in field_lower:
|
||||
return cls.organisation_number_variants(value)
|
||||
elif 'bankgiro' in field_lower or field_lower == 'bg':
|
||||
return cls.bankgiro_variants(value)
|
||||
elif 'plusgiro' in field_lower or field_lower == 'pg':
|
||||
return cls.plusgiro_variants(value)
|
||||
elif 'amount' in field_lower or 'belopp' in field_lower:
|
||||
return cls.amount_variants(value)
|
||||
elif 'date' in field_lower or 'datum' in field_lower:
|
||||
return cls.date_variants(value)
|
||||
elif 'invoice' in field_lower and 'number' in field_lower:
|
||||
return cls.invoice_number_variants(value)
|
||||
elif field_lower == 'invoicenumber':
|
||||
return cls.invoice_number_variants(value)
|
||||
elif 'ocr' in field_lower:
|
||||
return cls.ocr_number_variants(value)
|
||||
elif 'customer' in field_lower:
|
||||
return cls.customer_number_variants(value)
|
||||
else:
|
||||
# Default: return the original value plus a basic cleaned variant
|
||||
return [value, TextCleaner.clean_text(value)]
|
||||
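A minimal usage sketch of the dispatcher above (the import path shared.utils.format_variants is assumed from the package layout; values are illustrative):

from shared.utils.format_variants import FormatVariants

# Expand a ground-truth CSV value into every spelling the OCR text might use.
candidates = FormatVariants.get_variants("Bankgiro", "5561-2345")
# The set-backed list contains, in no particular order, "55612345",
# "5561 2345", "556-12345" and OCR-confusable respellings.

# Unknown field names fall back to the original value plus a cleaned copy.
print(FormatVariants.get_variants("FreeTextField", "Hello  World"))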
417
packages/shared/shared/utils/fuzzy_matcher.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
Fuzzy Matching Module
|
||||
|
||||
Provides fuzzy string matching with OCR-aware similarity scoring.
|
||||
Handles common OCR errors and format variations in invoice fields.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuzzyMatchResult:
|
||||
"""Result of a fuzzy match operation."""
|
||||
matched: bool
|
||||
score: float # 0.0 to 1.0
|
||||
ocr_value: str
|
||||
expected_value: str
|
||||
normalized_ocr: str
|
||||
normalized_expected: str
|
||||
match_type: str # 'exact', 'normalized', 'fuzzy', 'ocr_corrected'
|
||||
|
||||
|
||||
class FuzzyMatcher:
|
||||
"""
|
||||
Fuzzy string matcher optimized for OCR text matching.
|
||||
|
||||
Provides multiple matching strategies:
|
||||
1. Exact match
|
||||
2. Normalized match (case-insensitive, whitespace-normalized)
|
||||
3. OCR-corrected match (applying common OCR error corrections)
|
||||
4. Edit distance based fuzzy match
|
||||
5. Digit-sequence match (for numeric fields)
|
||||
"""
|
||||
|
||||
# Minimum similarity threshold for fuzzy matches
|
||||
DEFAULT_THRESHOLD = 0.85
|
||||
|
||||
# Field-specific thresholds (some fields need stricter matching)
|
||||
FIELD_THRESHOLDS = {
|
||||
'InvoiceNumber': 0.90,
|
||||
'OCR': 0.95, # OCR numbers need high precision
|
||||
'Amount': 0.95,
|
||||
'Bankgiro': 0.90,
|
||||
'Plusgiro': 0.90,
|
||||
'InvoiceDate': 0.90,
|
||||
'InvoiceDueDate': 0.90,
|
||||
'supplier_organisation_number': 0.85,
|
||||
'customer_number': 0.80, # More lenient for customer numbers
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_threshold(cls, field_name: str) -> float:
|
||||
"""Get the matching threshold for a specific field."""
|
||||
return cls.FIELD_THRESHOLDS.get(field_name, cls.DEFAULT_THRESHOLD)
|
||||
|
||||
@classmethod
|
||||
def levenshtein_distance(cls, s1: str, s2: str) -> int:
|
||||
"""
|
||||
Calculate Levenshtein (edit) distance between two strings.
|
||||
|
||||
This is the minimum number of single-character edits
|
||||
(insertions, deletions, substitutions) needed to change s1 into s2.
|
||||
"""
|
||||
if len(s1) < len(s2):
|
||||
return cls.levenshtein_distance(s2, s1)
|
||||
|
||||
if len(s2) == 0:
|
||||
return len(s1)
|
||||
|
||||
previous_row = range(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
# Cost is 0 if characters match, 1 otherwise
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
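# Worked example: levenshtein_distance("5561234567", "556I234567") == 1
# (one substitution), and "kitten" -> "sitting" needs 3 edits.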
|
||||
@classmethod
|
||||
def similarity_ratio(cls, s1: str, s2: str) -> float:
|
||||
"""
|
||||
Calculate similarity ratio between two strings.
|
||||
|
||||
Returns a value between 0.0 (completely different) and 1.0 (identical).
|
||||
Based on Levenshtein distance normalized by the length of the longer string.
|
||||
"""
|
||||
if not s1 and not s2:
|
||||
return 1.0
|
||||
if not s1 or not s2:
|
||||
return 0.0
|
||||
|
||||
max_len = max(len(s1), len(s2))
|
||||
distance = cls.levenshtein_distance(s1, s2)
|
||||
return 1.0 - (distance / max_len)
|
||||
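# With the distance above, similarity_ratio("5561234567", "556I234567")
# is 1 - 1/10 = 0.9.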
|
||||
@classmethod
|
||||
def ocr_aware_similarity(cls, ocr_text: str, expected: str) -> float:
|
||||
"""
|
||||
Calculate similarity with OCR error awareness.
|
||||
|
||||
This method considers common OCR errors when calculating similarity,
|
||||
giving higher scores when differences are likely OCR mistakes.
|
||||
"""
|
||||
if not ocr_text or not expected:
|
||||
return 0.0 if ocr_text != expected else 1.0
|
||||
|
||||
# First try exact match
|
||||
if ocr_text == expected:
|
||||
return 1.0
|
||||
|
||||
# Try with OCR corrections applied to ocr_text
|
||||
corrected = TextCleaner.apply_ocr_digit_corrections(ocr_text)
|
||||
if corrected == expected:
|
||||
return 0.98 # Slightly less than exact match
|
||||
|
||||
# Try normalized comparison
|
||||
norm_ocr = TextCleaner.normalize_for_comparison(ocr_text)
|
||||
norm_expected = TextCleaner.normalize_for_comparison(expected)
|
||||
if norm_ocr == norm_expected:
|
||||
return 0.95
|
||||
|
||||
# Calculate base similarity
|
||||
base_sim = cls.similarity_ratio(norm_ocr, norm_expected)
|
||||
|
||||
# Boost score if differences are common OCR errors
|
||||
boost = cls._calculate_ocr_error_boost(ocr_text, expected)
|
||||
|
||||
return min(1.0, base_sim + boost)
|
||||
|
||||
@classmethod
|
||||
def _calculate_ocr_error_boost(cls, ocr_text: str, expected: str) -> float:
|
||||
"""
|
||||
Calculate a score boost based on whether differences are likely OCR errors.
|
||||
|
||||
Returns a value between 0.0 and 0.1.
|
||||
"""
|
||||
if len(ocr_text) != len(expected):
|
||||
return 0.0
|
||||
|
||||
ocr_errors = 0
|
||||
total_diffs = 0
|
||||
|
||||
for oc, ec in zip(ocr_text, expected):
|
||||
if oc != ec:
|
||||
total_diffs += 1
|
||||
# Check if this is a known OCR confusion pair
|
||||
if cls._is_ocr_confusion_pair(oc, ec):
|
||||
ocr_errors += 1
|
||||
|
||||
if total_diffs == 0:
|
||||
return 0.0
|
||||
|
||||
# Boost proportional to how many differences are OCR errors
|
||||
ocr_error_ratio = ocr_errors / total_diffs
|
||||
return ocr_error_ratio * 0.1
|
||||
|
||||
@classmethod
|
||||
def _is_ocr_confusion_pair(cls, char1: str, char2: str) -> bool:
|
||||
"""Check if two characters are commonly confused in OCR."""
|
||||
confusion_pairs = {
|
||||
('0', 'O'), ('0', 'o'), ('0', 'D'), ('0', 'Q'),
|
||||
('1', 'l'), ('1', 'I'), ('1', 'i'), ('1', '|'),
|
||||
('2', 'Z'), ('2', 'z'),
|
||||
('5', 'S'), ('5', 's'),
|
||||
('6', 'G'), ('6', 'b'),
|
||||
('8', 'B'),
|
||||
('9', 'g'), ('9', 'q'),
|
||||
}
|
||||
|
||||
pair = (char1, char2)
|
||||
return pair in confusion_pairs or (char2, char1) in confusion_pairs
|
||||
|
||||
@classmethod
|
||||
def match_digits(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
|
||||
"""
|
||||
Match digit sequences with OCR error tolerance.
|
||||
|
||||
Optimized for numeric fields like OCR numbers, amounts, etc.
|
||||
"""
|
||||
# Extract digits
|
||||
ocr_digits = TextCleaner.extract_digits(ocr_text, apply_ocr_correction=True)
|
||||
expected_digits = TextCleaner.extract_digits(expected, apply_ocr_correction=False)
|
||||
|
||||
# Exact match after extraction
|
||||
if ocr_digits == expected_digits:
|
||||
return FuzzyMatchResult(
|
||||
matched=True,
|
||||
score=1.0,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_digits,
|
||||
normalized_expected=expected_digits,
|
||||
match_type='exact'
|
||||
)
|
||||
|
||||
# Calculate similarity
|
||||
score = cls.ocr_aware_similarity(ocr_digits, expected_digits)
|
||||
|
||||
return FuzzyMatchResult(
|
||||
matched=score >= threshold,
|
||||
score=score,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_digits,
|
||||
normalized_expected=expected_digits,
|
||||
match_type='fuzzy' if score >= threshold else 'no_match'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def match_amount(cls, ocr_text: str, expected: str, threshold: float = 0.95) -> FuzzyMatchResult:
|
||||
"""
|
||||
Match monetary amounts with format tolerance.
|
||||
|
||||
Handles different decimal separators (. vs ,) and thousand separators.
|
||||
"""
|
||||
from .validators import FieldValidators
|
||||
|
||||
# Parse both amounts
|
||||
ocr_amount = FieldValidators.parse_amount(ocr_text)
|
||||
expected_amount = FieldValidators.parse_amount(expected)
|
||||
|
||||
if ocr_amount is None or expected_amount is None:
|
||||
# Can't parse, fall back to string matching
|
||||
return cls.match_string(ocr_text, expected, threshold)
|
||||
|
||||
# Compare numeric values
|
||||
if abs(ocr_amount - expected_amount) < 0.01: # Within 1 cent
|
||||
return FuzzyMatchResult(
|
||||
matched=True,
|
||||
score=1.0,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=f"{ocr_amount:.2f}",
|
||||
normalized_expected=f"{expected_amount:.2f}",
|
||||
match_type='exact'
|
||||
)
|
||||
|
||||
# Calculate relative difference
|
||||
max_val = max(abs(ocr_amount), abs(expected_amount))
|
||||
if max_val > 0:
|
||||
diff_ratio = abs(ocr_amount - expected_amount) / max_val
|
||||
score = max(0.0, 1.0 - diff_ratio)
|
||||
else:
|
||||
score = 1.0 if ocr_amount == expected_amount else 0.0
|
||||
|
||||
return FuzzyMatchResult(
|
||||
matched=score >= threshold,
|
||||
score=score,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=f"{ocr_amount:.2f}" if ocr_amount else ocr_text,
|
||||
normalized_expected=f"{expected_amount:.2f}" if expected_amount else expected,
|
||||
match_type='fuzzy' if score >= threshold else 'no_match'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def match_date(cls, ocr_text: str, expected: str, threshold: float = 0.90) -> FuzzyMatchResult:
|
||||
"""
|
||||
Match dates with format tolerance.
|
||||
|
||||
Handles different date formats (ISO, European, compact, etc.)
|
||||
"""
|
||||
from .validators import FieldValidators
|
||||
|
||||
# Parse both dates to ISO format
|
||||
ocr_iso = FieldValidators.format_date_iso(ocr_text)
|
||||
expected_iso = FieldValidators.format_date_iso(expected)
|
||||
|
||||
if ocr_iso and expected_iso:
|
||||
if ocr_iso == expected_iso:
|
||||
return FuzzyMatchResult(
|
||||
matched=True,
|
||||
score=1.0,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_iso,
|
||||
normalized_expected=expected_iso,
|
||||
match_type='exact'
|
||||
)
|
||||
|
||||
# Fall back to string matching on digits
|
||||
return cls.match_digits(ocr_text, expected, threshold)
|
||||
|
||||
@classmethod
|
||||
def match_string(cls, ocr_text: str, expected: str, threshold: float = 0.85) -> FuzzyMatchResult:
|
||||
"""
|
||||
General string matching with multiple strategies.
|
||||
|
||||
Tries exact, normalized, and fuzzy matching in order.
|
||||
"""
|
||||
# Clean both strings
|
||||
ocr_clean = TextCleaner.clean_text(ocr_text)
|
||||
expected_clean = TextCleaner.clean_text(expected)
|
||||
|
||||
# Strategy 1: Exact match
|
||||
if ocr_clean == expected_clean:
|
||||
return FuzzyMatchResult(
|
||||
matched=True,
|
||||
score=1.0,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_clean,
|
||||
normalized_expected=expected_clean,
|
||||
match_type='exact'
|
||||
)
|
||||
|
||||
# Strategy 2: Case-insensitive match
|
||||
if ocr_clean.lower() == expected_clean.lower():
|
||||
return FuzzyMatchResult(
|
||||
matched=True,
|
||||
score=0.98,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_clean,
|
||||
normalized_expected=expected_clean,
|
||||
match_type='normalized'
|
||||
)
|
||||
|
||||
# Strategy 3: OCR-corrected match
|
||||
ocr_corrected = TextCleaner.apply_ocr_digit_corrections(ocr_clean)
|
||||
if ocr_corrected == expected_clean:
|
||||
return FuzzyMatchResult(
|
||||
matched=True,
|
||||
score=0.95,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_corrected,
|
||||
normalized_expected=expected_clean,
|
||||
match_type='ocr_corrected'
|
||||
)
|
||||
|
||||
# Strategy 4: Fuzzy match
|
||||
score = cls.ocr_aware_similarity(ocr_clean, expected_clean)
|
||||
|
||||
return FuzzyMatchResult(
|
||||
matched=score >= threshold,
|
||||
score=score,
|
||||
ocr_value=ocr_text,
|
||||
expected_value=expected,
|
||||
normalized_ocr=ocr_clean,
|
||||
normalized_expected=expected_clean,
|
||||
match_type='fuzzy' if score >= threshold else 'no_match'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def match_field(
|
||||
cls,
|
||||
field_name: str,
|
||||
ocr_value: str,
|
||||
expected_value: str,
|
||||
threshold: Optional[float] = None
|
||||
) -> FuzzyMatchResult:
|
||||
"""
|
||||
Match a field value using field-appropriate strategy.
|
||||
|
||||
Automatically selects the best matching strategy based on field type.
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = cls.get_threshold(field_name)
|
||||
|
||||
field_lower = field_name.lower()
|
||||
|
||||
# Route to appropriate matcher
|
||||
if 'amount' in field_lower or 'belopp' in field_lower:
|
||||
return cls.match_amount(ocr_value, expected_value, threshold)
|
||||
|
||||
if 'date' in field_lower or 'datum' in field_lower:
|
||||
return cls.match_date(ocr_value, expected_value, threshold)
|
||||
|
||||
if any(x in field_lower for x in ['ocr', 'bankgiro', 'plusgiro', 'org']):
|
||||
# Numeric fields with OCR errors
|
||||
return cls.match_digits(ocr_value, expected_value, threshold)
|
||||
|
||||
if 'invoice' in field_lower and 'number' in field_lower:
|
||||
# Invoice numbers can be alphanumeric
|
||||
return cls.match_string(ocr_value, expected_value, threshold)
|
||||
|
||||
# Default to string matching
|
||||
return cls.match_string(ocr_value, expected_value, threshold)
|
||||
|
||||
@classmethod
|
||||
def find_best_match(
|
||||
cls,
|
||||
ocr_value: str,
|
||||
candidates: list[str],
|
||||
field_name: str = '',
|
||||
threshold: Optional[float] = None
|
||||
) -> Optional[tuple[str, FuzzyMatchResult]]:
|
||||
"""
|
||||
Find the best matching candidate from a list.
|
||||
|
||||
Returns (matched_value, match_result) or None if no match above threshold.
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = cls.get_threshold(field_name) if field_name else cls.DEFAULT_THRESHOLD
|
||||
|
||||
best_match = None
|
||||
best_result = None
|
||||
|
||||
for candidate in candidates:
|
||||
result = cls.match_field(field_name, ocr_value, candidate, threshold=0.0)
|
||||
if result.score >= threshold:
|
||||
if best_result is None or result.score > best_result.score:
|
||||
best_match = candidate
|
||||
best_result = result
|
||||
|
||||
if best_match:
|
||||
return (best_match, best_result)
|
||||
return None
|
||||
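A minimal usage sketch of the matcher above (import path assumed from the package layout; values are illustrative):

from shared.utils.fuzzy_matcher import FuzzyMatcher

# Field-aware matching: amounts are compared numerically, so separator style
# does not matter.
result = FuzzyMatcher.match_field("Amount", "1 234,56", "1234.56")
print(result.matched, result.score, result.match_type)  # True 1.0 exact

# Pick the best candidate above the field's threshold, or get None back.
best = FuzzyMatcher.find_best_match(
    "556I234567",
    ["5561234567", "5569876543"],
    field_name="supplier_organisation_number",
)
if best:
    value, match = best
    print(value, round(match.score, 2))  # 5561234567 1.0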
384
packages/shared/shared/utils/ocr_corrections.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
OCR Error Corrections Module
|
||||
|
||||
Provides comprehensive OCR error correction tables and correction functions.
|
||||
Based on common OCR recognition errors in Swedish invoice documents.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionResult:
|
||||
"""Result of an OCR correction operation."""
|
||||
original: str
|
||||
corrected: str
|
||||
corrections_applied: list[tuple[int, str, str]] # (position, from_char, to_char)
|
||||
confidence: float # How confident we are in the correction
|
||||
|
||||
|
||||
class OCRCorrections:
|
||||
"""
|
||||
Comprehensive OCR error correction utilities.
|
||||
|
||||
Provides:
|
||||
- Character-level corrections for digits
|
||||
- Word-level corrections for common Swedish terms
|
||||
- Context-aware corrections
|
||||
- Multiple correction strategies
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Character-level OCR errors (digit fields)
|
||||
# =========================================================================
|
||||
|
||||
# Characters commonly misread as digits
|
||||
CHAR_TO_DIGIT = {
|
||||
# Letters that look like digits
|
||||
'O': '0', 'o': '0', # O -> 0
|
||||
'Q': '0', # Q -> 0 (less common)
|
||||
'D': '0', # D -> 0 (in some fonts)
|
||||
|
||||
'l': '1', 'I': '1', # l/I -> 1
|
||||
'i': '1', # i without dot -> 1
|
||||
'|': '1', # pipe -> 1
|
||||
'!': '1', # exclamation -> 1
|
||||
|
||||
'Z': '2', 'z': '2', # Z -> 2
|
||||
|
||||
'E': '3', # E -> 3 (rare)
|
||||
|
||||
'A': '4', 'h': '4', # A/h -> 4 (in some fonts)
|
||||
|
||||
'S': '5', 's': '5', # S -> 5
|
||||
|
||||
'G': '6', 'b': '6', # G/b -> 6
|
||||
|
||||
'T': '7', 't': '7', # T -> 7 (rare)
|
||||
|
||||
'B': '8', # B -> 8
|
||||
|
||||
'g': '9', 'q': '9', # g/q -> 9
|
||||
}
|
||||
|
||||
# Digits commonly misread as other characters
|
||||
DIGIT_TO_CHAR = {
|
||||
'0': ['O', 'o', 'D', 'Q'],
|
||||
'1': ['l', 'I', 'i', '|', '!'],
|
||||
'2': ['Z', 'z'],
|
||||
'3': ['E'],
|
||||
'4': ['A', 'h'],
|
||||
'5': ['S', 's'],
|
||||
'6': ['G', 'b'],
|
||||
'7': ['T', 't'],
|
||||
'8': ['B'],
|
||||
'9': ['g', 'q'],
|
||||
}
|
||||
|
||||
# Bidirectional confusion pairs (either direction is possible)
|
||||
CONFUSION_PAIRS = [
|
||||
('0', 'O'), ('0', 'o'), ('0', 'D'),
|
||||
('1', 'l'), ('1', 'I'), ('1', '|'),
|
||||
('2', 'Z'), ('2', 'z'),
|
||||
('5', 'S'), ('5', 's'),
|
||||
('6', 'G'), ('6', 'b'),
|
||||
('8', 'B'),
|
||||
('9', 'g'), ('9', 'q'),
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# Word-level OCR errors (Swedish invoice terms)
|
||||
# =========================================================================
|
||||
|
||||
# Common Swedish invoice terms and their OCR misreadings
|
||||
SWEDISH_TERM_CORRECTIONS = {
|
||||
# Faktura (Invoice)
|
||||
'faktura': ['Faktura', 'FAKTURA', 'faktúra', 'faKtura'],
|
||||
'fakturanummer': ['Fakturanummer', 'FAKTURANUMMER', 'fakturanr', 'fakt.nr'],
|
||||
'fakturadatum': ['Fakturadatum', 'FAKTURADATUM', 'fakt.datum'],
|
||||
|
||||
# Belopp (Amount)
|
||||
'belopp': ['Belopp', 'BELOPP', 'be1opp', 'bel0pp'],
|
||||
'summa': ['Summa', 'SUMMA', '5umma'],
|
||||
'total': ['Total', 'TOTAL', 'tota1', 't0tal'],
|
||||
'moms': ['Moms', 'MOMS', 'm0ms'],
|
||||
|
||||
# Dates
|
||||
'förfallodatum': ['Förfallodatum', 'FÖRFALLODATUM', 'förfa11odatum'],
|
||||
'datum': ['Datum', 'DATUM', 'dátum'],
|
||||
|
||||
# Payment
|
||||
'bankgiro': ['Bankgiro', 'BANKGIRO', 'BG', 'bg', 'bank giro'],
|
||||
'plusgiro': ['Plusgiro', 'PLUSGIRO', 'PG', 'pg', 'plus giro'],
|
||||
'postgiro': ['Postgiro', 'POSTGIRO'],
|
||||
'ocr': ['OCR', 'ocr', '0CR', 'OcR'],
|
||||
|
||||
# Organization
|
||||
'organisationsnummer': ['Organisationsnummer', 'ORGANISATIONSNUMMER', 'org.nr', 'orgnr'],
|
||||
'kundnummer': ['Kundnummer', 'KUNDNUMMER', 'kund nr', 'kundnr'],
|
||||
|
||||
# Currency
|
||||
'kronor': ['Kronor', 'KRONOR', 'kr', 'KR', 'SEK', 'sek'],
|
||||
'öre': ['Öre', 'ÖRE', 'ore', 'ORE'],
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Context patterns
|
||||
# =========================================================================
|
||||
|
||||
# Patterns that indicate the following/preceding text is a specific field
|
||||
CONTEXT_INDICATORS = {
|
||||
'invoice_number': [
|
||||
r'faktura\s*(?:nr|nummer)?[:\s]*',
|
||||
r'invoice\s*(?:no|number)?[:\s]*',
|
||||
r'fakt\.?\s*nr[:\s]*',
|
||||
r'inv[:\s]*#?',
|
||||
],
|
||||
'amount': [
|
||||
r'(?:att\s+)?betala[:\s]*',
|
||||
r'total[t]?[:\s]*',
|
||||
r'summa[:\s]*',
|
||||
r'belopp[:\s]*',
|
||||
r'amount[:\s]*',
|
||||
],
|
||||
'date': [
|
||||
r'datum[:\s]*',
|
||||
r'date[:\s]*',
|
||||
r'förfall(?:o)?datum[:\s]*',
|
||||
r'fakturadatum[:\s]*',
|
||||
],
|
||||
'ocr': [
|
||||
r'ocr[:\s]*',
|
||||
r'referens[:\s]*',
|
||||
r'betalningsreferens[:\s]*',
|
||||
],
|
||||
'bankgiro': [
|
||||
r'bankgiro[:\s]*',
|
||||
r'bg[:\s]*',
|
||||
r'bank\s*giro[:\s]*',
|
||||
],
|
||||
'plusgiro': [
|
||||
r'plusgiro[:\s]*',
|
||||
r'pg[:\s]*',
|
||||
r'plus\s*giro[:\s]*',
|
||||
r'postgiro[:\s]*',
|
||||
],
|
||||
'org_number': [
|
||||
r'org\.?\s*(?:nr|nummer)?[:\s]*',
|
||||
r'organisationsnummer[:\s]*',
|
||||
r'vat[:\s]*',
|
||||
r'moms(?:reg)?\.?\s*(?:nr|nummer)?[:\s]*',
|
||||
],
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Correction Methods
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def correct_digits(cls, text: str, aggressive: bool = False) -> CorrectionResult:
|
||||
"""
|
||||
Apply digit corrections to text.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
aggressive: If True, correct all potential digit-like characters.
|
||||
If False, only correct characters adjacent to digits.
|
||||
|
||||
Returns:
|
||||
CorrectionResult with original, corrected text, and details.
|
||||
"""
|
||||
corrections = []
|
||||
result = []
|
||||
|
||||
for i, char in enumerate(text):
|
||||
if char.isdigit():
|
||||
result.append(char)
|
||||
elif char in cls.CHAR_TO_DIGIT:
|
||||
if aggressive:
|
||||
# Always correct
|
||||
corrected_char = cls.CHAR_TO_DIGIT[char]
|
||||
corrections.append((i, char, corrected_char))
|
||||
result.append(corrected_char)
|
||||
else:
|
||||
# Only correct if adjacent to digit
|
||||
prev_is_digit = i > 0 and (text[i-1].isdigit() or text[i-1] in cls.CHAR_TO_DIGIT)
|
||||
next_is_digit = i < len(text) - 1 and (text[i+1].isdigit() or text[i+1] in cls.CHAR_TO_DIGIT)
|
||||
|
||||
if prev_is_digit or next_is_digit:
|
||||
corrected_char = cls.CHAR_TO_DIGIT[char]
|
||||
corrections.append((i, char, corrected_char))
|
||||
result.append(corrected_char)
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
result.append(char)
|
||||
|
||||
corrected = ''.join(result)
|
||||
confidence = 1.0 - (len(corrections) * 0.05) # Decrease confidence per correction
|
||||
|
||||
return CorrectionResult(
|
||||
original=text,
|
||||
corrected=corrected,
|
||||
corrections_applied=corrections,
|
||||
confidence=max(0.5, confidence)
|
||||
)
|
||||
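# Worked example: correct_digits("556l23-4S67") corrects the l and the S
# (both sit next to digits) and returns "556123-4567", with two corrections
# recorded and a confidence of 0.90.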
|
||||
@classmethod
|
||||
def generate_digit_variants(cls, text: str) -> list[str]:
|
||||
"""
|
||||
Generate all possible digit interpretations of a text.
|
||||
|
||||
Useful for matching when we don't know which direction the OCR error went.
|
||||
"""
|
||||
if not text:
|
||||
return [text]
|
||||
|
||||
variants = {text}
|
||||
|
||||
# For each character that could be confused
|
||||
for i, char in enumerate(text):
|
||||
new_variants = set()
|
||||
for existing in variants:
|
||||
# If it's a digit, add letter variants
|
||||
if char.isdigit() and char in cls.DIGIT_TO_CHAR:
|
||||
for replacement in cls.DIGIT_TO_CHAR[char]:
|
||||
new_variants.add(existing[:i] + replacement + existing[i+1:])
|
||||
|
||||
# If it's a letter that looks like a digit, add digit variant
|
||||
if char in cls.CHAR_TO_DIGIT:
|
||||
new_variants.add(existing[:i] + cls.CHAR_TO_DIGIT[char] + existing[i+1:])
|
||||
|
||||
variants.update(new_variants)
|
||||
|
||||
# Limit explosion - only keep reasonable number
|
||||
if len(variants) > 100:
|
||||
break
|
||||
|
||||
return list(variants)
|
||||
|
||||
@classmethod
|
||||
def correct_swedish_term(cls, text: str) -> str:
|
||||
"""
|
||||
Correct common Swedish invoice terms that may have OCR errors.
|
||||
"""
|
||||
text_lower = text.lower()
|
||||
|
||||
for canonical, variants in cls.SWEDISH_TERM_CORRECTIONS.items():
|
||||
for variant in variants:
|
||||
if variant.lower() in text_lower:
|
||||
# Replace with canonical form (preserving case of first letter)
|
||||
pattern = re.compile(re.escape(variant), re.IGNORECASE)
|
||||
if text[0].isupper():
|
||||
replacement = canonical.capitalize()
|
||||
else:
|
||||
replacement = canonical
|
||||
text = pattern.sub(replacement, text)
|
||||
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def extract_with_context(cls, text: str, field_type: str) -> Optional[str]:
|
||||
"""
|
||||
Extract a field value using context indicators.
|
||||
|
||||
Looks for patterns like "Fakturanr: 12345" and extracts "12345".
|
||||
"""
|
||||
patterns = cls.CONTEXT_INDICATORS.get(field_type, [])
|
||||
|
||||
for pattern in patterns:
|
||||
# Look for pattern followed by value
|
||||
full_pattern = pattern + r'([^\s,;]+)'
|
||||
match = re.search(full_pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_likely_ocr_error(cls, char1: str, char2: str) -> bool:
|
||||
"""
|
||||
Check if two characters are commonly confused in OCR.
|
||||
"""
|
||||
pair = (char1, char2)
|
||||
reverse_pair = (char2, char1)
|
||||
|
||||
for p in cls.CONFUSION_PAIRS:
|
||||
if pair == p or reverse_pair == p:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def count_potential_ocr_errors(cls, s1: str, s2: str) -> tuple[int, int]:
|
||||
"""
|
||||
Count how many character differences between two strings
|
||||
are likely OCR errors vs other differences.
|
||||
|
||||
Returns: (ocr_errors, other_errors)
|
||||
"""
|
||||
if len(s1) != len(s2):
|
||||
return (0, abs(len(s1) - len(s2)))
|
||||
|
||||
ocr_errors = 0
|
||||
other_errors = 0
|
||||
|
||||
for c1, c2 in zip(s1, s2):
|
||||
if c1 != c2:
|
||||
if cls.is_likely_ocr_error(c1, c2):
|
||||
ocr_errors += 1
|
||||
else:
|
||||
other_errors += 1
|
||||
|
||||
return (ocr_errors, other_errors)
|
||||
|
||||
@classmethod
|
||||
def suggest_corrections(cls, text: str, expected_type: str = 'digit') -> list[tuple[str, float]]:
|
||||
"""
|
||||
Suggest possible corrections for a text with confidence scores.
|
||||
|
||||
Returns list of (corrected_text, confidence) tuples, sorted by confidence.
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
if expected_type == 'digit':
|
||||
# Apply digit corrections with different levels of aggressiveness
|
||||
mild = cls.correct_digits(text, aggressive=False)
|
||||
if mild.corrected != text:
|
||||
suggestions.append((mild.corrected, mild.confidence))
|
||||
|
||||
aggressive = cls.correct_digits(text, aggressive=True)
|
||||
if aggressive.corrected != text and aggressive.corrected != mild.corrected:
|
||||
suggestions.append((aggressive.corrected, aggressive.confidence * 0.9))
|
||||
|
||||
# Generate variants
|
||||
variants = cls.generate_digit_variants(text)
|
||||
for variant in variants[:10]: # Limit to top 10
|
||||
if variant != text and variant not in [s[0] for s in suggestions]:
|
||||
# Lower confidence for variants
|
||||
suggestions.append((variant, 0.7))
|
||||
|
||||
# Sort by confidence
|
||||
suggestions.sort(key=lambda x: x[1], reverse=True)
|
||||
return suggestions
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Convenience functions
|
||||
# =========================================================================
|
||||
|
||||
def correct_ocr_digits(text: str, aggressive: bool = False) -> str:
|
||||
"""Convenience function to correct OCR digit errors."""
|
||||
return OCRCorrections.correct_digits(text, aggressive).corrected
|
||||
|
||||
|
||||
def generate_ocr_variants(text: str) -> list[str]:
|
||||
"""Convenience function to generate OCR variants."""
|
||||
return OCRCorrections.generate_digit_variants(text)
|
||||
|
||||
|
||||
def is_ocr_confusion(char1: str, char2: str) -> bool:
|
||||
"""Convenience function to check if characters are OCR confusable."""
|
||||
return OCRCorrections.is_likely_ocr_error(char1, char2)
|
||||
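A short sketch of the convenience wrappers above (import path assumed from the package layout):

from shared.utils.ocr_corrections import (
    correct_ocr_digits,
    generate_ocr_variants,
    is_ocr_confusion,
)

print(correct_ocr_digits("OCR: 12O45678"))  # "OCR: 12045678" - only the O inside the digit run is corrected
print(is_ocr_confusion("8", "B"))           # True
print(generate_ocr_variants("1205"))        # includes "I205", "12O5", "120S", ...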
244
packages/shared/shared/utils/text_cleaner.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Text Cleaning Module
|
||||
|
||||
Provides text normalization and OCR error correction utilities.
|
||||
Used by both inference (field_extractor) and matching (normalizer) stages.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TextCleaner:
|
||||
"""
|
||||
Unified text cleaning utilities for invoice processing.
|
||||
|
||||
Handles:
|
||||
- Unicode normalization (zero-width chars, dash variants)
|
||||
- OCR error correction (O/0, l/1, etc.)
|
||||
- Whitespace normalization
|
||||
- Swedish-specific character handling
|
||||
"""
|
||||
|
||||
# Common OCR error correction map (for digit fields)
|
||||
# These characters are often misread where digits are expected
|
||||
OCR_DIGIT_CORRECTIONS = {
|
||||
'O': '0', 'o': '0',  # letter O -> digit 0
|
||||
'Q': '0',  # Q sometimes looks like 0
|
||||
'l': '1', 'I': '1',  # lowercase L / uppercase I -> digit 1
|
||||
'|': '1',  # vertical bar -> 1
|
||||
'i': '1',  # lowercase i -> 1
|
||||
'S': '5', 's': '5', # S -> 5
|
||||
'B': '8', # B -> 8
|
||||
'Z': '2', 'z': '2', # Z -> 2
|
||||
'G': '6', 'g': '6',  # G -> 6 (in some fonts)
|
||||
'A': '4',  # A -> 4 (in some fonts)
|
||||
'T': '7',  # T -> 7 (in some fonts)
|
||||
'q': '9', # q -> 9
|
||||
'D': '0', # D -> 0
|
||||
}
|
||||
|
||||
# Reverse map: digits misread as letters (for alphanumeric fields)
|
||||
OCR_LETTER_CORRECTIONS = {
|
||||
'0': 'O',
|
||||
'1': 'I',
|
||||
'5': 'S',
|
||||
'8': 'B',
|
||||
'2': 'Z',
|
||||
}
|
||||
|
||||
# Unicode special-character normalization
|
||||
UNICODE_NORMALIZATIONS = {
|
||||
# Various dashes -> standard hyphen
|
||||
'\u2013': '-', # en-dash –
|
||||
'\u2014': '-', # em-dash —
|
||||
'\u2212': '-', # minus sign −
|
||||
'\u00b7': '-', # middle dot ·
|
||||
'\u2010': '-', # hyphen ‐
|
||||
'\u2011': '-', # non-breaking hyphen ‑
|
||||
'\u2012': '-', # figure dash ‒
|
||||
'\u2015': '-', # horizontal bar ―
|
||||
|
||||
# Various spaces -> standard space
|
||||
'\u00a0': ' ', # non-breaking space
|
||||
'\u2002': ' ', # en space
|
||||
'\u2003': ' ', # em space
|
||||
'\u2009': ' ', # thin space
|
||||
'\u200a': ' ', # hair space
|
||||
|
||||
# Zero-width characters -> removed
|
||||
'\u200b': '', # zero-width space
|
||||
'\u200c': '', # zero-width non-joiner
|
||||
'\u200d': '', # zero-width joiner
|
||||
'\ufeff': '', # BOM / zero-width no-break space
|
||||
|
||||
# Various quotes -> standard quotes
|
||||
'\u2018': "'", # left single quote '
|
||||
'\u2019': "'", # right single quote '
|
||||
'\u201c': '"', # left double quote "
|
||||
'\u201d': '"', # right double quote "
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def clean_unicode(cls, text: str) -> str:
|
||||
"""
|
||||
Normalize Unicode characters to ASCII equivalents.
|
||||
|
||||
Handles:
|
||||
- Various dash types -> standard hyphen (-)
|
||||
- Various spaces -> standard space
|
||||
- Zero-width characters -> removed
|
||||
- Various quotes -> standard quotes
|
||||
"""
|
||||
for unicode_char, replacement in cls.UNICODE_NORMALIZATIONS.items():
|
||||
text = text.replace(unicode_char, replacement)
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def normalize_whitespace(cls, text: str) -> str:
|
||||
"""Collapse multiple whitespace to single space and strip."""
|
||||
return ' '.join(text.split())
|
||||
|
||||
@classmethod
|
||||
def clean_text(cls, text: str) -> str:
|
||||
"""
|
||||
Full text cleaning pipeline.
|
||||
|
||||
1. Normalize Unicode
|
||||
2. Normalize whitespace
|
||||
3. Strip
|
||||
|
||||
This is safe for all field types.
|
||||
"""
|
||||
text = cls.clean_unicode(text)
|
||||
text = cls.normalize_whitespace(text)
|
||||
return text.strip()
|
||||
|
||||
@classmethod
|
||||
def apply_ocr_digit_corrections(cls, text: str) -> str:
|
||||
"""
|
||||
Apply OCR error corrections for digit-only fields.
|
||||
|
||||
Use this when the field is expected to contain only digits
|
||||
(e.g., OCR number, organization number digits, etc.)
|
||||
|
||||
Example:
|
||||
"556l23-4S67" -> "556123-4567"
|
||||
"""
|
||||
result = []
|
||||
for char in text:
|
||||
if char in cls.OCR_DIGIT_CORRECTIONS:
|
||||
result.append(cls.OCR_DIGIT_CORRECTIONS[char])
|
||||
else:
|
||||
result.append(char)
|
||||
return ''.join(result)
|
||||
|
||||
@classmethod
|
||||
def extract_digits(cls, text: str, apply_ocr_correction: bool = True) -> str:
|
||||
"""
|
||||
Extract only digits from text.
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
apply_ocr_correction: If True, apply OCR corrections ONLY to characters
|
||||
that are adjacent to digits (not standalone letters)
|
||||
|
||||
Returns:
|
||||
String containing only digits
|
||||
"""
|
||||
if apply_ocr_correction:
|
||||
# Apply OCR corrections only to characters that sit inside a digit sequence,
|
||||
# e.g. the O in "556O23" should be corrected, but the ABC in "ABC 123" should not
|
||||
result = []
|
||||
for i, char in enumerate(text):
|
||||
if char.isdigit():
|
||||
result.append(char)
|
||||
elif char in cls.OCR_DIGIT_CORRECTIONS:
|
||||
# Check whether an adjacent character is a digit
|
||||
prev_is_digit = i > 0 and (text[i - 1].isdigit() or text[i - 1] in cls.OCR_DIGIT_CORRECTIONS)
|
||||
next_is_digit = i < len(text) - 1 and (text[i + 1].isdigit() or text[i + 1] in cls.OCR_DIGIT_CORRECTIONS)
|
||||
if prev_is_digit or next_is_digit:
|
||||
result.append(cls.OCR_DIGIT_CORRECTIONS[char])
|
||||
# All other characters are skipped
|
||||
return ''.join(result)
|
||||
else:
|
||||
return re.sub(r'\D', '', text)
|
||||
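# Worked example: extract_digits("Org.nr 556O23-4567") returns "5560234567";
# the O inside the digit run becomes 0, while the letters in "Org.nr" and the
# hyphen are dropped. With apply_ocr_correction=False the O is dropped as well.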
|
||||
@classmethod
|
||||
def clean_for_digits(cls, text: str) -> str:
|
||||
"""
|
||||
Clean text that should primarily contain digits.
|
||||
|
||||
Pipeline:
|
||||
1. Clean Unicode
|
||||
2. Apply OCR digit corrections
|
||||
3. Normalize whitespace
|
||||
|
||||
Preserves separators (-, /) for formatted numbers like "556123-4567"
|
||||
"""
|
||||
text = cls.clean_unicode(text)
|
||||
text = cls.apply_ocr_digit_corrections(text)
|
||||
text = cls.normalize_whitespace(text)
|
||||
return text.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_ocr_variants(cls, text: str) -> list[str]:
|
||||
"""
|
||||
Generate possible OCR error variants of the input text.
|
||||
|
||||
This is useful for matching: if we have a CSV value,
|
||||
we generate variants that might appear in OCR output.
|
||||
|
||||
Example:
|
||||
"5561234567" -> ["5561234567", "556I234567", "5561234S67", ...]
|
||||
"""
|
||||
variants = {text}
|
||||
|
||||
# Generate letter look-alike variants for digits
|
||||
for digit, letter in cls.OCR_LETTER_CORRECTIONS.items():
|
||||
if digit in text:
|
||||
variants.add(text.replace(digit, letter))
|
||||
|
||||
# Generate digit variants for letters
|
||||
for letter, digit in cls.OCR_DIGIT_CORRECTIONS.items():
|
||||
if letter in text:
|
||||
variants.add(text.replace(letter, digit))
|
||||
|
||||
return list(variants)
|
||||
|
||||
@classmethod
|
||||
def normalize_amount_text(cls, text: str) -> str:
|
||||
"""
|
||||
Normalize amount text for parsing.
|
||||
|
||||
- Removes currency symbols and labels
|
||||
- Normalizes separators
|
||||
- Handles Swedish format (space as thousand separator)
|
||||
"""
|
||||
text = cls.clean_text(text)
|
||||
|
||||
# Remove currency symbols and labels (word boundaries ensure whole-word matches)
|
||||
text = re.sub(r'(?i)\b(kr|sek|kronor|öre)\b', '', text)
|
||||
|
||||
# Remove thousand-separator spaces (Swedish: "1 234,56" -> "1234,56")
|
||||
# while keeping the surrounding digits intact
|
||||
text = re.sub(r'(\d)\s+(\d)', r'\1\2', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
@classmethod
|
||||
def normalize_for_comparison(cls, text: str) -> str:
|
||||
"""
|
||||
Normalize text for loose comparison.
|
||||
|
||||
- Lowercase
|
||||
- Remove all non-alphanumeric
|
||||
- Apply OCR corrections
|
||||
|
||||
This is the most aggressive normalization, used for fuzzy matching.
|
||||
"""
|
||||
text = cls.clean_text(text)
|
||||
text = text.lower()
|
||||
text = cls.apply_ocr_digit_corrections(text)
|
||||
text = re.sub(r'[^a-z0-9]', '', text)
|
||||
return text
|
||||
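A compact sketch of the cleaning pipelines above (import path assumed from the package layout):

from shared.utils.text_cleaner import TextCleaner

raw = "Fakturanr:\u00a0A\u2013123\u200b"
print(TextCleaner.clean_text(raw))                  # "Fakturanr: A-123"
print(TextCleaner.clean_for_digits("556l23-4S67"))  # "556123-4567"
print(TextCleaner.normalize_for_comparison("SE 556123-4567"))  # "5e5561234567" (note the aggressive s -> 5 swap)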
393
packages/shared/shared/utils/validators.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Field Validators Module
|
||||
|
||||
Provides validation functions for Swedish invoice fields.
|
||||
Used by both inference (to validate extracted values) and matching (to filter candidates).
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .text_cleaner import TextCleaner
|
||||
|
||||
|
||||
class FieldValidators:
|
||||
"""
|
||||
Validators for Swedish invoice field values.
|
||||
|
||||
Includes:
|
||||
- Luhn (Mod10) checksum validation
|
||||
- Format validation for specific field types
|
||||
- Range validation for dates and amounts
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Luhn (Mod10) Checksum
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def luhn_checksum(cls, digits: str) -> bool:
|
||||
"""
|
||||
Validate using Luhn (Mod10) algorithm.
|
||||
|
||||
Used for:
|
||||
- Bankgiro numbers
|
||||
- Plusgiro numbers
|
||||
- OCR reference numbers
|
||||
- Swedish organization numbers
|
||||
|
||||
The checksum is valid if the total modulo 10 equals 0.
|
||||
"""
|
||||
# Keep digits only
|
||||
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
|
||||
|
||||
if not digits or not digits.isdigit():
|
||||
return False
|
||||
|
||||
total = 0
|
||||
for i, char in enumerate(reversed(digits)):
|
||||
digit = int(char)
|
||||
if i % 2 == 1:  # double every second digit, counting from the right
|
||||
digit *= 2
|
||||
if digit > 9:
|
||||
digit -= 9
|
||||
total += digit
|
||||
|
||||
return total % 10 == 0
|
||||
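# Worked example: for "5561234567" the digits doubled from the right
# (positions 1, 3, 5, 7, 9) contribute 3 + 8 + 4 + 3 + 1 and the remaining
# digits 7 + 5 + 3 + 1 + 5, giving a total of 40, so 40 % 10 == 0 and the
# checksum is valid.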
|
||||
@classmethod
|
||||
def calculate_luhn_check_digit(cls, digits: str) -> int:
|
||||
"""
|
||||
Calculate the Luhn check digit for a number.
|
||||
|
||||
Given a number without check digit, returns the digit that would make it valid.
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
|
||||
|
||||
# Compute the Luhn sum of the existing digits
|
||||
total = 0
|
||||
for i, char in enumerate(reversed(digits)):
|
||||
digit = int(char)
|
||||
if i % 2 == 0:  # note: a check digit will be appended, so even positions are doubled here
|
||||
digit *= 2
|
||||
if digit > 9:
|
||||
digit -= 9
|
||||
total += digit
|
||||
|
||||
# Compute the required check digit
|
||||
check_digit = (10 - (total % 10)) % 10
|
||||
return check_digit
|
||||
|
||||
# =========================================================================
|
||||
# Organisation Number Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_organisation_number(cls, value: str) -> bool:
|
||||
"""
|
||||
Validate Swedish organisation number.
|
||||
|
||||
Format: NNNNNN-NNNN (10 digits)
|
||||
- First digit: 1-9
|
||||
- Third digit: >= 2 (distinguishes from personal numbers)
|
||||
- Last digit: Luhn check digit
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
# Handle VAT formats
|
||||
if len(digits) == 12 and digits.endswith('01'):
|
||||
digits = digits[:10]
|
||||
elif len(digits) == 14 and digits.startswith('46') and digits.endswith('01'):
|
||||
digits = digits[2:12]
|
||||
|
||||
if len(digits) != 10:
|
||||
return False
|
||||
|
||||
# First digit must be 1-9
|
||||
if digits[0] == '0':
|
||||
return False
|
||||
|
||||
# Third digit >= 2 (distinguishes organisation numbers from personal numbers)
|
||||
# Note: some special organisations do not follow this rule, so it is relaxed here
|
||||
# if int(digits[2]) < 2:
|
||||
# return False
|
||||
|
||||
# Luhn check
|
||||
return cls.luhn_checksum(digits)
|
||||
|
||||
# =========================================================================
|
||||
# Bankgiro Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_bankgiro(cls, value: str) -> bool:
|
||||
"""
|
||||
Validate Swedish Bankgiro number.
|
||||
|
||||
Format: 7 or 8 digits with Luhn checksum
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
if len(digits) < 7 or len(digits) > 8:
|
||||
return False
|
||||
|
||||
return cls.luhn_checksum(digits)
|
||||
|
||||
@classmethod
|
||||
def format_bankgiro(cls, value: str) -> Optional[str]:
|
||||
"""
|
||||
Format Bankgiro number to standard format.
|
||||
|
||||
Returns: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits), or None if invalid
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
if len(digits) == 7:
|
||||
return f"{digits[:3]}-{digits[3:]}"
|
||||
elif len(digits) == 8:
|
||||
return f"{digits[:4]}-{digits[4:]}"
|
||||
else:
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Plusgiro Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_plusgiro(cls, value: str) -> bool:
|
||||
"""
|
||||
Validate Swedish Plusgiro number.
|
||||
|
||||
Format: 2-8 digits with Luhn checksum
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
if len(digits) < 2 or len(digits) > 8:
|
||||
return False
|
||||
|
||||
return cls.luhn_checksum(digits)
|
||||
|
||||
@classmethod
|
||||
def format_plusgiro(cls, value: str) -> Optional[str]:
|
||||
"""
|
||||
Format Plusgiro number to standard format.
|
||||
|
||||
Returns: XXXXXXX-X format, or None if invalid
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
if len(digits) < 2 or len(digits) > 8:
|
||||
return None
|
||||
|
||||
return f"{digits[:-1]}-{digits[-1]}"
|
||||
|
||||
# =========================================================================
|
||||
# OCR Number Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_ocr_number(cls, value: str, validate_checksum: bool = True) -> bool:
|
||||
"""
|
||||
Validate Swedish OCR reference number.
|
||||
|
||||
- Typically 10-25 digits
|
||||
- Usually has Luhn checksum (but not always enforced)
|
||||
"""
|
||||
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
|
||||
|
||||
if len(digits) < 5 or len(digits) > 25:
|
||||
return False
|
||||
|
||||
if validate_checksum:
|
||||
return cls.luhn_checksum(digits)
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Amount Validation
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def is_valid_amount(cls, value: str, min_amount: float = 0.0, max_amount: float = 10_000_000.0) -> bool:
|
||||
"""
|
||||
Validate monetary amount.
|
||||
|
||||
- Must be positive (or at least >= min_amount)
|
||||
- Should be within reasonable range
|
||||
"""
|
||||
try:
|
||||
# Try to parse
|
||||
text = TextCleaner.normalize_amount_text(value)
|
||||
# Normalize to a dot as the decimal separator
|
||||
text = text.replace(' ', '').replace(',', '.')
|
||||
# If there are multiple dots, keep only the last one
|
||||
if text.count('.') > 1:
|
||||
parts = text.rsplit('.', 1)
|
||||
text = parts[0].replace('.', '') + '.' + parts[1]
|
||||
|
||||
amount = float(text)
|
||||
return min_amount <= amount <= max_amount
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse_amount(cls, value: str) -> Optional[float]:
|
||||
"""
|
||||
Parse amount from string, handling various formats.
|
||||
|
||||
Returns float or None if parsing fails.
|
||||
"""
|
||||
try:
|
||||
text = TextCleaner.normalize_amount_text(value)
|
||||
text = text.replace(' ', '')
|
||||
|
||||
# Detect the format and parse
|
||||
# Swedish/German format: comma is the decimal separator
|
||||
if re.match(r'^[\d.]+,\d{1,2}$', text):
|
||||
text = text.replace('.', '').replace(',', '.')
|
||||
# US format: dot is the decimal separator
|
||||
elif re.match(r'^[\d,]+\.\d{1,2}$', text):
|
||||
text = text.replace(',', '')
|
||||
else:
|
||||
# Simple format
|
||||
text = text.replace(',', '.')
|
||||
if text.count('.') > 1:
|
||||
parts = text.rsplit('.', 1)
|
||||
text = parts[0].replace('.', '') + '.' + parts[1]
|
||||
|
||||
return float(text)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
    # =========================================================================
    # Date Validation
    # =========================================================================

    @classmethod
    def is_valid_date(cls, value: str, min_year: int = 2000, max_year: int = 2100) -> bool:
        """
        Validate date string.

        - Year should be within reasonable range
        - Month 1-12
        - Day 1-31 (basic check)
        """
        parsed = cls.parse_date(value)
        if parsed is None:
            return False

        year, month, day = parsed
        if not (min_year <= year <= max_year):
            return False
        if not (1 <= month <= 12):
            return False
        if not (1 <= day <= 31):
            return False

        # More precise calendar validation
        try:
            datetime(year, month, day)
            return True
        except ValueError:
            return False

    @classmethod
    def parse_date(cls, value: str) -> Optional[tuple[int, int, int]]:
        """
        Parse date from string.

        Returns (year, month, day) tuple or None.
        """
        from .format_variants import FormatVariants
        return FormatVariants._parse_date(value)

    @classmethod
    def format_date_iso(cls, value: str) -> Optional[str]:
        """
        Format date to ISO format (YYYY-MM-DD).

        Returns formatted string or None if parsing fails.
        """
        parsed = cls.parse_date(value)
        if parsed is None:
            return None

        year, month, day = parsed
        return f"{year}-{month:02d}-{day:02d}"
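    # Illustrative example (parse_date() delegates to FormatVariants._parse_date,
    # which is assumed here to accept ISO-style input, including single-digit
    # month/day parts):
    #   is_valid_date("2024-02-30")  -> False  (rejected by the datetime() check above)
    #   format_date_iso("2024-2-5")  -> "2024-02-05"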
    # =========================================================================
    # Invoice Number Validation
    # =========================================================================

    @classmethod
    def is_valid_invoice_number(cls, value: str, min_length: int = 1, max_length: int = 30) -> bool:
        """
        Validate invoice number.

        Basic validation - just a length check, since invoice numbers are highly variable.
        """
        clean = TextCleaner.clean_text(value)
        if not clean:
            return False

        # Extract the meaningful characters (letters and digits)
        meaningful = re.sub(r'[^a-zA-Z0-9]', '', clean)
        return min_length <= len(meaningful) <= max_length
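    # Illustrative example (assuming clean_text() only trims whitespace/noise):
    # punctuation is ignored when counting meaningful characters, so
    #   is_valid_invoice_number("INV-2024/0042")  -> True   (11 alphanumeric characters)
    #   is_valid_invoice_number("--- ---")        -> False  (no letters or digits)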
    # =========================================================================
    # Generic Validation
    # =========================================================================

    @classmethod
    def validate_field(cls, field_name: str, value: str) -> tuple[bool, Optional[str]]:
        """
        Validate a field by name.

        Returns (is_valid, error_message).
        """
        if not value:
            return False, "Empty value"

        field_lower = field_name.lower()

        if 'organisation' in field_lower or 'org' in field_lower:
            if cls.is_valid_organisation_number(value):
                return True, None
            return False, "Invalid organisation number format or checksum"

        elif 'bankgiro' in field_lower:
            if cls.is_valid_bankgiro(value):
                return True, None
            return False, "Invalid Bankgiro format or checksum"

        elif 'plusgiro' in field_lower:
            if cls.is_valid_plusgiro(value):
                return True, None
            return False, "Invalid Plusgiro format or checksum"

        elif 'ocr' in field_lower:
            if cls.is_valid_ocr_number(value, validate_checksum=False):
                return True, None
            return False, "Invalid OCR number length"

        elif 'amount' in field_lower:
            if cls.is_valid_amount(value):
                return True, None
            return False, "Invalid amount format"

        elif 'date' in field_lower:
            if cls.is_valid_date(value):
                return True, None
            return False, "Invalid date format"

        elif 'invoice' in field_lower and 'number' in field_lower:
            if cls.is_valid_invoice_number(value):
                return True, None
            return False, "Invalid invoice number"

        else:
            # Default: only check that the value is non-empty
            if TextCleaner.clean_text(value):
                return True, None
            return False, "Empty value after cleaning"
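    # Illustrative example (same TextCleaner assumptions as above): routing is
    # keyword-based on the lower-cased field name, so "supplier_org_number" hits
    # the organisation-number branch and "total_amount" hits the amount branch:
    #   validate_field("total_amount", "1 234,56")  -> (True, None)
    #   validate_field("total_amount", "abc")       -> (False, "Invalid amount format")
    # Any unrecognised field name only gets the final non-empty check.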