WIP
@@ -72,7 +72,10 @@
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")"
|
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")",
|
||||||
|
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/structured_data/*.csv 2>/dev/null | head -20\")",
|
||||||
|
"Bash(tasklist:*)",
|
||||||
|
"Bash(findstr:*)"
|
||||||
],
|
],
|
||||||
"deny": [],
|
"deny": [],
|
||||||
"ask": [],
|
"ask": [],
|
||||||
|
|||||||
@@ -9,12 +9,24 @@ import argparse
import sys
import time
import os
+import signal
import warnings
from pathlib import Path
from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing

+# Global flag for graceful shutdown
+_shutdown_requested = False


+def _signal_handler(signum, frame):
+    """Handle interrupt signals for graceful shutdown."""
+    global _shutdown_requested
+    _shutdown_requested = True
+    print("\n\nShutdown requested. Finishing current batch and saving progress...")
+    print("(Press Ctrl+C again to force quit)\n")

# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
@@ -111,6 +123,12 @@ def process_single_document(args_tuple):

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)
+    # Store metadata fields from CSV
+    report.split = row_dict.get('split')
+    report.customer_number = row_dict.get('customer_number')
+    report.supplier_name = row_dict.get('supplier_name')
+    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
+    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        'doc_id': doc_id,
@@ -204,6 +222,67 @@ def process_single_document(args_tuple):
                context_keywords=best.context_keywords
            ))

+    # Match supplier_accounts and map to Bankgiro/Plusgiro
+    supplier_accounts_value = row_dict.get('supplier_accounts')
+    if supplier_accounts_value:
+        # Parse accounts: "BG:xxx | PG:yyy" format
+        accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
+        for account in accounts:
+            account = account.strip()
+            if not account:
+                continue
+
+            # Determine account type (BG or PG) and extract account number
+            account_type = None
+            account_number = account  # Default to full value
+
+            if account.upper().startswith('BG:'):
+                account_type = 'Bankgiro'
+                account_number = account[3:].strip()  # Remove "BG:" prefix
+            elif account.upper().startswith('BG '):
+                account_type = 'Bankgiro'
+                account_number = account[2:].strip()  # Remove "BG" prefix
+            elif account.upper().startswith('PG:'):
+                account_type = 'Plusgiro'
+                account_number = account[3:].strip()  # Remove "PG:" prefix
+            elif account.upper().startswith('PG '):
+                account_type = 'Plusgiro'
+                account_number = account[2:].strip()  # Remove "PG" prefix
+            else:
+                # Try to guess from format - Plusgiro often has format XXXXXXX-X
+                digits = ''.join(c for c in account if c.isdigit())
+                if len(digits) == 8 and '-' in account:
+                    account_type = 'Plusgiro'
+                elif len(digits) in (7, 8):
+                    account_type = 'Bankgiro'  # Default to Bankgiro
+
+            if not account_type:
+                continue
+
+            # Normalize and match using the account number (without prefix)
+            normalized = normalize_field('supplier_accounts', account_number)
+            field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
+
+            if field_matches:
+                best = field_matches[0]
+                # Add to matches under the target class (Bankgiro/Plusgiro)
+                if account_type not in matches:
+                    matches[account_type] = []
+                matches[account_type].extend(field_matches)
+                matched_fields.add('supplier_accounts')
+
+                report.add_field_result(FieldMatchResult(
+                    field_name=f'supplier_accounts({account_type})',
+                    csv_value=account_number,  # Store without prefix
+                    matched=True,
+                    score=best.score,
+                    matched_text=best.matched_text,
+                    candidate_used=best.value,
+                    bbox=best.bbox,
+                    page_no=page_no,
+                    context_keywords=best.context_keywords
+                ))

    # Count annotations
    annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
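As a reading aid, a minimal standalone sketch of the BG/PG prefix handling added above; the helper name and sample string are illustrative only and not part of the commit.

def split_supplier_accounts(raw: str) -> list[tuple[str, str]]:
    """Return (account_type, account_number) pairs from a 'BG:xxx | PG:yyy' string."""
    pairs = []
    for account in (a.strip() for a in raw.split('|')):
        if account.upper().startswith(('BG:', 'BG ')):
            pairs.append(('Bankgiro', account[3:].strip()))
        elif account.upper().startswith(('PG:', 'PG ')):
            pairs.append(('Plusgiro', account[3:].strip()))
    return pairs

# split_supplier_accounts('BG:5393-9484 | PG:48676043')
# -> [('Bankgiro', '5393-9484'), ('Plusgiro', '48676043')]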
@@ -329,6 +408,10 @@ def main():

    args = parser.parse_args()

+    # Register signal handlers for graceful shutdown
+    signal.signal(signal.SIGINT, _signal_handler)
+    signal.signal(signal.SIGTERM, _signal_handler)

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
@@ -364,6 +447,7 @@ def main():
        from ..data.db import DocumentDB
        db = DocumentDB()
        db.connect()
+        db.create_tables()  # Ensure tables exist
        print("Connected to database for status checking")

    # Global stats
@@ -458,6 +542,11 @@ def main():
    try:
        # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
+            # Check for shutdown request
+            if _shutdown_requested:
+                print("\nShutdown requested. Stopping after current batch...")
+                break
+
            print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")

            # Load only this CSV file
@@ -548,6 +637,13 @@ def main():
                'Bankgiro': row.Bankgiro,
                'Plusgiro': row.Plusgiro,
                'Amount': row.Amount,
+                # New fields
+                'supplier_organisation_number': row.supplier_organisation_number,
+                'supplier_accounts': row.supplier_accounts,
+                # Metadata fields (not for matching, but for database storage)
+                'split': row.split,
+                'customer_number': row.customer_number,
+                'supplier_name': row.supplier_name,
            }

            tasks.append((
@@ -647,11 +743,19 @@ def main():
            futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                       for task in tasks}

+            # Per-document timeout: 120 seconds (2 minutes)
+            # This prevents a single stuck document from blocking the entire batch
+            DOCUMENT_TIMEOUT = 120

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                doc_id = futures[future]
                try:
-                    result = future.result()
+                    result = future.result(timeout=DOCUMENT_TIMEOUT)
                    handle_result(result)
+                except TimeoutError:
+                    handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
+                    # Cancel the stuck future
+                    future.cancel()
                except Exception as e:
                    handle_error(doc_id, e)
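One caveat worth noting: as_completed() only yields futures that have already finished, so result(timeout=DOCUMENT_TIMEOUT) will rarely raise here, and ProcessPoolExecutor cannot cancel an already-running task. A deadline on the whole batch can instead go through the timeout parameter of as_completed() itself; a hedged sketch, where BATCH_TIMEOUT is an illustrative name, not part of the commit:

try:
    for future in tqdm(as_completed(futures, timeout=BATCH_TIMEOUT), total=len(futures), desc="Processing"):
        doc_id = futures[future]
        try:
            handle_result(future.result())
        except Exception as e:
            handle_error(doc_id, e)
except TimeoutError:
    # Raised when some futures are still unfinished at the deadline.
    for f, doc_id in futures.items():
        if not f.done():
            handle_error(doc_id, f"still running after {BATCH_TIMEOUT}s")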
@@ -34,7 +34,13 @@ def create_tables(conn):
            annotations_generated INTEGER,
            processing_time_ms REAL,
            timestamp TIMESTAMPTZ,
-            errors JSONB DEFAULT '[]'
+            errors JSONB DEFAULT '[]',
+            -- New fields for extended CSV format
+            split TEXT,
+            customer_number TEXT,
+            supplier_name TEXT,
+            supplier_organisation_number TEXT,
+            supplier_accounts TEXT
        );

        CREATE TABLE IF NOT EXISTS field_results (
@@ -56,6 +62,26 @@ def create_tables(conn):
        CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
        CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
        CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
+
+        -- Add new columns to existing tables if they don't exist (for migration)
+        DO $$
+        BEGIN
+            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
+                ALTER TABLE documents ADD COLUMN split TEXT;
+            END IF;
+            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
+                ALTER TABLE documents ADD COLUMN customer_number TEXT;
+            END IF;
+            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
+                ALTER TABLE documents ADD COLUMN supplier_name TEXT;
+            END IF;
+            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
+                ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
+            END IF;
+            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
+                ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
+            END IF;
+        END $$;
    """)
    conn.commit()
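For comparison, the same idempotent migration can be written without the DO block on PostgreSQL 9.6 and later via ADD COLUMN IF NOT EXISTS; a minimal sketch, assuming the same connection/cursor as the surrounding function:

NEW_DOCUMENT_COLUMNS = ['split', 'customer_number', 'supplier_name',
                        'supplier_organisation_number', 'supplier_accounts']
for column in NEW_DOCUMENT_COLUMNS:
    # Idempotent: does nothing if the column already exists.
    cursor.execute(f"ALTER TABLE documents ADD COLUMN IF NOT EXISTS {column} TEXT;")
conn.commit()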
@@ -82,7 +108,8 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
                INSERT INTO documents
                (document_id, pdf_path, pdf_type, success, total_pages,
                 fields_matched, fields_total, annotations_generated,
-                 processing_time_ms, timestamp, errors)
+                 processing_time_ms, timestamp, errors,
+                 split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
                VALUES %s
                ON CONFLICT (document_id) DO UPDATE SET
                    pdf_path = EXCLUDED.pdf_path,
@@ -94,7 +121,12 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
                    annotations_generated = EXCLUDED.annotations_generated,
                    processing_time_ms = EXCLUDED.processing_time_ms,
                    timestamp = EXCLUDED.timestamp,
-                    errors = EXCLUDED.errors
+                    errors = EXCLUDED.errors,
+                    split = EXCLUDED.split,
+                    customer_number = EXCLUDED.customer_number,
+                    supplier_name = EXCLUDED.supplier_name,
+                    supplier_organisation_number = EXCLUDED.supplier_organisation_number,
+                    supplier_accounts = EXCLUDED.supplier_accounts
            """, doc_batch)
            doc_batch = []
@@ -150,7 +182,13 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
                record.get('annotations_generated'),
                record.get('processing_time_ms'),
                record.get('timestamp'),
-                json.dumps(record.get('errors', []))
+                json.dumps(record.get('errors', [])),
+                # New fields
+                record.get('split'),
+                record.get('customer_number'),
+                record.get('supplier_name'),
+                record.get('supplier_organisation_number'),
+                record.get('supplier_accounts'),
            ))

            for field in record.get('field_results', []):
@@ -63,6 +63,12 @@ class AutoLabelReport:
    processing_time_ms: float = 0.0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    errors: list[str] = field(default_factory=list)
+    # New metadata fields (from CSV, not for matching)
+    split: str | None = None
+    customer_number: str | None = None
+    supplier_name: str | None = None
+    supplier_organisation_number: str | None = None
+    supplier_accounts: str | None = None

    def add_field_result(self, result: FieldMatchResult) -> None:
        """Add a field matching result."""
@@ -87,7 +93,13 @@ class AutoLabelReport:
            'label_paths': self.label_paths,
            'processing_time_ms': self.processing_time_ms,
            'timestamp': self.timestamp,
-            'errors': self.errors
+            'errors': self.errors,
+            # New metadata fields
+            'split': self.split,
+            'customer_number': self.customer_number,
+            'supplier_name': self.supplier_name,
+            'supplier_organisation_number': self.supplier_organisation_number,
+            'supplier_accounts': self.supplier_accounts,
        }

    def to_json(self, indent: int | None = None) -> str:
@@ -25,6 +25,12 @@ class InvoiceRow:
    Bankgiro: str | None = None
    Plusgiro: str | None = None
    Amount: Decimal | None = None
+    # New fields
+    split: str | None = None  # train/test split indicator
+    customer_number: str | None = None  # Customer number (no matching needed)
+    supplier_name: str | None = None  # Supplier name (no matching)
+    supplier_organisation_number: str | None = None  # Swedish org number (needs matching)
+    supplier_accounts: str | None = None  # Supplier accounts (needs matching)

    # Raw values for reference
    raw_data: dict = field(default_factory=dict)
@@ -40,6 +46,8 @@ class InvoiceRow:
            'Bankgiro': self.Bankgiro,
            'Plusgiro': self.Plusgiro,
            'Amount': str(self.Amount) if self.Amount else None,
+            'supplier_organisation_number': self.supplier_organisation_number,
+            'supplier_accounts': self.supplier_accounts,
        }

    def get_field_value(self, field_name: str) -> str | None:
@@ -68,6 +76,12 @@ class CSVLoader:
        'Bankgiro': 'Bankgiro',
        'Plusgiro': 'Plusgiro',
        'Amount': 'Amount',
+        # New fields
+        'split': 'split',
+        'customer_number': 'customer_number',
+        'supplier_name': 'supplier_name',
+        'supplier_organisation_number': 'supplier_organisation_number',
+        'supplier_accounts': 'supplier_accounts',
    }

    def __init__(
@@ -200,6 +214,12 @@ class CSVLoader:
            Bankgiro=self._parse_string(row.get('Bankgiro')),
            Plusgiro=self._parse_string(row.get('Plusgiro')),
            Amount=self._parse_amount(row.get('Amount')),
+            # New fields
+            split=self._parse_string(row.get('split')),
+            customer_number=self._parse_string(row.get('customer_number')),
+            supplier_name=self._parse_string(row.get('supplier_name')),
+            supplier_organisation_number=self._parse_string(row.get('supplier_organisation_number')),
+            supplier_accounts=self._parse_string(row.get('supplier_accounts')),
            raw_data=dict(row)
        )
@@ -318,14 +338,16 @@ class CSVLoader:
                row.OCR,
                row.Bankgiro,
                row.Plusgiro,
-                row.Amount
+                row.Amount,
+                row.supplier_organisation_number,
+                row.supplier_accounts,
            ]
            if not any(matchable_fields):
                issues.append({
                    'row': i,
                    'doc_id': row.DocumentId,
                    'field': 'All',
-                    'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)'
+                    'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount/supplier_organisation_number/supplier_accounts)'
                })

        return issues
src/data/db.py | 111
@@ -26,6 +26,73 @@ class DocumentDB:
        self.conn = psycopg2.connect(self.connection_string)
        return self.conn

+    def create_tables(self):
+        """Create database tables if they don't exist."""
+        conn = self.connect()
+        with conn.cursor() as cursor:
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS documents (
+                    document_id TEXT PRIMARY KEY,
+                    pdf_path TEXT,
+                    pdf_type TEXT,
+                    success BOOLEAN,
+                    total_pages INTEGER,
+                    fields_matched INTEGER,
+                    fields_total INTEGER,
+                    annotations_generated INTEGER,
+                    processing_time_ms REAL,
+                    timestamp TIMESTAMPTZ,
+                    errors JSONB DEFAULT '[]',
+                    -- Extended CSV format fields
+                    split TEXT,
+                    customer_number TEXT,
+                    supplier_name TEXT,
+                    supplier_organisation_number TEXT,
+                    supplier_accounts TEXT
+                );
+
+                CREATE TABLE IF NOT EXISTS field_results (
+                    id SERIAL PRIMARY KEY,
+                    document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
+                    field_name TEXT,
+                    csv_value TEXT,
+                    matched BOOLEAN,
+                    score REAL,
+                    matched_text TEXT,
+                    candidate_used TEXT,
+                    bbox JSONB,
+                    page_no INTEGER,
+                    context_keywords JSONB DEFAULT '[]',
+                    error TEXT
+                );
+
+                CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
+                CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
+                CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
+                CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
+
+                -- Add new columns to existing tables if they don't exist (for migration)
+                DO $$
+                BEGIN
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
+                        ALTER TABLE documents ADD COLUMN split TEXT;
+                    END IF;
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
+                        ALTER TABLE documents ADD COLUMN customer_number TEXT;
+                    END IF;
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
+                        ALTER TABLE documents ADD COLUMN supplier_name TEXT;
+                    END IF;
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
+                        ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
+                    END IF;
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
+                        ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
+                    END IF;
+                END $$;
+            """)
+        conn.commit()

    def close(self):
        """Close database connection."""
        if self.conn:
@@ -110,7 +177,9 @@ class DocumentDB:
            cursor.execute("""
                SELECT document_id, pdf_path, pdf_type, success, total_pages,
                       fields_matched, fields_total, annotations_generated,
-                       processing_time_ms, timestamp, errors
+                       processing_time_ms, timestamp, errors,
+                       split, customer_number, supplier_name,
+                       supplier_organisation_number, supplier_accounts
                FROM documents WHERE document_id = %s
            """, (doc_id,))
            row = cursor.fetchone()
@@ -129,6 +198,12 @@ class DocumentDB:
                'processing_time_ms': row[8],
                'timestamp': str(row[9]) if row[9] else None,
                'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
+                # New fields
+                'split': row[11],
+                'customer_number': row[12],
+                'supplier_name': row[13],
+                'supplier_organisation_number': row[14],
+                'supplier_accounts': row[15],
                'field_results': []
            }
@@ -253,7 +328,9 @@ class DocumentDB:
            cursor.execute("""
                SELECT document_id, pdf_path, pdf_type, success, total_pages,
                       fields_matched, fields_total, annotations_generated,
-                       processing_time_ms, timestamp, errors
+                       processing_time_ms, timestamp, errors,
+                       split, customer_number, supplier_name,
+                       supplier_organisation_number, supplier_accounts
                FROM documents WHERE document_id = ANY(%s)
            """, (doc_ids,))
@@ -270,6 +347,12 @@ class DocumentDB:
                'processing_time_ms': row[8],
                'timestamp': str(row[9]) if row[9] else None,
                'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
+                # New fields
+                'split': row[11],
+                'customer_number': row[12],
+                'supplier_name': row[13],
+                'supplier_organisation_number': row[14],
+                'supplier_accounts': row[15],
                'field_results': []
            }
@@ -315,8 +398,9 @@ class DocumentDB:
                INSERT INTO documents
                (document_id, pdf_path, pdf_type, success, total_pages,
                 fields_matched, fields_total, annotations_generated,
-                 processing_time_ms, timestamp, errors)
-                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+                 processing_time_ms, timestamp, errors,
+                 split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
+                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """, (
                doc_id,
                report.get('pdf_path'),
@@ -328,7 +412,13 @@ class DocumentDB:
                report.get('annotations_generated'),
                report.get('processing_time_ms'),
                report.get('timestamp'),
-                json.dumps(report.get('errors', []))
+                json.dumps(report.get('errors', [])),
+                # New fields
+                report.get('split'),
+                report.get('customer_number'),
+                report.get('supplier_name'),
+                report.get('supplier_organisation_number'),
+                report.get('supplier_accounts'),
            ))

            # Batch insert field results using execute_values
@@ -387,7 +477,13 @@ class DocumentDB:
                r.get('annotations_generated'),
                r.get('processing_time_ms'),
                r.get('timestamp'),
-                json.dumps(r.get('errors', []))
+                json.dumps(r.get('errors', [])),
+                # New fields
+                r.get('split'),
+                r.get('customer_number'),
+                r.get('supplier_name'),
+                r.get('supplier_organisation_number'),
+                r.get('supplier_accounts'),
            )
            for r in reports
        ]
@@ -395,7 +491,8 @@ class DocumentDB:
                INSERT INTO documents
                (document_id, pdf_path, pdf_type, success, total_pages,
                 fields_matched, fields_total, annotations_generated,
-                 processing_time_ms, timestamp, errors)
+                 processing_time_ms, timestamp, errors,
+                 split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
                VALUES %s
            """, doc_values)
@@ -14,6 +14,12 @@ from functools import cached_property
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
+_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212]')  # en-dash, em-dash, minus sign
+
+
+def _normalize_dashes(text: str) -> str:
+    """Normalize different dash types to standard hyphen-minus (ASCII 45)."""
+    return _DASH_PATTERN.sub('-', text)


class TokenLike(Protocol):
@@ -143,6 +149,9 @@ CONTEXT_KEYWORDS = {
    'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
    'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
    'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
+    'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
+                                     'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
+    'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}

@@ -207,7 +216,10 @@ class FieldMatcher:

        # Strategy 4: Substring match (for values embedded in longer text)
        # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
-        if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
+        # Note: Amount is excluded because short numbers like "451" can incorrectly match
+        # in OCR payment lines or other unrelated text
+        if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
+                          'supplier_organisation_number', 'supplier_accounts'):
            substring_matches = self._find_substring_matches(page_tokens, value, field_name)
            matches.extend(substring_matches)
@@ -237,7 +249,8 @@ class FieldMatcher:
        """Find tokens that exactly match the value."""
        matches = []
        value_lower = value.lower()
-        value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
+        value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
+                                                                           'supplier_organisation_number', 'supplier_accounts') else None

        for token in tokens:
            token_text = token.text.strip()
@@ -355,33 +368,36 @@ class FieldMatcher:
        matches = []

        # Supported fields for substring matching
-        supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
+        supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
+                            'supplier_organisation_number', 'supplier_accounts')
        if field_name not in supported_fields:
            return matches

        for token in tokens:
            token_text = token.text.strip()
+            # Normalize different dash types to hyphen-minus for matching
+            token_text_normalized = _normalize_dashes(token_text)

            # Skip if token is the same length as value (would be exact match)
-            if len(token_text) <= len(value):
+            if len(token_text_normalized) <= len(value):
                continue

-            # Check if value appears as substring
-            if value in token_text:
+            # Check if value appears as substring (using normalized text)
+            if value in token_text_normalized:
                # Verify it's a proper boundary match (not part of a larger number)
-                idx = token_text.find(value)
+                idx = token_text_normalized.find(value)

                # Check character before (if exists)
                if idx > 0:
-                    char_before = token_text[idx - 1]
+                    char_before = token_text_normalized[idx - 1]
                    # Must be non-digit (allow : space - etc)
                    if char_before.isdigit():
                        continue

                # Check character after (if exists)
                end_idx = idx + len(value)
-                if end_idx < len(token_text):
-                    char_after = token_text[end_idx]
+                if end_idx < len(token_text_normalized):
+                    char_after = token_text_normalized[end_idx]
                    # Must be non-digit
                    if char_after.isdigit():
                        continue
@@ -39,9 +39,12 @@ class FieldNormalizer:

    @staticmethod
    def clean_text(text: str) -> str:
-        """Remove invisible characters and normalize whitespace."""
+        """Remove invisible characters and normalize whitespace and dashes."""
        # Remove zero-width characters
        text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
+        # Normalize different dash types to standard hyphen-minus (ASCII 45)
+        # en-dash (–, U+2013), em-dash (—, U+2014), minus sign (−, U+2212)
+        text = re.sub(r'[\u2013\u2014\u2212]', '-', text)
        # Normalize whitespace
        text = ' '.join(text.split())
        return text.strip()
@@ -130,6 +133,133 @@ class FieldNormalizer:

        return list(set(v for v in variants if v))

+    @staticmethod
+    def normalize_organisation_number(value: str) -> list[str]:
+        """
+        Normalize Swedish organisation number and generate VAT number variants.
+
+        Organisation number format: NNNNNN-NNNN (6 digits + hyphen + 4 digits)
+        Swedish VAT format: SE + org_number (10 digits) + 01
+
+        Examples:
+            '556123-4567' -> ['556123-4567', '5561234567', 'SE556123456701', ...]
+            '5561234567' -> ['5561234567', '556123-4567', 'SE556123456701', ...]
+            'SE556123456701' -> ['SE556123456701', '5561234567', '556123-4567', ...]
+        """
+        value = FieldNormalizer.clean_text(value)
+
+        # Check if input is a VAT number (starts with SE, ends with 01)
+        org_digits = None
+        if value.upper().startswith('SE') and len(value) >= 12:
+            # Extract org number from VAT: SE + 10 digits + 01
+            potential_org = re.sub(r'\D', '', value[2:])  # Remove SE prefix, keep digits
+            if len(potential_org) == 12 and potential_org.endswith('01'):
+                org_digits = potential_org[:-2]  # Remove trailing 01
+            elif len(potential_org) == 10:
+                org_digits = potential_org
+
+        if org_digits is None:
+            org_digits = re.sub(r'\D', '', value)
+
+        variants = [value]
+
+        if org_digits:
+            variants.append(org_digits)
+
+            # Standard format: NNNNNN-NNNN (10 digits total)
+            if len(org_digits) == 10:
+                with_dash = f"{org_digits[:6]}-{org_digits[6:]}"
+                variants.append(with_dash)
+
+                # Swedish VAT format: SE + org_number + 01
+                vat_number = f"SE{org_digits}01"
+                variants.append(vat_number)
+                variants.append(vat_number.lower())  # se556123456701
+                # With spaces: SE 5561234567 01
+                variants.append(f"SE {org_digits} 01")
+                variants.append(f"SE {org_digits[:6]}-{org_digits[6:]} 01")
+                # Without 01 suffix (some invoices show just SE + org)
+                variants.append(f"SE{org_digits}")
+                variants.append(f"SE {org_digits}")
+
+            # Some may have 12 digits (century prefix): NNNNNNNN-NNNN
+            elif len(org_digits) == 12:
+                with_dash = f"{org_digits[:8]}-{org_digits[8:]}"
+                variants.append(with_dash)
+                # Also try without century prefix
+                short_version = org_digits[2:]
+                variants.append(short_version)
+                variants.append(f"{short_version[:6]}-{short_version[6:]}")
+                # VAT with short version
+                vat_number = f"SE{short_version}01"
+                variants.append(vat_number)
+
+        return list(set(v for v in variants if v))
+
+    @staticmethod
+    def normalize_supplier_accounts(value: str) -> list[str]:
+        """
+        Normalize supplier accounts field.
+
+        The field may contain multiple accounts separated by ' | '.
+        Format examples:
+            'PG:48676043 | PG:49128028 | PG:8915035'
+            'BG:5393-9484'
+
+        Each account is normalized separately to generate variants.
+
+        Examples:
+            'PG:48676043' -> ['PG:48676043', '48676043', '4867604-3']
+            'BG:5393-9484' -> ['BG:5393-9484', '5393-9484', '53939484']
+        """
+        value = FieldNormalizer.clean_text(value)
+        variants = []
+
+        # Split by ' | ' to handle multiple accounts
+        accounts = [acc.strip() for acc in value.split('|')]
+
+        for account in accounts:
+            account = account.strip()
+            if not account:
+                continue
+
+            # Add original value
+            variants.append(account)
+
+            # Remove prefix (PG:, BG:, etc.)
+            if ':' in account:
+                prefix, number = account.split(':', 1)
+                number = number.strip()
+                variants.append(number)  # Just the number without prefix
+
+                # Also add with different prefix formats
+                prefix_upper = prefix.strip().upper()
+                variants.append(f"{prefix_upper}:{number}")
+                variants.append(f"{prefix_upper}: {number}")  # With space
+            else:
+                number = account
+
+            # Extract digits only
+            digits_only = re.sub(r'\D', '', number)
+
+            if digits_only:
+                variants.append(digits_only)
+
+                # Plusgiro format: XXXXXXX-X (7 digits + check digit)
+                if len(digits_only) == 8:
+                    with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
+                    variants.append(with_dash)
+                    # Also try 4-4 format for bankgiro
+                    variants.append(f"{digits_only[:4]}-{digits_only[4:]}")
+                elif len(digits_only) == 7:
+                    with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
+                    variants.append(with_dash)
+                elif len(digits_only) == 10:
+                    # 6-4 format (like org number)
+                    variants.append(f"{digits_only[:6]}-{digits_only[6:]}")
+
+        return list(set(v for v in variants if v))

    @staticmethod
    def normalize_amount(value: str) -> list[str]:
        """
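A quick usage sketch of the two new normalizers; the values are illustrative examples drawn from the docstrings above, not repository test fixtures.

variants = FieldNormalizer.normalize_organisation_number('556123-4567')
# Expect the bare digits, the dashed form and VAT variants,
# e.g. '5561234567', '556123-4567', 'SE556123456701'.

variants = FieldNormalizer.normalize_supplier_accounts('BG:5393-9484 | PG:48676043')
# Expect per-account variants such as '5393-9484', '53939484',
# '48676043' and '4867604-3'.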
@@ -264,40 +394,71 @@ class FieldNormalizer:
            '2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
            '13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
            '13 december 2025' -> ['2025-12-13', ...]
+
+        Note: For ambiguous formats like DD/MM/YYYY vs MM/DD/YYYY,
+        we generate variants for BOTH interpretations to maximize matching.
        """
        value = FieldNormalizer.clean_text(value)
        variants = [value]

-        parsed_date = None
+        parsed_dates = []  # May have multiple interpretations

        # Try different date formats
        date_patterns = [
            # ISO format with optional time (e.g., 2026-01-09 00:00:00)
            (r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
-            # European format with /
-            (r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
-            # European format with .
-            (r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
-            # European format with -
-            (r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
            # Swedish format: YYMMDD
            (r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
            # Swedish format: YYYYMMDD
            (r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
        ]

+        # Ambiguous patterns - try both DD/MM and MM/DD interpretations
+        ambiguous_patterns = [
+            # Format with / - could be DD/MM/YYYY (European) or MM/DD/YYYY (US)
+            r'^(\d{1,2})/(\d{1,2})/(\d{4})$',
+            # Format with . - typically European DD.MM.YYYY
+            r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$',
+            # Format with - (not ISO) - could be DD-MM-YYYY or MM-DD-YYYY
+            r'^(\d{1,2})-(\d{1,2})-(\d{4})$',
+        ]
+
+        # Try unambiguous patterns first
        for pattern, extractor in date_patterns:
            match = re.match(pattern, value)
            if match:
                try:
                    year, month, day = extractor(match)
-                    parsed_date = datetime(year, month, day)
+                    parsed_dates.append(datetime(year, month, day))
                    break
                except ValueError:
                    continue

+        # Try ambiguous patterns with both interpretations
+        if not parsed_dates:
+            for pattern in ambiguous_patterns:
+                match = re.match(pattern, value)
+                if match:
+                    n1, n2, year = int(match[1]), int(match[2]), int(match[3])
+
+                    # Try DD/MM/YYYY (European - day first)
+                    try:
+                        parsed_dates.append(datetime(year, n2, n1))
+                    except ValueError:
+                        pass
+
+                    # Try MM/DD/YYYY (US - month first) if different and valid
+                    if n1 != n2:
+                        try:
+                            parsed_dates.append(datetime(year, n1, n2))
+                        except ValueError:
+                            pass
+
+                    if parsed_dates:
+                        break
+
        # Try Swedish month names
-        if not parsed_date:
+        if not parsed_dates:
            for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
                if month_name in value.lower():
                    # Extract day and year
@@ -308,16 +469,28 @@ class FieldNormalizer:
                    if year < 100:
                        year = 2000 + year if year < 50 else 1900 + year
                    try:
-                        parsed_date = datetime(year, int(month_num), day)
+                        parsed_dates.append(datetime(year, int(month_num), day))
                        break
                    except ValueError:
                        continue

-        if parsed_date:
+        # Generate variants for all parsed date interpretations
+        swedish_months_full = [
+            'januari', 'februari', 'mars', 'april', 'maj', 'juni',
+            'juli', 'augusti', 'september', 'oktober', 'november', 'december'
+        ]
+        swedish_months_abbrev = [
+            'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
+            'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
+        ]
+
+        for parsed_date in parsed_dates:
            # Generate different formats
            iso = parsed_date.strftime('%Y-%m-%d')
            eu_slash = parsed_date.strftime('%d/%m/%Y')
+            us_slash = parsed_date.strftime('%m/%d/%Y')  # US format MM/DD/YYYY
            eu_dot = parsed_date.strftime('%d.%m.%Y')
+            iso_dot = parsed_date.strftime('%Y.%m.%d')  # ISO with dots (e.g., 2024.02.08)
            compact = parsed_date.strftime('%Y%m%d')  # YYYYMMDD
            compact_short = parsed_date.strftime('%y%m%d')  # YYMMDD (e.g., 260108)
@@ -329,21 +502,13 @@ class FieldNormalizer:
            spaced_short = parsed_date.strftime('%y %m %d')

            # Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026")
-            swedish_months_full = [
-                'januari', 'februari', 'mars', 'april', 'maj', 'juni',
-                'juli', 'augusti', 'september', 'oktober', 'november', 'december'
-            ]
-            swedish_months_abbrev = [
-                'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
-                'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
-            ]
            month_full = swedish_months_full[parsed_date.month - 1]
            month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
            swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
            swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"

            variants.extend([
-                iso, eu_slash, eu_dot, compact, compact_short,
+                iso, eu_slash, us_slash, eu_dot, iso_dot, compact, compact_short,
                eu_dot_short, spaced_full, spaced_short,
                swedish_format_full, swedish_format_abbrev
            ])
@@ -360,6 +525,8 @@ NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
    'Amount': FieldNormalizer.normalize_amount,
    'InvoiceDate': FieldNormalizer.normalize_date,
    'InvoiceDueDate': FieldNormalizer.normalize_date,
+    'supplier_organisation_number': FieldNormalizer.normalize_organisation_number,
+    'supplier_accounts': FieldNormalizer.normalize_supplier_accounts,
}

@@ -11,6 +11,7 @@ import csv


# Field class mapping for YOLO
+# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
    'InvoiceNumber': 0,
    'InvoiceDate': 1,
@@ -19,6 +20,16 @@ FIELD_CLASSES = {
    'Bankgiro': 4,
    'Plusgiro': 5,
    'Amount': 6,
+    'supplier_organisation_number': 7,
+}
+
+# Fields that need matching but map to other YOLO classes
+# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
+ACCOUNT_FIELD_MAPPING = {
+    'supplier_accounts': {
+        'BG': 'Bankgiro',  # BG:xxx -> Bankgiro class
+        'PG': 'Plusgiro',  # PG:xxx -> Plusgiro class
+    }
}

CLASS_NAMES = [
@@ -29,6 +40,7 @@ CLASS_NAMES = [
    'bankgiro',
    'plusgiro',
    'amount',
+    'supplier_org_number',
]

@@ -52,6 +52,7 @@ class DatasetItem:
    page_no: int
    labels: list[YOLOAnnotation]
    is_scanned: bool = False  # True if bbox is in pixel coords, False if in PDF points
+    csv_split: str | None = None  # CSV-defined split ('train', 'test', etc.)


class DBYOLODataset:
@@ -202,7 +203,7 @@ class DBYOLODataset:
                total_images += len(images)
                continue

-            labels_by_page, is_scanned = doc_data
+            labels_by_page, is_scanned, csv_split = doc_data

            for image_path in images:
                total_images += 1
@@ -218,7 +219,8 @@ class DBYOLODataset:
                        image_path=image_path,
                        page_no=page_no,
                        labels=page_labels,
-                        is_scanned=is_scanned
+                        is_scanned=is_scanned,
+                        csv_split=csv_split
                    ))
                else:
                    skipped_no_labels += 1
@@ -237,16 +239,17 @@
         self.items, self._doc_ids_ordered = self._split_dataset(all_items)
         print(f"Split '{self.split}': {len(self.items)} items")
 
-    def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]]:
+    def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
         """
         Load labels from database for given document IDs using batch queries.
 
         Returns:
-            Dict of doc_id -> (page_labels, is_scanned)
+            Dict of doc_id -> (page_labels, is_scanned, split)
             where page_labels is {page_no -> list[YOLOAnnotation]}
-            and is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
+            is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
+            split is the CSV-defined split ('train', 'test', etc.) or None
         """
-        result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]] = {}
+        result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]] = {}
 
         # Query in batches using efficient batch method
         batch_size = 500
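A hedged sketch of the new return shape, with invented placeholder values, to show that every consumer now unpacks three elements per document instead of two:

# Sketch of the new _load_labels_from_db return shape; the literal values are invented.
# Each document id now maps to a 3-tuple (page_labels, is_scanned, csv_split).
doc_data = {
    'doc-001': ({0: ['annotation-placeholder']}, False, 'train'),
    'doc-002': ({0: ['annotation-placeholder']}, True, None),  # no 'split' column value
}

for doc_id, (labels_by_page, is_scanned, csv_split) in doc_data.items():
    print(doc_id, len(labels_by_page), is_scanned, csv_split)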
@@ -263,6 +266,9 @@
             # Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points)
             is_scanned = doc.get('pdf_type') == 'scanned'
 
+            # Get CSV-defined split
+            csv_split = doc.get('split')
+
             page_labels: dict[int, list[YOLOAnnotation]] = {}
 
             for field_result in doc.get('field_results', []):
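Since doc.get('split') can return mixed-case values (the membership checks below accept 'train', 'Train', 'TRAIN'), one possible tightening is to normalize the raw value once when it is read. This is a suggestion, not something the commit does:

# Hypothetical normalizer -- not part of this commit. The later membership checks
# against ('train', 'Train', 'TRAIN') could then collapse to a single lowercase value.
def normalize_split(raw: str | None) -> str | None:
    """Return 'train'/'val'/'test' in lowercase, or None for empty/unknown values."""
    if raw is None:
        return None
    value = raw.strip().lower()
    return value if value in ('train', 'val', 'test') else None

assert normalize_split('TRAIN') == 'train'
assert normalize_split('  Test ') == 'test'
assert normalize_split('') is None
assert normalize_split(None) is None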
@@ -292,7 +298,7 @@
                     page_labels[page_no].append(annotation)
 
             if page_labels:
-                result[doc_id] = (page_labels, is_scanned)
+                result[doc_id] = (page_labels, is_scanned, csv_split)
 
         return result
 
@@ -333,7 +339,10 @@
 
     def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
         """
-        Split items into train/val/test based on document ID.
+        Split items into train/val/test based on CSV-defined split field.
 
+        If CSV has 'split' field, use it directly.
+        Otherwise, fall back to random splitting based on train_ratio/val_ratio.
+
         Returns:
             Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
@@ -341,13 +350,64 @@
         """
         # Group by document ID for proper splitting
         doc_items: dict[str, list[DatasetItem]] = {}
+        doc_csv_split: dict[str, str | None] = {}  # Track CSV split per document
 
         for item in items:
             if item.document_id not in doc_items:
                 doc_items[item.document_id] = []
+                doc_csv_split[item.document_id] = item.csv_split
             doc_items[item.document_id].append(item)
 
-        # Shuffle document IDs
+        # Check if we have CSV-defined splits
+        has_csv_splits = any(split is not None for split in doc_csv_split.values())
+
         doc_ids = list(doc_items.keys())
+
+        if has_csv_splits:
+            # Use CSV-defined splits
+            print("Using CSV-defined split field for train/val/test assignment")
+
+            # Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
+            # 'val' is taken from train set using val_ratio
+            split_doc_ids = []
+
+            if self.split == 'train':
+                # Get documents marked as 'train' or no split defined
+                train_docs = [doc_id for doc_id in doc_ids
+                              if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
+
+                # Take train_ratio of train docs for actual training, rest for val
+                random.seed(self.seed)
+                random.shuffle(train_docs)
+
+                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
+                split_doc_ids = train_docs[:n_train]
+
+            elif self.split == 'val':
+                # Get documents marked as 'train' and take val portion
+                train_docs = [doc_id for doc_id in doc_ids
+                              if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
+
+                random.seed(self.seed)
+                random.shuffle(train_docs)
+
+                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
+                split_doc_ids = train_docs[n_train:]
+
+            else:  # test
+                # Get documents marked as 'test'
+                split_doc_ids = [doc_id for doc_id in doc_ids
+                                 if doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
+
+            # Apply limit if specified
+            if self.limit is not None and self.limit < len(split_doc_ids):
+                split_doc_ids = split_doc_ids[:self.limit]
+                print(f"Limited to {self.limit} documents")
+
+        else:
+            # Fall back to random splitting (original behavior)
+            print("No CSV split field found, using random splitting")
+
             random.seed(self.seed)
             random.shuffle(doc_ids)
 
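Note that 'val' is never read from the CSV: it is carved out of the CSV 'train' documents with a seeded shuffle, so the same seed and ratios always produce disjoint, reproducible train/val sets. A standalone sketch of that carve-out, with invented document ids and ratios:

import random

# Standalone sketch of the train/val carve-out above; document ids and ratios are invented.
doc_csv_split = {f'doc-{i:03d}': 'train' for i in range(10)}
doc_csv_split['doc-010'] = 'test'

train_ratio, val_ratio, seed = 0.8, 0.1, 42

train_docs = [d for d, s in doc_csv_split.items() if s in (None, 'train', 'Train', 'TRAIN')]
random.seed(seed)
random.shuffle(train_docs)

# The same seed and the same shuffle run in both the 'train' and 'val' branches,
# so train_docs[:n_train] and train_docs[n_train:] can never overlap.
n_train = int(len(train_docs) * (train_ratio / (train_ratio + val_ratio)))
train_ids, val_ids = train_docs[:n_train], train_docs[n_train:]
test_ids = [d for d, s in doc_csv_split.items() if s in ('test', 'Test', 'TEST')]

print(len(train_ids), len(val_ids), len(test_ids))  # 8 2 1 with these made-up numbers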
@@ -381,23 +441,58 @@
         Split items using cached data from a shared dataset.
 
         Uses pre-computed doc_ids order for consistent splits.
+        Respects CSV-defined splits if available.
         """
-        # Group by document ID
+        # Group by document ID and track CSV splits
         doc_items: dict[str, list[DatasetItem]] = {}
+        doc_csv_split: dict[str, str | None] = {}
 
         for item in self._all_items:
             if item.document_id not in doc_items:
                 doc_items[item.document_id] = []
+                doc_csv_split[item.document_id] = item.csv_split
             doc_items[item.document_id].append(item)
 
-        # Use cached doc_ids order
+        # Check if we have CSV-defined splits
+        has_csv_splits = any(split is not None for split in doc_csv_split.values())
+
         doc_ids = self._doc_ids_ordered
 
-        # Calculate split indices
+        if has_csv_splits:
+            # Use CSV-defined splits
+            if self.split == 'train':
+                train_docs = [doc_id for doc_id in doc_ids
+                              if doc_id in doc_csv_split and
+                              doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
+
+                random.seed(self.seed)
+                random.shuffle(train_docs)
+
+                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
+                split_doc_ids = train_docs[:n_train]
+
+            elif self.split == 'val':
+                train_docs = [doc_id for doc_id in doc_ids
+                              if doc_id in doc_csv_split and
+                              doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
+
+                random.seed(self.seed)
+                random.shuffle(train_docs)
+
+                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
+                split_doc_ids = train_docs[n_train:]
+
+            else:  # test
+                split_doc_ids = [doc_id for doc_id in doc_ids
+                                 if doc_id in doc_csv_split and
+                                 doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
+
+        else:
+            # Fall back to random splitting
             n_total = len(doc_ids)
             n_train = int(n_total * self.train_ratio)
             n_val = int(n_total * self.val_ratio)
 
-        # Split document IDs based on split type
             if self.split == 'train':
                 split_doc_ids = doc_ids[:n_train]
             elif self.split == 'val':
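The CSV-split selection is now implemented twice, in _split_dataset and in the cached-split path above. If that duplication becomes a maintenance concern, it could be factored into a shared helper; the sketch below (select_split_doc_ids) is hypothetical and not part of this commit:

import random

# Hypothetical helper -- not in this commit -- holding the CSV-split selection
# that _split_dataset and _split_from_cache currently both implement inline.
def select_split_doc_ids(doc_ids: list[str],
                         doc_csv_split: dict[str, str | None],
                         split: str,
                         train_ratio: float,
                         val_ratio: float,
                         seed: int) -> list[str]:
    """Pick the document ids belonging to `split` using CSV-defined assignments."""
    if split in ('train', 'val'):
        train_docs = [d for d in doc_ids
                      if doc_csv_split.get(d) in (None, 'train', 'Train', 'TRAIN')]
        random.seed(seed)
        random.shuffle(train_docs)
        n_train = int(len(train_docs) * (train_ratio / (train_ratio + val_ratio)))
        return train_docs[:n_train] if split == 'train' else train_docs[n_train:]
    # test
    return [d for d in doc_ids if doc_csv_split.get(d) in ('test', 'Test', 'TEST')]

ids = [f'doc-{i}' for i in range(5)]
splits = {d: 'train' for d in ids}
print(select_split_doc_ids(ids, splits, 'val', 0.8, 0.1, 42))

Both methods could then delegate to it with their own doc_ids ordering, keeping the seeded shuffle identical in each place.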