From 425b8fdedf42942a2229c328411f6ced35fed68d Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Fri, 16 Jan 2026 23:10:01 +0100 Subject: [PATCH] WIP --- .claude/settings.local.json | 5 +- src/cli/autolabel.py | 108 +++++++++++++++- src/cli/import_report_to_db.py | 46 ++++++- src/data/autolabel_report.py | 14 ++- src/data/csv_loader.py | 26 +++- src/data/db.py | 111 ++++++++++++++-- src/matcher/field_matcher.py | 36 ++++-- src/normalize/normalizer.py | 209 +++++++++++++++++++++++++++---- src/yolo/annotation_generator.py | 12 ++ src/yolo/db_dataset.py | 173 +++++++++++++++++++------ 10 files changed, 653 insertions(+), 87 deletions(-) diff --git a/.claude/settings.local.json b/.claude/settings.local.json index ee7b05d..3c93f74 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -72,7 +72,10 @@ "Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la\")", "Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")", "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")", - "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")" + "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")", + "Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")", + "Bash(tasklist:*)", + "Bash(findstr:*)" ], "deny": [], "ask": [], diff --git a/src/cli/autolabel.py b/src/cli/autolabel.py index d0c8d3c..7c240cf 100644 --- a/src/cli/autolabel.py +++ b/src/cli/autolabel.py @@ -9,12 +9,24 @@ import argparse import sys import time import os +import signal import warnings from pathlib import Path from tqdm import tqdm -from concurrent.futures import ProcessPoolExecutor, as_completed +from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError import multiprocessing +# Global flag for graceful shutdown +_shutdown_requested = False + + +def _signal_handler(signum, frame): 
+ """Handle interrupt signals for graceful shutdown.""" + global _shutdown_requested + _shutdown_requested = True + print("\n\nShutdown requested. Finishing current batch and saving progress...") + print("(Press Ctrl+C again to force quit)\n") + # Windows compatibility: use 'spawn' method for multiprocessing # This is required on Windows and is also safer for libraries like PaddleOCR if sys.platform == 'win32': @@ -111,6 +123,12 @@ def process_single_document(args_tuple): report = AutoLabelReport(document_id=doc_id) report.pdf_path = str(pdf_path) + # Store metadata fields from CSV + report.split = row_dict.get('split') + report.customer_number = row_dict.get('customer_number') + report.supplier_name = row_dict.get('supplier_name') + report.supplier_organisation_number = row_dict.get('supplier_organisation_number') + report.supplier_accounts = row_dict.get('supplier_accounts') result = { 'doc_id': doc_id, @@ -204,6 +222,67 @@ def process_single_document(args_tuple): context_keywords=best.context_keywords )) + # Match supplier_accounts and map to Bankgiro/Plusgiro + supplier_accounts_value = row_dict.get('supplier_accounts') + if supplier_accounts_value: + # Parse accounts: "BG:xxx | PG:yyy" format + accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')] + for account in accounts: + account = account.strip() + if not account: + continue + + # Determine account type (BG or PG) and extract account number + account_type = None + account_number = account # Default to full value + + if account.upper().startswith('BG:'): + account_type = 'Bankgiro' + account_number = account[3:].strip() # Remove "BG:" prefix + elif account.upper().startswith('BG '): + account_type = 'Bankgiro' + account_number = account[2:].strip() # Remove "BG" prefix + elif account.upper().startswith('PG:'): + account_type = 'Plusgiro' + account_number = account[3:].strip() # Remove "PG:" prefix + elif account.upper().startswith('PG '): + account_type = 'Plusgiro' + account_number = account[2:].strip() # Remove "PG" prefix + else: + # Try to guess from format - Plusgiro often has format XXXXXXX-X + digits = ''.join(c for c in account if c.isdigit()) + if len(digits) == 8 and '-' in account: + account_type = 'Plusgiro' + elif len(digits) in (7, 8): + account_type = 'Bankgiro' # Default to Bankgiro + + if not account_type: + continue + + # Normalize and match using the account number (without prefix) + normalized = normalize_field('supplier_accounts', account_number) + field_matches = matcher.find_matches(tokens, account_type, normalized, page_no) + + if field_matches: + best = field_matches[0] + # Add to matches under the target class (Bankgiro/Plusgiro) + if account_type not in matches: + matches[account_type] = [] + matches[account_type].extend(field_matches) + matched_fields.add('supplier_accounts') + + report.add_field_result(FieldMatchResult( + field_name=f'supplier_accounts({account_type})', + csv_value=account_number, # Store without prefix + matched=True, + score=best.score, + matched_text=best.matched_text, + candidate_used=best.value, + bbox=best.bbox, + page_no=page_no, + context_keywords=best.context_keywords + )) + # Count annotations annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi) @@ -329,6 +408,10 @@ def main(): args = parser.parse_args() + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + # Import here to avoid slow startup from ..data import CSVLoader, 
AutoLabelReport, FieldMatchResult from ..data.autolabel_report import ReportWriter @@ -364,6 +447,7 @@ def main(): from ..data.db import DocumentDB db = DocumentDB() db.connect() + db.create_tables() # Ensure tables exist print("Connected to database for status checking") # Global stats @@ -458,6 +542,11 @@ def main(): try: # Process CSV files one by one (streaming) for csv_idx, csv_file in enumerate(csv_files): + # Check for shutdown request + if _shutdown_requested: + print("\nShutdown requested. Stopping after current batch...") + break + print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}") # Load only this CSV file @@ -548,6 +637,13 @@ def main(): 'Bankgiro': row.Bankgiro, 'Plusgiro': row.Plusgiro, 'Amount': row.Amount, + # New fields + 'supplier_organisation_number': row.supplier_organisation_number, + 'supplier_accounts': row.supplier_accounts, + # Metadata fields (not for matching, but for database storage) + 'split': row.split, + 'customer_number': row.customer_number, + 'supplier_name': row.supplier_name, } tasks.append(( @@ -647,11 +743,19 @@ def main(): futures = {executor.submit(process_single_document, task): task[0]['DocumentId'] for task in tasks} + # Per-document timeout: 120 seconds (2 minutes) + # This prevents a single stuck document from blocking the entire batch + DOCUMENT_TIMEOUT = 120 + for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"): doc_id = futures[future] try: - result = future.result() + result = future.result(timeout=DOCUMENT_TIMEOUT) handle_result(result) + except TimeoutError: + handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s") + # Cancel the stuck future + future.cancel() except Exception as e: handle_error(doc_id, e) diff --git a/src/cli/import_report_to_db.py b/src/cli/import_report_to_db.py index 4c309a1..3afd81d 100644 --- a/src/cli/import_report_to_db.py +++ b/src/cli/import_report_to_db.py @@ -34,7 +34,13 @@ def create_tables(conn): annotations_generated INTEGER, processing_time_ms REAL, timestamp TIMESTAMPTZ, - errors JSONB DEFAULT '[]' + errors JSONB DEFAULT '[]', + -- New fields for extended CSV format + split TEXT, + customer_number TEXT, + supplier_name TEXT, + supplier_organisation_number TEXT, + supplier_accounts TEXT ); CREATE TABLE IF NOT EXISTS field_results ( @@ -56,6 +62,26 @@ def create_tables(conn): CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id); CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name); CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched); + + -- Add new columns to existing tables if they don't exist (for migration) + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN + ALTER TABLE documents ADD COLUMN split TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN + ALTER TABLE documents ADD COLUMN customer_number TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN + ALTER TABLE documents ADD COLUMN supplier_name TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN + ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE 
table_name='documents' AND column_name='supplier_accounts') THEN + ALTER TABLE documents ADD COLUMN supplier_accounts TEXT; + END IF; + END $$; """) conn.commit() @@ -82,7 +108,8 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_ INSERT INTO documents (document_id, pdf_path, pdf_type, success, total_pages, fields_matched, fields_total, annotations_generated, - processing_time_ms, timestamp, errors) + processing_time_ms, timestamp, errors, + split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts) VALUES %s ON CONFLICT (document_id) DO UPDATE SET pdf_path = EXCLUDED.pdf_path, @@ -94,7 +121,12 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_ annotations_generated = EXCLUDED.annotations_generated, processing_time_ms = EXCLUDED.processing_time_ms, timestamp = EXCLUDED.timestamp, - errors = EXCLUDED.errors + errors = EXCLUDED.errors, + split = EXCLUDED.split, + customer_number = EXCLUDED.customer_number, + supplier_name = EXCLUDED.supplier_name, + supplier_organisation_number = EXCLUDED.supplier_organisation_number, + supplier_accounts = EXCLUDED.supplier_accounts """, doc_batch) doc_batch = [] @@ -150,7 +182,13 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_ record.get('annotations_generated'), record.get('processing_time_ms'), record.get('timestamp'), - json.dumps(record.get('errors', [])) + json.dumps(record.get('errors', [])), + # New fields + record.get('split'), + record.get('customer_number'), + record.get('supplier_name'), + record.get('supplier_organisation_number'), + record.get('supplier_accounts'), )) for field in record.get('field_results', []): diff --git a/src/data/autolabel_report.py b/src/data/autolabel_report.py index 32a2473..d0a7e7a 100644 --- a/src/data/autolabel_report.py +++ b/src/data/autolabel_report.py @@ -63,6 +63,12 @@ class AutoLabelReport: processing_time_ms: float = 0.0 timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) errors: list[str] = field(default_factory=list) + # New metadata fields (from CSV, not for matching) + split: str | None = None + customer_number: str | None = None + supplier_name: str | None = None + supplier_organisation_number: str | None = None + supplier_accounts: str | None = None def add_field_result(self, result: FieldMatchResult) -> None: """Add a field matching result.""" @@ -87,7 +93,13 @@ class AutoLabelReport: 'label_paths': self.label_paths, 'processing_time_ms': self.processing_time_ms, 'timestamp': self.timestamp, - 'errors': self.errors + 'errors': self.errors, + # New metadata fields + 'split': self.split, + 'customer_number': self.customer_number, + 'supplier_name': self.supplier_name, + 'supplier_organisation_number': self.supplier_organisation_number, + 'supplier_accounts': self.supplier_accounts, } def to_json(self, indent: int | None = None) -> str: diff --git a/src/data/csv_loader.py b/src/data/csv_loader.py index 96c0265..cac43a4 100644 --- a/src/data/csv_loader.py +++ b/src/data/csv_loader.py @@ -25,6 +25,12 @@ class InvoiceRow: Bankgiro: str | None = None Plusgiro: str | None = None Amount: Decimal | None = None + # New fields + split: str | None = None # train/test split indicator + customer_number: str | None = None # Customer number (no matching needed) + supplier_name: str | None = None # Supplier name (no matching) + supplier_organisation_number: str | None = None # Swedish org number (needs matching) + supplier_accounts: str | None = None # Supplier accounts 
(needs matching) # Raw values for reference raw_data: dict = field(default_factory=dict) @@ -40,6 +46,8 @@ class InvoiceRow: 'Bankgiro': self.Bankgiro, 'Plusgiro': self.Plusgiro, 'Amount': str(self.Amount) if self.Amount else None, + 'supplier_organisation_number': self.supplier_organisation_number, + 'supplier_accounts': self.supplier_accounts, } def get_field_value(self, field_name: str) -> str | None: @@ -68,6 +76,12 @@ class CSVLoader: 'Bankgiro': 'Bankgiro', 'Plusgiro': 'Plusgiro', 'Amount': 'Amount', + # New fields + 'split': 'split', + 'customer_number': 'customer_number', + 'supplier_name': 'supplier_name', + 'supplier_organisation_number': 'supplier_organisation_number', + 'supplier_accounts': 'supplier_accounts', } def __init__( @@ -200,6 +214,12 @@ class CSVLoader: Bankgiro=self._parse_string(row.get('Bankgiro')), Plusgiro=self._parse_string(row.get('Plusgiro')), Amount=self._parse_amount(row.get('Amount')), + # New fields + split=self._parse_string(row.get('split')), + customer_number=self._parse_string(row.get('customer_number')), + supplier_name=self._parse_string(row.get('supplier_name')), + supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')), + supplier_accounts=self._parse_string(row.get('supplier_accounts')), raw_data=dict(row) ) @@ -318,14 +338,16 @@ class CSVLoader: row.OCR, row.Bankgiro, row.Plusgiro, - row.Amount + row.Amount, + row.supplier_organisation_number, + row.supplier_accounts, ] if not any(matchable_fields): issues.append({ 'row': i, 'doc_id': row.DocumentId, 'field': 'All', - 'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)' + 'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)' }) return issues diff --git a/src/data/db.py b/src/data/db.py index 3340abc..b1e0e7f 100644 --- a/src/data/db.py +++ b/src/data/db.py @@ -26,6 +26,73 @@ class DocumentDB: self.conn = psycopg2.connect(self.connection_string) return self.conn + def create_tables(self): + """Create database tables if they don't exist.""" + conn = self.connect() + with conn.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS documents ( + document_id TEXT PRIMARY KEY, + pdf_path TEXT, + pdf_type TEXT, + success BOOLEAN, + total_pages INTEGER, + fields_matched INTEGER, + fields_total INTEGER, + annotations_generated INTEGER, + processing_time_ms REAL, + timestamp TIMESTAMPTZ, + errors JSONB DEFAULT '[]', + -- Extended CSV format fields + split TEXT, + customer_number TEXT, + supplier_name TEXT, + supplier_organisation_number TEXT, + supplier_accounts TEXT + ); + + CREATE TABLE IF NOT EXISTS field_results ( + id SERIAL PRIMARY KEY, + document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE, + field_name TEXT, + csv_value TEXT, + matched BOOLEAN, + score REAL, + matched_text TEXT, + candidate_used TEXT, + bbox JSONB, + page_no INTEGER, + context_keywords JSONB DEFAULT '[]', + error TEXT + ); + + CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success); + CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id); + CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name); + CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched); + + -- Add new columns to existing tables if they don't exist (for migration) + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN + ALTER 
TABLE documents ADD COLUMN split TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN + ALTER TABLE documents ADD COLUMN customer_number TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN + ALTER TABLE documents ADD COLUMN supplier_name TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN + ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN + ALTER TABLE documents ADD COLUMN supplier_accounts TEXT; + END IF; + END $$; + """) + conn.commit() + def close(self): """Close database connection.""" if self.conn: @@ -110,7 +177,9 @@ class DocumentDB: cursor.execute(""" SELECT document_id, pdf_path, pdf_type, success, total_pages, fields_matched, fields_total, annotations_generated, - processing_time_ms, timestamp, errors + processing_time_ms, timestamp, errors, + split, customer_number, supplier_name, + supplier_organisation_number, supplier_accounts FROM documents WHERE document_id = %s """, (doc_id,)) row = cursor.fetchone() @@ -129,6 +198,12 @@ class DocumentDB: 'processing_time_ms': row[8], 'timestamp': str(row[9]) if row[9] else None, 'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'), + # New fields + 'split': row[11], + 'customer_number': row[12], + 'supplier_name': row[13], + 'supplier_organisation_number': row[14], + 'supplier_accounts': row[15], 'field_results': [] } @@ -253,7 +328,9 @@ class DocumentDB: cursor.execute(""" SELECT document_id, pdf_path, pdf_type, success, total_pages, fields_matched, fields_total, annotations_generated, - processing_time_ms, timestamp, errors + processing_time_ms, timestamp, errors, + split, customer_number, supplier_name, + supplier_organisation_number, supplier_accounts FROM documents WHERE document_id = ANY(%s) """, (doc_ids,)) @@ -270,6 +347,12 @@ class DocumentDB: 'processing_time_ms': row[8], 'timestamp': str(row[9]) if row[9] else None, 'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'), + # New fields + 'split': row[11], + 'customer_number': row[12], + 'supplier_name': row[13], + 'supplier_organisation_number': row[14], + 'supplier_accounts': row[15], 'field_results': [] } @@ -315,8 +398,9 @@ class DocumentDB: INSERT INTO documents (document_id, pdf_path, pdf_type, success, total_pages, fields_matched, fields_total, annotations_generated, - processing_time_ms, timestamp, errors) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + processing_time_ms, timestamp, errors, + split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, ( doc_id, report.get('pdf_path'), @@ -328,7 +412,13 @@ class DocumentDB: report.get('annotations_generated'), report.get('processing_time_ms'), report.get('timestamp'), - json.dumps(report.get('errors', [])) + json.dumps(report.get('errors', [])), + # New fields + report.get('split'), + report.get('customer_number'), + report.get('supplier_name'), + report.get('supplier_organisation_number'), + report.get('supplier_accounts'), )) # Batch insert field results using execute_values @@ -387,7 +477,13 @@ class DocumentDB: 
r.get('annotations_generated'), r.get('processing_time_ms'), r.get('timestamp'), - json.dumps(r.get('errors', [])) + json.dumps(r.get('errors', [])), + # New fields + r.get('split'), + r.get('customer_number'), + r.get('supplier_name'), + r.get('supplier_organisation_number'), + r.get('supplier_accounts'), ) for r in reports ] @@ -395,7 +491,8 @@ class DocumentDB: INSERT INTO documents (document_id, pdf_path, pdf_type, success, total_pages, fields_matched, fields_total, annotations_generated, - processing_time_ms, timestamp, errors) + processing_time_ms, timestamp, errors, + split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts) VALUES %s """, doc_values) diff --git a/src/matcher/field_matcher.py b/src/matcher/field_matcher.py index 611aa77..4f14a9f 100644 --- a/src/matcher/field_matcher.py +++ b/src/matcher/field_matcher.py @@ -14,6 +14,12 @@ from functools import cached_property _DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})') _WHITESPACE_PATTERN = re.compile(r'\s+') _NON_DIGIT_PATTERN = re.compile(r'\D') +_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212]') # en-dash, em-dash, minus sign + + +def _normalize_dashes(text: str) -> str: + """Normalize different dash types to standard hyphen-minus (ASCII 45).""" + return _DASH_PATTERN.sub('-', text) class TokenLike(Protocol): @@ -143,6 +149,9 @@ CONTEXT_KEYWORDS = { 'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'], 'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'], 'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'], + 'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer', + 'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'], + 'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'], } @@ -207,7 +216,10 @@ class FieldMatcher: # Strategy 4: Substring match (for values embedded in longer text) # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205" - if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'): + # Note: Amount is excluded because short numbers like "451" can incorrectly match + # in OCR payment lines or other unrelated text + if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', + 'supplier_organisation_number', 'supplier_accounts'): substring_matches = self._find_substring_matches(page_tokens, value, field_name) matches.extend(substring_matches) @@ -237,7 +249,8 @@ class FieldMatcher: """Find tokens that exactly match the value.""" matches = [] value_lower = value.lower() - value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None + value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', + 'supplier_organisation_number', 'supplier_accounts') else None for token in tokens: token_text = token.text.strip() @@ -355,33 +368,36 @@ class FieldMatcher: matches = [] # Supported fields for substring matching - supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount') + supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount', + 'supplier_organisation_number', 'supplier_accounts') if field_name not in supported_fields: return matches for token in tokens: token_text = token.text.strip() + # Normalize different dash types to hyphen-minus for 
matching + token_text_normalized = _normalize_dashes(token_text) # Skip if token is the same length as value (would be exact match) - if len(token_text) <= len(value): + if len(token_text_normalized) <= len(value): continue - # Check if value appears as substring - if value in token_text: + # Check if value appears as substring (using normalized text) + if value in token_text_normalized: # Verify it's a proper boundary match (not part of a larger number) - idx = token_text.find(value) + idx = token_text_normalized.find(value) # Check character before (if exists) if idx > 0: - char_before = token_text[idx - 1] + char_before = token_text_normalized[idx - 1] # Must be non-digit (allow : space - etc) if char_before.isdigit(): continue # Check character after (if exists) end_idx = idx + len(value) - if end_idx < len(token_text): - char_after = token_text[end_idx] + if end_idx < len(token_text_normalized): + char_after = token_text_normalized[end_idx] # Must be non-digit if char_after.isdigit(): continue diff --git a/src/normalize/normalizer.py b/src/normalize/normalizer.py index f802a67..ba66209 100644 --- a/src/normalize/normalizer.py +++ b/src/normalize/normalizer.py @@ -39,9 +39,12 @@ class FieldNormalizer: @staticmethod def clean_text(text: str) -> str: - """Remove invisible characters and normalize whitespace.""" + """Remove invisible characters and normalize whitespace and dashes.""" # Remove zero-width characters text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text) + # Normalize different dash types to standard hyphen-minus (ASCII 45) + # en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212) + text = re.sub(r'[\u2013\u2014\u2212]', '-', text) # Normalize whitespace text = ' '.join(text.split()) return text.strip() @@ -130,6 +133,133 @@ class FieldNormalizer: return list(set(v for v in variants if v)) + @staticmethod + def normalize_organisation_number(value: str) -> list[str]: + """ + Normalize Swedish organisation number and generate VAT number variants. + + Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits) + Swedish VAT format: SE + org_number (10 digits) + 01 + + Examples: + '556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...] + '5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...] + 'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...] 
+ """ + value = FieldNormalizer.clean_text(value) + + # Check if input is a VAT number (starts with SE, ends with 01) + org_digits = None + if value.upper().startswith('SE') and len(value) >= 12: + # Extract org number from VAT: SE + 10 digits + 01 + potential_org = re.sub(r'\D', '', value[2:]) # Remove SE prefix, keep digits + if len(potential_org) == 12 and potential_org.endswith('01'): + org_digits = potential_org[:-2] # Remove trailing 01 + elif len(potential_org) == 10: + org_digits = potential_org + + if org_digits is None: + org_digits = re.sub(r'\D', '', value) + + variants = [value] + + if org_digits: + variants.append(org_digits) + + # Standard format: NNNNNN-NNNN (10 digits total) + if len(org_digits) == 10: + with_dash = f"{org_digits[:6]}-{org_digits[6:]}" + variants.append(with_dash) + + # Swedish VAT format: SE + org_number + 01 + vat_number = f"SE{org_digits}01" + variants.append(vat_number) + variants.append(vat_number.lower()) # se556123456701 + # With spaces: SE 5561234567 01 + variants.append(f"SE {org_digits} 01") + variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01") + # Without 01 suffix (some invoices show just SE + org) + variants.append(f"SE{org_digits}") + variants.append(f"SE {org_digits}") + + # Some may have 12 digits (century prefix): NNNNNNNN-NNNN + elif len(org_digits) == 12: + with_dash = f"{org_digits[:8]}-{org_digits[8:]}" + variants.append(with_dash) + # Also try without century prefix + short_version = org_digits[2:] + variants.append(short_version) + variants.append(f"{short_version[:6]}-{short_version[6:]}") + # VAT with short version + vat_number = f"SE{short_version}01" + variants.append(vat_number) + + return list(set(v for v in variants if v)) + + @staticmethod + def normalize_supplier_accounts(value: str) -> list[str]: + """ + Normalize supplier accounts field. + + The field may contain multiple accounts separated by ' | '. + Format examples: + 'PG:48676043 | PG:49128028 | PG:8915035' + 'BG:5393-9484' + + Each account is normalized separately to generate variants. + + Examples: + 'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3'] + 'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484'] + """ + value = FieldNormalizer.clean_text(value) + variants = [] + + # Split by ' | ' to handle multiple accounts + accounts = [acc.strip() for acc in value.split('|')] + + for account in accounts: + account = account.strip() + if not account: + continue + + # Add original value + variants.append(account) + + # Remove prefix (PG:, BG:, etc.) 
+ if ':' in account: + prefix, number = account.split(':', 1) + number = number.strip() + variants.append(number) # Just the number without prefix + + # Also add with different prefix formats + prefix_upper = prefix.strip().upper() + variants.append(f"{prefix_upper}:{number}") + variants.append(f"{prefix_upper}: {number}") # With space + else: + number = account + + # Extract digits only + digits_only = re.sub(r'\D', '', number) + + if digits_only: + variants.append(digits_only) + + # Plusgiro format: XXXXXXX-X (7 digits + check digit) + if len(digits_only) == 8: + with_dash = f"{digits_only[:-1]}-{digits_only[-1]}" + variants.append(with_dash) + # Also try 4-4 format for bankgiro + variants.append(f"{digits_only[:4]}-{digits_only[4:]}") + elif len(digits_only) == 7: + with_dash = f"{digits_only[:-1]}-{digits_only[-1]}" + variants.append(with_dash) + elif len(digits_only) == 10: + # 6-4 format (like org number) + variants.append(f"{digits_only[:6]}-{digits_only[6:]}") + + return list(set(v for v in variants if v)) + @staticmethod def normalize_amount(value: str) -> list[str]: """ @@ -264,40 +394,71 @@ class FieldNormalizer: '2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025'] '13/12/2025' -> ['2025-12-13', '13/12/2025', ...] '13 december 2025' -> ['2025-12-13', ...] + + Note: For ambiguous formats like DD/MM/YYYY vs MM/DD/YYYY, + we generate variants for BOTH interpretations to maximize matching. """ value = FieldNormalizer.clean_text(value) variants = [value] - parsed_date = None + parsed_dates = [] # May have multiple interpretations # Try different date formats date_patterns = [ # ISO format with optional time (e.g., 2026-01-09 00:00:00) (r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))), - # European format with / - (r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))), - # European format with . - (r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))), - # European format with - - (r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))), # Swedish format: YYMMDD (r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))), # Swedish format: YYYYMMDD (r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))), ] + # Ambiguous patterns - try both DD/MM and MM/DD interpretations + ambiguous_patterns = [ + # Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US) + r'^(\d{1,2})/(\d{1,2})/(\d{4})$', + # Format with . 
- typically European DD.MM.YYYY + r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', + # Format with - (not ISO) - could be DD-MM-YYYY or MM-DD-YYYY + r'^(\d{1,2})-(\d{1,2})-(\d{4})$', + ] + + # Try unambiguous patterns first for pattern, extractor in date_patterns: match = re.match(pattern, value) if match: try: year, month, day = extractor(match) - parsed_date = datetime(year, month, day) + parsed_dates.append(datetime(year, month, day)) break except ValueError: continue + # Try ambiguous patterns with both interpretations + if not parsed_dates: + for pattern in ambiguous_patterns: + match = re.match(pattern, value) + if match: + n1, n2, year = int(match[1]), int(match[2]), int(match[3]) + + # Try DD/MM/YYYY (European - day first) + try: + parsed_dates.append(datetime(year, n2, n1)) + except ValueError: + pass + + # Try MM/DD/YYYY (US - month first) if different and valid + if n1 != n2: + try: + parsed_dates.append(datetime(year, n1, n2)) + except ValueError: + pass + + if parsed_dates: + break + # Try Swedish month names - if not parsed_date: + if not parsed_dates: for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items(): if month_name in value.lower(): # Extract day and year @@ -308,16 +469,28 @@ class FieldNormalizer: if year < 100: year = 2000 + year if year < 50 else 1900 + year try: - parsed_date = datetime(year, int(month_num), day) + parsed_dates.append(datetime(year, int(month_num), day)) break except ValueError: continue - if parsed_date: + # Generate variants for all parsed date interpretations + swedish_months_full = [ + 'januari', 'februari', 'mars', 'april', 'maj', 'juni', + 'juli', 'augusti', 'september', 'oktober', 'november', 'december' + ] + swedish_months_abbrev = [ + 'jan', 'feb', 'mar', 'apr', 'maj', 'jun', + 'jul', 'aug', 'sep', 'okt', 'nov', 'dec' + ] + + for parsed_date in parsed_dates: # Generate different formats iso = parsed_date.strftime('%Y-%m-%d') eu_slash = parsed_date.strftime('%d/%m/%Y') + us_slash = parsed_date.strftime('%m/%d/%Y') # US format MM/DD/YYYY eu_dot = parsed_date.strftime('%d.%m.%Y') + iso_dot = parsed_date.strftime('%Y.%m.%d') # ISO with dots (e.g., 2024.02.08) compact = parsed_date.strftime('%Y%m%d') # YYYYMMDD compact_short = parsed_date.strftime('%y%m%d') # YYMMDD (e.g., 260108) @@ -329,21 +502,13 @@ class FieldNormalizer: spaced_short = parsed_date.strftime('%y %m %d') # Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026") - swedish_months_full = [ - 'januari', 'februari', 'mars', 'april', 'maj', 'juni', - 'juli', 'augusti', 'september', 'oktober', 'november', 'december' - ] - swedish_months_abbrev = [ - 'jan', 'feb', 'mar', 'apr', 'maj', 'jun', - 'jul', 'aug', 'sep', 'okt', 'nov', 'dec' - ] month_full = swedish_months_full[parsed_date.month - 1] month_abbrev = swedish_months_abbrev[parsed_date.month - 1] swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}" swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}" variants.extend([ - iso, eu_slash, eu_dot, compact, compact_short, + iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short, eu_dot_short, spaced_full, spaced_short, swedish_format_full, swedish_format_abbrev ]) @@ -360,6 +525,8 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = { 'Amount': FieldNormalizer.normalize_amount, 'InvoiceDate': FieldNormalizer.normalize_date, 'InvoiceDueDate': FieldNormalizer.normalize_date, + 'supplier_organisation_number': FieldNormalizer.normalize_organisation_number, + 'supplier_accounts': 
FieldNormalizer.normalize_supplier_accounts, } diff --git a/src/yolo/annotation_generator.py b/src/yolo/annotation_generator.py index dc95e38..9dc0823 100644 --- a/src/yolo/annotation_generator.py +++ b/src/yolo/annotation_generator.py @@ -11,6 +11,7 @@ import csv # Field class mapping for YOLO +# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro FIELD_CLASSES = { 'InvoiceNumber': 0, 'InvoiceDate': 1, @@ -19,6 +20,16 @@ FIELD_CLASSES = { 'Bankgiro': 4, 'Plusgiro': 5, 'Amount': 6, + 'supplier_organisation_number': 7, +} + +# Fields that need matching but map to other YOLO classes +# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type +ACCOUNT_FIELD_MAPPING = { + 'supplier_accounts': { + 'BG': 'Bankgiro', # BG:xxx -> Bankgiro class + 'PG': 'Plusgiro', # PG:xxx -> Plusgiro class + } } CLASS_NAMES = [ @@ -29,6 +40,7 @@ CLASS_NAMES = [ 'bankgiro', 'plusgiro', 'amount', + 'supplier_org_number', ] diff --git a/src/yolo/db_dataset.py b/src/yolo/db_dataset.py index 5ea6665..8043446 100644 --- a/src/yolo/db_dataset.py +++ b/src/yolo/db_dataset.py @@ -52,6 +52,7 @@ class DatasetItem: page_no: int labels: list[YOLOAnnotation] is_scanned: bool = False # True if bbox is in pixel coords, False if in PDF points + csv_split: str | None = None # CSV-defined split ('train', 'test', etc.) class DBYOLODataset: @@ -202,7 +203,7 @@ class DBYOLODataset: total_images += len(images) continue - labels_by_page, is_scanned = doc_data + labels_by_page, is_scanned, csv_split = doc_data for image_path in images: total_images += 1 @@ -218,7 +219,8 @@ class DBYOLODataset: image_path=image_path, page_no=page_no, labels=page_labels, - is_scanned=is_scanned + is_scanned=is_scanned, + csv_split=csv_split )) else: skipped_no_labels += 1 @@ -237,16 +239,17 @@ class DBYOLODataset: self.items, self._doc_ids_ordered = self._split_dataset(all_items) print(f"Split '{self.split}': {len(self.items)} items") - def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]]: + def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]: """ Load labels from database for given document IDs using batch queries. Returns: - Dict of doc_id -> (page_labels, is_scanned) + Dict of doc_id -> (page_labels, is_scanned, split) where page_labels is {page_no -> list[YOLOAnnotation]} - and is_scanned indicates if bbox is in pixel coords (True) or PDF points (False) + is_scanned indicates if bbox is in pixel coords (True) or PDF points (False) + split is the CSV-defined split ('train', 'test', etc.) 
or None """ - result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]] = {} + result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]] = {} # Query in batches using efficient batch method batch_size = 500 @@ -263,6 +266,9 @@ class DBYOLODataset: # Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points) is_scanned = doc.get('pdf_type') == 'scanned' + # Get CSV-defined split + csv_split = doc.get('split') + page_labels: dict[int, list[YOLOAnnotation]] = {} for field_result in doc.get('field_results', []): @@ -292,7 +298,7 @@ class DBYOLODataset: page_labels[page_no].append(annotation) if page_labels: - result[doc_id] = (page_labels, is_scanned) + result[doc_id] = (page_labels, is_scanned, csv_split) return result @@ -333,7 +339,10 @@ class DBYOLODataset: def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]: """ - Split items into train/val/test based on document ID. + Split items into train/val/test based on CSV-defined split field. + + If CSV has 'split' field, use it directly. + Otherwise, fall back to random splitting based on train_ratio/val_ratio. Returns: Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be @@ -341,33 +350,84 @@ class DBYOLODataset: """ # Group by document ID for proper splitting doc_items: dict[str, list[DatasetItem]] = {} + doc_csv_split: dict[str, str | None] = {} # Track CSV split per document + for item in items: if item.document_id not in doc_items: doc_items[item.document_id] = [] + doc_csv_split[item.document_id] = item.csv_split doc_items[item.document_id].append(item) - # Shuffle document IDs + # Check if we have CSV-defined splits + has_csv_splits = any(split is not None for split in doc_csv_split.values()) + doc_ids = list(doc_items.keys()) - random.seed(self.seed) - random.shuffle(doc_ids) - # Apply limit if specified (before splitting) - if self.limit is not None and self.limit < len(doc_ids): - doc_ids = doc_ids[:self.limit] - print(f"Limited to {self.limit} documents") + if has_csv_splits: + # Use CSV-defined splits + print("Using CSV-defined split field for train/val/test assignment") - # Calculate split indices - n_total = len(doc_ids) - n_train = int(n_total * self.train_ratio) - n_val = int(n_total * self.val_ratio) + # Map split values: 'train' -> train, 'test' -> test, None -> train (fallback) + # 'val' is taken from train set using val_ratio + split_doc_ids = [] - # Split document IDs - if self.split == 'train': - split_doc_ids = doc_ids[:n_train] - elif self.split == 'val': - split_doc_ids = doc_ids[n_train:n_train + n_val] - else: # test - split_doc_ids = doc_ids[n_train + n_val:] + if self.split == 'train': + # Get documents marked as 'train' or no split defined + train_docs = [doc_id for doc_id in doc_ids + if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')] + + # Take train_ratio of train docs for actual training, rest for val + random.seed(self.seed) + random.shuffle(train_docs) + + n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio))) + split_doc_ids = train_docs[:n_train] + + elif self.split == 'val': + # Get documents marked as 'train' and take val portion + train_docs = [doc_id for doc_id in doc_ids + if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')] + + random.seed(self.seed) + random.shuffle(train_docs) + + n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio))) + split_doc_ids = train_docs[n_train:] + + else: # test + # Get documents marked as 
'test' + split_doc_ids = [doc_id for doc_id in doc_ids + if doc_csv_split[doc_id] in ('test', 'Test', 'TEST')] + + # Apply limit if specified + if self.limit is not None and self.limit < len(split_doc_ids): + split_doc_ids = split_doc_ids[:self.limit] + print(f"Limited to {self.limit} documents") + + else: + # Fall back to random splitting (original behavior) + print("No CSV split field found, using random splitting") + + random.seed(self.seed) + random.shuffle(doc_ids) + + # Apply limit if specified (before splitting) + if self.limit is not None and self.limit < len(doc_ids): + doc_ids = doc_ids[:self.limit] + print(f"Limited to {self.limit} documents") + + # Calculate split indices + n_total = len(doc_ids) + n_train = int(n_total * self.train_ratio) + n_val = int(n_total * self.val_ratio) + + # Split document IDs + if self.split == 'train': + split_doc_ids = doc_ids[:n_train] + elif self.split == 'val': + split_doc_ids = doc_ids[n_train:n_train + n_val] + else: # test + split_doc_ids = doc_ids[n_train + n_val:] # Collect items for this split split_items = [] @@ -381,29 +441,64 @@ class DBYOLODataset: Split items using cached data from a shared dataset. Uses pre-computed doc_ids order for consistent splits. + Respects CSV-defined splits if available. """ - # Group by document ID + # Group by document ID and track CSV splits doc_items: dict[str, list[DatasetItem]] = {} + doc_csv_split: dict[str, str | None] = {} + for item in self._all_items: if item.document_id not in doc_items: doc_items[item.document_id] = [] + doc_csv_split[item.document_id] = item.csv_split doc_items[item.document_id].append(item) - # Use cached doc_ids order + # Check if we have CSV-defined splits + has_csv_splits = any(split is not None for split in doc_csv_split.values()) + doc_ids = self._doc_ids_ordered - # Calculate split indices - n_total = len(doc_ids) - n_train = int(n_total * self.train_ratio) - n_val = int(n_total * self.val_ratio) + if has_csv_splits: + # Use CSV-defined splits + if self.split == 'train': + train_docs = [doc_id for doc_id in doc_ids + if doc_id in doc_csv_split and + doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')] - # Split document IDs based on split type - if self.split == 'train': - split_doc_ids = doc_ids[:n_train] - elif self.split == 'val': - split_doc_ids = doc_ids[n_train:n_train + n_val] - else: # test - split_doc_ids = doc_ids[n_train + n_val:] + random.seed(self.seed) + random.shuffle(train_docs) + + n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio))) + split_doc_ids = train_docs[:n_train] + + elif self.split == 'val': + train_docs = [doc_id for doc_id in doc_ids + if doc_id in doc_csv_split and + doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')] + + random.seed(self.seed) + random.shuffle(train_docs) + + n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio))) + split_doc_ids = train_docs[n_train:] + + else: # test + split_doc_ids = [doc_id for doc_id in doc_ids + if doc_id in doc_csv_split and + doc_csv_split[doc_id] in ('test', 'Test', 'TEST')] + + else: + # Fall back to random splitting + n_total = len(doc_ids) + n_train = int(n_total * self.train_ratio) + n_val = int(n_total * self.val_ratio) + + if self.split == 'train': + split_doc_ids = doc_ids[:n_train] + elif self.split == 'val': + split_doc_ids = doc_ids[n_train:n_train + n_val] + else: # test + split_doc_ids = doc_ids[n_train + n_val:] # Collect items for this split split_items = []
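
Notes (illustrative, not part of the patch): the supplier_accounts handling added in src/cli/autolabel.py and src/normalize/normalizer.py assumes a "BG:xxx | PG:yyy" convention, maps each account to the existing Bankgiro/Plusgiro classes, and matches on loose text variants (digits only, with or without a check-digit dash). The standalone sketch below re-implements that convention for clarity; the function names and the digit-count fallback heuristic are taken from the diff, but this is not the patched module itself.

```python
"""Sketch of the 'BG:... | PG:...' parsing and variant generation described in the diff."""
import re


def parse_supplier_accounts(value: str) -> list[tuple[str, str]]:
    """Split a 'BG:xxx | PG:yyy' string into (account_type, number) pairs."""
    pairs = []
    for account in (part.strip() for part in value.split('|')):
        if not account:
            continue
        upper = account.upper()
        if upper.startswith(('BG:', 'BG ')):
            pairs.append(('Bankgiro', account[3:].strip()))
        elif upper.startswith(('PG:', 'PG ')):
            pairs.append(('Plusgiro', account[3:].strip()))
        else:
            # No prefix: fall back to the digit-count heuristic used in the patch.
            digits = re.sub(r'\D', '', account)
            if len(digits) == 8 and '-' in account:
                pairs.append(('Plusgiro', account))
            elif len(digits) in (7, 8):
                pairs.append(('Bankgiro', account))  # default to Bankgiro
    return pairs


def account_variants(number: str) -> set[str]:
    """Generate text variants for matching: raw, digits only, dashed forms."""
    digits = re.sub(r'\D', '', number)
    variants = {number, digits}
    if len(digits) in (7, 8):
        variants.add(f"{digits[:-1]}-{digits[-1]}")   # Plusgiro-style check digit
    if len(digits) == 8:
        variants.add(f"{digits[:4]}-{digits[4:]}")    # Bankgiro-style 4-4
    return {v for v in variants if v}


if __name__ == "__main__":
    for raw in ("PG:48676043 | PG:49128028 | PG:8915035", "BG:5393-9484"):
        for account_type, number in parse_supplier_accounts(raw):
            print(account_type, number, sorted(account_variants(number)))
```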
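The db_dataset.py changes switch split assignment from purely random to CSV-driven: documents tagged 'test' form the test split, while 'train' (or untagged) documents are shuffled with the fixed seed and divided between train and val by train_ratio : val_ratio. The following minimal sketch shows that selection logic in isolation, under the same assumptions as the diff (case-insensitive tags, val carved out of the CSV train pool); the helper name is hypothetical.

```python
"""Sketch of the CSV-defined train/val/test selection added to DBYOLODataset."""
import random


def select_split(doc_split: dict[str, str | None], split: str,
                 train_ratio: float = 0.8, val_ratio: float = 0.1,
                 seed: int = 42) -> list[str]:
    """Return the document IDs belonging to the requested split."""
    if split == 'test':
        # Only documents explicitly tagged 'test' in the CSV.
        return [d for d, s in doc_split.items() if s and s.lower() == 'test']

    # 'train' and untagged documents form the pool shared by train and val.
    pool = [d for d, s in doc_split.items() if s is None or s.lower() == 'train']
    random.seed(seed)
    random.shuffle(pool)  # same seed for both calls keeps train/val disjoint
    n_train = int(len(pool) * (train_ratio / (train_ratio + val_ratio)))
    return pool[:n_train] if split == 'train' else pool[n_train:]


if __name__ == "__main__":
    docs = {'a': 'train', 'b': 'train', 'c': None, 'd': 'test', 'e': 'Train'}
    for which in ('train', 'val', 'test'):
        print(which, select_split(docs, which))
```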