Commit 425b8fdedf (parent 53d1e8db25) by Yaojia Wang, 2026-01-16 23:10:01 +01:00
10 changed files with 653 additions and 87 deletions

View File

@@ -72,7 +72,10 @@
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")", "Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")", "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")" "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")",
"Bash(tasklist:*)",
"Bash(findstr:*)"
],
"deny": [],
"ask": [],

View File

@@ -9,12 +9,24 @@ import argparse
import sys
import time
import os
import signal
import warnings
from pathlib import Path
from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing
# Global flag for graceful shutdown
_shutdown_requested = False
def _signal_handler(signum, frame):
"""Handle interrupt signals for graceful shutdown."""
global _shutdown_requested
_shutdown_requested = True
print("\n\nShutdown requested. Finishing current batch and saving progress...")
print("(Press Ctrl+C again to force quit)\n")
# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
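As an aside, a minimal standalone sketch of the shutdown-flag pattern added above (the fake_batch function is a hypothetical stand-in, not the repo's CLI):

import signal
import sys
import time
import multiprocessing

_shutdown_requested = False

def _signal_handler(signum, frame):
    # First Ctrl+C only sets the flag; the main loop checks it between batches.
    global _shutdown_requested
    _shutdown_requested = True

def fake_batch(i):
    time.sleep(0.1)  # pretend to process one batch of documents
    return i

if __name__ == '__main__':
    if sys.platform == 'win32':
        multiprocessing.set_start_method('spawn', force=True)
    signal.signal(signal.SIGINT, _signal_handler)
    for batch in range(100):
        if _shutdown_requested:
            print("Shutdown requested, stopping after current batch")
            break
        fake_batch(batch)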
@@ -111,6 +123,12 @@ def process_single_document(args_tuple):
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
# Store metadata fields from CSV
report.split = row_dict.get('split')
report.customer_number = row_dict.get('customer_number')
report.supplier_name = row_dict.get('supplier_name')
report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
report.supplier_accounts = row_dict.get('supplier_accounts')
result = {
'doc_id': doc_id,
@@ -204,6 +222,67 @@ def process_single_document(args_tuple):
context_keywords=best.context_keywords
))
# Match supplier_accounts and map to Bankgiro/Plusgiro
supplier_accounts_value = row_dict.get('supplier_accounts')
if supplier_accounts_value:
# Parse accounts: "BG:xxx | PG:yyy" format
accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Determine account type (BG or PG) and extract account number
account_type = None
account_number = account # Default to full value
if account.upper().startswith('BG:'):
account_type = 'Bankgiro'
account_number = account[3:].strip() # Remove "BG:" prefix
elif account.upper().startswith('BG '):
account_type = 'Bankgiro'
account_number = account[2:].strip() # Remove "BG" prefix
elif account.upper().startswith('PG:'):
account_type = 'Plusgiro'
account_number = account[3:].strip() # Remove "PG:" prefix
elif account.upper().startswith('PG '):
account_type = 'Plusgiro'
account_number = account[2:].strip() # Remove "PG" prefix
else:
# Try to guess from format - Plusgiro often has format XXXXXXX-X
digits = ''.join(c for c in account if c.isdigit())
if len(digits) == 8 and '-' in account:
account_type = 'Plusgiro'
elif len(digits) in (7, 8):
account_type = 'Bankgiro' # Default to Bankgiro
if not account_type:
continue
# Normalize and match using the account number (without prefix)
normalized = normalize_field('supplier_accounts', account_number)
field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
if field_matches:
best = field_matches[0]
# Add to matches under the target class (Bankgiro/Plusgiro)
if account_type not in matches:
matches[account_type] = []
matches[account_type].extend(field_matches)
matched_fields.add('supplier_accounts')
report.add_field_result(FieldMatchResult(
field_name=f'supplier_accounts({account_type})',
csv_value=account_number, # Store without prefix
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
# Count annotations
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
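A condensed sketch of the prefix-based account parsing above, assuming the same 'BG:xxx | PG:yyy' input convention (parse_supplier_accounts is an illustrative helper, not the repo's function):

def parse_supplier_accounts(raw: str) -> list[tuple[str, str]]:
    """Return (account_type, account_number) pairs from 'BG:xxx | PG:yyy' strings."""
    results = []
    for account in (a.strip() for a in raw.split('|')):
        if not account:
            continue
        upper = account.upper()
        if upper.startswith('BG:') or upper.startswith('BG '):
            results.append(('Bankgiro', account[3:].strip()))
        elif upper.startswith('PG:') or upper.startswith('PG '):
            results.append(('Plusgiro', account[3:].strip()))
        else:
            # No prefix: guess from the digit count, as the diff does.
            digits = ''.join(c for c in account if c.isdigit())
            if len(digits) == 8 and '-' in account:
                results.append(('Plusgiro', account))
            elif len(digits) in (7, 8):
                results.append(('Bankgiro', account))
    return results

# parse_supplier_accounts('BG:5393-9484 | PG:48676043')
# -> [('Bankgiro', '5393-9484'), ('Plusgiro', '48676043')]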
@@ -329,6 +408,10 @@ def main():
args = parser.parse_args()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)
# Import here to avoid slow startup
from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
from ..data.autolabel_report import ReportWriter
@@ -364,6 +447,7 @@ def main():
from ..data.db import DocumentDB
db = DocumentDB()
db.connect()
db.create_tables() # Ensure tables exist
print("Connected to database for status checking")
# Global stats
@@ -458,6 +542,11 @@ def main():
try:
# Process CSV files one by one (streaming)
for csv_idx, csv_file in enumerate(csv_files):
# Check for shutdown request
if _shutdown_requested:
print("\nShutdown requested. Stopping after current batch...")
break
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
# Load only this CSV file
@@ -548,6 +637,13 @@ def main():
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
# New fields
'supplier_organisation_number': row.supplier_organisation_number,
'supplier_accounts': row.supplier_accounts,
# Metadata fields (not for matching, but for database storage)
'split': row.split,
'customer_number': row.customer_number,
'supplier_name': row.supplier_name,
}
tasks.append((
@@ -647,11 +743,19 @@ def main():
futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
for task in tasks}
# Per-document timeout: 120 seconds (2 minutes)
# This prevents a single stuck document from blocking the entire batch
DOCUMENT_TIMEOUT = 120
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
-result = future.result()
+result = future.result(timeout=DOCUMENT_TIMEOUT)
handle_result(result)
except TimeoutError:
handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
# Cancel the stuck future
future.cancel()
except Exception as e:
handle_error(doc_id, e)

View File

@@ -34,7 +34,13 @@ def create_tables(conn):
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
-errors JSONB DEFAULT '[]'
+errors JSONB DEFAULT '[]',
-- New fields for extended CSV format
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
@@ -56,6 +62,26 @@ def create_tables(conn):
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
@@ -82,7 +108,8 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
-processing_time_ms, timestamp, errors)
+processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
ON CONFLICT (document_id) DO UPDATE SET
pdf_path = EXCLUDED.pdf_path,
@@ -94,7 +121,12 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
annotations_generated = EXCLUDED.annotations_generated,
processing_time_ms = EXCLUDED.processing_time_ms,
timestamp = EXCLUDED.timestamp,
-errors = EXCLUDED.errors
+errors = EXCLUDED.errors,
split = EXCLUDED.split,
customer_number = EXCLUDED.customer_number,
supplier_name = EXCLUDED.supplier_name,
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
supplier_accounts = EXCLUDED.supplier_accounts
""", doc_batch)
doc_batch = []
@@ -150,7 +182,13 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
record.get('annotations_generated'),
record.get('processing_time_ms'),
record.get('timestamp'),
-json.dumps(record.get('errors', []))
+json.dumps(record.get('errors', [])),
# New fields
record.get('split'),
record.get('customer_number'),
record.get('supplier_name'),
record.get('supplier_organisation_number'),
record.get('supplier_accounts'),
))
for field in record.get('field_results', []):
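A reduced sketch of the execute_values upsert pattern used above, on a toy table instead of the real schema (the DSN is a placeholder):

import psycopg2
from psycopg2.extras import execute_values

rows = [('doc-1', 'train'), ('doc-2', 'test')]  # toy (document_id, split) rows

conn = psycopg2.connect("dbname=example")  # placeholder connection string
with conn, conn.cursor() as cur:
    cur.execute("CREATE TABLE IF NOT EXISTS docs (document_id TEXT PRIMARY KEY, split TEXT)")
    execute_values(cur, """
        INSERT INTO docs (document_id, split)
        VALUES %s
        ON CONFLICT (document_id) DO UPDATE SET split = EXCLUDED.split
    """, rows)
conn.close()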

View File

@@ -63,6 +63,12 @@ class AutoLabelReport:
processing_time_ms: float = 0.0
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
errors: list[str] = field(default_factory=list)
# New metadata fields (from CSV, not for matching)
split: str | None = None
customer_number: str | None = None
supplier_name: str | None = None
supplier_organisation_number: str | None = None
supplier_accounts: str | None = None
def add_field_result(self, result: FieldMatchResult) -> None:
"""Add a field matching result."""
@@ -87,7 +93,13 @@ class AutoLabelReport:
'label_paths': self.label_paths,
'processing_time_ms': self.processing_time_ms,
'timestamp': self.timestamp,
-'errors': self.errors
+'errors': self.errors,
# New metadata fields
'split': self.split,
'customer_number': self.customer_number,
'supplier_name': self.supplier_name,
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def to_json(self, indent: int | None = None) -> str:

View File

@@ -25,6 +25,12 @@ class InvoiceRow:
Bankgiro: str | None = None
Plusgiro: str | None = None
Amount: Decimal | None = None
# New fields
split: str | None = None # train/test split indicator
customer_number: str | None = None # Customer number (no matching needed)
supplier_name: str | None = None # Supplier name (no matching)
supplier_organisation_number: str | None = None # Swedish org number (needs matching)
supplier_accounts: str | None = None # Supplier accounts (needs matching)
# Raw values for reference
raw_data: dict = field(default_factory=dict)
@@ -40,6 +46,8 @@ class InvoiceRow:
'Bankgiro': self.Bankgiro,
'Plusgiro': self.Plusgiro,
'Amount': str(self.Amount) if self.Amount else None,
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def get_field_value(self, field_name: str) -> str | None:
@@ -68,6 +76,12 @@ class CSVLoader:
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'Amount': 'Amount',
# New fields
'split': 'split',
'customer_number': 'customer_number',
'supplier_name': 'supplier_name',
'supplier_organisation_number': 'supplier_organisation_number',
'supplier_accounts': 'supplier_accounts',
}
def __init__(
@@ -200,6 +214,12 @@ class CSVLoader:
Bankgiro=self._parse_string(row.get('Bankgiro')),
Plusgiro=self._parse_string(row.get('Plusgiro')),
Amount=self._parse_amount(row.get('Amount')),
# New fields
split=self._parse_string(row.get('split')),
customer_number=self._parse_string(row.get('customer_number')),
supplier_name=self._parse_string(row.get('supplier_name')),
supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')),
supplier_accounts=self._parse_string(row.get('supplier_accounts')),
raw_data=dict(row)
)
@@ -318,14 +338,16 @@ class CSVLoader:
row.OCR,
row.Bankgiro,
row.Plusgiro,
-row.Amount
+row.Amount,
row.supplier_organisation_number,
row.supplier_accounts,
]
if not any(matchable_fields):
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'All',
-'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)'
+'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)'
})
return issues
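A standalone sketch of how rows with the extra columns can be read and checked for "no matchable fields" using only the csv module (column names follow the diff; the loader class itself is not reproduced):

import csv
import io

MATCHABLE = ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
             'supplier_organisation_number', 'supplier_accounts')

sample = io.StringIO(
    "DocumentId,InvoiceNumber,supplier_accounts,split\n"
    "doc-1,12345,BG:5393-9484,train\n"
    "doc-2,,,test\n"
)

for i, row in enumerate(csv.DictReader(sample)):
    # Flag rows where every matchable column is empty or missing.
    if not any((row.get(name) or '').strip() for name in MATCHABLE):
        print(f"row {i}: doc {row['DocumentId']} has no matchable fields")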

View File

@@ -26,6 +26,73 @@ class DocumentDB:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def create_tables(self):
"""Create database tables if they don't exist."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
document_id TEXT PRIMARY KEY,
pdf_path TEXT,
pdf_type TEXT,
success BOOLEAN,
total_pages INTEGER,
fields_matched INTEGER,
fields_total INTEGER,
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]',
-- Extended CSV format fields
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
field_name TEXT,
csv_value TEXT,
matched BOOLEAN,
score REAL,
matched_text TEXT,
candidate_used TEXT,
bbox JSONB,
page_no INTEGER,
context_keywords JSONB DEFAULT '[]',
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
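The column-migration check above can also be done from Python; a minimal sketch that queries information_schema before issuing an ALTER, same idea as the DO $$ block (placeholder DSN):

import psycopg2

conn = psycopg2.connect("dbname=example")  # placeholder connection string
with conn, conn.cursor() as cur:
    cur.execute("""
        SELECT 1 FROM information_schema.columns
        WHERE table_name = 'documents' AND column_name = 'split'
    """)
    if cur.fetchone() is None:
        # Column missing: add it (idempotent thanks to the existence check).
        cur.execute("ALTER TABLE documents ADD COLUMN split TEXT")
conn.close()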
def close(self):
"""Close database connection."""
if self.conn:
@@ -110,7 +177,9 @@ class DocumentDB:
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
-processing_time_ms, timestamp, errors
+processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
@@ -129,6 +198,12 @@ class DocumentDB:
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
@@ -253,7 +328,9 @@ class DocumentDB:
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
-processing_time_ms, timestamp, errors
+processing_time_ms, timestamp, errors,
split, customer_number, supplier_name,
supplier_organisation_number, supplier_accounts
FROM documents WHERE document_id = ANY(%s)
""", (doc_ids,))
@@ -270,6 +347,12 @@ class DocumentDB:
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
# New fields
'split': row[11],
'customer_number': row[12],
'supplier_name': row[13],
'supplier_organisation_number': row[14],
'supplier_accounts': row[15],
'field_results': []
}
@@ -315,8 +398,9 @@ class DocumentDB:
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
-processing_time_ms, timestamp, errors)
-VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+processing_time_ms, timestamp, errors,
+split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
+VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
doc_id,
report.get('pdf_path'),
@@ -328,7 +412,13 @@ class DocumentDB:
report.get('annotations_generated'),
report.get('processing_time_ms'),
report.get('timestamp'),
-json.dumps(report.get('errors', []))
+json.dumps(report.get('errors', [])),
# New fields
report.get('split'),
report.get('customer_number'),
report.get('supplier_name'),
report.get('supplier_organisation_number'),
report.get('supplier_accounts'),
))
# Batch insert field results using execute_values
@@ -387,7 +477,13 @@ class DocumentDB:
r.get('annotations_generated'),
r.get('processing_time_ms'),
r.get('timestamp'),
-json.dumps(r.get('errors', []))
+json.dumps(r.get('errors', [])),
# New fields
r.get('split'),
r.get('customer_number'),
r.get('supplier_name'),
r.get('supplier_organisation_number'),
r.get('supplier_accounts'),
)
for r in reports
]
@@ -395,7 +491,8 @@ class DocumentDB:
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
-processing_time_ms, timestamp, errors)
+processing_time_ms, timestamp, errors,
+split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
""", doc_values)

View File

@@ -14,6 +14,12 @@ from functools import cached_property
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212]') # en-dash, em-dash, minus sign
def _normalize_dashes(text: str) -> str:
"""Normalize different dash types to standard hyphen-minus (ASCII 45)."""
return _DASH_PATTERN.sub('-', text)
class TokenLike(Protocol):
@@ -143,6 +149,9 @@ CONTEXT_KEYWORDS = {
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}
@@ -207,7 +216,10 @@ class FieldMatcher:
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
-if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
+# Note: Amount is excluded because short numbers like "451" can incorrectly match
# in OCR payment lines or other unrelated text
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
'supplier_organisation_number', 'supplier_accounts'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
@@ -237,7 +249,8 @@
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
-value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
+value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
+'supplier_organisation_number', 'supplier_accounts') else None
for token in tokens:
token_text = token.text.strip()
@@ -355,33 +368,36 @@ class FieldMatcher:
matches = []
# Supported fields for substring matching
-supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
+supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
+'supplier_organisation_number', 'supplier_accounts')
if field_name not in supported_fields:
return matches
for token in tokens:
token_text = token.text.strip()
# Normalize different dash types to hyphen-minus for matching
token_text_normalized = _normalize_dashes(token_text)
# Skip if token is the same length as value (would be exact match)
-if len(token_text) <= len(value):
+if len(token_text_normalized) <= len(value):
continue
-# Check if value appears as substring
-if value in token_text:
+# Check if value appears as substring (using normalized text)
+if value in token_text_normalized:
# Verify it's a proper boundary match (not part of a larger number)
-idx = token_text.find(value)
+idx = token_text_normalized.find(value)
# Check character before (if exists)
if idx > 0:
-char_before = token_text[idx - 1]
+char_before = token_text_normalized[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
-if end_idx < len(token_text):
-char_after = token_text[end_idx]
+if end_idx < len(token_text_normalized):
+char_after = token_text_normalized[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
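A self-contained sketch of the dash-normalized, digit-boundary substring check introduced above (the regex mirrors the diff; contains_value is an illustrative helper):

import re

_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212]')  # en-dash, em-dash, minus sign

def contains_value(token_text: str, value: str) -> bool:
    text = _DASH_PATTERN.sub('-', token_text.strip())
    if len(text) <= len(value) or value not in text:
        return False
    idx = text.find(value)
    # Reject matches that sit inside a longer digit run on either side.
    if idx > 0 and text[idx - 1].isdigit():
        return False
    end = idx + len(value)
    if end < len(text) and text[end].isdigit():
        return False
    return True

# contains_value('Org.nr: 556123\u20134567', '556123-4567')  -> True  (en-dash normalized)
# contains_value('9556123-45678', '556123-4567')             -> False (digit boundary)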

View File

@@ -39,9 +39,12 @@ class FieldNormalizer:
@staticmethod
def clean_text(text: str) -> str:
-"""Remove invisible characters and normalize whitespace."""
+"""Remove invisible characters and normalize whitespace and dashes."""
# Remove zero-width characters
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
# Normalize different dash types to standard hyphen-minus (ASCII 45)
# en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212)
text = re.sub(r'[\u2013\u2014\u2212]', '-', text)
# Normalize whitespace
text = ' '.join(text.split())
return text.strip()
@@ -130,6 +133,133 @@ class FieldNormalizer:
return list(set(v for v in variants if v))
@staticmethod
def normalize_organisation_number(value: str) -> list[str]:
"""
Normalize Swedish organisation number and generate VAT number variants.
Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
Swedish VAT format: SE + org_number (10 digits) + 01
Examples:
'556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
'5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
"""
value = FieldNormalizer.clean_text(value)
# Check if input is a VAT number (starts with SE, ends with 01)
org_digits = None
if value.upper().startswith('SE') and len(value) >= 12:
# Extract org number from VAT: SE + 10 digits + 01
potential_org = re.sub(r'\D', '', value[2:]) # Remove SE prefix, keep digits
if len(potential_org) == 12 and potential_org.endswith('01'):
org_digits = potential_org[:-2] # Remove trailing 01
elif len(potential_org) == 10:
org_digits = potential_org
if org_digits is None:
org_digits = re.sub(r'\D', '', value)
variants = [value]
if org_digits:
variants.append(org_digits)
# Standard format: NNNNNN-NNNN (10 digits total)
if len(org_digits) == 10:
with_dash = f"{org_digits[:6]}-{org_digits[6:]}"
variants.append(with_dash)
# Swedish VAT format: SE + org_number + 01
vat_number = f"SE{org_digits}01"
variants.append(vat_number)
variants.append(vat_number.lower()) # se556123456701
# With spaces: SE 5561234567 01
variants.append(f"SE {org_digits} 01")
variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01")
# Without 01 suffix (some invoices show just SE + org)
variants.append(f"SE{org_digits}")
variants.append(f"SE {org_digits}")
# Some may have 12 digits (century prefix): NNNNNNNN-NNNN
elif len(org_digits) == 12:
with_dash = f"{org_digits[:8]}-{org_digits[8:]}"
variants.append(with_dash)
# Also try without century prefix
short_version = org_digits[2:]
variants.append(short_version)
variants.append(f"{short_version[:6]}-{short_version[6:]}")
# VAT with short version
vat_number = f"SE{short_version}01"
variants.append(vat_number)
return list(set(v for v in variants if v))
@staticmethod
def normalize_supplier_accounts(value: str) -> list[str]:
"""
Normalize supplier accounts field.
The field may contain multiple accounts separated by ' | '.
Format examples:
'PG:48676043 | PG:49128028 | PG:8915035'
'BG:5393-9484'
Each account is normalized separately to generate variants.
Examples:
'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
"""
value = FieldNormalizer.clean_text(value)
variants = []
# Split by ' | ' to handle multiple accounts
accounts = [acc.strip() for acc in value.split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Add original value
variants.append(account)
# Remove prefix (PG:, BG:, etc.)
if ':' in account:
prefix, number = account.split(':', 1)
number = number.strip()
variants.append(number) # Just the number without prefix
# Also add with different prefix formats
prefix_upper = prefix.strip().upper()
variants.append(f"{prefix_upper}:{number}")
variants.append(f"{prefix_upper}: {number}") # With space
else:
number = account
# Extract digits only
digits_only = re.sub(r'\D', '', number)
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also try 4-4 format for bankgiro
variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
elif len(digits_only) == 10:
# 6-4 format (like org number)
variants.append(f"{digits_only[:6]}-{digits_only[6:]}")
return list(set(v for v in variants if v))
@staticmethod
def normalize_amount(value: str) -> list[str]:
"""
@@ -264,40 +394,71 @@ class FieldNormalizer:
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
Note: For ambiguous formats like DD/MM/YYYY vs MM/DD/YYYY,
we generate variants for BOTH interpretations to maximize matching.
"""
value = FieldNormalizer.clean_text(value)
variants = [value]
-parsed_date = None
+parsed_dates = []  # May have multiple interpretations
# Try different date formats
date_patterns = [
# ISO format with optional time (e.g., 2026-01-09 00:00:00)
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
-# European format with /
-(r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
-# European format with .
-(r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
-# European format with -
-(r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# Swedish format: YYMMDD
(r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYYYMMDD
(r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
]
# Ambiguous patterns - try both DD/MM and MM/DD interpretations
ambiguous_patterns = [
# Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
# Format with . - typically European DD.MM.YYYY
r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
# Format with - (not ISO) - could be DD-MM-YYYY or MM-DD-YYYY
r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
]
# Try unambiguous patterns first
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
if match:
try:
year, month, day = extractor(match)
-parsed_date = datetime(year, month, day)
+parsed_dates.append(datetime(year, month, day))
break
except ValueError:
continue
# Try ambiguous patterns with both interpretations
if not parsed_dates:
for pattern in ambiguous_patterns:
match = re.match(pattern, value)
if match:
n1, n2, year = int(match[1]), int(match[2]), int(match[3])
# Try DD/MM/YYYY (European - day first)
try:
parsed_dates.append(datetime(year, n2, n1))
except ValueError:
pass
# Try MM/DD/YYYY (US - month first) if different and valid
if n1 != n2:
try:
parsed_dates.append(datetime(year, n1, n2))
except ValueError:
pass
if parsed_dates:
break
# Try Swedish month names
-if not parsed_date:
+if not parsed_dates:
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
if month_name in value.lower():
# Extract day and year
@@ -308,16 +469,28 @@ class FieldNormalizer:
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
try:
-parsed_date = datetime(year, int(month_num), day)
+parsed_dates.append(datetime(year, int(month_num), day))
break
except ValueError:
continue
-if parsed_date:
+# Generate variants for all parsed date interpretations
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
for parsed_date in parsed_dates:
# Generate different formats
iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y')
us_slash = parsed_date.strftime('%m/%d/%Y') # US format MM/DD/YYYY
eu_dot = parsed_date.strftime('%d.%m.%Y')
iso_dot = parsed_date.strftime('%Y.%m.%d') # ISO with dots (e.g., 2024.02.08)
compact = parsed_date.strftime('%Y%m%d')  # YYYYMMDD
compact_short = parsed_date.strftime('%y%m%d')  # YYMMDD (e.g., 260108)
@@ -329,21 +502,13 @@ class FieldNormalizer:
spaced_short = parsed_date.strftime('%y %m %d')
# Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026")
-swedish_months_full = [
-'januari', 'februari', 'mars', 'april', 'maj', 'juni',
-'juli', 'augusti', 'september', 'oktober', 'november', 'december'
-]
-swedish_months_abbrev = [
-'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
-'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
-]
month_full = swedish_months_full[parsed_date.month - 1]
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
variants.extend([
-iso, eu_slash, eu_dot, compact, compact_short,
+iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
eu_dot_short, spaced_full, spaced_short,
swedish_format_full, swedish_format_abbrev
])
@@ -360,6 +525,8 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
'Amount': FieldNormalizer.normalize_amount,
'InvoiceDate': FieldNormalizer.normalize_date,
'InvoiceDueDate': FieldNormalizer.normalize_date,
'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
}
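To illustrate the variant-generation idea (not the repo's exact output set), a small sketch that builds org-number/VAT variants along the lines the new normalizer describes:

import re

def org_number_variants(value: str) -> set[str]:
    digits = re.sub(r'\D', '', value)
    variants = {value, digits}
    if len(digits) == 10:
        variants.add(f"{digits[:6]}-{digits[6:]}")   # standard NNNNNN-NNNN form
        variants.add(f"SE{digits}01")                # Swedish VAT form
        variants.add(f"SE {digits} 01")
    return {v for v in variants if v}

# org_number_variants('556123-4567') includes
# '5561234567', '556123-4567', 'SE556123456701', 'SE 5561234567 01'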

View File

@@ -11,6 +11,7 @@ import csv
# Field class mapping for YOLO
# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
'InvoiceNumber': 0,
'InvoiceDate': 1,
@@ -19,6 +20,16 @@ FIELD_CLASSES = {
'Bankgiro': 4,
'Plusgiro': 5,
'Amount': 6,
'supplier_organisation_number': 7,
}
# Fields that need matching but map to other YOLO classes
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
ACCOUNT_FIELD_MAPPING = {
'supplier_accounts': {
'BG': 'Bankgiro', # BG:xxx -> Bankgiro class
'PG': 'Plusgiro', # PG:xxx -> Plusgiro class
}
}
CLASS_NAMES = [
@@ -29,6 +40,7 @@ CLASS_NAMES = [
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number',
]
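A short sketch of how a supplier_accounts match can be routed to a YOLO class id via the two tables above (only the class entries shown in the diff are reproduced here; class_id_for is an illustrative helper):

# Entries shown in the diff (indices 2-3 omitted here).
FIELD_CLASSES = {
    'InvoiceNumber': 0, 'InvoiceDate': 1,
    'Bankgiro': 4, 'Plusgiro': 5, 'Amount': 6, 'supplier_organisation_number': 7,
}
ACCOUNT_FIELD_MAPPING = {'supplier_accounts': {'BG': 'Bankgiro', 'PG': 'Plusgiro'}}

def class_id_for(field_name: str, account_prefix: str | None = None) -> int:
    # supplier_accounts has no class of its own; route via the account prefix.
    if field_name in ACCOUNT_FIELD_MAPPING and account_prefix:
        field_name = ACCOUNT_FIELD_MAPPING[field_name][account_prefix]
    return FIELD_CLASSES[field_name]

# class_id_for('supplier_accounts', 'PG') -> 5  (Plusgiro)
# class_id_for('supplier_organisation_number') -> 7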

View File

@@ -52,6 +52,7 @@ class DatasetItem:
page_no: int
labels: list[YOLOAnnotation]
is_scanned: bool = False  # True if bbox is in pixel coords, False if in PDF points
csv_split: str | None = None # CSV-defined split ('train', 'test', etc.)
class DBYOLODataset:
@@ -202,7 +203,7 @@ class DBYOLODataset:
total_images += len(images)
continue
-labels_by_page, is_scanned = doc_data
+labels_by_page, is_scanned, csv_split = doc_data
for image_path in images:
total_images += 1
@@ -218,7 +219,8 @@ class DBYOLODataset:
image_path=image_path,
page_no=page_no,
labels=page_labels,
-is_scanned=is_scanned
+is_scanned=is_scanned,
csv_split=csv_split
))
else:
skipped_no_labels += 1
@@ -237,16 +239,17 @@ class DBYOLODataset:
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
print(f"Split '{self.split}': {len(self.items)} items")
-def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]]:
+def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
"""
Load labels from database for given document IDs using batch queries.
Returns:
-Dict of doc_id -> (page_labels, is_scanned)
+Dict of doc_id -> (page_labels, is_scanned, split)
where page_labels is {page_no -> list[YOLOAnnotation]}
-and is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
+is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
+split is the CSV-defined split ('train', 'test', etc.) or None
"""
-result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]] = {}
+result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]] = {}
# Query in batches using efficient batch method # Query in batches using efficient batch method
batch_size = 500 batch_size = 500
@@ -263,6 +266,9 @@ class DBYOLODataset:
# Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points)
is_scanned = doc.get('pdf_type') == 'scanned'
# Get CSV-defined split
csv_split = doc.get('split')
page_labels: dict[int, list[YOLOAnnotation]] = {}
for field_result in doc.get('field_results', []):
@@ -292,7 +298,7 @@ class DBYOLODataset:
page_labels[page_no].append(annotation)
if page_labels:
-result[doc_id] = (page_labels, is_scanned)
+result[doc_id] = (page_labels, is_scanned, csv_split)
return result
@@ -333,7 +339,10 @@ class DBYOLODataset:
def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
"""
-Split items into train/val/test based on document ID.
+Split items into train/val/test based on CSV-defined split field.
If CSV has 'split' field, use it directly.
Otherwise, fall back to random splitting based on train_ratio/val_ratio.
Returns:
Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
@@ -341,13 +350,64 @@
"""
# Group by document ID for proper splitting
doc_items: dict[str, list[DatasetItem]] = {}
doc_csv_split: dict[str, str | None] = {} # Track CSV split per document
for item in items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_csv_split[item.document_id] = item.csv_split
doc_items[item.document_id].append(item)
-# Shuffle document IDs
+# Check if we have CSV-defined splits
has_csv_splits = any(split is not None for split in doc_csv_split.values())
doc_ids = list(doc_items.keys())
if has_csv_splits:
# Use CSV-defined splits
print("Using CSV-defined split field for train/val/test assignment")
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
# 'val' is taken from train set using val_ratio
split_doc_ids = []
if self.split == 'train':
# Get documents marked as 'train' or no split defined
train_docs = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
# Take train_ratio of train docs for actual training, rest for val
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[:n_train]
elif self.split == 'val':
# Get documents marked as 'train' and take val portion
train_docs = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[n_train:]
else: # test
# Get documents marked as 'test'
split_doc_ids = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
# Apply limit if specified
if self.limit is not None and self.limit < len(split_doc_ids):
split_doc_ids = split_doc_ids[:self.limit]
print(f"Limited to {self.limit} documents")
else:
# Fall back to random splitting (original behavior)
print("No CSV split field found, using random splitting")
random.seed(self.seed)
random.shuffle(doc_ids)
@@ -381,23 +441,58 @@ class DBYOLODataset:
Split items using cached data from a shared dataset.
Uses pre-computed doc_ids order for consistent splits.
Respects CSV-defined splits if available.
"""
-# Group by document ID
+# Group by document ID and track CSV splits
doc_items: dict[str, list[DatasetItem]] = {}
doc_csv_split: dict[str, str | None] = {}
for item in self._all_items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_csv_split[item.document_id] = item.csv_split
doc_items[item.document_id].append(item)
-# Use cached doc_ids order
+# Check if we have CSV-defined splits
has_csv_splits = any(split is not None for split in doc_csv_split.values())
doc_ids = self._doc_ids_ordered
-# Calculate split indices
+if has_csv_splits:
# Use CSV-defined splits
if self.split == 'train':
train_docs = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[:n_train]
elif self.split == 'val':
train_docs = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[n_train:]
else: # test
split_doc_ids = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
else:
# Fall back to random splitting
n_total = len(doc_ids)
n_train = int(n_total * self.train_ratio)
n_val = int(n_total * self.val_ratio)
# Split document IDs based on split type
if self.split == 'train':
split_doc_ids = doc_ids[:n_train]
elif self.split == 'val':
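Finally, a compact sketch of the split logic described above: honour a CSV 'split' column when present and carve the val set out of the train documents deterministically (resolve_split is a standalone helper, not the dataset class):

import random

def resolve_split(doc_split: dict[str, str | None], want: str,
                  train_ratio: float = 0.8, val_ratio: float = 0.1, seed: int = 42) -> list[str]:
    doc_ids = list(doc_split)
    if not any(s is not None for s in doc_split.values()):
        # No CSV split column: a purely random split would be used instead (not shown here).
        raise ValueError("no CSV-defined splits")
    if want == 'test':
        return [d for d in doc_ids if (doc_split[d] or '').lower() == 'test']
    # 'train' and 'val' both come from the CSV train pool (None falls back to train).
    pool = [d for d in doc_ids if doc_split[d] is None or doc_split[d].lower() == 'train']
    rng = random.Random(seed)
    rng.shuffle(pool)
    n_train = int(len(pool) * (train_ratio / (train_ratio + val_ratio)))
    return pool[:n_train] if want == 'train' else pool[n_train:]

# docs = {'a': 'train', 'b': 'train', 'c': 'test', 'd': None}
# resolve_split(docs, 'test') -> ['c']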