restructure project

Yaojia Wang
2026-01-27 23:58:17 +01:00
parent 58bf75db68
commit d6550375b0
230 changed files with 5513 additions and 1756 deletions

View File

@@ -0,0 +1,20 @@
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
&& rm -rf /var/lib/apt/lists/*
# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared
# Install training package
COPY packages/training /app/packages/training
RUN pip install --no-cache-dir -e /app/packages/training
WORKDIR /app/packages/training
CMD ["python", "run_training.py", "--task-id", "${TASK_ID}"]

View File

@@ -0,0 +1,4 @@
-e ../shared
ultralytics>=8.1.0
tqdm>=4.65.0
torch>=2.0.0

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
Training Service Entry Point.
Runs a specific training task by ID (for Azure ACI on-demand mode)
or polls the database for pending tasks (for local dev).
"""
import argparse
import logging
import sys
import time
from training.data.training_db import TrainingTaskDB
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
"""Execute a single training task."""
task_id = task["task_id"]
config = task.get("config") or {}
logger.info("Starting training task %s with config: %s", task_id, config)
db.update_status(task_id, "running")
try:
from training.cli.train import run_training
result = run_training(
epochs=config.get("epochs", 100),
batch=config.get("batch_size", 16),
model=config.get("base_model", "yolo11n.pt"),
imgsz=config.get("imgsz", 1280),
name=config.get("name", f"training_{task_id[:8]}"),
)
db.complete_task(
task_id,
model_path=result.get("model_path", ""),
metrics=result.get("metrics", {}),
)
logger.info("Training task %s completed successfully.", task_id)
except Exception as e:
logger.exception("Training task %s failed", task_id)
db.fail_task(task_id, str(e))
sys.exit(1)
def main() -> None:
parser = argparse.ArgumentParser(description="Invoice Training Service")
parser.add_argument(
"--task-id",
help="Specific task ID to run (ACI on-demand mode)",
)
parser.add_argument(
"--poll",
action="store_true",
help="Poll database for pending tasks (local dev mode)",
)
parser.add_argument(
"--poll-interval",
type=int,
default=60,
help="Seconds between polls (default: 60)",
)
args = parser.parse_args()
db = TrainingTaskDB()
if args.task_id:
task = db.get_task(args.task_id)
if not task:
logger.error("Task %s not found", args.task_id)
sys.exit(1)
execute_training_task(db, task)
elif args.poll:
logger.info(
"Starting training service in poll mode (interval=%ds)",
args.poll_interval,
)
while True:
tasks = db.get_pending_tasks(limit=1)
for task in tasks:
execute_training_task(db, task)
time.sleep(args.poll_interval)
else:
parser.print_help()
sys.exit(1)
if __name__ == "__main__":
main()
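TrainingTaskDB lives in training.data.training_db and is not part of this diff. A minimal sketch of the interface this entry point assumes, derived only from the calls above (method names come from the code; parameter and return types are assumptions):

# Hedged interface sketch; the real TrainingTaskDB may differ.
from typing import Optional, Protocol

class TaskStore(Protocol):
    def get_task(self, task_id: str) -> Optional[dict]: ...
    def get_pending_tasks(self, limit: int = 1) -> list[dict]: ...
    def update_status(self, task_id: str, status: str) -> None: ...
    def complete_task(self, task_id: str, model_path: str, metrics: dict) -> None: ...
    def fail_task(self, task_id: str, error: str) -> None: ...

In ACI mode the container invokes run_training.py with --task-id; for local development, --poll with an optional --poll-interval keeps checking the database for pending tasks.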

View File

@@ -0,0 +1,13 @@
from setuptools import setup, find_packages
setup(
name="invoice-training",
version="0.1.0",
packages=find_packages(),
python_requires=">=3.11",
install_requires=[
"invoice-shared",
"ultralytics>=8.1.0",
"tqdm>=4.65.0",
],
)
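For local development, the same editable installs the Dockerfile performs can be done from the repository root (a sketch; the package directories are the ones added in this commit):

pip install -e packages/shared -e packages/training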

View File

@@ -0,0 +1,599 @@
#!/usr/bin/env python3
"""
Label Analysis CLI
Analyzes auto-generated labels to identify failures and diagnose root causes.
Reads labeling results from the PostgreSQL database instead of JSONL files.
"""
import argparse
import csv
import json
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from shared.config import get_db_connection_string
from shared.normalize import normalize_field
from shared.matcher import FieldMatcher
from shared.pdf import is_text_pdf, extract_text_tokens
from training.yolo.annotation_generator import FIELD_CLASSES
from shared.data.db import DocumentDB
@dataclass
class FieldAnalysis:
"""Analysis result for a single field."""
field_name: str
csv_value: str
expected: bool # True if CSV has value
labeled: bool # True if label file has this field
matched: bool # True if matcher finds it
# Diagnosis
failure_reason: Optional[str] = None
details: dict = field(default_factory=dict)
@dataclass
class DocumentAnalysis:
"""Analysis result for a document."""
doc_id: str
pdf_exists: bool
pdf_type: str # "text" or "scanned"
total_pages: int
# Per-field analysis
fields: list[FieldAnalysis] = field(default_factory=list)
# Summary
csv_fields_count: int = 0 # Fields with values in CSV
labeled_fields_count: int = 0 # Fields in label files
matched_fields_count: int = 0 # Fields matcher can find
@property
def has_issues(self) -> bool:
"""Check if document has any labeling issues."""
return any(
f.expected and not f.labeled
for f in self.fields
)
@property
def missing_labels(self) -> list[FieldAnalysis]:
"""Get fields that should be labeled but aren't."""
return [f for f in self.fields if f.expected and not f.labeled]
class LabelAnalyzer:
"""Analyzes labels and diagnoses failures."""
def __init__(
self,
csv_path: str,
pdf_dir: str,
dataset_dir: str,
use_db: bool = True
):
self.csv_path = Path(csv_path)
self.pdf_dir = Path(pdf_dir)
self.dataset_dir = Path(dataset_dir)
self.use_db = use_db
self.matcher = FieldMatcher()
self.csv_data = {}
self.label_data = {}
self.report_data = {}
# Database connection
self.db = None
if use_db:
self.db = DocumentDB()
self.db.connect()
# Class ID to name mapping
self.class_names = list(FIELD_CLASSES.keys())
def load_csv(self):
"""Load CSV data."""
with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row['DocumentId']
self.csv_data[doc_id] = row
print(f"Loaded {len(self.csv_data)} records from CSV")
def load_labels(self):
"""Load all label files from dataset."""
for split in ['train', 'val', 'test']:
label_dir = self.dataset_dir / split / 'labels'
if not label_dir.exists():
continue
for label_file in label_dir.glob('*.txt'):
# Parse document ID from filename (uuid_page_XXX.txt)
name = label_file.stem
parts = name.rsplit('_page_', 1)
if len(parts) == 2:
doc_id = parts[0]
page_no = int(parts[1])
else:
continue
if doc_id not in self.label_data:
self.label_data[doc_id] = {'pages': {}, 'split': split}
# Parse label file
labels = []
with open(label_file, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 5:
class_id = int(parts[0])
labels.append({
'class_id': class_id,
'class_name': self.class_names[class_id],
'x_center': float(parts[1]),
'y_center': float(parts[2]),
'width': float(parts[3]),
'height': float(parts[4])
})
self.label_data[doc_id]['pages'][page_no] = labels
total_docs = len(self.label_data)
total_labels = sum(
len(labels)
for doc in self.label_data.values()
for labels in doc['pages'].values()
)
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")
def load_report(self):
"""Load autolabel report from database."""
if not self.db:
print("Database not configured, skipping report loading")
return
# Get document IDs from CSV to query
doc_ids = list(self.csv_data.keys())
if not doc_ids:
return
# Query in batches to avoid memory issues
batch_size = 1000
loaded = 0
for i in range(0, len(doc_ids), batch_size):
batch_ids = doc_ids[i:i + batch_size]
for doc_id in batch_ids:
doc = self.db.get_document(doc_id)
if doc:
self.report_data[doc_id] = doc
loaded += 1
print(f"Loaded {loaded} autolabel reports from database")
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
"""Analyze a single document."""
csv_row = self.csv_data.get(doc_id, {})
label_info = self.label_data.get(doc_id, {'pages': {}})
report = self.report_data.get(doc_id, {})
# Check PDF
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
pdf_exists = pdf_path.exists()
# Skip documents without PDF if requested
if skip_missing_pdf and not pdf_exists:
return None
pdf_type = "unknown"
total_pages = 0
if pdf_exists:
pdf_type = "scanned" if not is_text_pdf(pdf_path) else "text"
total_pages = len(label_info['pages']) or report.get('total_pages', 0)
analysis = DocumentAnalysis(
doc_id=doc_id,
pdf_exists=pdf_exists,
pdf_type=pdf_type,
total_pages=total_pages
)
# Get labeled classes
labeled_classes = set()
for page_labels in label_info['pages'].values():
for label in page_labels:
labeled_classes.add(label['class_name'])
# Analyze each field
for field_name in FIELD_CLASSES.keys():
csv_value = csv_row.get(field_name, '')
if csv_value is None:
csv_value = ''
csv_value = str(csv_value).strip()
# Handle datetime values (remove time part)
if ' 00:00:00' in csv_value:
csv_value = csv_value.replace(' 00:00:00', '')
expected = bool(csv_value)
labeled = field_name in labeled_classes
field_analysis = FieldAnalysis(
field_name=field_name,
csv_value=csv_value,
expected=expected,
labeled=labeled,
matched=False
)
if expected:
analysis.csv_fields_count += 1
if labeled:
analysis.labeled_fields_count += 1
# Diagnose failures
if expected and not labeled:
field_analysis.failure_reason = self._diagnose_failure(
doc_id, field_name, csv_value, pdf_path, pdf_type, report
)
field_analysis.details = self._get_failure_details(
doc_id, field_name, csv_value, pdf_path, pdf_type
)
elif not expected and labeled:
field_analysis.failure_reason = "EXTRA_LABEL"
field_analysis.details = {'note': 'Labeled but no CSV value'}
analysis.fields.append(field_analysis)
return analysis
def _diagnose_failure(
self,
doc_id: str,
field_name: str,
csv_value: str,
pdf_path: Path,
pdf_type: str,
report: dict
) -> str:
"""Diagnose why a field wasn't labeled."""
if not pdf_path.exists():
return "PDF_NOT_FOUND"
if pdf_type == "scanned":
return "SCANNED_PDF"
# Try to match now with current normalizer (not historical report)
if pdf_path.exists() and pdf_type == "text":
try:
# Check all pages
for page_no in range(10): # Max 10 pages
try:
tokens = list(extract_text_tokens(pdf_path, page_no))
if not tokens:
break
normalized = normalize_field(field_name, csv_value)
matches = self.matcher.find_matches(tokens, field_name, normalized, page_no)
if matches:
return "MATCHER_OK_NOW" # Would match with current normalizer
except Exception:
break
return "VALUE_NOT_IN_PDF"
except Exception as e:
return f"PDF_ERROR: {str(e)[:50]}"
return "UNKNOWN"
def _get_failure_details(
self,
doc_id: str,
field_name: str,
csv_value: str,
pdf_path: Path,
pdf_type: str
) -> dict:
"""Get detailed information about a failure."""
details = {
'csv_value': csv_value,
'normalized_candidates': [],
'pdf_tokens_sample': [],
'potential_matches': []
}
# Get normalized candidates
try:
details['normalized_candidates'] = normalize_field(field_name, csv_value)
except Exception:
pass
# Get PDF tokens if available
if pdf_path.exists() and pdf_type == "text":
try:
tokens = list(extract_text_tokens(pdf_path, 0))[:100]
# Find tokens that might be related
candidates = details['normalized_candidates']
for token in tokens:
text = token.text.strip()
# Check if any candidate is substring or similar
for cand in candidates:
if cand in text or text in cand:
details['potential_matches'].append({
'token': text,
'candidate': cand,
'bbox': token.bbox
})
break
# Also collect date-like or number-like tokens for reference
if field_name in ('InvoiceDate', 'InvoiceDueDate'):
if any(c.isdigit() for c in text) and len(text) >= 6:
details['pdf_tokens_sample'].append(text)
elif field_name == 'Amount':
if any(c.isdigit() for c in text) and (',' in text or '.' in text or len(text) >= 4):
details['pdf_tokens_sample'].append(text)
# Limit samples
details['pdf_tokens_sample'] = details['pdf_tokens_sample'][:10]
details['potential_matches'] = details['potential_matches'][:5]
except Exception:
pass
return details
def run_analysis(self, limit: Optional[int] = None, skip_missing_pdf: bool = True) -> list[DocumentAnalysis]:
"""Run analysis on all documents."""
self.load_csv()
self.load_labels()
self.load_report()
results = []
doc_ids = list(self.csv_data.keys())
skipped = 0
for doc_id in doc_ids:
analysis = self.analyze_document(doc_id, skip_missing_pdf=skip_missing_pdf)
if analysis is None:
skipped += 1
continue
results.append(analysis)
if limit and len(results) >= limit:
break
if skipped > 0:
print(f"Skipped {skipped} documents without PDF files")
return results
def generate_report(
self,
results: list[DocumentAnalysis],
output_path: str,
verbose: bool = False
):
"""Generate analysis report."""
output = Path(output_path)
output.parent.mkdir(parents=True, exist_ok=True)
# Collect statistics
stats = {
'total_documents': len(results),
'documents_with_issues': 0,
'total_expected_fields': 0,
'total_labeled_fields': 0,
'missing_labels': 0,
'extra_labels': 0,
'failure_reasons': defaultdict(int),
'failures_by_field': defaultdict(lambda: defaultdict(int))
}
issues = []
for analysis in results:
stats['total_expected_fields'] += analysis.csv_fields_count
stats['total_labeled_fields'] += analysis.labeled_fields_count
if analysis.has_issues:
stats['documents_with_issues'] += 1
for f in analysis.fields:
if f.expected and not f.labeled:
stats['missing_labels'] += 1
stats['failure_reasons'][f.failure_reason] += 1
stats['failures_by_field'][f.field_name][f.failure_reason] += 1
issues.append({
'doc_id': analysis.doc_id,
'field': f.field_name,
'csv_value': f.csv_value,
'reason': f.failure_reason,
'details': f.details if verbose else {}
})
elif not f.expected and f.labeled:
stats['extra_labels'] += 1
# Write JSON report
report = {
'summary': {
'total_documents': stats['total_documents'],
'documents_with_issues': stats['documents_with_issues'],
'issue_rate': f"{stats['documents_with_issues'] / stats['total_documents'] * 100:.1f}%",
'total_expected_fields': stats['total_expected_fields'],
'total_labeled_fields': stats['total_labeled_fields'],
'label_coverage': f"{stats['total_labeled_fields'] / max(1, stats['total_expected_fields']) * 100:.1f}%",
'missing_labels': stats['missing_labels'],
'extra_labels': stats['extra_labels']
},
'failure_reasons': dict(stats['failure_reasons']),
'failures_by_field': {
field: dict(reasons)
for field, reasons in stats['failures_by_field'].items()
},
'issues': issues
}
with open(output, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
print(f"\nReport saved to: {output}")
return report
def print_summary(report: dict):
"""Print summary to console."""
summary = report['summary']
print("\n" + "=" * 60)
print("LABEL ANALYSIS SUMMARY")
print("=" * 60)
print(f"\nDocuments:")
print(f" Total: {summary['total_documents']}")
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")
print(f"\nFields:")
print(f" Expected: {summary['total_expected_fields']}")
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
print(f" Missing: {summary['missing_labels']}")
print(f" Extra: {summary['extra_labels']}")
print(f"\nFailure Reasons:")
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
print(f" {reason}: {count}")
print(f"\nFailures by Field:")
for field, reasons in report['failures_by_field'].items():
total = sum(reasons.values())
print(f" {field}: {total}")
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
print(f" - {reason}: {count}")
# Show sample issues
if report['issues']:
print(f"\n" + "-" * 60)
print("SAMPLE ISSUES (first 10)")
print("-" * 60)
for issue in report['issues'][:10]:
print(f"\n[{issue['doc_id']}] {issue['field']}")
print(f" CSV value: {issue['csv_value']}")
print(f" Reason: {issue['reason']}")
if issue.get('details'):
details = issue['details']
if details.get('normalized_candidates'):
print(f" Candidates: {details['normalized_candidates'][:5]}")
if details.get('pdf_tokens_sample'):
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}")
if details.get('potential_matches'):
print(f" Potential matches:")
for pm in details['potential_matches'][:3]:
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'")
def main():
parser = argparse.ArgumentParser(
description='Analyze auto-generated labels and diagnose failures'
)
parser.add_argument(
'--csv', '-c',
default='data/structured_data/document_export_20260109_220326.csv',
help='Path to structured data CSV file'
)
parser.add_argument(
'--pdf-dir', '-p',
default='data/raw_pdfs',
help='Directory containing PDF files'
)
parser.add_argument(
'--dataset', '-d',
default='data/dataset',
help='Dataset directory with labels'
)
parser.add_argument(
'--output', '-o',
default='reports/label_analysis.json',
help='Output path for analysis report'
)
parser.add_argument(
'--limit', '-l',
type=int,
default=None,
help='Limit number of documents to analyze'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Include detailed failure information'
)
parser.add_argument(
'--single', '-s',
help='Analyze single document ID'
)
parser.add_argument(
'--no-db',
action='store_true',
help='Skip database, only analyze label files'
)
args = parser.parse_args()
analyzer = LabelAnalyzer(
csv_path=args.csv,
pdf_dir=args.pdf_dir,
dataset_dir=args.dataset,
use_db=not args.no_db
)
if args.single:
# Analyze single document
analyzer.load_csv()
analyzer.load_labels()
analyzer.load_report()
# Use skip_missing_pdf=False so a missing PDF still yields an analysis object
analysis = analyzer.analyze_document(args.single, skip_missing_pdf=False)
print(f"\n{'=' * 60}")
print(f"Document: {analysis.doc_id}")
print(f"{'=' * 60}")
print(f"PDF exists: {analysis.pdf_exists}")
print(f"PDF type: {analysis.pdf_type}")
print(f"Pages: {analysis.total_pages}")
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")
for f in analysis.fields:
status = "" if f.labeled else ("" if f.expected else "-")
value_str = f.csv_value[:30] if f.csv_value else "(empty)"
print(f" [{status}] {f.field_name}: {value_str}")
if f.failure_reason:
print(f" Reason: {f.failure_reason}")
if f.details.get('normalized_candidates'):
print(f" Candidates: {f.details['normalized_candidates']}")
if f.details.get('potential_matches'):
print(f" Potential matches in PDF:")
for pm in f.details['potential_matches'][:3]:
print(f" - '{pm['token']}'")
else:
# Full analysis
print("Running label analysis...")
results = analyzer.run_analysis(limit=args.limit)
report = analyzer.generate_report(results, args.output, verbose=args.verbose)
print_summary(report)
if __name__ == '__main__':
main()
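Invocation sketch (the script's module path after the restructure is not shown here, so the name below is illustrative; the flags come from the argparse definition above):

# Full analysis with detailed failure info, writing reports/label_analysis.json
python analyze_labels.py --limit 500 --verbose

# Diagnose a single document without touching the database
python analyze_labels.py --single <document-id> --no-db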

View File

@@ -0,0 +1,434 @@
#!/usr/bin/env python3
"""
Analyze Auto-Label Report
Generates statistics and insights from database or autolabel_report.jsonl
"""
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from shared.config import get_db_connection_string
def load_reports_from_db() -> dict:
"""Load statistics directly from database using SQL aggregation."""
from shared.data.db import DocumentDB
db = DocumentDB()
db.connect()
stats = {
'total': 0,
'successful': 0,
'failed': 0,
'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
'by_field': defaultdict(lambda: {
'total': 0,
'matched': 0,
'exact_match': 0,
'flexible_match': 0,
'scores': [],
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
}),
'errors': defaultdict(int),
'processing_times': [],
}
conn = db.connect()
with conn.cursor() as cursor:
# Overall stats
cursor.execute("""
SELECT
COUNT(*) as total,
SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful,
SUM(CASE WHEN NOT success THEN 1 ELSE 0 END) as failed
FROM documents
""")
row = cursor.fetchone()
stats['total'] = row[0] or 0
stats['successful'] = row[1] or 0
stats['failed'] = row[2] or 0
# By PDF type
cursor.execute("""
SELECT
pdf_type,
COUNT(*) as total,
SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful
FROM documents
GROUP BY pdf_type
""")
for row in cursor.fetchall():
pdf_type = row[0] or 'unknown'
stats['by_pdf_type'][pdf_type] = {
'total': row[1] or 0,
'successful': row[2] or 0
}
# Processing times
cursor.execute("""
SELECT AVG(processing_time_ms), MIN(processing_time_ms), MAX(processing_time_ms)
FROM documents
WHERE processing_time_ms > 0
""")
row = cursor.fetchone()
if row[0]:
stats['processing_time_stats'] = {
'avg_ms': float(row[0]),
'min_ms': float(row[1]),
'max_ms': float(row[2])
}
# Field stats
cursor.execute("""
SELECT
field_name,
COUNT(*) as total,
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched,
SUM(CASE WHEN matched AND score >= 0.99 THEN 1 ELSE 0 END) as exact_match,
SUM(CASE WHEN matched AND score < 0.99 THEN 1 ELSE 0 END) as flexible_match,
AVG(CASE WHEN matched THEN score END) as avg_score
FROM field_results
GROUP BY field_name
ORDER BY field_name
""")
for row in cursor.fetchall():
field_name = row[0]
stats['by_field'][field_name] = {
'total': row[1] or 0,
'matched': row[2] or 0,
'exact_match': row[3] or 0,
'flexible_match': row[4] or 0,
'avg_score': float(row[5]) if row[5] else 0,
'scores': [], # Not loading individual scores for efficiency
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
}
# Field stats by PDF type
cursor.execute("""
SELECT
fr.field_name,
d.pdf_type,
COUNT(*) as total,
SUM(CASE WHEN fr.matched THEN 1 ELSE 0 END) as matched
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
GROUP BY fr.field_name, d.pdf_type
""")
for row in cursor.fetchall():
field_name = row[0]
pdf_type = row[1] or 'unknown'
if field_name in stats['by_field']:
stats['by_field'][field_name]['by_pdf_type'][pdf_type] = {
'total': row[2] or 0,
'matched': row[3] or 0
}
db.close()
return stats
def load_reports_from_file(report_path: str) -> list[dict]:
"""Load all reports from JSONL file(s). Supports glob patterns."""
path = Path(report_path)
# Handle glob pattern
if '*' in str(path) or '?' in str(path):
parent = path.parent
pattern = path.name
report_files = sorted(parent.glob(pattern))
else:
report_files = [path]
if not report_files:
return []
print(f"Reading {len(report_files)} report file(s):")
for f in report_files:
print(f" - {f.name}")
reports = []
for report_file in report_files:
if not report_file.exists():
continue
with open(report_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
reports.append(json.loads(line))
return reports
def analyze_reports(reports: list[dict]) -> dict:
"""Analyze reports and generate statistics."""
stats = {
'total': len(reports),
'successful': 0,
'failed': 0,
'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
'by_field': defaultdict(lambda: {
'total': 0,
'matched': 0,
'exact_match': 0, # score >= 0.99
'flexible_match': 0, # score < 0.99
'scores': [],
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
}),
'errors': defaultdict(int),
'processing_times': [],
}
for report in reports:
pdf_type = report.get('pdf_type') or 'unknown'
success = report.get('success', False)
# Overall stats
if success:
stats['successful'] += 1
else:
stats['failed'] += 1
# By PDF type
stats['by_pdf_type'][pdf_type]['total'] += 1
if success:
stats['by_pdf_type'][pdf_type]['successful'] += 1
# Processing time
proc_time = report.get('processing_time_ms', 0)
if proc_time > 0:
stats['processing_times'].append(proc_time)
# Errors
for error in report.get('errors', []):
stats['errors'][error] += 1
# Field results
for field_result in report.get('field_results', []):
field_name = field_result['field_name']
matched = field_result.get('matched', False)
score = field_result.get('score', 0.0)
stats['by_field'][field_name]['total'] += 1
stats['by_field'][field_name]['by_pdf_type'][pdf_type]['total'] += 1
if matched:
stats['by_field'][field_name]['matched'] += 1
stats['by_field'][field_name]['scores'].append(score)
stats['by_field'][field_name]['by_pdf_type'][pdf_type]['matched'] += 1
if score >= 0.99:
stats['by_field'][field_name]['exact_match'] += 1
else:
stats['by_field'][field_name]['flexible_match'] += 1
return stats
def print_report(stats: dict, verbose: bool = False):
"""Print analysis report."""
print("\n" + "=" * 60)
print("AUTO-LABEL REPORT ANALYSIS")
print("=" * 60)
# Overall stats
print(f"\n{'OVERALL STATISTICS':^60}")
print("-" * 60)
total = stats['total']
successful = stats['successful']
failed = stats['failed']
success_rate = successful / total * 100 if total > 0 else 0
print(f"Total documents: {total:>8}")
print(f"Successful: {successful:>8} ({success_rate:.1f}%)")
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)")
# Processing time
if 'processing_time_stats' in stats:
pts = stats['processing_time_stats']
print(f"\nProcessing time (ms):")
print(f" Average: {pts['avg_ms']:>8.1f}")
print(f" Min: {pts['min_ms']:>8.1f}")
print(f" Max: {pts['max_ms']:>8.1f}")
elif stats.get('processing_times'):
times = stats['processing_times']
avg_time = sum(times) / len(times)
min_time = min(times)
max_time = max(times)
print(f"\nProcessing time (ms):")
print(f" Average: {avg_time:>8.1f}")
print(f" Min: {min_time:>8.1f}")
print(f" Max: {max_time:>8.1f}")
# By PDF type
print(f"\n{'BY PDF TYPE':^60}")
print("-" * 60)
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
print("-" * 60)
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
type_total = type_stats['total']
type_success = type_stats['successful']
type_rate = type_success / type_total * 100 if type_total > 0 else 0
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")
# By field
print(f"\n{'FIELD MATCH STATISTICS':^60}")
print("-" * 60)
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
print("-" * 60)
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
if field_name not in stats['by_field']:
continue
field_stats = stats['by_field'][field_name]
total = field_stats['total']
matched = field_stats['matched']
exact = field_stats['exact_match']
flex = field_stats['flexible_match']
rate = matched / total * 100 if total > 0 else 0
# Handle avg_score from either DB or file analysis
if 'avg_score' in field_stats:
avg_score = field_stats['avg_score']
elif field_stats['scores']:
avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
else:
avg_score = 0
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")
# Field match by PDF type
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
print("-" * 60)
for pdf_type in sorted(stats['by_pdf_type'].keys()):
print(f"\n[{pdf_type.upper()}]")
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
print("-" * 50)
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
if field_name not in stats['by_field']:
continue
type_stats = stats['by_field'][field_name]['by_pdf_type'].get(pdf_type, {'total': 0, 'matched': 0})
total = type_stats['total']
matched = type_stats['matched']
rate = matched / total * 100 if total > 0 else 0
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")
# Errors
if stats.get('errors') and verbose:
print(f"\n{'ERRORS':^60}")
print("-" * 60)
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
print(f"{count:>5}x {error[:50]}")
print("\n" + "=" * 60)
def export_json(stats: dict, output_path: str):
"""Export statistics to JSON file."""
# Convert defaultdicts to regular dicts for JSON serialization
export_data = {
'total': stats['total'],
'successful': stats['successful'],
'failed': stats['failed'],
'by_pdf_type': dict(stats['by_pdf_type']),
'by_field': {},
'errors': dict(stats.get('errors', {})),
}
# Processing time stats
if 'processing_time_stats' in stats:
export_data['processing_time_stats'] = stats['processing_time_stats']
elif stats.get('processing_times'):
times = stats['processing_times']
export_data['processing_time_stats'] = {
'avg_ms': sum(times) / len(times),
'min_ms': min(times),
'max_ms': max(times),
'count': len(times)
}
# Field stats
for field_name, field_stats in stats['by_field'].items():
avg_score = field_stats.get('avg_score', 0)
if not avg_score and field_stats.get('scores'):
avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
export_data['by_field'][field_name] = {
'total': field_stats['total'],
'matched': field_stats['matched'],
'exact_match': field_stats['exact_match'],
'flexible_match': field_stats['flexible_match'],
'match_rate': field_stats['matched'] / field_stats['total'] if field_stats['total'] > 0 else 0,
'avg_score': avg_score,
'by_pdf_type': dict(field_stats['by_pdf_type'])
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, ensure_ascii=False)
print(f"\nStatistics exported to: {output_path}")
def main():
parser = argparse.ArgumentParser(
description='Analyze auto-label report'
)
parser.add_argument(
'--report', '-r',
default=None,
help='Path to autolabel report JSONL file (uses database if not specified)'
)
parser.add_argument(
'--output', '-o',
help='Export statistics to JSON file'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Show detailed error messages'
)
parser.add_argument(
'--from-file',
action='store_true',
help='Force reading from JSONL file instead of database'
)
args = parser.parse_args()
# Decide source
use_db = not args.from_file and args.report is None
if use_db:
print("Loading statistics from database...")
stats = load_reports_from_db()
print(f"Loaded stats for {stats['total']} documents")
else:
report_path = args.report or 'reports/autolabel_report.jsonl'
path = Path(report_path)
# Check if file exists (handle glob patterns)
if '*' not in str(path) and '?' not in str(path) and not path.exists():
print(f"Error: Report file not found: {path}")
return 1
print(f"Loading reports from: {report_path}")
reports = load_reports_from_file(report_path)
print(f"Loaded {len(reports)} reports")
stats = analyze_reports(reports)
print_report(stats, verbose=args.verbose)
if args.output:
export_json(stats, args.output)
return 0
if __name__ == '__main__':
sys.exit(main())
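Invocation sketch (the script name is illustrative; flags and defaults come from the argparse definition above):

# Aggregate directly from PostgreSQL (the default when --report is omitted)
python analyze_report.py --output reports/autolabel_stats.json

# Analyze sharded JSONL reports instead of the database
python analyze_report.py --report "reports/autolabel_report*.jsonl" --verbose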

View File

@@ -0,0 +1,752 @@
#!/usr/bin/env python3
"""
Auto-labeling CLI
Generates YOLO training data from PDFs and structured CSV data.
"""
import argparse
import sys
import time
import os
import signal
import shutil
import warnings
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing
# Global flag for graceful shutdown
_shutdown_requested = False
def _signal_handler(signum, frame):
    """Handle interrupt signals for graceful shutdown."""
    global _shutdown_requested
    if _shutdown_requested:
        # Second Ctrl+C: skip the graceful path and terminate immediately
        sys.exit(1)
    _shutdown_requested = True
    print("\n\nShutdown requested. Finishing current batch and saving progress...")
    print("(Press Ctrl+C again to force quit)\n")
# Windows compatibility: 'spawn' is already the default (and only) start method on Windows;
# forcing it keeps the intent explicit and is safer for libraries like PaddleOCR
if sys.platform == 'win32':
multiprocessing.set_start_method('spawn', force=True)
from shared.config import get_db_connection_string, PATHS, AUTOLABEL
# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
_worker_initialized = False
_worker_type = None # 'cpu' or 'gpu'
def _init_cpu_worker():
"""Initialize CPU worker (no OCR engine needed)."""
global _worker_initialized, _worker_type
_worker_initialized = True
_worker_type = 'cpu'
def _init_gpu_worker():
"""Initialize GPU worker with OCR engine (called once per worker)."""
global _worker_ocr_engine, _worker_initialized, _worker_type
# Suppress PaddlePaddle/PaddleX reinitialization warnings
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
warnings.filterwarnings('ignore', message='.*reinitialization.*')
# Set environment variable to suppress paddle warnings
os.environ['GLOG_minloglevel'] = '2' # Suppress INFO and WARNING logs
# OCR engine will be lazily initialized on first use
_worker_ocr_engine = None
_worker_initialized = True
_worker_type = 'gpu'
def _init_worker():
"""Initialize worker process with OCR engine (called once per worker).
Legacy function for backwards compatibility.
"""
_init_gpu_worker()
def _get_ocr_engine():
"""Get or create OCR engine for current worker."""
global _worker_ocr_engine
if _worker_ocr_engine is None:
# Suppress warnings during OCR initialization
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
from shared.ocr import OCREngine
_worker_ocr_engine = OCREngine()
return _worker_ocr_engine
def _save_output_img(output_img, image_path: Path) -> None:
"""Save OCR output_img to replace the original rendered image."""
from PIL import Image as PILImage
# Convert numpy array to PIL Image and save
if output_img is not None:
img = PILImage.fromarray(output_img)
img.save(str(image_path))
# If output_img is None, the original image is already saved
def process_single_document(args_tuple):
"""
Process a single document (worker function for parallel processing).
Args:
args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)
Returns:
dict with results
"""
import shutil
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple
# Import inside worker to avoid pickling issues
from training.data.autolabel_report import AutoLabelReport
from shared.pdf import PDFDocument
from training.yolo.annotation_generator import FIELD_CLASSES
from training.processing.document_processor import process_page, record_unmatched_fields
start_time = time.time()
pdf_path = Path(pdf_path_str)
output_dir = Path(output_dir_str)
doc_id = row_dict['DocumentId']
# Clean up existing temp folder for this document (for re-matching)
temp_doc_dir = output_dir / 'temp' / doc_id
if temp_doc_dir.exists():
shutil.rmtree(temp_doc_dir, ignore_errors=True)
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
# Store metadata fields from CSV
report.split = row_dict.get('split')
report.customer_number = row_dict.get('customer_number')
report.supplier_name = row_dict.get('supplier_name')
report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
report.supplier_accounts = row_dict.get('supplier_accounts')
result = {
'doc_id': doc_id,
'success': False,
'pages': [],
'report': None,
'stats': {name: 0 for name in FIELD_CLASSES.keys()}
}
try:
# Use PDFDocument context manager for efficient PDF handling
# Opens PDF only once, caches dimensions, handles cleanup automatically
with PDFDocument(pdf_path) as pdf_doc:
# Check PDF type (uses cached document)
use_ocr = not pdf_doc.is_text_pdf()
report.pdf_type = "scanned" if use_ocr else "text"
# Skip OCR if requested
if use_ocr and skip_ocr:
report.errors.append("Skipped (scanned PDF)")
report.processing_time_ms = (time.time() - start_time) * 1000
result['report'] = report.to_dict()
return result
# Get OCR engine from worker cache (only created once per worker)
ocr_engine = None
if use_ocr:
ocr_engine = _get_ocr_engine()
# Process each page
page_annotations = []
matched_fields = set()
# Render all pages and process (uses cached document handle)
images_dir = output_dir / 'temp' / doc_id / 'images'
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
report.total_pages += 1
# Get dimensions from cache (no additional PDF open)
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# Extract tokens
if use_ocr:
# Use extract_with_image to get both tokens and preprocessed image
# PaddleOCR coordinates are relative to output_img, not original image
ocr_result = ocr_engine.extract_with_image(
str(image_path),
page_no,
scale_to_pdf_points=72 / dpi
)
tokens = ocr_result.tokens
# Save output_img to replace the original rendered image
# This ensures coordinates match the saved image
_save_output_img(ocr_result.output_img, image_path)
# Update image dimensions to match output_img
if ocr_result.output_img is not None:
img_height, img_width = ocr_result.output_img.shape[:2]
else:
# Use cached document for text extraction
tokens = list(pdf_doc.extract_text_tokens(page_no))
# Get page dimensions
page = pdf_doc.doc[page_no]
page_height = page.rect.height
page_width = page.rect.width
# Use shared processing logic
matches = {}
annotations, ann_count = process_page(
tokens=tokens,
row_dict=row_dict,
page_no=page_no,
page_height=page_height,
page_width=page_width,
img_width=img_width,
img_height=img_height,
dpi=dpi,
min_confidence=min_confidence,
matches=matches,
matched_fields=matched_fields,
report=report,
result_stats=result['stats'],
)
if annotations:
page_annotations.append({
'image_path': str(image_path),
'page_no': page_no,
'count': ann_count
})
report.annotations_generated += ann_count
# Record unmatched fields using shared logic
record_unmatched_fields(row_dict, matched_fields, report)
if page_annotations:
result['pages'] = page_annotations
result['success'] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
report.processing_time_ms = (time.time() - start_time) * 1000
result['report'] = report.to_dict()
return result
def main():
parser = argparse.ArgumentParser(
description='Generate YOLO annotations from PDFs and CSV data'
)
parser.add_argument(
'--csv', '-c',
default=f"{PATHS['csv_dir']}/*.csv",
help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
)
parser.add_argument(
'--pdf-dir', '-p',
default=PATHS['pdf_dir'],
help='Directory containing PDF files'
)
parser.add_argument(
'--output', '-o',
default=PATHS['output_dir'],
help='Output directory for dataset'
)
parser.add_argument(
'--dpi',
type=int,
default=AUTOLABEL['dpi'],
help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
)
parser.add_argument(
'--min-confidence',
type=float,
default=AUTOLABEL['min_confidence'],
help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
)
parser.add_argument(
'--report',
default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
)
parser.add_argument(
'--max-records',
type=int,
default=10000,
help='Max records per report file for sharding (default: 10000, 0 = single file)'
)
parser.add_argument(
'--single',
help='Process single document ID only'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Verbose output'
)
parser.add_argument(
'--workers', '-w',
type=int,
default=4,
help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
)
parser.add_argument(
'--cpu-workers',
type=int,
default=None,
help='Number of CPU workers for text PDFs (enables dual-pool mode)'
)
parser.add_argument(
'--gpu-workers',
type=int,
default=1,
help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
)
parser.add_argument(
'--skip-ocr',
action='store_true',
help='Skip scanned PDFs (text-layer only)'
)
parser.add_argument(
'--limit', '-l',
type=int,
default=None,
help='Limit number of documents to process (for testing)'
)
args = parser.parse_args()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, _signal_handler)
signal.signal(signal.SIGTERM, _signal_handler)
# Import here to avoid slow startup
from shared.data import CSVLoader
from training.data.autolabel_report import AutoLabelReport, FieldMatchResult, ReportWriter
from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from shared.pdf.renderer import get_render_dimensions
from shared.ocr import OCREngine
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
# Handle comma-separated CSV paths
csv_input = args.csv
if ',' in csv_input and '*' not in csv_input:
csv_input = [p.strip() for p in csv_input.split(',')]
# Get list of CSV files (don't load all data at once)
temp_loader = CSVLoader(csv_input, args.pdf_dir)
csv_files = temp_loader.csv_paths
pdf_dir = temp_loader.pdf_dir
print(f"Found {len(csv_files)} CSV file(s) to process")
# Setup output directories
output_dir = Path(args.output)
# Only create temp directory for images (no train/val/test split during labeling)
(output_dir / 'temp').mkdir(parents=True, exist_ok=True)
# Report writer with optional sharding
report_path = Path(args.report)
report_path.parent.mkdir(parents=True, exist_ok=True)
report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)
# Database connection for checking existing documents
from shared.data.db import DocumentDB
db = DocumentDB()
db.connect()
db.create_tables() # Ensure tables exist
print("Connected to database for status checking")
# Global stats
stats = {
'total': 0,
'successful': 0,
'failed': 0,
'skipped': 0,
'skipped_db': 0, # Skipped because already in DB
'retried': 0, # Re-processed failed ones
'annotations': 0,
'tasks_submitted': 0, # Tracks tasks submitted across all CSVs for limit
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
}
# Track all processed items for final split (write to temp file to save memory)
processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
processed_items_file.parent.mkdir(parents=True, exist_ok=True)
processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
processed_count = 0
seen_doc_ids = set()
# Batch for database updates
db_batch = []
DB_BATCH_SIZE = 100
# Helper function to handle result and update database
# Defined outside the loop so nonlocal can properly reference db_batch
def handle_result(result):
nonlocal processed_count, db_batch
# Write report to file
if result['report']:
report_writer.write_dict(result['report'])
# Add to database batch
db_batch.append(result['report'])
if len(db_batch) >= DB_BATCH_SIZE:
db.save_documents_batch(db_batch)
db_batch.clear()
if result['success']:
# Write to temp file instead of memory
import json
processed_items_writer.write(json.dumps({
'doc_id': result['doc_id'],
'pages': result['pages']
}) + '\n')
processed_items_writer.flush()
processed_count += 1
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
def handle_error(doc_id, error):
nonlocal db_batch
stats['failed'] += 1
error_report = {
'document_id': doc_id,
'success': False,
'errors': [f"Worker error: {str(error)}"]
}
report_writer.write_dict(error_report)
db_batch.append(error_report)
if len(db_batch) >= DB_BATCH_SIZE:
db.save_documents_batch(db_batch)
db_batch.clear()
if args.verbose:
print(f"Error processing {doc_id}: {error}")
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
dual_pool_coordinator = None
use_dual_pool = args.cpu_workers is not None
if use_dual_pool:
from training.processing import DualPoolCoordinator
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
dual_pool_coordinator = DualPoolCoordinator(
cpu_workers=args.cpu_workers,
gpu_workers=args.gpu_workers,
gpu_id=0,
task_timeout=300.0,
)
dual_pool_coordinator.start()
try:
# Process CSV files one by one (streaming)
for csv_idx, csv_file in enumerate(csv_files):
# Check for shutdown request
if _shutdown_requested:
print("\nShutdown requested. Stopping after current batch...")
break
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
# Load only this CSV file
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
rows = single_loader.load_all()
# Filter to single document if specified
if args.single:
rows = [r for r in rows if r.DocumentId == args.single]
if not rows:
continue
# Deduplicate across CSV files
rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
for r in rows:
seen_doc_ids.add(r.DocumentId)
if not rows:
print(f" Skipping CSV (no new documents)")
continue
# Batch query database for all document IDs in this CSV
csv_doc_ids = [r.DocumentId for r in rows]
db_status_map = db.check_documents_status_batch(csv_doc_ids)
# Count how many are already processed successfully
already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)
# Skip entire CSV if all documents are already processed
if already_processed == len(rows):
print(f" Skipping CSV (all {len(rows)} documents already processed)")
stats['skipped_db'] += len(rows)
continue
# Count how many new documents need processing in this CSV
new_to_process = len(rows) - already_processed
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
stats['total'] += len(rows)
# Prepare tasks for this CSV
tasks = []
skipped_in_csv = 0
retry_in_csv = 0
# Calculate how many more we can process if limit is set
# Use tasks_submitted counter which tracks across all CSVs
if args.limit:
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
if remaining_limit <= 0:
print(f" Reached limit of {args.limit} new documents, stopping.")
break
else:
remaining_limit = float('inf')
# Collect doc_ids that need retry (for batch delete)
retry_doc_ids = []
for row in rows:
# Stop adding tasks if we've reached the limit
if len(tasks) >= remaining_limit:
break
doc_id = row.DocumentId
# Check document status from batch query result
db_status = db_status_map.get(doc_id) # None if not in DB
# Skip if already successful in database
if db_status is True:
stats['skipped_db'] += 1
skipped_in_csv += 1
continue
# Check if this is a retry (was failed before)
if db_status is False:
stats['retried'] += 1
retry_in_csv += 1
retry_doc_ids.append(doc_id)
pdf_path = single_loader.get_pdf_path(row)
if not pdf_path:
stats['skipped'] += 1
continue
row_dict = {
'DocumentId': row.DocumentId,
'InvoiceNumber': row.InvoiceNumber,
'InvoiceDate': row.InvoiceDate,
'InvoiceDueDate': row.InvoiceDueDate,
'OCR': row.OCR,
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
# New fields for matching
'supplier_organisation_number': row.supplier_organisation_number,
'supplier_accounts': row.supplier_accounts,
'customer_number': row.customer_number,
# Metadata fields (not for matching, but for database storage)
'split': row.split,
'supplier_name': row.supplier_name,
}
tasks.append((
row_dict,
str(pdf_path),
str(output_dir),
args.dpi,
args.min_confidence,
args.skip_ocr
))
if skipped_in_csv > 0 or retry_in_csv > 0:
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
# Clean up retry documents: delete from database and remove temp folders
if retry_doc_ids:
# Batch delete from database (field_results will be cascade deleted)
with db.connect().cursor() as cursor:
cursor.execute(
"DELETE FROM documents WHERE document_id = ANY(%s)",
(retry_doc_ids,)
)
db.connect().commit()
# Remove temp folders
for doc_id in retry_doc_ids:
temp_doc_dir = output_dir / 'temp' / doc_id
if temp_doc_dir.exists():
shutil.rmtree(temp_doc_dir, ignore_errors=True)
print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
if not tasks:
continue
# Update tasks_submitted counter for limit tracking
stats['tasks_submitted'] += len(tasks)
if use_dual_pool:
# Dual-pool mode using pre-initialized DualPoolCoordinator
# (process_text_pdf, process_scanned_pdf already imported above)
# Convert tasks to new format
documents = []
for task in tasks:
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
# Pre-classify PDF type
try:
is_text = is_text_pdf(pdf_path_str)
except Exception:
is_text = False
documents.append({
"id": row_dict["DocumentId"],
"row_dict": row_dict,
"pdf_path": pdf_path_str,
"output_dir": output_dir_str,
"dpi": dpi,
"min_confidence": min_confidence,
"is_scanned": not is_text,
"has_text": is_text,
"text_length": 1000 if is_text else 0, # Approximate
})
# Count task types
text_count = sum(1 for d in documents if not d["is_scanned"])
scan_count = len(documents) - text_count
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
# Progress tracking with tqdm
pbar = tqdm(total=len(documents), desc="Processing")
def on_result(task_result):
"""Handle successful result."""
result = task_result.data
handle_result(result)
pbar.update(1)
def on_error(task_id, error):
"""Handle failed task."""
handle_error(task_id, error)
pbar.update(1)
# Process with pre-initialized coordinator (workers stay alive)
results = dual_pool_coordinator.process_batch(
documents=documents,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=on_result,
on_error=on_error,
id_field="id",
)
pbar.close()
# Log summary
successful = sum(1 for r in results if r.success)
failed = len(results) - successful
print(f" Batch complete: {successful} successful, {failed} failed")
else:
# Single-pool mode (original behavior)
print(f" Processing {len(tasks)} documents with {args.workers} workers...")
# Process documents in parallel (inside CSV loop for streaming)
# Use single process for debugging or when workers=1
if args.workers == 1:
for task in tqdm(tasks, desc="Processing"):
result = process_single_document(task)
handle_result(result)
else:
# Parallel processing with worker initialization
# Each worker initializes OCR engine once and reuses it
with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
for task in tasks}
# Per-result timeout: 120 seconds (2 minutes).
# Note: as_completed() only yields futures that have already finished, so this
# timeout guards result retrieval; it cannot interrupt a worker that hangs
# mid-document (dual-pool mode's task_timeout handles that case).
DOCUMENT_TIMEOUT = 120
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
result = future.result(timeout=DOCUMENT_TIMEOUT)
handle_result(result)
except TimeoutError:
handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
# Cancel the stuck future
future.cancel()
except Exception as e:
handle_error(doc_id, e)
# Flush remaining database batch after each CSV
if db_batch:
db.save_documents_batch(db_batch)
db_batch.clear()
finally:
# Shutdown dual-pool coordinator if it was started
if dual_pool_coordinator is not None:
dual_pool_coordinator.shutdown()
# Close temp file
processed_items_writer.close()
# Use the in-memory counter instead of re-reading the file (performance fix)
# processed_count already tracks the number of successfully processed items
# Cleanup processed_items temp file (not needed anymore)
processed_items_file.unlink(missing_ok=True)
# Close database connection
db.close()
# Print summary
print("\n" + "=" * 50)
print("Auto-labeling Complete")
print("=" * 50)
print(f"Total documents: {stats['total']}")
print(f"Successful: {stats['successful']}")
print(f"Failed: {stats['failed']}")
print(f"Skipped (no PDF): {stats['skipped']}")
print(f"Skipped (in DB): {stats['skipped_db']}")
print(f"Retried (failed): {stats['retried']}")
print(f"Total annotations: {stats['annotations']}")
print(f"\nImages saved to: {output_dir / 'temp'}")
print(f"Labels stored in: PostgreSQL database")
print(f"\nAnnotations by field:")
for field, count in stats['by_field'].items():
print(f" {field}: {count}")
shard_files = report_writer.get_shard_files()
if len(shard_files) > 1:
print(f"\nReport files ({len(shard_files)}):")
for sf in shard_files:
print(f" - {sf}")
else:
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
if __name__ == '__main__':
main()
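Invocation sketch (the script name is illustrative; flags come from the argparse definition above):

# Single-pool: 4 workers, text-layer PDFs only, small test batch
python autolabel.py --workers 4 --skip-ocr --limit 100

# Dual-pool: text PDFs on CPU workers, scanned PDFs on one GPU OCR worker
python autolabel.py --cpu-workers 6 --gpu-workers 1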

View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
Import existing JSONL report files into PostgreSQL database.
Usage:
python -m src.cli.import_report_to_db --report "reports/autolabel_report_v4*.jsonl"
"""
import argparse
import json
import sys
from pathlib import Path
import psycopg2
from psycopg2.extras import execute_values
from shared.config import get_db_connection_string, PATHS
def create_tables(conn):
"""Create database tables."""
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
document_id TEXT PRIMARY KEY,
pdf_path TEXT,
pdf_type TEXT,
success BOOLEAN,
total_pages INTEGER,
fields_matched INTEGER,
fields_total INTEGER,
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]',
-- New fields for extended CSV format
split TEXT,
customer_number TEXT,
supplier_name TEXT,
supplier_organisation_number TEXT,
supplier_accounts TEXT
);
CREATE TABLE IF NOT EXISTS field_results (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
field_name TEXT,
csv_value TEXT,
matched BOOLEAN,
score REAL,
matched_text TEXT,
candidate_used TEXT,
bbox JSONB,
page_no INTEGER,
context_keywords JSONB DEFAULT '[]',
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
-- Add new columns to existing tables if they don't exist (for migration)
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
ALTER TABLE documents ADD COLUMN split TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
ALTER TABLE documents ADD COLUMN customer_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
ALTER TABLE documents ADD COLUMN supplier_name TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
END IF;
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
END IF;
END $$;
""")
conn.commit()
def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_size: int = 1000) -> dict:
"""Import a single JSONL file into database."""
stats = {'imported': 0, 'skipped': 0, 'errors': 0}
# Get existing document IDs if skipping
existing_ids = set()
if skip_existing:
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents")
existing_ids = {row[0] for row in cursor.fetchall()}
doc_batch = []
field_batch = []
def flush_batches():
nonlocal doc_batch, field_batch
if doc_batch:
with conn.cursor() as cursor:
execute_values(cursor, """
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors,
split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
VALUES %s
ON CONFLICT (document_id) DO UPDATE SET
pdf_path = EXCLUDED.pdf_path,
pdf_type = EXCLUDED.pdf_type,
success = EXCLUDED.success,
total_pages = EXCLUDED.total_pages,
fields_matched = EXCLUDED.fields_matched,
fields_total = EXCLUDED.fields_total,
annotations_generated = EXCLUDED.annotations_generated,
processing_time_ms = EXCLUDED.processing_time_ms,
timestamp = EXCLUDED.timestamp,
errors = EXCLUDED.errors,
split = EXCLUDED.split,
customer_number = EXCLUDED.customer_number,
supplier_name = EXCLUDED.supplier_name,
supplier_organisation_number = EXCLUDED.supplier_organisation_number,
supplier_accounts = EXCLUDED.supplier_accounts
""", doc_batch)
doc_batch = []
if field_batch:
with conn.cursor() as cursor:
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_batch)
field_batch = []
conn.commit()
with open(jsonl_path, 'r', encoding='utf-8') as f:
for line_no, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
except json.JSONDecodeError as e:
print(f" Warning: Line {line_no} - JSON parse error: {e}")
stats['errors'] += 1
continue
doc_id = record.get('document_id')
if not doc_id:
stats['errors'] += 1
continue
# Only import successful documents
if not record.get('success'):
stats['skipped'] += 1
continue
# Check if already exists
if skip_existing and doc_id in existing_ids:
stats['skipped'] += 1
continue
# Add to batch
doc_batch.append((
doc_id,
record.get('pdf_path'),
record.get('pdf_type'),
record.get('success'),
record.get('total_pages'),
record.get('fields_matched'),
record.get('fields_total'),
record.get('annotations_generated'),
record.get('processing_time_ms'),
record.get('timestamp'),
json.dumps(record.get('errors', [])),
# New fields
record.get('split'),
record.get('customer_number'),
record.get('supplier_name'),
record.get('supplier_organisation_number'),
record.get('supplier_accounts'),
))
for field in record.get('field_results', []):
field_batch.append((
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
))
stats['imported'] += 1
existing_ids.add(doc_id)
# Flush batch if needed
if len(doc_batch) >= batch_size:
flush_batches()
print(f" Processed {stats['imported'] + stats['skipped']} records...")
# Final flush
flush_batches()
return stats
def main():
parser = argparse.ArgumentParser(description='Import JSONL reports to PostgreSQL database')
parser.add_argument('--report', type=str, default=f"{PATHS['reports_dir']}/autolabel_report*.jsonl",
help='Report file path or glob pattern')
parser.add_argument('--db', type=str, default=None,
help='PostgreSQL connection string (uses config.py if not specified)')
parser.add_argument('--no-skip', action='store_true',
help='Do not skip existing documents (replace them)')
parser.add_argument('--batch-size', type=int, default=1000,
help='Batch size for bulk inserts')
args = parser.parse_args()
# Use config if db not specified
db_connection = args.db or get_db_connection_string()
# Find report files
report_path = Path(args.report)
if '*' in str(report_path) or '?' in str(report_path):
parent = report_path.parent
pattern = report_path.name
report_files = sorted(parent.glob(pattern))
else:
report_files = [report_path] if report_path.exists() else []
if not report_files:
print(f"No report files found: {args.report}")
return
print(f"Found {len(report_files)} report file(s)")
# Connect to database
conn = psycopg2.connect(db_connection)
create_tables(conn)
# Import each file
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
for report_file in report_files:
print(f"\nImporting: {report_file.name}")
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")
for key in total_stats:
total_stats[key] += stats[key]
# Print summary
print("\n" + "=" * 50)
print("Import Complete")
print("=" * 50)
print(f"Total imported: {total_stats['imported']}")
print(f"Total skipped: {total_stats['skipped']}")
print(f"Total errors: {total_stats['errors']}")
# Quick stats from database
with conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM documents")
total_docs = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM documents WHERE success = true")
success_docs = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM field_results")
total_fields = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM field_results WHERE matched = true")
matched_fields = cursor.fetchone()[0]
conn.close()
print(f"\nDatabase Stats:")
print(f" Documents: {total_docs} ({success_docs} successful)")
print(f" Field results: {total_fields} ({matched_fields} matched)")
if total_fields > 0:
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,424 @@
#!/usr/bin/env python3
"""
Re-process failed matches and store detailed information including OCR values,
CSV values, and source CSV filename in a new table.
"""
import argparse
import json
import glob
import os
import sys
import time
from pathlib import Path
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from tqdm import tqdm
from shared.config import DEFAULT_DPI
from shared.data.db import DocumentDB
from shared.data.csv_loader import CSVLoader
from shared.normalize.normalizer import normalize_field
def create_failed_match_table(db: DocumentDB):
"""Create the failed_match_details table."""
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
DROP TABLE IF EXISTS failed_match_details;
CREATE TABLE failed_match_details (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL,
field_name TEXT NOT NULL,
csv_value TEXT,
csv_value_normalized TEXT,
ocr_value TEXT,
ocr_value_normalized TEXT,
all_ocr_candidates JSONB,
matched BOOLEAN DEFAULT FALSE,
match_score REAL,
pdf_path TEXT,
pdf_type TEXT,
csv_filename TEXT,
page_no INTEGER,
bbox JSONB,
error TEXT,
reprocessed_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(document_id, field_name)
);
CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
""")
conn.commit()
print("Created table: failed_match_details")
def get_failed_documents(db: DocumentDB) -> list:
"""Get all documents that have at least one failed field match."""
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
WHERE fr.matched = false
ORDER BY fr.document_id
""")
return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
for row in cursor.fetchall()]
def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
"""Get all failed field results for a document."""
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name, csv_value, error
FROM field_results
WHERE document_id = %s AND matched = false
""", (doc_id,))
return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
for row in cursor.fetchall()]
# Cache for CSV data
_csv_cache = {}
def build_csv_cache(csv_files: list):
"""Build a cache of document_id to csv_filename mapping."""
global _csv_cache
_csv_cache = {}
for csv_file in csv_files:
csv_filename = os.path.basename(csv_file)
loader = CSVLoader(csv_file)
for row in loader.iter_rows():
if row.DocumentId not in _csv_cache:
_csv_cache[row.DocumentId] = csv_filename
def find_csv_filename(doc_id: str) -> str | None:
    """Find which CSV filename contains the given document ID, or None if unknown."""
    return _csv_cache.get(doc_id)
def init_worker():
"""Initialize worker process."""
import os
import warnings
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["GLOG_minloglevel"] = "2"
warnings.filterwarnings("ignore")
def process_single_document(args):
"""Process a single document and extract OCR values for failed fields."""
doc_info, failed_fields, csv_filename = args
doc_id = doc_info['document_id']
pdf_path = doc_info['pdf_path']
pdf_type = doc_info['pdf_type']
results = []
# Try to extract OCR from PDF
try:
if pdf_path and os.path.exists(pdf_path):
from shared.pdf import PDFDocument
from shared.ocr import OCREngine
pdf_doc = PDFDocument(pdf_path)
is_scanned = pdf_doc.detect_type() == "scanned"
# Collect all OCR text blocks
all_ocr_texts = []
if is_scanned:
# Use OCR for scanned PDFs
ocr_engine = OCREngine()
for page_no in range(pdf_doc.page_count):
# Render page to image
img = pdf_doc.render_page(page_no, dpi=DEFAULT_DPI)
if img is None:
continue
# OCR the image
ocr_results = ocr_engine.extract_from_image(img)
for block in ocr_results:
all_ocr_texts.append({
'text': block.get('text', ''),
'bbox': block.get('bbox'),
'page_no': page_no
})
else:
# Use text extraction for text PDFs
for page_no in range(pdf_doc.page_count):
tokens = list(pdf_doc.extract_text_tokens(page_no))
for token in tokens:
all_ocr_texts.append({
'text': token.text,
'bbox': token.bbox,
'page_no': page_no
})
# For each failed field, try to find matching OCR
for field in failed_fields:
field_name = field['field_name']
csv_value = field['csv_value']
error = field['error']
# Normalize CSV value
csv_normalized = normalize_field(field_name, csv_value) if csv_value else None
# Try to find best match in OCR
best_score = 0
best_ocr = None
best_bbox = None
best_page = None
for ocr_block in all_ocr_texts:
ocr_text = ocr_block['text']
if not ocr_text:
continue
ocr_normalized = normalize_field(field_name, ocr_text)
# Calculate similarity
                    if csv_normalized and ocr_normalized:
                        # Exact match: best possible score, stop scanning this field.
                        if csv_normalized == ocr_normalized:
                            best_score = 1.0
                            best_ocr = ocr_text
                            best_bbox = ocr_block['bbox']
                            best_page = ocr_block['page_no']
                            break
                        # Partial match: score substring containment by length ratio,
                        # e.g. "12345" inside "faktura12345" scores 5/12.
                        if csv_normalized in ocr_normalized:
                            score = len(csv_normalized) / max(len(ocr_normalized), 1)
                        elif ocr_normalized in csv_normalized:
                            score = len(ocr_normalized) / max(len(csv_normalized), 1)
                        else:
                            continue
                        if score > best_score:
                            best_score = score
                            best_ocr = ocr_text
                            best_bbox = ocr_block['bbox']
                            best_page = ocr_block['page_no']
results.append({
'document_id': doc_id,
'field_name': field_name,
'csv_value': csv_value,
'csv_value_normalized': csv_normalized,
'ocr_value': best_ocr,
'ocr_value_normalized': normalize_field(field_name, best_ocr) if best_ocr else None,
'all_ocr_candidates': [t['text'] for t in all_ocr_texts[:100]], # Limit to 100
'matched': best_score > 0.8,
'match_score': best_score,
'pdf_path': pdf_path,
'pdf_type': pdf_type,
'csv_filename': csv_filename,
'page_no': best_page,
'bbox': list(best_bbox) if best_bbox else None,
'error': error
})
else:
# PDF not found
for field in failed_fields:
results.append({
'document_id': doc_id,
'field_name': field['field_name'],
'csv_value': field['csv_value'],
'csv_value_normalized': normalize_field(field['field_name'], field['csv_value']) if field['csv_value'] else None,
'ocr_value': None,
'ocr_value_normalized': None,
'all_ocr_candidates': [],
'matched': False,
'match_score': 0,
'pdf_path': pdf_path,
'pdf_type': pdf_type,
'csv_filename': csv_filename,
'page_no': None,
'bbox': None,
'error': f"PDF not found: {pdf_path}"
})
except Exception as e:
for field in failed_fields:
results.append({
'document_id': doc_id,
'field_name': field['field_name'],
'csv_value': field['csv_value'],
'csv_value_normalized': None,
'ocr_value': None,
'ocr_value_normalized': None,
'all_ocr_candidates': [],
'matched': False,
'match_score': 0,
'pdf_path': pdf_path,
'pdf_type': pdf_type,
'csv_filename': csv_filename,
'page_no': None,
'bbox': None,
'error': str(e)
})
return results
def save_results_batch(db: DocumentDB, results: list):
"""Save results to failed_match_details table."""
if not results:
return
conn = db.connect()
with conn.cursor() as cursor:
for r in results:
cursor.execute("""
INSERT INTO failed_match_details
(document_id, field_name, csv_value, csv_value_normalized,
ocr_value, ocr_value_normalized, all_ocr_candidates,
matched, match_score, pdf_path, pdf_type, csv_filename,
page_no, bbox, error)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (document_id, field_name) DO UPDATE SET
csv_value = EXCLUDED.csv_value,
csv_value_normalized = EXCLUDED.csv_value_normalized,
ocr_value = EXCLUDED.ocr_value,
ocr_value_normalized = EXCLUDED.ocr_value_normalized,
all_ocr_candidates = EXCLUDED.all_ocr_candidates,
matched = EXCLUDED.matched,
match_score = EXCLUDED.match_score,
pdf_path = EXCLUDED.pdf_path,
pdf_type = EXCLUDED.pdf_type,
csv_filename = EXCLUDED.csv_filename,
page_no = EXCLUDED.page_no,
bbox = EXCLUDED.bbox,
error = EXCLUDED.error,
reprocessed_at = NOW()
""", (
r['document_id'],
r['field_name'],
r['csv_value'],
r['csv_value_normalized'],
r['ocr_value'],
r['ocr_value_normalized'],
json.dumps(r['all_ocr_candidates']),
r['matched'],
r['match_score'],
r['pdf_path'],
r['pdf_type'],
r['csv_filename'],
r['page_no'],
json.dumps(r['bbox']) if r['bbox'] else None,
r['error']
))
conn.commit()
def main():
parser = argparse.ArgumentParser(description='Re-process failed matches')
parser.add_argument('--csv', required=True, help='CSV files glob pattern')
parser.add_argument('--pdf-dir', required=True, help='PDF directory')
parser.add_argument('--workers', type=int, default=3, help='Number of workers')
parser.add_argument('--limit', type=int, help='Limit number of documents to process')
args = parser.parse_args()
# Expand CSV glob
csv_files = sorted(glob.glob(args.csv))
print(f"Found {len(csv_files)} CSV files")
# Build CSV cache
print("Building CSV filename cache...")
build_csv_cache(csv_files)
print(f"Cached {len(_csv_cache)} document IDs")
# Connect to database
db = DocumentDB()
db.connect()
# Create new table
create_failed_match_table(db)
# Get all failed documents
print("Fetching failed documents...")
failed_docs = get_failed_documents(db)
print(f"Found {len(failed_docs)} documents with failed matches")
if args.limit:
failed_docs = failed_docs[:args.limit]
print(f"Limited to {len(failed_docs)} documents")
# Prepare tasks
tasks = []
for doc in failed_docs:
failed_fields = get_failed_fields_for_document(db, doc['document_id'])
csv_filename = find_csv_filename(doc['document_id'])
if failed_fields:
tasks.append((doc, failed_fields, csv_filename))
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
# Process with multiprocessing
total_results = 0
batch_results = []
batch_size = 50
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
futures = {executor.submit(process_single_document, task): task[0]['document_id']
for task in tasks}
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
results = future.result(timeout=120)
batch_results.extend(results)
total_results += len(results)
# Save in batches
if len(batch_results) >= batch_size:
save_results_batch(db, batch_results)
batch_results = []
except TimeoutError:
print(f"\nTimeout processing {doc_id}")
except Exception as e:
print(f"\nError processing {doc_id}: {e}")
# Save remaining results
if batch_results:
save_results_batch(db, batch_results)
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
# Show summary
conn = db.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name, COUNT(*) as total,
COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
COALESCE(AVG(match_score), 0) as avg_score
FROM failed_match_details
GROUP BY field_name
ORDER BY total DESC
""")
print("\nSummary by field:")
print("-" * 70)
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
print("-" * 70)
for row in cursor.fetchall():
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
db.close()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""
Training CLI
Trains YOLO model on dataset with labels from PostgreSQL database.
Images are read from filesystem, labels are dynamically generated from DB.
"""
import argparse
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI, PATHS
def main():
parser = argparse.ArgumentParser(
description='Train YOLO model for invoice field detection'
)
parser.add_argument(
'--dataset-dir', '-d',
default=PATHS['output_dir'],
help='Dataset directory containing temp/{doc_id}/images/ (default: data/dataset)'
)
parser.add_argument(
'--model', '-m',
default='yolov8s.pt',
help='Base model (default: yolov8s.pt)'
)
parser.add_argument(
'--epochs', '-e',
type=int,
default=100,
help='Number of epochs (default: 100)'
)
parser.add_argument(
'--batch', '-b',
type=int,
default=16,
help='Batch size (default: 16)'
)
parser.add_argument(
'--imgsz',
type=int,
default=1280,
help='Image size (default: 1280)'
)
parser.add_argument(
'--project',
default='runs/train',
help='Project directory (default: runs/train)'
)
parser.add_argument(
'--name',
default='invoice_fields',
help='Run name (default: invoice_fields)'
)
parser.add_argument(
'--device',
default='0',
help='Device (0, 1, cpu, mps)'
)
parser.add_argument(
'--resume',
action='store_true',
help='Resume from last checkpoint'
)
parser.add_argument(
'--workers',
type=int,
default=4,
help='Number of data loader workers (default: 4, reduce if OOM)'
)
parser.add_argument(
'--cache',
action='store_true',
help='Cache images in RAM (faster but uses more memory)'
)
parser.add_argument(
'--low-memory',
action='store_true',
        help='Enable low memory mode (batch<=8, workers<=4, no cache)'
)
parser.add_argument(
'--train-ratio',
type=float,
default=0.8,
help='Training set ratio (default: 0.8)'
)
parser.add_argument(
'--val-ratio',
type=float,
default=0.1,
help='Validation set ratio (default: 0.1)'
)
parser.add_argument(
'--seed',
type=int,
default=42,
help='Random seed for split (default: 42)'
)
parser.add_argument(
'--dpi',
type=int,
default=DEFAULT_DPI,
help=f'DPI used for rendering (default: {DEFAULT_DPI}, must match autolabel rendering)'
)
parser.add_argument(
'--export-only',
action='store_true',
help='Only export dataset to YOLO format, do not train'
)
parser.add_argument(
'--limit',
type=int,
default=None,
help='Limit number of documents for training (default: all)'
)
args = parser.parse_args()
# Apply low-memory mode if specified
if args.low_memory:
print("🔧 Low memory mode enabled")
        args.batch = min(args.batch, 8)  # Cap batch size at 8
        args.workers = min(args.workers, 4)  # Cap data loader workers at 4
args.cache = False
print(f" Batch size: {args.batch}")
print(f" Workers: {args.workers}")
print(f" Cache: disabled")
# Validate dataset directory
dataset_dir = Path(args.dataset_dir)
temp_dir = dataset_dir / 'temp'
if not temp_dir.exists():
print(f"Error: Temp directory not found: {temp_dir}")
print("Run autolabel first to generate images.")
sys.exit(1)
print("=" * 60)
print("YOLO Training with Database Labels")
print("=" * 60)
print(f"Dataset dir: {dataset_dir}")
print(f"Model: {args.model}")
print(f"Epochs: {args.epochs}")
print(f"Batch size: {args.batch}")
print(f"Image size: {args.imgsz}")
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
if args.limit:
print(f"Document limit: {args.limit}")
# Connect to database
from shared.data.db import DocumentDB
print("\nConnecting to database...")
db = DocumentDB()
db.connect()
# Create datasets from database
from training.yolo.db_dataset import create_datasets
print("Loading dataset from database...")
datasets = create_datasets(
images_dir=dataset_dir,
db=db,
train_ratio=args.train_ratio,
val_ratio=args.val_ratio,
seed=args.seed,
dpi=args.dpi,
limit=args.limit
)
print(f"\nDataset splits:")
print(f" Train: {len(datasets['train'])} items")
print(f" Val: {len(datasets['val'])} items")
print(f" Test: {len(datasets['test'])} items")
if len(datasets['train']) == 0:
print("\nError: No training data found!")
print("Make sure autolabel has been run and images exist in temp directory.")
db.close()
sys.exit(1)
# Export to YOLO format (required for Ultralytics training)
print("\nExporting dataset to YOLO format...")
for split_name, dataset in datasets.items():
count = dataset.export_to_yolo_format(dataset_dir, split_name)
print(f" {split_name}: {count} items exported")
# Generate YOLO config files
from training.yolo.annotation_generator import AnnotationGenerator
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")
if args.export_only:
print("\nExport complete (--export-only specified, skipping training)")
db.close()
return
# Start training
print("\n" + "=" * 60)
print("Starting YOLO Training")
print("=" * 60)
from ultralytics import YOLO
# Load model
last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
if args.resume and last_checkpoint.exists():
print(f"Resuming from: {last_checkpoint}")
model = YOLO(str(last_checkpoint))
else:
model = YOLO(args.model)
# Training arguments
data_yaml = dataset_dir / 'dataset.yaml'
train_args = {
'data': str(data_yaml.absolute()),
'epochs': args.epochs,
'batch': args.batch,
'imgsz': args.imgsz,
'project': args.project,
'name': args.name,
'device': args.device,
'exist_ok': True,
'pretrained': True,
'verbose': True,
'workers': args.workers,
'cache': args.cache,
'resume': args.resume and last_checkpoint.exists(),
# Document-specific augmentation settings
'degrees': 5.0,
'translate': 0.05,
'scale': 0.2,
'shear': 0.0,
'perspective': 0.0,
'flipud': 0.0,
'fliplr': 0.0,
'mosaic': 0.0,
'mixup': 0.0,
'hsv_h': 0.0,
'hsv_s': 0.1,
'hsv_v': 0.2,
}
# Train
results = model.train(**train_args)
# Print results
print("\n" + "=" * 60)
print("Training Complete")
print("=" * 60)
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
# Validate on test set
print("\nRunning validation on test set...")
metrics = model.val(split='test')
print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}")
# Close database
db.close()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,336 @@
#!/usr/bin/env python3
"""
CLI for cross-validation of invoice field extraction using LLM.
Validates documents with failed field matches by sending them to an LLM
and comparing the extraction results.
"""
import argparse
import sys
from pathlib import Path
def main():
parser = argparse.ArgumentParser(
description='Cross-validate invoice field extraction using LLM'
)
subparsers = parser.add_subparsers(dest='command', help='Commands')
# Stats command
stats_parser = subparsers.add_parser('stats', help='Show failed match statistics')
# Validate command
validate_parser = subparsers.add_parser('validate', help='Validate documents with failed matches')
validate_parser.add_argument(
'--limit', '-l',
type=int,
default=10,
help='Maximum number of documents to validate (default: 10)'
)
validate_parser.add_argument(
'--provider', '-p',
choices=['openai', 'anthropic'],
default='openai',
help='LLM provider to use (default: openai)'
)
validate_parser.add_argument(
'--model', '-m',
help='Model to use (default: gpt-4o for OpenAI, claude-sonnet-4-20250514 for Anthropic)'
)
validate_parser.add_argument(
'--single', '-s',
help='Validate a single document ID'
)
# Compare command
compare_parser = subparsers.add_parser('compare', help='Compare validation results')
compare_parser.add_argument(
'document_id',
nargs='?',
help='Document ID to compare (or omit to show all)'
)
compare_parser.add_argument(
'--limit', '-l',
type=int,
default=20,
help='Maximum number of results to show (default: 20)'
)
# Report command
report_parser = subparsers.add_parser('report', help='Generate validation report')
report_parser.add_argument(
'--output', '-o',
default='reports/llm_validation_report.json',
help='Output file path (default: reports/llm_validation_report.json)'
)
args = parser.parse_args()
if not args.command:
parser.print_help()
return
from inference.validation import LLMValidator
validator = LLMValidator()
validator.connect()
validator.create_validation_table()
if args.command == 'stats':
show_stats(validator)
elif args.command == 'validate':
if args.single:
validate_single(validator, args.single, args.provider, args.model)
else:
validate_batch(validator, args.limit, args.provider, args.model)
elif args.command == 'compare':
if args.document_id:
compare_single(validator, args.document_id)
else:
compare_all(validator, args.limit)
elif args.command == 'report':
generate_report(validator, args.output)
validator.close()
def show_stats(validator):
"""Show statistics about failed matches."""
stats = validator.get_failed_match_stats()
print("\n" + "=" * 50)
print("Failed Match Statistics")
print("=" * 50)
print(f"\nDocuments with failures: {stats['documents_with_failures']}")
print(f"Already validated: {stats['already_validated']}")
print(f"Remaining to validate: {stats['remaining']}")
print("\nFailures by field:")
for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]):
print(f" {field}: {count}")
def validate_single(validator, doc_id: str, provider: str, model: str):
"""Validate a single document."""
print(f"\nValidating document: {doc_id}")
print(f"Provider: {provider}, Model: {model or 'default'}")
print()
result = validator.validate_document(doc_id, provider, model)
if result.error:
print(f"ERROR: {result.error}")
return
print(f"Processing time: {result.processing_time_ms:.0f}ms")
print(f"Model used: {result.model_used}")
print("\nExtracted fields:")
print(f" Invoice Number: {result.invoice_number}")
print(f" Invoice Date: {result.invoice_date}")
print(f" Due Date: {result.invoice_due_date}")
print(f" OCR: {result.ocr_number}")
print(f" Bankgiro: {result.bankgiro}")
print(f" Plusgiro: {result.plusgiro}")
print(f" Amount: {result.amount}")
print(f" Org Number: {result.supplier_organisation_number}")
# Show comparison
print("\n" + "-" * 50)
print("Comparison with autolabel:")
comparison = validator.compare_results(doc_id)
for field, data in comparison.items():
if data.get('csv_value'):
status = "" if data['agreement'] else ""
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
print(f" {status} {field}:")
print(f" CSV: {data['csv_value']}")
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
print(f" LLM: {data['llm_value']}")
def validate_batch(validator, limit: int, provider: str, model: str):
"""Validate a batch of documents."""
print(f"\nValidating up to {limit} documents with failed matches")
print(f"Provider: {provider}, Model: {model or 'default'}")
print()
results = validator.validate_batch(
limit=limit,
provider=provider,
model=model,
verbose=True
)
# Summary
success = sum(1 for r in results if not r.error)
failed = len(results) - success
total_time = sum(r.processing_time_ms or 0 for r in results)
print("\n" + "=" * 50)
print("Validation Complete")
print("=" * 50)
print(f"Total: {len(results)}")
print(f"Success: {success}")
print(f"Failed: {failed}")
print(f"Total time: {total_time/1000:.1f}s")
if success > 0:
print(f"Avg time: {total_time/success:.0f}ms per document")
def compare_single(validator, doc_id: str):
"""Compare results for a single document."""
comparison = validator.compare_results(doc_id)
if 'error' in comparison:
print(f"Error: {comparison['error']}")
return
print(f"\nComparison for document: {doc_id}")
print("=" * 60)
for field, data in comparison.items():
if data.get('csv_value') is None:
continue
status = "" if data['agreement'] else ""
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
print(f"\n{status} {field}:")
print(f" CSV value: {data['csv_value']}")
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
print(f" LLM extracted: {data['llm_value']}")
def compare_all(validator, limit: int):
"""Show comparison summary for all validated documents."""
conn = validator.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT document_id FROM llm_validations
WHERE error IS NULL
ORDER BY created_at DESC
LIMIT %s
""", (limit,))
doc_ids = [row[0] for row in cursor.fetchall()]
if not doc_ids:
print("No validated documents found.")
return
print(f"\nComparison Summary ({len(doc_ids)} documents)")
print("=" * 80)
# Aggregate stats
field_stats = {}
for doc_id in doc_ids:
comparison = validator.compare_results(doc_id)
if 'error' in comparison:
continue
for field, data in comparison.items():
if data.get('csv_value') is None:
continue
if field not in field_stats:
field_stats[field] = {
'total': 0,
'autolabel_matched': 0,
'llm_agrees': 0,
'llm_correct_auto_wrong': 0,
}
stats = field_stats[field]
stats['total'] += 1
if data['autolabel_matched']:
stats['autolabel_matched'] += 1
if data['agreement']:
stats['llm_agrees'] += 1
# LLM found correct value when autolabel failed
if not data['autolabel_matched'] and data['agreement']:
stats['llm_correct_auto_wrong'] += 1
print(f"\n{'Field':<30} {'Total':>6} {'Auto OK':>8} {'LLM Agrees':>10} {'LLM Found':>10}")
print("-" * 80)
for field, stats in sorted(field_stats.items()):
print(f"{field:<30} {stats['total']:>6} {stats['autolabel_matched']:>8} "
f"{stats['llm_agrees']:>10} {stats['llm_correct_auto_wrong']:>10}")
def generate_report(validator, output_path: str):
"""Generate a detailed validation report."""
import json
from datetime import datetime
conn = validator.connect()
with conn.cursor() as cursor:
# Get all validated documents
cursor.execute("""
SELECT document_id, invoice_number, invoice_date, invoice_due_date,
ocr_number, bankgiro, plusgiro, amount,
supplier_organisation_number, model_used, processing_time_ms,
error, created_at
FROM llm_validations
ORDER BY created_at DESC
""")
validations = []
for row in cursor.fetchall():
doc_id = row[0]
comparison = validator.compare_results(doc_id) if not row[11] else {}
validations.append({
'document_id': doc_id,
'llm_extraction': {
'invoice_number': row[1],
'invoice_date': row[2],
'invoice_due_date': row[3],
'ocr_number': row[4],
'bankgiro': row[5],
'plusgiro': row[6],
'amount': row[7],
'supplier_organisation_number': row[8],
},
'model_used': row[9],
'processing_time_ms': row[10],
'error': row[11],
'created_at': str(row[12]) if row[12] else None,
'comparison': comparison,
})
# Calculate summary stats
stats = validator.get_failed_match_stats()
report = {
'generated_at': datetime.now().isoformat(),
'summary': {
'total_documents_with_failures': stats['documents_with_failures'],
'documents_validated': stats['already_validated'],
'failures_by_field': stats['failures_by_field'],
},
'validations': validations,
}
# Write report
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
print(f"\nReport generated: {output_path}")
print(f"Total validations: {len(validations)}")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,313 @@
"""
Auto-Label Report Generator
Generates quality control reports for auto-labeling process.
"""
import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any
@dataclass
class FieldMatchResult:
"""Result of matching a single field."""
field_name: str
csv_value: str | None
matched: bool
score: float = 0.0
matched_text: str | None = None
candidate_used: str | None = None # Which normalized variant matched
bbox: tuple[float, float, float, float] | None = None
page_no: int = 0
context_keywords: list[str] = field(default_factory=list)
error: str | None = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
# Convert bbox to native Python floats to avoid numpy serialization issues
bbox_list = None
if self.bbox:
bbox_list = [float(x) for x in self.bbox]
return {
'field_name': self.field_name,
'csv_value': self.csv_value,
'matched': self.matched,
'score': float(self.score) if self.score else 0.0,
'matched_text': self.matched_text,
'candidate_used': self.candidate_used,
'bbox': bbox_list,
'page_no': int(self.page_no) if self.page_no else 0,
'context_keywords': self.context_keywords,
'error': self.error
}
@dataclass
class AutoLabelReport:
"""Report for a single document's auto-labeling process."""
document_id: str
pdf_path: str | None = None
pdf_type: str | None = None # 'text' | 'scanned' | 'mixed'
success: bool = False
total_pages: int = 0
fields_matched: int = 0
fields_total: int = 0
field_results: list[FieldMatchResult] = field(default_factory=list)
annotations_generated: int = 0
image_paths: list[str] = field(default_factory=list)
label_paths: list[str] = field(default_factory=list)
processing_time_ms: float = 0.0
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
errors: list[str] = field(default_factory=list)
# New metadata fields (from CSV, not for matching)
split: str | None = None
customer_number: str | None = None
supplier_name: str | None = None
supplier_organisation_number: str | None = None
supplier_accounts: str | None = None
def add_field_result(self, result: FieldMatchResult) -> None:
"""Add a field matching result."""
self.field_results.append(result)
self.fields_total += 1
if result.matched:
self.fields_matched += 1
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
'document_id': self.document_id,
'pdf_path': self.pdf_path,
'pdf_type': self.pdf_type,
'success': self.success,
'total_pages': self.total_pages,
'fields_matched': self.fields_matched,
'fields_total': self.fields_total,
'field_results': [r.to_dict() for r in self.field_results],
'annotations_generated': self.annotations_generated,
'image_paths': self.image_paths,
'label_paths': self.label_paths,
'processing_time_ms': self.processing_time_ms,
'timestamp': self.timestamp,
'errors': self.errors,
# New metadata fields
'split': self.split,
'customer_number': self.customer_number,
'supplier_name': self.supplier_name,
'supplier_organisation_number': self.supplier_organisation_number,
'supplier_accounts': self.supplier_accounts,
}
def to_json(self, indent: int | None = None) -> str:
"""Convert to JSON string."""
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
@property
def match_rate(self) -> float:
"""Calculate field match rate."""
if self.fields_total == 0:
return 0.0
return self.fields_matched / self.fields_total
def get_summary(self) -> dict:
"""Get a summary of the report."""
return {
'document_id': self.document_id,
'success': self.success,
'match_rate': f"{self.match_rate:.1%}",
'fields': f"{self.fields_matched}/{self.fields_total}",
'annotations': self.annotations_generated,
'errors': len(self.errors)
}
class ReportWriter:
"""Writes auto-label reports to file with optional sharding."""
def __init__(
self,
output_path: str | Path,
max_records_per_file: int = 0
):
"""
Initialize report writer.
Args:
output_path: Path to output JSONL file (base name if sharding)
max_records_per_file: Max records per file (0 = no limit, single file)
"""
self.output_path = Path(output_path)
self.output_path.parent.mkdir(parents=True, exist_ok=True)
self.max_records_per_file = max_records_per_file
# Sharding state
self._current_shard = 0
self._records_in_current_shard = 0
self._shard_files: list[Path] = []
def _get_shard_path(self) -> Path:
"""Get the path for current shard."""
if self.max_records_per_file > 0:
base = self.output_path.stem
suffix = self.output_path.suffix
shard_path = self.output_path.parent / f"{base}_part{self._current_shard:03d}{suffix}"
else:
shard_path = self.output_path
if shard_path not in self._shard_files:
self._shard_files.append(shard_path)
return shard_path
def _check_shard_rotation(self) -> None:
"""Check if we need to rotate to a new shard file."""
if self.max_records_per_file > 0:
if self._records_in_current_shard >= self.max_records_per_file:
self._current_shard += 1
self._records_in_current_shard = 0
def write(self, report: AutoLabelReport) -> None:
"""Append a report to the output file."""
self._check_shard_rotation()
shard_path = self._get_shard_path()
with open(shard_path, 'a', encoding='utf-8') as f:
f.write(report.to_json() + '\n')
self._records_in_current_shard += 1
def write_dict(self, report_dict: dict) -> None:
"""Append a report dict to the output file (for parallel processing)."""
self._check_shard_rotation()
shard_path = self._get_shard_path()
with open(shard_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
f.flush()
self._records_in_current_shard += 1
def write_batch(self, reports: list[AutoLabelReport]) -> None:
"""Write multiple reports."""
for report in reports:
self.write(report)
def get_shard_files(self) -> list[Path]:
"""Get list of all shard files created."""
return self._shard_files.copy()
class ReportReader:
"""Reads auto-label reports from file(s)."""
def __init__(self, input_path: str | Path):
"""
Initialize report reader.
Args:
input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl')
"""
self.input_path = Path(input_path)
# Handle glob pattern
if '*' in str(input_path) or '?' in str(input_path):
parent = self.input_path.parent
pattern = self.input_path.name
self.input_paths = sorted(parent.glob(pattern))
else:
self.input_paths = [self.input_path]
def read_all(self) -> list[AutoLabelReport]:
"""Read all reports from file(s)."""
reports = []
for input_path in self.input_paths:
if not input_path.exists():
continue
with open(input_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
data = json.loads(line)
report = self._dict_to_report(data)
reports.append(report)
return reports
def _dict_to_report(self, data: dict) -> AutoLabelReport:
"""Convert dictionary to AutoLabelReport."""
field_results = []
for fr_data in data.get('field_results', []):
bbox = tuple(fr_data['bbox']) if fr_data.get('bbox') else None
field_results.append(FieldMatchResult(
field_name=fr_data['field_name'],
csv_value=fr_data.get('csv_value'),
matched=fr_data.get('matched', False),
score=fr_data.get('score', 0.0),
matched_text=fr_data.get('matched_text'),
candidate_used=fr_data.get('candidate_used'),
bbox=bbox,
page_no=fr_data.get('page_no', 0),
context_keywords=fr_data.get('context_keywords', []),
error=fr_data.get('error')
))
        return AutoLabelReport(
            document_id=data['document_id'],
            pdf_path=data.get('pdf_path'),
            pdf_type=data.get('pdf_type'),
            success=data.get('success', False),
            total_pages=data.get('total_pages', 0),
            fields_matched=data.get('fields_matched', 0),
            fields_total=data.get('fields_total', 0),
            field_results=field_results,
            annotations_generated=data.get('annotations_generated', 0),
            image_paths=data.get('image_paths', []),
            label_paths=data.get('label_paths', []),
            processing_time_ms=data.get('processing_time_ms', 0.0),
            timestamp=data.get('timestamp', ''),
            errors=data.get('errors', []),
            # Restore the CSV metadata fields written by to_dict()
            split=data.get('split'),
            customer_number=data.get('customer_number'),
            supplier_name=data.get('supplier_name'),
            supplier_organisation_number=data.get('supplier_organisation_number'),
            supplier_accounts=data.get('supplier_accounts'),
        )
def get_statistics(self) -> dict:
"""Calculate statistics from all reports."""
reports = self.read_all()
if not reports:
return {'total': 0}
successful = sum(1 for r in reports if r.success)
total_fields_matched = sum(r.fields_matched for r in reports)
total_fields = sum(r.fields_total for r in reports)
total_annotations = sum(r.annotations_generated for r in reports)
# Per-field statistics
field_stats = {}
for report in reports:
for fr in report.field_results:
if fr.field_name not in field_stats:
field_stats[fr.field_name] = {'matched': 0, 'total': 0, 'avg_score': 0.0}
field_stats[fr.field_name]['total'] += 1
if fr.matched:
field_stats[fr.field_name]['matched'] += 1
field_stats[fr.field_name]['avg_score'] += fr.score
# Calculate averages
for field_name, stats in field_stats.items():
if stats['matched'] > 0:
stats['avg_score'] /= stats['matched']
stats['match_rate'] = stats['matched'] / stats['total'] if stats['total'] > 0 else 0
return {
'total': len(reports),
'successful': successful,
'success_rate': successful / len(reports),
'total_fields_matched': total_fields_matched,
'total_fields': total_fields,
'overall_match_rate': total_fields_matched / total_fields if total_fields > 0 else 0,
'total_annotations': total_annotations,
'field_statistics': field_stats
}

View File

@@ -0,0 +1,134 @@
"""Database operations for training tasks."""
import json
import logging
from datetime import datetime, timezone
import psycopg2
import psycopg2.extras
from shared.config import get_db_connection_string
logger = logging.getLogger(__name__)
class TrainingTaskDB:
"""Read/write training_tasks table."""
def _connect(self):
return psycopg2.connect(get_db_connection_string())
def get_task(self, task_id: str) -> dict | None:
"""Get a single training task by ID."""
conn = self._connect()
try:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"SELECT * FROM training_tasks WHERE task_id = %s",
(task_id,),
)
return cur.fetchone()
finally:
conn.close()
def get_pending_tasks(self, limit: int = 1) -> list[dict]:
"""Get pending tasks ordered by creation time."""
conn = self._connect()
try:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(
"""
SELECT * FROM training_tasks
WHERE status = 'pending'
ORDER BY created_at ASC
LIMIT %s
""",
(limit,),
)
return cur.fetchall()
finally:
conn.close()
def update_status(self, task_id: str, status: str) -> None:
"""Update task status with timestamp."""
conn = self._connect()
try:
with conn.cursor() as cur:
if status == "running":
cur.execute(
"UPDATE training_tasks SET status = %s, started_at = %s WHERE task_id = %s",
(status, datetime.now(timezone.utc), task_id),
)
else:
cur.execute(
"UPDATE training_tasks SET status = %s WHERE task_id = %s",
(status, task_id),
)
conn.commit()
finally:
conn.close()
def complete_task(
self, task_id: str, model_path: str, metrics: dict
) -> None:
"""Mark task as completed with results."""
conn = self._connect()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE training_tasks
SET status = 'completed',
completed_at = %s,
model_path = %s,
metrics = %s
WHERE task_id = %s
""",
(
datetime.now(timezone.utc),
model_path,
json.dumps(metrics),
task_id,
),
)
conn.commit()
finally:
conn.close()
def fail_task(self, task_id: str, error_message: str) -> None:
"""Mark task as failed."""
conn = self._connect()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE training_tasks
SET status = 'failed',
completed_at = %s,
error_message = %s
WHERE task_id = %s
""",
(datetime.now(timezone.utc), error_message[:2000], task_id),
)
conn.commit()
finally:
conn.close()
def create_task(self, config: dict) -> str:
"""Create a new training task. Returns task_id."""
conn = self._connect()
try:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO training_tasks (config)
VALUES (%s)
RETURNING task_id
""",
(json.dumps(config),),
)
task_id = str(cur.fetchone()[0])
conn.commit()
return task_id
finally:
conn.close()

View File

@@ -0,0 +1,22 @@
"""
Processing module for multi-pool parallel processing.
This module provides a robust dual-pool architecture for processing
documents with both CPU-bound and GPU-bound tasks.
"""
from training.processing.worker_pool import WorkerPool, TaskResult
from training.processing.cpu_pool import CPUWorkerPool
from training.processing.gpu_pool import GPUWorkerPool
from training.processing.task_dispatcher import TaskDispatcher, TaskType
from training.processing.dual_pool_coordinator import DualPoolCoordinator
__all__ = [
"WorkerPool",
"TaskResult",
"CPUWorkerPool",
"GPUWorkerPool",
"TaskDispatcher",
"TaskType",
"DualPoolCoordinator",
]

View File

@@ -0,0 +1,323 @@
"""
Task functions for autolabel processing in multi-pool architecture.
Provides CPU and GPU task functions that can be called from worker pools.
"""
from __future__ import annotations
import os
import time
import warnings
from pathlib import Path
from typing import Any, Dict, Optional
from shared.config import DEFAULT_DPI
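# Example task payload (a sketch; keys mirror the process_* docstrings below,
# document ID and paths are illustrative):
#   task_data = {
#       "row_dict": {"DocumentId": "doc-123", "Amount": "123.45"},
#       "pdf_path": "data/pdfs/doc-123.pdf",
#       "output_dir": "data/dataset",
#       "dpi": DEFAULT_DPI,
#       "min_confidence": 0.5,
#   }
#   result = process_text_pdf(task_data)  # or process_scanned_pdf(task_data) for scans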
# Global OCR instance (initialized once per GPU worker process)
_ocr_engine: Optional[Any] = None
def init_cpu_worker() -> None:
"""
Initialize CPU worker process.
Disables GPU access and suppresses unnecessary warnings.
"""
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
def init_gpu_worker(gpu_id: int = 0, gpu_mem: int = 4000) -> None:
"""
Initialize GPU worker process with PaddleOCR.
Args:
gpu_id: GPU device ID.
gpu_mem: Maximum GPU memory in MB.
"""
global _ocr_engine
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
os.environ["GLOG_minloglevel"] = "2"
# Suppress PaddleX warnings
warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
warnings.filterwarnings("ignore", message=".*reinitialization.*")
# Lazy initialization - OCR will be created on first use
_ocr_engine = None
def _get_ocr_engine():
"""Get or create OCR engine for current GPU worker."""
global _ocr_engine
if _ocr_engine is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
from shared.ocr import OCREngine
_ocr_engine = OCREngine()
return _ocr_engine
def _save_output_img(output_img, image_path: Path) -> None:
"""Save OCR preprocessed image to replace rendered image."""
from PIL import Image as PILImage
if output_img is not None:
img = PILImage.fromarray(output_img)
img.save(str(image_path))
def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a text PDF (CPU task - no OCR needed).
Args:
task_data: Dictionary with keys:
- row_dict: Document fields from CSV
- pdf_path: Path to PDF file
- output_dir: Output directory
- dpi: Rendering DPI
- min_confidence: Minimum match confidence
Returns:
Result dictionary with success status, annotations, and report.
"""
import shutil
from training.data.autolabel_report import AutoLabelReport
from shared.pdf import PDFDocument
from training.yolo.annotation_generator import FIELD_CLASSES
from training.processing.document_processor import process_page, record_unmatched_fields
row_dict = task_data["row_dict"]
pdf_path = Path(task_data["pdf_path"])
output_dir = Path(task_data["output_dir"])
dpi = task_data.get("dpi", DEFAULT_DPI)
min_confidence = task_data.get("min_confidence", 0.5)
start_time = time.time()
doc_id = row_dict["DocumentId"]
# Clean up existing temp folder for this document (for re-matching)
temp_doc_dir = output_dir / "temp" / doc_id
if temp_doc_dir.exists():
shutil.rmtree(temp_doc_dir, ignore_errors=True)
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
report.pdf_type = "text"
# Store metadata fields from CSV (same as single document mode)
report.split = row_dict.get('split')
report.customer_number = row_dict.get('customer_number')
report.supplier_name = row_dict.get('supplier_name')
report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
report.supplier_accounts = row_dict.get('supplier_accounts')
result = {
"doc_id": doc_id,
"success": False,
"pages": [],
"report": None,
"stats": {name: 0 for name in FIELD_CLASSES.keys()},
}
try:
with PDFDocument(pdf_path) as pdf_doc:
page_annotations = []
matched_fields = set()
images_dir = output_dir / "temp" / doc_id / "images"
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
report.total_pages += 1
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# Text extraction (no OCR)
tokens = list(pdf_doc.extract_text_tokens(page_no))
# Get page dimensions for payment line detection
page = pdf_doc.doc[page_no]
page_height = page.rect.height
page_width = page.rect.width
# Use shared processing logic (same as single document mode)
matches = {}
annotations, ann_count = process_page(
tokens=tokens,
row_dict=row_dict,
page_no=page_no,
page_height=page_height,
page_width=page_width,
img_width=img_width,
img_height=img_height,
dpi=dpi,
min_confidence=min_confidence,
matches=matches,
matched_fields=matched_fields,
report=report,
result_stats=result["stats"],
)
if annotations:
page_annotations.append(
{
"image_path": str(image_path),
"page_no": page_no,
"count": ann_count,
}
)
report.annotations_generated += ann_count
# Record unmatched fields using shared logic
record_unmatched_fields(row_dict, matched_fields, report)
if page_annotations:
result["pages"] = page_annotations
result["success"] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
report.processing_time_ms = (time.time() - start_time) * 1000
result["report"] = report.to_dict()
return result
def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a scanned PDF (GPU task - requires OCR).
Args:
task_data: Dictionary with keys:
- row_dict: Document fields from CSV
- pdf_path: Path to PDF file
- output_dir: Output directory
- dpi: Rendering DPI
- min_confidence: Minimum match confidence
Returns:
Result dictionary with success status, annotations, and report.
"""
import shutil
from training.data.autolabel_report import AutoLabelReport
from shared.pdf import PDFDocument
from training.yolo.annotation_generator import FIELD_CLASSES
from training.processing.document_processor import process_page, record_unmatched_fields
row_dict = task_data["row_dict"]
pdf_path = Path(task_data["pdf_path"])
output_dir = Path(task_data["output_dir"])
dpi = task_data.get("dpi", DEFAULT_DPI)
min_confidence = task_data.get("min_confidence", 0.5)
start_time = time.time()
doc_id = row_dict["DocumentId"]
# Clean up existing temp folder for this document (for re-matching)
temp_doc_dir = output_dir / "temp" / doc_id
if temp_doc_dir.exists():
shutil.rmtree(temp_doc_dir, ignore_errors=True)
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
report.pdf_type = "scanned"
# Store metadata fields from CSV (same as single document mode)
report.split = row_dict.get('split')
report.customer_number = row_dict.get('customer_number')
report.supplier_name = row_dict.get('supplier_name')
report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
report.supplier_accounts = row_dict.get('supplier_accounts')
result = {
"doc_id": doc_id,
"success": False,
"pages": [],
"report": None,
"stats": {name: 0 for name in FIELD_CLASSES.keys()},
}
try:
# Get OCR engine from worker cache
ocr_engine = _get_ocr_engine()
with PDFDocument(pdf_path) as pdf_doc:
page_annotations = []
matched_fields = set()
images_dir = output_dir / "temp" / doc_id / "images"
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
report.total_pages += 1
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# Get page dimensions for payment line detection
page = pdf_doc.doc[page_no]
page_height = page.rect.height
page_width = page.rect.width
# OCR extraction
ocr_result = ocr_engine.extract_with_image(
str(image_path),
page_no,
scale_to_pdf_points=72 / dpi,
)
tokens = ocr_result.tokens
# Save preprocessed image
_save_output_img(ocr_result.output_img, image_path)
# Update dimensions to match OCR output
if ocr_result.output_img is not None:
img_height, img_width = ocr_result.output_img.shape[:2]
# Use shared processing logic (same as single document mode)
matches = {}
annotations, ann_count = process_page(
tokens=tokens,
row_dict=row_dict,
page_no=page_no,
page_height=page_height,
page_width=page_width,
img_width=img_width,
img_height=img_height,
dpi=dpi,
min_confidence=min_confidence,
matches=matches,
matched_fields=matched_fields,
report=report,
result_stats=result["stats"],
)
if annotations:
page_annotations.append(
{
"image_path": str(image_path),
"page_no": page_no,
"count": ann_count,
}
)
report.annotations_generated += ann_count
# Record unmatched fields using shared logic
record_unmatched_fields(row_dict, matched_fields, report)
if page_annotations:
result["pages"] = page_annotations
result["success"] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
report.processing_time_ms = (time.time() - start_time) * 1000
result["report"] = report.to_dict()
return result

View File

@@ -0,0 +1,71 @@
"""
CPU Worker Pool for text PDF processing.
This pool handles CPU-bound tasks like text extraction from native PDFs
that don't require OCR.
"""
from __future__ import annotations
import logging
import os
from typing import Callable, Optional
from training.processing.worker_pool import WorkerPool
logger = logging.getLogger(__name__)
# Global resources for CPU workers (initialized once per process)
_cpu_initialized: bool = False
def _init_cpu_worker() -> None:
"""
Initialize a CPU worker process.
Disables GPU access and sets up CPU-only environment.
"""
global _cpu_initialized
# Disable GPU access for CPU workers
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Set threading limits for better CPU utilization
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
_cpu_initialized = True
logger.debug(f"CPU worker initialized in process {os.getpid()}")
class CPUWorkerPool(WorkerPool):
"""
Worker pool for CPU-bound tasks.
Handles text PDF processing that doesn't require OCR.
Each worker is initialized with CUDA disabled to prevent
accidental GPU memory consumption.
Example:
with CPUWorkerPool(max_workers=4) as pool:
future = pool.submit(process_text_pdf, pdf_path)
result = future.result()
"""
def __init__(self, max_workers: int = 4) -> None:
"""
Initialize CPU worker pool.
Args:
max_workers: Number of CPU worker processes.
Defaults to 4 for balanced performance.
"""
super().__init__(max_workers=max_workers, use_gpu=False, gpu_id=-1)
def get_initializer(self) -> Optional[Callable[..., None]]:
"""Return the CPU worker initializer."""
return _init_cpu_worker
def get_init_args(self) -> tuple:
"""Return empty args for CPU initializer."""
return ()

View File

@@ -0,0 +1,448 @@
"""
Shared document processing logic for autolabel.
This module provides the core processing functions used by both
single document mode and batch processing mode to ensure consistent
matching and annotation logic.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from training.data.autolabel_report import FieldMatchResult
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from shared.ocr.machine_code_parser import MachineCodeParser
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
def match_supplier_accounts(
tokens: list,
supplier_accounts_value: str,
matcher: FieldMatcher,
page_no: int,
matches: Dict[str, list],
matched_fields: Set[str],
report: Any,
) -> None:
"""
Match supplier_accounts field and map to Bankgiro/Plusgiro.
This logic is shared between single document mode and batch mode
to ensure consistent BG/PG type detection.
Args:
tokens: List of text tokens from the page
supplier_accounts_value: Raw value from CSV (e.g., "BG:xxx | PG:yyy")
matcher: FieldMatcher instance
page_no: Current page number
matches: Dictionary to store matched fields (modified in place)
matched_fields: Set of matched field names (modified in place)
report: AutoLabelReport instance
"""
if not supplier_accounts_value:
return
# Parse accounts: "BG:xxx | PG:yyy" format
accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]
for account in accounts:
account = account.strip()
if not account:
continue
# Determine account type (BG or PG) and extract account number
account_type = None
account_number = account # Default to full value
if account.upper().startswith('BG:'):
account_type = 'Bankgiro'
account_number = account[3:].strip() # Remove "BG:" prefix
elif account.upper().startswith('BG '):
account_type = 'Bankgiro'
account_number = account[2:].strip() # Remove "BG" prefix
elif account.upper().startswith('PG:'):
account_type = 'Plusgiro'
account_number = account[3:].strip() # Remove "PG:" prefix
elif account.upper().startswith('PG '):
account_type = 'Plusgiro'
account_number = account[2:].strip() # Remove "PG" prefix
else:
# Try to guess from format - Plusgiro often has format XXXXXXX-X
digits = ''.join(c for c in account if c.isdigit())
if len(digits) == 8 and '-' in account:
account_type = 'Plusgiro'
elif len(digits) in (7, 8):
account_type = 'Bankgiro' # Default to Bankgiro
if not account_type:
continue
# Normalize and match using the account number (without prefix)
normalized = normalize_field('supplier_accounts', account_number)
field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)
if field_matches:
best = field_matches[0]
# Add to matches under the target class (Bankgiro/Plusgiro)
if account_type not in matches:
matches[account_type] = []
matches[account_type].extend(field_matches)
matched_fields.add('supplier_accounts')
report.add_field_result(FieldMatchResult(
field_name=f'supplier_accounts({account_type})',
csv_value=account_number, # Store without prefix
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
def detect_payment_line(
tokens: list,
page_height: float,
page_width: float,
) -> Optional[Any]:
"""
Detect payment line (machine code) and return the parsed result.
This function only detects and parses the payment line, without generating
annotations. The caller can use the result to extract amount for cross-validation.
Args:
tokens: List of text tokens from the page
page_height: Page height in PDF points
page_width: Page width in PDF points
Returns:
MachineCodeResult if standard format detected (confidence >= 0.95), None otherwise
"""
# Use 55% of page height as bottom region to catch payment lines
# that may be in the middle of the page (e.g., payment slips)
mc_parser = MachineCodeParser(bottom_region_ratio=0.55)
mc_result = mc_parser.parse(tokens, page_height, page_width)
# Only return if we found a STANDARD payment line format
# (confidence 0.95 means standard pattern matched with # and > symbols)
is_standard_format = mc_result.confidence >= 0.95
if is_standard_format:
return mc_result
return None
def match_payment_line(
tokens: list,
page_height: float,
page_width: float,
min_confidence: float,
generator: AnnotationGenerator,
annotations: list,
img_width: int,
img_height: int,
dpi: int,
matched_fields: Set[str],
report: Any,
page_no: int,
mc_result: Optional[Any] = None,
) -> None:
"""
Annotate payment line (machine code) using pre-detected result.
This logic is shared between single document mode and batch mode
to ensure consistent payment_line detection.
Args:
tokens: List of text tokens from the page
page_height: Page height in PDF points
page_width: Page width in PDF points
min_confidence: Minimum confidence threshold
generator: AnnotationGenerator instance
annotations: List of annotations (modified in place)
img_width: Image width in pixels
img_height: Image height in pixels
dpi: DPI used for rendering
matched_fields: Set of matched field names (modified in place)
report: AutoLabelReport instance
page_no: Current page number
mc_result: Pre-detected MachineCodeResult (from detect_payment_line)
"""
# Use pre-detected result if provided, otherwise detect now
if mc_result is None:
mc_result = detect_payment_line(tokens, page_height, page_width)
# Only add payment_line if we have a valid standard format result
if mc_result is None:
return
if mc_result.confidence >= min_confidence:
region_bbox = mc_result.get_region_bbox()
if region_bbox:
generator.add_payment_line_annotation(
annotations, region_bbox, mc_result.confidence,
img_width, img_height, dpi=dpi
)
# Store payment_line result in database
matched_fields.add('payment_line')
report.add_field_result(FieldMatchResult(
field_name='payment_line',
csv_value=mc_result.raw_line[:200] if mc_result.raw_line else '',
matched=True,
score=mc_result.confidence,
matched_text=f"OCR:{mc_result.ocr or ''} Amount:{mc_result.amount or ''} BG:{mc_result.bankgiro or ''}",
candidate_used='machine_code_parser',
bbox=region_bbox,
page_no=page_no,
context_keywords=['payment_line', 'machine_code']
))
def match_standard_fields(
tokens: list,
row_dict: Dict[str, Any],
matcher: FieldMatcher,
page_no: int,
matches: Dict[str, list],
matched_fields: Set[str],
report: Any,
payment_line_amount: Optional[str] = None,
payment_line_bbox: Optional[tuple] = None,
) -> None:
"""
Match standard fields from CSV to tokens.
This excludes payment_line (detected separately) and supplier_accounts
(handled by match_supplier_accounts).
Args:
tokens: List of text tokens from the page
row_dict: Dictionary of field values from CSV
matcher: FieldMatcher instance
page_no: Current page number
matches: Dictionary to store matched fields (modified in place)
matched_fields: Set of matched field names (modified in place)
report: AutoLabelReport instance
payment_line_amount: Amount extracted from payment_line (takes priority over CSV)
payment_line_bbox: Bounding box of payment_line region (used as fallback for Amount)
"""
for field_name in FIELD_CLASSES.keys():
# Skip fields handled separately
if field_name == 'payment_line':
continue
if field_name in ('Bankgiro', 'Plusgiro'):
continue # Handled via supplier_accounts
value = row_dict.get(field_name)
# For Amount field: only use payment_line amount if it matches CSV value
use_payment_line_amount = False
if field_name == 'Amount' and payment_line_amount and value:
# Parse both amounts and check if they're close
try:
csv_amt = float(str(value).replace(',', '.').replace(' ', ''))
pl_amt = float(str(payment_line_amount).replace(',', '.').replace(' ', ''))
if abs(csv_amt - pl_amt) < 0.01:
# Payment line amount matches CSV, use it for better bbox
value = payment_line_amount
use_payment_line_amount = True
# Otherwise keep CSV value for matching
except (ValueError, TypeError):
pass
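# e.g. a CSV value of "1 234,50" and a payment-line amount of "1234.50" both
# parse to 1234.5, so the payment-line value is used for matching (its bbox
# also serves as a fallback below if no token match is found).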
if not value:
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
if field_matches:
best = field_matches[0]
matches[field_name] = field_matches
matched_fields.add(field_name)
# For Amount: note if we used payment_line amount
csv_value_display = str(row_dict.get(field_name, value))
if field_name == 'Amount' and use_payment_line_amount:
csv_value_display = f"{row_dict.get(field_name)} (matched via payment_line: {payment_line_amount})"
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=csv_value_display,
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
elif field_name == 'Amount' and use_payment_line_amount and payment_line_bbox:
# Fallback: Amount not found via token matching, but payment_line
# successfully extracted a matching amount. Use payment_line bbox.
# This handles cases where text PDFs merge multiple values into one token.
from shared.matcher.field_matcher import Match
fallback_match = Match(
field='Amount',
value=payment_line_amount,
bbox=payment_line_bbox,
page_no=page_no,
score=0.9,
matched_text=f"Amount:{payment_line_amount}",
context_keywords=['payment_line', 'amount']
)
matches[field_name] = [fallback_match]
matched_fields.add(field_name)
csv_value_display = f"{row_dict.get(field_name)} (via payment_line: {payment_line_amount})"
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=csv_value_display,
matched=True,
score=0.9, # High confidence since payment_line parsing succeeded
matched_text=f"Amount:{payment_line_amount}",
candidate_used='payment_line_fallback',
bbox=payment_line_bbox,
page_no=page_no,
context_keywords=['payment_line', 'amount']
))
def record_unmatched_fields(
row_dict: Dict[str, Any],
matched_fields: Set[str],
report: Any,
) -> None:
"""
Record fields from CSV that were not matched.
Args:
row_dict: Dictionary of field values from CSV
matched_fields: Set of matched field names
report: AutoLabelReport instance
"""
for field_name in FIELD_CLASSES.keys():
if field_name == 'payment_line':
continue # payment_line doesn't come from CSV
if field_name in ('Bankgiro', 'Plusgiro'):
continue # These come from supplier_accounts
value = row_dict.get(field_name)
if value and field_name not in matched_fields:
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=-1
))
# Check if supplier_accounts was not matched
if row_dict.get('supplier_accounts') and 'supplier_accounts' not in matched_fields:
report.add_field_result(FieldMatchResult(
field_name='supplier_accounts',
csv_value=str(row_dict.get('supplier_accounts')),
matched=False,
page_no=-1
))
def process_page(
tokens: list,
row_dict: Dict[str, Any],
page_no: int,
page_height: float,
page_width: float,
img_width: int,
img_height: int,
dpi: int,
min_confidence: float,
matches: Dict[str, list],
matched_fields: Set[str],
report: Any,
result_stats: Dict[str, int],
) -> Tuple[list, int]:
"""
Process a single page: match fields and generate annotations.
This is the main entry point for page processing, used by both
single document mode and batch mode.
Processing order:
1. Detect payment_line first to extract amount
2. Match standard fields (using payment_line amount if available)
3. Match supplier_accounts
4. Generate annotations
Args:
tokens: List of text tokens from the page
row_dict: Dictionary of field values from CSV
page_no: Current page number
page_height: Page height in PDF points
page_width: Page width in PDF points
img_width: Image width in pixels
img_height: Image height in pixels
dpi: DPI used for rendering
min_confidence: Minimum confidence threshold
matches: Dictionary to store matched fields (modified in place)
matched_fields: Set of matched field names (modified in place)
report: AutoLabelReport instance
result_stats: Dictionary of annotation stats (modified in place)
Returns:
Tuple of (annotations list, annotation count)
"""
matcher = FieldMatcher()
generator = AnnotationGenerator(min_confidence=min_confidence)
# Step 1: Detect payment_line FIRST to extract amount
# This allows us to use the payment_line amount for matching Amount field
mc_result = detect_payment_line(tokens, page_height, page_width)
# Extract amount and bbox from payment_line if available
payment_line_amount = None
payment_line_bbox = None
if mc_result and mc_result.amount:
payment_line_amount = mc_result.amount
payment_line_bbox = mc_result.get_region_bbox()
# Step 2: Match standard fields (using payment_line amount if available)
match_standard_fields(
tokens, row_dict, matcher, page_no,
matches, matched_fields, report,
payment_line_amount=payment_line_amount,
payment_line_bbox=payment_line_bbox
)
# Step 3: Match supplier_accounts -> Bankgiro/Plusgiro
supplier_accounts_value = row_dict.get('supplier_accounts')
if supplier_accounts_value:
match_supplier_accounts(
tokens, supplier_accounts_value, matcher, page_no,
matches, matched_fields, report
)
# Generate annotations from matches
annotations = generator.generate_from_matches(
matches, img_width, img_height, dpi=dpi
)
# Step 4: Add payment_line annotation (reuse the pre-detected result)
match_payment_line(
tokens, page_height, page_width, min_confidence,
generator, annotations, img_width, img_height, dpi,
matched_fields, report, page_no,
mc_result=mc_result
)
# Update stats
for ann in annotations:
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result_stats[class_name] += 1
return annotations, len(annotations)
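# --- Illustrative usage (not part of the module) --------------------------------
# A minimal sketch of how process_page is wired up for one page of a text PDF.
# The shared.pdf helpers exist elsewhere in this repo; the page size in PDF points
# is a placeholder (A4), the file paths are made up, and _DemoReport stands in for
# the AutoLabelReport used by the real pipeline.
if __name__ == "__main__":
    from collections import defaultdict
    from shared.pdf import extract_text_tokens
    from shared.pdf.renderer import get_render_dimensions

    class _DemoReport:
        """Stand-in report that just collects FieldMatchResult objects."""
        def __init__(self) -> None:
            self.results = []
        def add_field_result(self, result) -> None:
            self.results.append(result)

    demo_pdf = "invoices/demo.pdf"  # placeholder path
    demo_tokens = list(extract_text_tokens(demo_pdf, 0))
    demo_img_w, demo_img_h = get_render_dimensions(demo_pdf, 0, 300)
    demo_matches, demo_matched = {}, set()
    demo_report = _DemoReport()
    demo_stats = defaultdict(int)

    demo_annotations, demo_count = process_page(
        tokens=demo_tokens,
        row_dict={"InvoiceNumber": "12345", "Amount": "1234,50"},  # placeholder CSV row
        page_no=0,
        page_height=842.0, page_width=595.0,    # assumed A4 size in PDF points
        img_width=demo_img_w, img_height=demo_img_h, dpi=300,
        min_confidence=0.7,
        matches=demo_matches, matched_fields=demo_matched,
        report=demo_report, result_stats=demo_stats,
    )
    print(f"{demo_count} annotation(s), matched fields: {sorted(demo_matched)}")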

View File

@@ -0,0 +1,339 @@
"""
Dual Pool Coordinator for managing CPU and GPU worker pools.
Coordinates task distribution between CPU and GPU pools, handles result
collection using as_completed(), and provides callbacks for progress tracking.
"""
from __future__ import annotations
import logging
import time
from concurrent.futures import Future, TimeoutError, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from training.processing.cpu_pool import CPUWorkerPool
from training.processing.gpu_pool import GPUWorkerPool
from training.processing.task_dispatcher import Task, TaskDispatcher, TaskType
from training.processing.worker_pool import TaskResult
logger = logging.getLogger(__name__)
@dataclass
class BatchStats:
"""Statistics for a batch processing run."""
total: int = 0
cpu_submitted: int = 0
gpu_submitted: int = 0
successful: int = 0
failed: int = 0
cpu_time: float = 0.0
gpu_time: float = 0.0
errors: List[str] = field(default_factory=list)
@property
def success_rate(self) -> float:
"""Calculate success rate as percentage."""
if self.total == 0:
return 0.0
return (self.successful / self.total) * 100
class DualPoolCoordinator:
"""
Coordinates CPU and GPU worker pools for parallel document processing.
Uses separate ProcessPoolExecutor instances for CPU and GPU tasks,
with as_completed() for efficient result collection across both pools.
Key features:
- Automatic task classification (CPU vs GPU)
- Parallel submission to both pools
- Unified result collection with timeouts
- Progress callbacks for UI integration
- Proper resource cleanup
Example:
with DualPoolCoordinator(cpu_workers=4, gpu_workers=1) as coord:
results = coord.process_batch(
documents=docs,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=lambda r: save_to_db(r),
)
"""
def __init__(
self,
cpu_workers: int = 4,
gpu_workers: int = 1,
gpu_id: int = 0,
task_timeout: float = 300.0,
) -> None:
"""
Initialize the dual pool coordinator.
Args:
cpu_workers: Number of CPU worker processes.
gpu_workers: Number of GPU worker processes (usually 1).
gpu_id: GPU device ID to use.
task_timeout: Timeout in seconds for individual tasks.
"""
self.cpu_workers = cpu_workers
self.gpu_workers = gpu_workers
self.gpu_id = gpu_id
self.task_timeout = task_timeout
self._cpu_pool: Optional[CPUWorkerPool] = None
self._gpu_pool: Optional[GPUWorkerPool] = None
self._dispatcher = TaskDispatcher()
self._started = False
def start(self) -> None:
"""Start both worker pools."""
if self._started:
raise RuntimeError("Coordinator already started")
logger.info(
f"Starting DualPoolCoordinator: "
f"{self.cpu_workers} CPU workers, {self.gpu_workers} GPU workers"
)
self._cpu_pool = CPUWorkerPool(max_workers=self.cpu_workers)
self._gpu_pool = GPUWorkerPool(
max_workers=self.gpu_workers,
gpu_id=self.gpu_id,
)
self._cpu_pool.start()
self._gpu_pool.start()
self._started = True
def shutdown(self, wait: bool = True) -> None:
"""Shutdown both worker pools."""
logger.info("Shutting down DualPoolCoordinator")
if self._cpu_pool is not None:
self._cpu_pool.shutdown(wait=wait)
self._cpu_pool = None
if self._gpu_pool is not None:
self._gpu_pool.shutdown(wait=wait)
self._gpu_pool = None
self._started = False
def __enter__(self) -> "DualPoolCoordinator":
"""Context manager entry."""
self.start()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Context manager exit."""
self.shutdown(wait=True)
def process_batch(
self,
documents: List[dict],
cpu_task_fn: Callable[[dict], Any],
gpu_task_fn: Callable[[dict], Any],
on_result: Optional[Callable[[TaskResult], None]] = None,
on_error: Optional[Callable[[str, Exception], None]] = None,
on_progress: Optional[Callable[[int, int], None]] = None,
id_field: str = "id",
) -> List[TaskResult]:
"""
Process a batch of documents using both CPU and GPU pools.
Documents are automatically classified and routed to the appropriate
pool. Results are collected as they complete.
Args:
documents: List of document info dicts to process.
cpu_task_fn: Function to process text PDFs (called in CPU pool).
gpu_task_fn: Function to process scanned PDFs (called in GPU pool).
on_result: Callback for each successful result.
on_error: Callback for each failed task (task_id, exception).
on_progress: Callback for progress updates (completed, total).
id_field: Field name to use as task ID.
Returns:
List of TaskResult objects for all tasks.
Raises:
RuntimeError: If coordinator is not started.
"""
if not self._started:
raise RuntimeError("Coordinator not started. Use context manager or call start().")
if not documents:
return []
stats = BatchStats(total=len(documents))
# Create and partition tasks
tasks = self._dispatcher.create_tasks(documents, id_field=id_field)
cpu_tasks, gpu_tasks = self._dispatcher.partition_tasks(tasks)
# Submit tasks to pools
futures_map: Dict[Future, Task] = {}
# Submit CPU tasks
for task in cpu_tasks:
future = self._cpu_pool.submit(cpu_task_fn, task.data)
futures_map[future] = task
stats.cpu_submitted += 1
# Submit GPU tasks
for task in gpu_tasks:
future = self._gpu_pool.submit(gpu_task_fn, task.data)
futures_map[future] = task
stats.gpu_submitted += 1
logger.info(
f"Submitted {stats.cpu_submitted} CPU tasks, {stats.gpu_submitted} GPU tasks"
)
# Collect results as they complete
results: List[TaskResult] = []
completed = 0
for future in as_completed(futures_map.keys(), timeout=self.task_timeout * len(documents)):
task = futures_map[future]
pool_type = "CPU" if task.task_type == TaskType.CPU else "GPU"
start_time = time.time()
try:
data = future.result(timeout=self.task_timeout)
processing_time = time.time() - start_time
result = TaskResult(
task_id=task.id,
success=True,
data=data,
pool_type=pool_type,
processing_time=processing_time,
)
stats.successful += 1
if pool_type == "CPU":
stats.cpu_time += processing_time
else:
stats.gpu_time += processing_time
if on_result is not None:
try:
on_result(result)
except Exception as e:
logger.warning(f"on_result callback failed: {e}")
except TimeoutError:
error_msg = f"Task timed out after {self.task_timeout}s"
logger.error(f"[{pool_type}] Task {task.id}: {error_msg}")
result = TaskResult(
task_id=task.id,
success=False,
data=None,
error=error_msg,
pool_type=pool_type,
)
stats.failed += 1
stats.errors.append(f"{task.id}: {error_msg}")
if on_error is not None:
try:
on_error(task.id, TimeoutError(error_msg))
except Exception as e:
logger.warning(f"on_error callback failed: {e}")
except Exception as e:
error_msg = str(e)
logger.error(f"[{pool_type}] Task {task.id} failed: {error_msg}")
result = TaskResult(
task_id=task.id,
success=False,
data=None,
error=error_msg,
pool_type=pool_type,
)
stats.failed += 1
stats.errors.append(f"{task.id}: {error_msg}")
if on_error is not None:
try:
on_error(task.id, e)
except Exception as callback_error:
logger.warning(f"on_error callback failed: {callback_error}")
results.append(result)
completed += 1
if on_progress is not None:
try:
on_progress(completed, stats.total)
except Exception as e:
logger.warning(f"on_progress callback failed: {e}")
# Log final stats
logger.info(
f"Batch complete: {stats.successful}/{stats.total} successful "
f"({stats.success_rate:.1f}%), {stats.failed} failed"
)
if stats.cpu_submitted > 0:
logger.info(f"CPU: {stats.cpu_submitted} tasks, {stats.cpu_time:.2f}s total")
if stats.gpu_submitted > 0:
logger.info(f"GPU: {stats.gpu_submitted} tasks, {stats.gpu_time:.2f}s total")
return results
def process_single(
self,
document: dict,
cpu_task_fn: Callable[[dict], Any],
gpu_task_fn: Callable[[dict], Any],
id_field: str = "id",
) -> TaskResult:
"""
Process a single document.
Convenience method for processing one document at a time.
Args:
document: Document info dict.
cpu_task_fn: Function for text PDF processing.
gpu_task_fn: Function for scanned PDF processing.
id_field: Field name for task ID.
Returns:
TaskResult for the document.
"""
results = self.process_batch(
documents=[document],
cpu_task_fn=cpu_task_fn,
gpu_task_fn=gpu_task_fn,
id_field=id_field,
)
return results[0] if results else TaskResult(
task_id=str(document.get(id_field, "unknown")),
success=False,
data=None,
error="No result returned",
)
@property
def is_running(self) -> bool:
"""Check if both pools are running."""
return (
self._started
and self._cpu_pool is not None
and self._gpu_pool is not None
and self._cpu_pool.is_running
and self._gpu_pool.is_running
)
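# --- Illustrative usage (not part of the module) --------------------------------
# A minimal sketch of process_batch with its callback hooks. The task functions
# below are placeholders for the real text-PDF / OCR pipelines, and both demo
# documents are classified as text PDFs, so no work is dispatched to the GPU pool.
def _demo_cpu_task(doc: dict) -> dict:
    """Placeholder CPU task: pretend to process a text PDF."""
    return {"id": doc["id"], "via": "cpu"}


def _demo_gpu_task(doc: dict) -> dict:
    """Placeholder GPU task: pretend to run OCR on a scanned PDF."""
    return {"id": doc["id"], "via": "gpu"}


if __name__ == "__main__":
    demo_docs = [
        {"id": "a", "text_length": 4000, "page_count": 2},  # -> CPU pool
        {"id": "b", "text_length": 2500, "page_count": 1},  # -> CPU pool
    ]
    with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
        demo_results = coord.process_batch(
            documents=demo_docs,
            cpu_task_fn=_demo_cpu_task,
            gpu_task_fn=_demo_gpu_task,
            on_progress=lambda done, total: logger.info("progress %d/%d", done, total),
            on_error=lambda task_id, exc: logger.error("task %s failed: %s", task_id, exc),
        )
    print([r.data for r in demo_results if r.success])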

View File

@@ -0,0 +1,110 @@
"""
GPU Worker Pool for OCR processing.
This pool handles GPU-bound tasks like PaddleOCR for scanned PDF processing.
"""
from __future__ import annotations
import logging
import os
from typing import Any, Callable, Optional
from training.processing.worker_pool import WorkerPool
logger = logging.getLogger(__name__)
# Global OCR instance for GPU workers (initialized once per process)
_ocr_instance: Optional[Any] = None
_gpu_initialized: bool = False
def _init_gpu_worker(gpu_id: int = 0) -> None:
"""
Initialize a GPU worker process with PaddleOCR.
Args:
gpu_id: GPU device ID to use.
"""
global _ocr_instance, _gpu_initialized
# Set GPU device before importing paddle
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# Reduce logging noise
os.environ["GLOG_minloglevel"] = "2"
# Suppress PaddleX warnings
import warnings
warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
warnings.filterwarnings("ignore", message=".*reinitialization.*")
try:
# Import PaddleOCR after setting environment
# PaddleOCR 3.x uses paddle.set_device() for GPU control, not use_gpu param
import paddle
paddle.set_device(f"gpu:{gpu_id}")
from paddleocr import PaddleOCR
# PaddleOCR 3.x init - minimal params, GPU controlled via paddle.set_device
_ocr_instance = PaddleOCR(lang="en")
_gpu_initialized = True
logger.info(f"GPU worker initialized on GPU {gpu_id} in process {os.getpid()}")
except Exception as e:
logger.error(f"Failed to initialize GPU worker: {e}")
raise
def get_ocr_instance() -> Any:
"""
Get the initialized OCR instance for the current worker.
Returns:
PaddleOCR instance.
Raises:
RuntimeError: If OCR is not initialized.
"""
global _ocr_instance
if _ocr_instance is None:
raise RuntimeError("OCR not initialized. This function must be called from a GPU worker.")
return _ocr_instance
class GPUWorkerPool(WorkerPool):
"""
Worker pool for GPU-bound OCR tasks.
Handles scanned PDF processing using PaddleOCR with GPU acceleration.
Typically limited to 1 worker to avoid GPU memory conflicts.
Example:
with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
future = pool.submit(process_scanned_pdf, pdf_path)
result = future.result()
"""
def __init__(
self,
max_workers: int = 1,
gpu_id: int = 0,
) -> None:
"""
Initialize GPU worker pool.
Args:
max_workers: Number of GPU worker processes.
Defaults to 1 to avoid GPU memory conflicts.
gpu_id: GPU device ID to use.
"""
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
def get_initializer(self) -> Optional[Callable[..., None]]:
"""Return the GPU worker initializer."""
return _init_gpu_worker
def get_init_args(self) -> tuple:
"""Return args for GPU initializer."""
return (self.gpu_id,)
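# --- Illustrative usage (not part of the module) --------------------------------
# The function submitted to the pool runs inside a worker process, so it can rely
# on get_ocr_instance() having been set up by _init_gpu_worker. The image path is
# a placeholder, and the exact PaddleOCR call is version-dependent (ocr.ocr() in
# the classic API); adjust to the installed version.
def _demo_ocr_task(image_path: str) -> Any:
    ocr = get_ocr_instance()
    return ocr.ocr(image_path)


if __name__ == "__main__":
    with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
        future = pool.submit(_demo_ocr_task, "page_000.png")
        print(future.result())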

View File

@@ -0,0 +1,174 @@
"""
Task Dispatcher for classifying and routing tasks to appropriate worker pools.
Determines whether a document should be processed by CPU (text PDF) or
GPU (scanned PDF requiring OCR) workers.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Tuple
logger = logging.getLogger(__name__)
class TaskType(Enum):
"""Task type classification."""
CPU = auto() # Text PDF - no OCR needed
GPU = auto() # Scanned PDF - requires OCR
@dataclass
class Task:
"""
Represents a processing task.
Attributes:
id: Unique task identifier.
task_type: Whether task needs CPU or GPU processing.
data: Task payload (document info, paths, etc.).
"""
id: str
task_type: TaskType
data: Any
class TaskDispatcher:
"""
Classifies and partitions tasks for CPU and GPU worker pools.
Uses PDF characteristics to determine if OCR is needed:
- Text PDFs with extractable text -> CPU
- Scanned PDFs / image-based PDFs -> GPU (OCR)
Example:
dispatcher = TaskDispatcher()
tasks = [Task(id="1", task_type=dispatcher.classify(doc), data=doc) for doc in docs]
cpu_tasks, gpu_tasks = dispatcher.partition_tasks(tasks)
"""
def __init__(
self,
text_char_threshold: int = 100,
ocr_ratio_threshold: float = 0.3,
) -> None:
"""
Initialize the task dispatcher.
Args:
text_char_threshold: Minimum average characters per page for a PDF to be treated as a text PDF.
ocr_ratio_threshold: If the text/expected ratio falls below this, use OCR (reserved; not used by the current classification logic).
"""
self.text_char_threshold = text_char_threshold
self.ocr_ratio_threshold = ocr_ratio_threshold
def classify_by_pdf_info(
self,
has_text: bool,
text_length: int,
page_count: int = 1,
) -> TaskType:
"""
Classify task based on PDF text extraction info.
Args:
has_text: Whether PDF has extractable text layer.
text_length: Number of characters extracted.
page_count: Number of pages in PDF.
Returns:
TaskType.CPU for text PDFs, TaskType.GPU for scanned PDFs.
"""
if not has_text:
return TaskType.GPU
# Check if text density is reasonable
avg_chars_per_page = text_length / max(page_count, 1)
if avg_chars_per_page < self.text_char_threshold:
return TaskType.GPU
return TaskType.CPU
def classify_document(self, doc_info: dict) -> TaskType:
"""
Classify a document based on its metadata.
Args:
doc_info: Document information dict with keys like:
- 'is_scanned': bool (if known)
- 'text_length': int
- 'page_count': int
- 'pdf_path': str
Returns:
TaskType for the document.
"""
# If explicitly marked as scanned
if doc_info.get("is_scanned", False):
return TaskType.GPU
# If we have text extraction info
text_length = doc_info.get("text_length", 0)
page_count = doc_info.get("page_count", 1)
has_text = doc_info.get("has_text", text_length > 0)
return self.classify_by_pdf_info(
has_text=has_text,
text_length=text_length,
page_count=page_count,
)
def partition_tasks(
self,
tasks: List[Task],
) -> Tuple[List[Task], List[Task]]:
"""
Partition tasks into CPU and GPU groups.
Args:
tasks: List of Task objects with task_type set.
Returns:
Tuple of (cpu_tasks, gpu_tasks).
"""
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
logger.info(
f"Task partition: {len(cpu_tasks)} CPU tasks, {len(gpu_tasks)} GPU tasks"
)
return cpu_tasks, gpu_tasks
def create_tasks(
self,
documents: List[dict],
id_field: str = "id",
) -> List[Task]:
"""
Create Task objects from document dicts.
Args:
documents: List of document info dicts.
id_field: Field name to use as task ID.
Returns:
List of Task objects with types classified.
"""
tasks = []
for doc in documents:
task_id = str(doc.get(id_field, id(doc)))
task_type = self.classify_document(doc)
tasks.append(Task(id=task_id, task_type=task_type, data=doc))
cpu_count = sum(1 for t in tasks if t.task_type == TaskType.CPU)
gpu_count = len(tasks) - cpu_count
logger.debug(f"Created {len(tasks)} tasks: {cpu_count} CPU, {gpu_count} GPU")
return tasks
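# --- Illustrative usage (not part of the module) --------------------------------
# Classification relies only on the metadata keys documented above; the values
# here are made up to show each routing rule.
if __name__ == "__main__":
    demo_dispatcher = TaskDispatcher()
    demo_docs = [
        {"id": "text-invoice", "text_length": 2400, "page_count": 2},  # 1200 chars/page -> CPU
        {"id": "scanned-invoice", "is_scanned": True},                 # explicit flag -> GPU
        {"id": "sparse-text", "text_length": 120, "page_count": 3},    # 40 chars/page < 100 -> GPU
    ]
    demo_tasks = demo_dispatcher.create_tasks(demo_docs, id_field="id")
    demo_cpu, demo_gpu = demo_dispatcher.partition_tasks(demo_tasks)
    print("CPU:", [t.id for t in demo_cpu], "GPU:", [t.id for t in demo_gpu])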

View File

@@ -0,0 +1,182 @@
"""
Abstract base class for worker pools.
Provides a unified interface for CPU and GPU worker pools with proper
initialization, task submission, and resource cleanup.
"""
from __future__ import annotations
import logging
import multiprocessing as mp
from abc import ABC, abstractmethod
from concurrent.futures import Future, ProcessPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
@dataclass
class TaskResult:
"""Container for task execution results."""
task_id: str
success: bool
data: Any
error: Optional[str] = None
processing_time: float = 0.0
pool_type: str = ""
extra: dict = field(default_factory=dict)
class WorkerPool(ABC):
"""
Abstract base class for worker pools.
Provides a common interface for ProcessPoolExecutor-based worker pools
with proper initialization using the 'spawn' start method for CUDA
compatibility.
Attributes:
max_workers: Maximum number of worker processes.
use_gpu: Whether this pool uses GPU resources.
gpu_id: GPU device ID (only relevant if use_gpu=True).
"""
def __init__(
self,
max_workers: int,
use_gpu: bool = False,
gpu_id: int = 0,
) -> None:
"""
Initialize the worker pool configuration.
Args:
max_workers: Maximum number of worker processes.
use_gpu: Whether this pool uses GPU resources.
gpu_id: GPU device ID for GPU pools.
"""
self.max_workers = max_workers
self.use_gpu = use_gpu
self.gpu_id = gpu_id
self._executor: Optional[ProcessPoolExecutor] = None
self._started = False
@property
def name(self) -> str:
"""Return the pool name for logging."""
return self.__class__.__name__
@abstractmethod
def get_initializer(self) -> Optional[Callable[..., None]]:
"""
Return the worker initialization function.
This function is called once per worker process when it starts.
Use it to load models, set environment variables, etc.
Returns:
Callable to initialize each worker, or None if no initialization needed.
"""
pass
@abstractmethod
def get_init_args(self) -> tuple:
"""
Return arguments for the initializer function.
Returns:
Tuple of arguments to pass to the initializer.
"""
pass
def start(self) -> None:
"""
Start the worker pool.
Creates a ProcessPoolExecutor with the 'spawn' start method
for CUDA compatibility.
Raises:
RuntimeError: If the pool is already started.
"""
if self._started:
raise RuntimeError(f"{self.name} is already started")
# Use 'spawn' for CUDA compatibility
ctx = mp.get_context("spawn")
initializer = self.get_initializer()
initargs = self.get_init_args()
logger.info(
f"Starting {self.name} with {self.max_workers} workers "
f"(GPU: {self.use_gpu}, GPU ID: {self.gpu_id})"
)
self._executor = ProcessPoolExecutor(
max_workers=self.max_workers,
mp_context=ctx,
initializer=initializer,
initargs=initargs if initializer else (),
)
self._started = True
def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Future:
"""
Submit a task to the worker pool.
Args:
fn: Function to execute.
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
Future representing the pending result.
Raises:
RuntimeError: If the pool is not started.
"""
if not self._started or self._executor is None:
raise RuntimeError(f"{self.name} is not started. Call start() first.")
return self._executor.submit(fn, *args, **kwargs)
def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None:
"""
Shutdown the worker pool.
Args:
wait: If True, wait for all pending futures to complete.
cancel_futures: If True, cancel all pending futures.
"""
if self._executor is not None:
logger.info(f"Shutting down {self.name} (wait={wait})")
self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
self._executor = None
self._started = False
@property
def is_running(self) -> bool:
"""Check if the pool is currently running."""
return self._started and self._executor is not None
def __enter__(self) -> "WorkerPool":
"""Context manager entry - start the pool."""
self.start()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Context manager exit - shutdown the pool."""
self.shutdown(wait=True)
def __repr__(self) -> str:
status = "running" if self.is_running else "stopped"
return (
f"{self.__class__.__name__}("
f"workers={self.max_workers}, "
f"gpu={self.use_gpu}, "
f"status={status})"
)
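# --- Illustrative usage (not part of the module) --------------------------------
# The smallest possible concrete pool: no per-worker initialization. Real pools
# (see cpu_pool / gpu_pool) return an initializer that loads models or pins a GPU.
def _square(x: int) -> int:
    """Trivial picklable task used only for this sketch."""
    return x * x


class _NoInitPool(WorkerPool):
    """Minimal WorkerPool subclass used only for this sketch."""
    def get_initializer(self) -> Optional[Callable[..., None]]:
        return None
    def get_init_args(self) -> tuple:
        return ()


if __name__ == "__main__":
    with _NoInitPool(max_workers=2) as pool:
        demo_futures = [pool.submit(_square, n) for n in range(4)]
        print([f.result() for f in demo_futures])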

View File

@@ -0,0 +1,5 @@
from .annotation_generator import AnnotationGenerator, generate_annotations
from .dataset_builder import DatasetBuilder
from .db_dataset import DBYOLODataset, create_datasets
__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder', 'DBYOLODataset', 'create_datasets']

View File

@@ -0,0 +1,386 @@
"""
YOLO Annotation Generator
Generates YOLO format annotations from matched fields.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Field class mapping for YOLO
# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
'InvoiceNumber': 0,
'InvoiceDate': 1,
'InvoiceDueDate': 2,
'OCR': 3,
'Bankgiro': 4,
'Plusgiro': 5,
'Amount': 6,
'supplier_organisation_number': 7,
'customer_number': 8,
'payment_line': 9, # Machine code payment line at bottom of invoice
}
# Fields that need matching but map to other YOLO classes
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
ACCOUNT_FIELD_MAPPING = {
'supplier_accounts': {
'BG': 'Bankgiro', # BG:xxx -> Bankgiro class
'PG': 'Plusgiro', # PG:xxx -> Plusgiro class
}
}
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
'supplier_org_number',
'customer_number',
'payment_line', # Machine code payment line at bottom of invoice
]
@dataclass
class YOLOAnnotation:
"""Represents a single YOLO annotation."""
class_id: int
x_center: float # normalized 0-1
y_center: float # normalized 0-1
width: float # normalized 0-1
height: float # normalized 0-1
confidence: float = 1.0
def to_string(self) -> str:
"""Convert to YOLO format string."""
return f"{self.class_id} {self.x_center:.6f} {self.y_center:.6f} {self.width:.6f} {self.height:.6f}"
class AnnotationGenerator:
"""Generates YOLO annotations from document matches."""
def __init__(
self,
min_confidence: float = 0.7,
bbox_padding_px: int = 20, # Absolute padding in pixels
min_bbox_height_px: int = 30 # Minimum bbox height
):
"""
Initialize annotation generator.
Args:
min_confidence: Minimum match score to include in training
bbox_padding_px: Absolute padding in pixels to add around bboxes
min_bbox_height_px: Minimum bbox height in pixels
"""
self.min_confidence = min_confidence
self.bbox_padding_px = bbox_padding_px
self.min_bbox_height_px = min_bbox_height_px
def generate_from_matches(
self,
matches: dict[str, list[Any]], # field_name -> list of Match
image_width: float,
image_height: float,
dpi: int = 300
) -> list[YOLOAnnotation]:
"""
Generate YOLO annotations from field matches.
Args:
matches: Dict of field_name -> list of Match objects
image_width: Width of the rendered image in pixels
image_height: Height of the rendered image in pixels
dpi: DPI used for rendering (needed to convert PDF coords to pixels)
Returns:
List of YOLOAnnotation objects
"""
annotations = []
# Scale factor to convert PDF points (72 DPI) to rendered pixels
scale = dpi / 72.0
for field_name, field_matches in matches.items():
if field_name not in FIELD_CLASSES:
continue
class_id = FIELD_CLASSES[field_name]
# Take only the best match per field
if field_matches:
best_match = field_matches[0] # Already sorted by score
if best_match.score < self.min_confidence:
continue
# best_match.bbox is in PDF points (72 DPI), convert to pixels
x0, y0, x1, y1 = best_match.bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Add absolute padding
pad = self.bbox_padding_px
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(image_width, x1 + pad)
y1 = min(image_height, y1 + pad)
# Ensure minimum height
current_height = y1 - y0
if current_height < self.min_bbox_height_px:
extra = (self.min_bbox_height_px - current_height) / 2
y0 = max(0, y0 - extra)
y1 = min(image_height, y1 + extra)
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
# Clamp values to 0-1
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
annotations.append(YOLOAnnotation(
class_id=class_id,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
confidence=best_match.score
))
return annotations
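# Worked example (illustrative numbers): with dpi=300 the scale factor is
# 300/72 ≈ 4.167, so a PDF-point bbox (100, 700, 180, 715) lands at roughly
# (417, 2917, 750, 2979) px. After 20 px padding (the 30 px minimum height is
# already satisfied), on a 2480x3508 px page (A4 at 300 DPI) the label becomes
# approximately: x_center≈0.235, y_center≈0.840, width≈0.151, height≈0.029.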
def add_payment_line_annotation(
self,
annotations: list[YOLOAnnotation],
payment_line_bbox: tuple[float, float, float, float],
confidence: float,
image_width: float,
image_height: float,
dpi: int = 300
) -> list[YOLOAnnotation]:
"""
Add payment_line annotation from machine code parser result.
Args:
annotations: Existing list of annotations to append to
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
confidence: Confidence score from machine code parser
image_width: Width of the rendered image in pixels
image_height: Height of the rendered image in pixels
dpi: DPI used for rendering
Returns:
Updated annotations list with payment_line annotation added
"""
if not payment_line_bbox or confidence < self.min_confidence:
return annotations
# Scale factor to convert PDF points (72 DPI) to rendered pixels
scale = dpi / 72.0
x0, y0, x1, y1 = payment_line_bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Add absolute padding
pad = self.bbox_padding_px
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(image_width, x1 + pad)
y1 = min(image_height, y1 + pad)
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
# Clamp values to 0-1
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
annotations.append(YOLOAnnotation(
class_id=FIELD_CLASSES['payment_line'],
x_center=x_center,
y_center=y_center,
width=width,
height=height,
confidence=confidence
))
return annotations
def save_annotations(
self,
annotations: list[YOLOAnnotation],
output_path: str | Path
) -> None:
"""Save annotations to a YOLO format text file."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for ann in annotations:
f.write(ann.to_string() + '\n')
@staticmethod
def generate_classes_file(output_path: str | Path) -> None:
"""Generate the classes.txt file for YOLO."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for class_name in CLASS_NAMES:
f.write(class_name + '\n')
@staticmethod
def generate_yaml_config(
output_path: str | Path,
train_path: str = 'train/images',
val_path: str = 'val/images',
test_path: str = 'test/images',
use_wsl_paths: bool | None = None
) -> None:
"""
Generate YOLO dataset YAML configuration.
Args:
output_path: Path to output YAML file
train_path: Relative path to training images
val_path: Relative path to validation images
test_path: Relative path to test images
use_wsl_paths: If True, convert Windows paths to WSL format.
If None, auto-detect based on environment.
"""
import os
import platform
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
dataset_dir = output_path.parent.absolute()
dataset_path_str = str(dataset_dir)
# Auto-detect WSL environment
if use_wsl_paths is None:
# Check if running inside WSL
is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
# Check WSL_DISTRO_NAME environment variable (set when running in WSL)
is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
use_wsl_paths = is_wsl
# Convert path format based on environment
if use_wsl_paths:
# Running in WSL: convert Windows paths to /mnt/c/... format
dataset_path_str = dataset_path_str.replace('\\', '/')
if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
drive = dataset_path_str[0].lower()
dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
elif platform.system() == 'Windows':
# Running on native Windows: use forward slashes for YOLO compatibility
dataset_path_str = dataset_path_str.replace('\\', '/')
config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}
# Classes
names:
"""
for i, name in enumerate(CLASS_NAMES):
config += f" {i}: {name}\n"
with open(output_path, 'w') as f:
f.write(config)
def generate_annotations(
pdf_path: str | Path,
structured_data: dict[str, Any],
output_dir: str | Path,
dpi: int = 300
) -> list[Path]:
"""
Generate YOLO annotations for a PDF using structured data.
Args:
pdf_path: Path to the PDF file
structured_data: Dict with field values from CSV
output_dir: Directory to save images and labels
dpi: Resolution for rendering
Returns:
List of paths to generated annotation files
"""
from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from shared.pdf.renderer import get_render_dimensions
from shared.ocr import OCREngine
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
output_dir = Path(output_dir)
images_dir = output_dir / 'images'
labels_dir = output_dir / 'labels'
images_dir.mkdir(parents=True, exist_ok=True)
labels_dir.mkdir(parents=True, exist_ok=True)
generator = AnnotationGenerator()
matcher = FieldMatcher()
annotation_files = []
# Check PDF type
use_ocr = not is_text_pdf(pdf_path)
# Initialize OCR if needed
ocr_engine = OCREngine() if use_ocr else None
# Process each page
for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
# Get image dimensions
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
# Extract tokens
if use_ocr:
tokens = ocr_engine.extract_from_image(str(image_path), page_no)
else:
tokens = list(extract_text_tokens(pdf_path, page_no))
# Match fields
matches = {}
for field_name in FIELD_CLASSES.keys():
value = structured_data.get(field_name)
if value is None or str(value).strip() == '':
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
if field_matches:
matches[field_name] = field_matches
# Generate annotations (pass DPI for coordinate conversion)
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
# Save annotations
if annotations:
label_path = labels_dir / f"{image_path.stem}.txt"
generator.save_annotations(annotations, label_path)
annotation_files.append(label_path)
return annotation_files
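# --- Illustrative usage (not part of the module) --------------------------------
# Paths and field values are placeholders; structured_data keys must match the
# FIELD_CLASSES entries defined at the top of this file.
if __name__ == "__main__":
    demo_label_files = generate_annotations(
        pdf_path="invoices/demo.pdf",
        structured_data={"InvoiceNumber": "12345", "Amount": "1234,50", "OCR": "94201234567"},
        output_dir="dataset_demo",
        dpi=300,
    )
    print(f"Wrote {len(demo_label_files)} label file(s)")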

View File

@@ -0,0 +1,249 @@
"""
YOLO Dataset Builder
Builds training dataset from PDFs and structured CSV data.
"""
import csv
import shutil
from dataclasses import dataclass
from pathlib import Path
import random
@dataclass
class DatasetStats:
"""Statistics about the generated dataset."""
total_documents: int
successful: int
failed: int
total_annotations: int
annotations_per_class: dict[str, int]
train_count: int
val_count: int
test_count: int
class DatasetBuilder:
"""Builds YOLO training dataset from PDFs and CSV data."""
def __init__(
self,
pdf_dir: str | Path,
csv_path: str | Path,
output_dir: str | Path,
document_id_column: str = 'DocumentId',
dpi: int = 300,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1
):
"""
Initialize dataset builder.
Args:
pdf_dir: Directory containing PDF files
csv_path: Path to structured data CSV
output_dir: Output directory for dataset
document_id_column: Column name for document ID
dpi: Resolution for rendering
train_ratio: Fraction for training set
val_ratio: Fraction for validation set
test_ratio: Fraction for test set
"""
self.pdf_dir = Path(pdf_dir)
self.csv_path = Path(csv_path)
self.output_dir = Path(output_dir)
self.document_id_column = document_id_column
self.dpi = dpi
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.test_ratio = test_ratio
def load_structured_data(self) -> dict[str, dict]:
"""Load structured data from CSV."""
data = {}
with open(self.csv_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row.get(self.document_id_column)
if doc_id:
data[doc_id] = row
return data
def find_pdf_for_document(self, doc_id: str) -> Path | None:
"""Find PDF file for a document ID."""
# Try common naming patterns
patterns = [
f"{doc_id}.pdf",
f"{doc_id.lower()}.pdf",
f"{doc_id.upper()}.pdf",
f"*{doc_id}*.pdf",
]
for pattern in patterns:
matches = list(self.pdf_dir.glob(pattern))
if matches:
return matches[0]
return None
def build(self, seed: int = 42) -> DatasetStats:
"""
Build the complete dataset.
Args:
seed: Random seed for train/val/test split
Returns:
DatasetStats with build results
"""
from .annotation_generator import AnnotationGenerator, CLASS_NAMES
random.seed(seed)
# Setup output directories
for split in ['train', 'val', 'test']:
(self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
(self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
# Generate config files
AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml')
# Load structured data
structured_data = self.load_structured_data()
# Stats tracking
stats = {
'total': 0,
'successful': 0,
'failed': 0,
'annotations': 0,
'per_class': {name: 0 for name in CLASS_NAMES},
'splits': {'train': 0, 'val': 0, 'test': 0}
}
# Process each document
processed_items = []
for doc_id, data in structured_data.items():
stats['total'] += 1
pdf_path = self.find_pdf_for_document(doc_id)
if not pdf_path:
print(f"Warning: PDF not found for document {doc_id}")
stats['failed'] += 1
continue
try:
# Generate to temp dir first
temp_dir = self.output_dir / 'temp' / doc_id
temp_dir.mkdir(parents=True, exist_ok=True)
from .annotation_generator import generate_annotations
annotation_files = generate_annotations(
pdf_path, data, temp_dir, self.dpi
)
if annotation_files:
processed_items.append({
'doc_id': doc_id,
'temp_dir': temp_dir,
'annotation_files': annotation_files
})
stats['successful'] += 1
else:
print(f"Warning: No annotations generated for {doc_id}")
stats['failed'] += 1
except Exception as e:
print(f"Error processing {doc_id}: {e}")
stats['failed'] += 1
# Shuffle and split
random.shuffle(processed_items)
n_train = int(len(processed_items) * self.train_ratio)
n_val = int(len(processed_items) * self.val_ratio)
splits = {
'train': processed_items[:n_train],
'val': processed_items[n_train:n_train + n_val],
'test': processed_items[n_train + n_val:]
}
# Move files to final locations
for split_name, items in splits.items():
for item in items:
temp_dir = item['temp_dir']
# Move images
for img in (temp_dir / 'images').glob('*'):
dest = self.output_dir / split_name / 'images' / img.name
shutil.move(str(img), str(dest))
# Move labels and count annotations
for label in (temp_dir / 'labels').glob('*.txt'):
dest = self.output_dir / split_name / 'labels' / label.name
shutil.move(str(label), str(dest))
# Count annotations per class
with open(dest, 'r') as f:
for line in f:
class_id = int(line.strip().split()[0])
if 0 <= class_id < len(CLASS_NAMES):
stats['per_class'][CLASS_NAMES[class_id]] += 1
stats['annotations'] += 1
stats['splits'][split_name] += 1
# Cleanup temp dir
shutil.rmtree(self.output_dir / 'temp', ignore_errors=True)
return DatasetStats(
total_documents=stats['total'],
successful=stats['successful'],
failed=stats['failed'],
total_annotations=stats['annotations'],
annotations_per_class=stats['per_class'],
train_count=stats['splits']['train'],
val_count=stats['splits']['val'],
test_count=stats['splits']['test']
)
def process_single_document(
self,
doc_id: str,
data: dict,
split: str = 'train'
) -> bool:
"""
Process a single document (for incremental building).
Args:
doc_id: Document ID
data: Structured data dict
split: Which split to add to
Returns:
True if successful
"""
from .annotation_generator import generate_annotations
pdf_path = self.find_pdf_for_document(doc_id)
if not pdf_path:
return False
try:
output_subdir = self.output_dir / split
annotation_files = generate_annotations(
pdf_path, data, output_subdir, self.dpi
)
return len(annotation_files) > 0
except Exception as e:
print(f"Error processing {doc_id}: {e}")
return False
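# --- Illustrative usage (not part of the module) --------------------------------
# Directory and CSV paths are placeholders; the CSV is expected to contain a
# DocumentId column plus the field columns used by the annotation generator.
if __name__ == "__main__":
    demo_builder = DatasetBuilder(
        pdf_dir="data/pdfs",
        csv_path="data/structured.csv",
        output_dir="data/yolo_dataset",
        dpi=300,
    )
    demo_build_stats = demo_builder.build(seed=42)
    print(
        f"{demo_build_stats.successful}/{demo_build_stats.total_documents} documents labelled, "
        f"{demo_build_stats.total_annotations} annotations "
        f"(train={demo_build_stats.train_count}, val={demo_build_stats.val_count}, "
        f"test={demo_build_stats.test_count})"
    )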

View File

@@ -0,0 +1,728 @@
"""
Database-backed YOLO Dataset
Loads images from filesystem and labels from PostgreSQL database.
Generates YOLO format labels dynamically at training time.
"""
from __future__ import annotations
import logging
import random
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional
import numpy as np
from PIL import Image
from shared.config import DEFAULT_DPI
from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
logger = logging.getLogger(__name__)
# Module-level LRU cache for image loading (shared across dataset instances)
@lru_cache(maxsize=256)
def _load_image_cached(image_path: str) -> tuple[np.ndarray, int, int]:
"""
Load and cache image from disk.
Args:
image_path: Path to image file (must be string for hashability)
Returns:
Tuple of (image_array, width, height)
"""
image = Image.open(image_path).convert('RGB')
width, height = image.size
image_array = np.array(image)
return image_array, width, height
def clear_image_cache():
"""Clear the image cache to free memory."""
_load_image_cached.cache_clear()
@dataclass
class DatasetItem:
"""Single item in the dataset."""
document_id: str
image_path: Path
page_no: int
labels: list[YOLOAnnotation]
is_scanned: bool = False # True if bbox is in pixel coords, False if in PDF points
csv_split: str | None = None # CSV-defined split ('train', 'test', etc.)
class DBYOLODataset:
"""
YOLO Dataset that reads labels from database.
This dataset:
1. Scans temp directory for rendered images
2. Queries database for bbox data
3. Generates YOLO labels dynamically
4. Performs train/val/test split at runtime
"""
def __init__(
self,
images_dir: str | Path,
db: Any, # DocumentDB instance
split: str = 'train',
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
dpi: int = DEFAULT_DPI, # Must match the DPI used in autolabel_tasks.py for rendering
min_confidence: float = 0.7,
bbox_padding_px: int = 20,
min_bbox_height_px: int = 30,
limit: int | None = None,
):
"""
Initialize database-backed YOLO dataset.
Args:
images_dir: Directory containing temp/{doc_id}/images/*.png
db: DocumentDB instance for label queries
split: Which split to use ('train', 'val', 'test')
train_ratio: Ratio for training set
val_ratio: Ratio for validation set
seed: Random seed for reproducible splits
dpi: DPI used for rendering (for coordinate conversion)
min_confidence: Minimum match score to include
bbox_padding_px: Padding around bboxes
min_bbox_height_px: Minimum bbox height
limit: Maximum number of documents to use (None for all)
"""
self.images_dir = Path(images_dir)
self.db = db
self.split = split
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.seed = seed
self.dpi = dpi
self.min_confidence = min_confidence
self.bbox_padding_px = bbox_padding_px
self.min_bbox_height_px = min_bbox_height_px
self.limit = limit
# Load and split dataset
self.items: list[DatasetItem] = []
self._all_items: list[DatasetItem] = [] # Cache all items for sharing
self._doc_ids_ordered: list[str] = [] # Cache ordered doc IDs for consistent splits
self._load_dataset()
@classmethod
def from_shared_data(
cls,
source_dataset: 'DBYOLODataset',
split: str,
) -> 'DBYOLODataset':
"""
Create a new dataset instance sharing data from an existing one.
This avoids re-loading data from filesystem and database.
Args:
source_dataset: Dataset to share data from
split: Which split to use ('train', 'val', 'test')
Returns:
New dataset instance with shared data
"""
# Create instance without loading (we'll share data)
instance = object.__new__(cls)
# Copy configuration
instance.images_dir = source_dataset.images_dir
instance.db = source_dataset.db
instance.split = split
instance.train_ratio = source_dataset.train_ratio
instance.val_ratio = source_dataset.val_ratio
instance.seed = source_dataset.seed
instance.dpi = source_dataset.dpi
instance.min_confidence = source_dataset.min_confidence
instance.bbox_padding_px = source_dataset.bbox_padding_px
instance.min_bbox_height_px = source_dataset.min_bbox_height_px
instance.limit = source_dataset.limit
# Share loaded data
instance._all_items = source_dataset._all_items
instance._doc_ids_ordered = source_dataset._doc_ids_ordered
# Split items for this split
instance.items = instance._split_dataset_from_cache()
print(f"Split '{split}': {len(instance.items)} items")
return instance
def _load_dataset(self):
"""Load dataset items from filesystem and database."""
# Find all document directories
temp_dir = self.images_dir / 'temp'
if not temp_dir.exists():
print(f"Temp directory not found: {temp_dir}")
return
# Collect all document IDs with images
doc_image_map: dict[str, list[Path]] = {}
for doc_dir in temp_dir.iterdir():
if not doc_dir.is_dir():
continue
images_path = doc_dir / 'images'
if not images_path.exists():
continue
images = list(images_path.glob('*.png'))
if images:
doc_image_map[doc_dir.name] = sorted(images)
print(f"Found {len(doc_image_map)} documents with images")
# Query database for all document labels
doc_ids = list(doc_image_map.keys())
doc_labels = self._load_labels_from_db(doc_ids)
print(f"Loaded labels for {len(doc_labels)} documents from database")
# Build dataset items
all_items: list[DatasetItem] = []
skipped_no_labels = 0
skipped_no_db_record = 0
total_images = 0
for doc_id, images in doc_image_map.items():
doc_data = doc_labels.get(doc_id)
# Skip documents that don't exist in database
if doc_data is None:
skipped_no_db_record += len(images)
total_images += len(images)
continue
labels_by_page, is_scanned, csv_split = doc_data
for image_path in images:
total_images += 1
# Extract page number from filename (e.g., "doc_page_000.png")
page_no = self._extract_page_no(image_path.stem)
# Get labels for this page
page_labels = labels_by_page.get(page_no, [])
if page_labels: # Only include pages with labels
all_items.append(DatasetItem(
document_id=doc_id,
image_path=image_path,
page_no=page_no,
labels=page_labels,
is_scanned=is_scanned,
csv_split=csv_split
))
else:
skipped_no_labels += 1
print(f"Total images found: {total_images}")
print(f"Images with labels: {len(all_items)}")
if skipped_no_db_record > 0:
print(f"Skipped {skipped_no_db_record} images (document not in database)")
if skipped_no_labels > 0:
print(f"Skipped {skipped_no_labels} images (no labels for page)")
# Cache all items for sharing with other splits
self._all_items = all_items
# Split dataset
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
print(f"Split '{self.split}': {len(self.items)} items")
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
"""
Load labels from database for given document IDs using batch queries.
Returns:
Dict of doc_id -> (page_labels, is_scanned, split)
where page_labels is {page_no -> list[YOLOAnnotation]}
is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
split is the CSV-defined split ('train', 'test', etc.) or None
"""
result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]] = {}
# Query in batches using efficient batch method
batch_size = 500
for i in range(0, len(doc_ids), batch_size):
batch_ids = doc_ids[i:i + batch_size]
# Use batch query instead of individual queries (N+1 fix)
docs_batch = self.db.get_documents_batch(batch_ids)
for doc_id, doc in docs_batch.items():
if not doc.get('success'):
continue
# Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points)
is_scanned = doc.get('pdf_type') == 'scanned'
# Get CSV-defined split
csv_split = doc.get('split')
page_labels: dict[int, list[YOLOAnnotation]] = {}
for field_result in doc.get('field_results', []):
if not field_result.get('matched'):
continue
field_name = field_result.get('field_name')
# Map supplier_accounts(X) to the actual class name (Bankgiro/Plusgiro)
yolo_class_name = field_name
if field_name and field_name.startswith('supplier_accounts('):
# Extract the account type: "supplier_accounts(Bankgiro)" -> "Bankgiro"
yolo_class_name = field_name.split('(')[1].rstrip(')')
if yolo_class_name not in FIELD_CLASSES:
continue
score = field_result.get('score', 0)
if score < self.min_confidence:
continue
bbox = field_result.get('bbox')
page_no = field_result.get('page_no', 0)
if bbox and len(bbox) == 4:
annotation = self._create_annotation(
field_name=yolo_class_name, # Use mapped class name
bbox=bbox,
score=score
)
if page_no not in page_labels:
page_labels[page_no] = []
page_labels[page_no].append(annotation)
if page_labels:
result[doc_id] = (page_labels, is_scanned, csv_split)
return result
def _create_annotation(
self,
field_name: str,
bbox: list[float],
score: float
) -> YOLOAnnotation:
"""
Create a YOLO annotation from bbox.
Note: bbox is in PDF points (72 DPI), will be normalized later.
"""
class_id = FIELD_CLASSES[field_name]
x0, y0, x1, y1 = bbox
# Store raw PDF coordinates - will be normalized when getting item
return YOLOAnnotation(
class_id=class_id,
x_center=(x0 + x1) / 2, # center in PDF points
y_center=(y0 + y1) / 2,
width=x1 - x0,
height=y1 - y0,
confidence=score
)
def _extract_page_no(self, stem: str) -> int:
"""Extract page number from image filename."""
# Format: "{doc_id}_page_{page_no:03d}"
parts = stem.rsplit('_', 1)
if len(parts) == 2:
try:
return int(parts[1])
except ValueError:
pass
return 0
def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
"""
Split items into train/val/test based on CSV-defined split field.
If CSV has 'split' field, use it directly.
Otherwise, fall back to random splitting based on train_ratio/val_ratio.
Returns:
Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
reused for consistent splits across shared datasets.
"""
# Group by document ID for proper splitting
doc_items: dict[str, list[DatasetItem]] = {}
doc_csv_split: dict[str, str | None] = {} # Track CSV split per document
for item in items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_csv_split[item.document_id] = item.csv_split
doc_items[item.document_id].append(item)
# Check if we have CSV-defined splits
has_csv_splits = any(split is not None for split in doc_csv_split.values())
doc_ids = list(doc_items.keys())
if has_csv_splits:
# Use CSV-defined splits
print("Using CSV-defined split field for train/val/test assignment")
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
# 'val' is taken from train set using val_ratio
split_doc_ids = []
if self.split == 'train':
# Get documents marked as 'train' or no split defined
train_docs = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
# Take train_ratio of train docs for actual training, rest for val
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[:n_train]
elif self.split == 'val':
# Get documents marked as 'train' and take val portion
train_docs = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[n_train:]
else: # test
# Get documents marked as 'test'
split_doc_ids = [doc_id for doc_id in doc_ids
if doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
# Apply limit if specified
if self.limit is not None and self.limit < len(split_doc_ids):
split_doc_ids = split_doc_ids[:self.limit]
print(f"Limited to {self.limit} documents")
else:
# Fall back to random splitting (original behavior)
print("No CSV split field found, using random splitting")
random.seed(self.seed)
random.shuffle(doc_ids)
# Apply limit if specified (before splitting)
if self.limit is not None and self.limit < len(doc_ids):
doc_ids = doc_ids[:self.limit]
print(f"Limited to {self.limit} documents")
# Calculate split indices
n_total = len(doc_ids)
n_train = int(n_total * self.train_ratio)
n_val = int(n_total * self.val_ratio)
# Split document IDs
if self.split == 'train':
split_doc_ids = doc_ids[:n_train]
elif self.split == 'val':
split_doc_ids = doc_ids[n_train:n_train + n_val]
else: # test
split_doc_ids = doc_ids[n_train + n_val:]
# Collect items for this split
split_items = []
for doc_id in split_doc_ids:
split_items.extend(doc_items[doc_id])
return split_items, doc_ids
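    # Walk-through of the CSV-split branch (illustrative counts, assuming the default
    # train_ratio=0.8 and val_ratio=0.1, with 90 docs marked 'train' and 10 marked 'test'):
    #   split='train' -> shuffle the 90 'train' docs, keep ~train_ratio/(train_ratio+val_ratio)
    #                    (about 8/9) of them
    #   split='val'   -> the remaining shuffled 'train' docs
    #   split='test'  -> the 10 docs marked 'test' (never mixed into train/val)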
def _split_dataset_from_cache(self) -> list[DatasetItem]:
"""
Split items using cached data from a shared dataset.
Uses pre-computed doc_ids order for consistent splits.
Respects CSV-defined splits if available.
"""
# Group by document ID and track CSV splits
doc_items: dict[str, list[DatasetItem]] = {}
doc_csv_split: dict[str, str | None] = {}
for item in self._all_items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_csv_split[item.document_id] = item.csv_split
doc_items[item.document_id].append(item)
# Check if we have CSV-defined splits
has_csv_splits = any(split is not None for split in doc_csv_split.values())
doc_ids = self._doc_ids_ordered
if has_csv_splits:
# Use CSV-defined splits
if self.split == 'train':
train_docs = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[:n_train]
elif self.split == 'val':
train_docs = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]
random.seed(self.seed)
random.shuffle(train_docs)
n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
split_doc_ids = train_docs[n_train:]
else: # test
split_doc_ids = [doc_id for doc_id in doc_ids
if doc_id in doc_csv_split and
doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]
else:
# Fall back to random splitting
n_total = len(doc_ids)
n_train = int(n_total * self.train_ratio)
n_val = int(n_total * self.val_ratio)
if self.split == 'train':
split_doc_ids = doc_ids[:n_train]
elif self.split == 'val':
split_doc_ids = doc_ids[n_train:n_train + n_val]
else: # test
split_doc_ids = doc_ids[n_train + n_val:]
# Collect items for this split
split_items = []
for doc_id in split_doc_ids:
if doc_id in doc_items:
split_items.extend(doc_items[doc_id])
return split_items
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
"""
Get image and labels for index.
Returns:
(image, labels) where:
- image: numpy array (H, W, C)
- labels: numpy array (N, 5) with [class_id, x_center, y_center, width, height]
"""
item = self.items[idx]
# Load image using LRU cache (significant speedup during training)
image_array, img_width, img_height = _load_image_cached(str(item.image_path))
# Convert annotations to YOLO format (normalized)
labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)
return image_array, labels
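    # Typical access pattern (sketch; assumes `ds` is a constructed DBYOLODataset):
    #   image, labels = ds[0]
    #   image.shape   -> (H, W, C) rendered page pixels
    #   labels.shape  -> (N, 5) rows of [class_id, x_center, y_center, width, height],
    #                    all coordinates normalized to [0, 1]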
def _convert_labels(
self,
annotations: list[YOLOAnnotation],
img_width: int,
img_height: int,
is_scanned: bool = False
) -> np.ndarray:
"""
Convert annotations to normalized YOLO format.
Args:
annotations: List of annotations
img_width: Actual image width in pixels
img_height: Actual image height in pixels
            is_scanned: Kept for API compatibility. Since the OCR coordinate fix,
                all bboxes (text and scanned PDFs alike) are stored in PDF points,
                so scaling is the same in both cases.
Returns:
numpy array (N, 5) with [class_id, x_center, y_center, width, height]
"""
if not annotations:
return np.zeros((0, 5), dtype=np.float32)
# Scale factor: PDF points (72 DPI) -> rendered pixels
# Note: After the OCR coordinate fix, ALL bbox (both text and scanned PDF)
# are stored in PDF points, so we always apply the same scaling.
scale = self.dpi / 72.0
labels = []
for ann in annotations:
            # Convert PDF points (72 DPI) to rendered pixels
x_center_px = ann.x_center * scale
y_center_px = ann.y_center * scale
width_px = ann.width * scale
height_px = ann.height * scale
# Add padding
pad = self.bbox_padding_px
width_px += 2 * pad
height_px += 2 * pad
# Ensure minimum height
if height_px < self.min_bbox_height_px:
height_px = self.min_bbox_height_px
# Normalize to 0-1
x_center = x_center_px / img_width
y_center = y_center_px / img_height
width = width_px / img_width
height = height_px / img_height
# Clamp to valid range
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
labels.append([ann.class_id, x_center, y_center, width, height])
return np.array(labels, dtype=np.float32)
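    # Worked example (illustrative values, assuming dpi=144, bbox_padding_px=2,
    # min_bbox_height_px=8, and an A4 page rendered at 144 DPI -> 1190x1684 px):
    #   annotation in PDF points: center=(100, 200), size=(50, 10)
    #   scale = 144 / 72 = 2.0       -> center=(200, 400) px, size=(100, 20) px
    #   padding of 2 px per side     -> size=(104, 24) px (above min height, kept)
    #   normalize by 1190x1684       -> x=0.1681, y=0.2375, w=0.0874, h=0.0143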
def get_image_path(self, idx: int) -> Path:
"""Get image path for index."""
return self.items[idx].image_path
def get_labels_for_yolo(self, idx: int) -> str:
"""
Get YOLO format labels as string for index.
Returns:
String with YOLO format labels (one per line)
"""
item = self.items[idx]
# Use cached image loading to avoid duplicate disk reads
_, img_width, img_height = _load_image_cached(str(item.image_path))
labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)
lines = []
for label in labels:
class_id = int(label[0])
x_center, y_center, width, height = label[1:5]
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
return '\n'.join(lines)
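    # Example of one emitted line, continuing the numbers from _convert_labels above
    # (class id 3 is arbitrary):
    #   "3 0.168067 0.237530 0.087395 0.014252"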
def export_to_yolo_format(
self,
output_dir: str | Path,
split_name: Optional[str] = None
) -> int:
"""
Export dataset to standard YOLO format (images + label files).
This is useful for training with standard YOLO training scripts.
Args:
output_dir: Output directory
split_name: Name for the split subdirectory (default: self.split)
Returns:
Number of items exported
"""
import shutil
output_dir = Path(output_dir)
split_name = split_name or self.split
images_out = output_dir / split_name / 'images'
labels_out = output_dir / split_name / 'labels'
# Clear existing directories before export
if images_out.exists():
shutil.rmtree(images_out)
if labels_out.exists():
shutil.rmtree(labels_out)
images_out.mkdir(parents=True, exist_ok=True)
labels_out.mkdir(parents=True, exist_ok=True)
count = 0
for idx in range(len(self)):
item = self.items[idx]
# Copy image
dest_image = images_out / item.image_path.name
shutil.copy2(item.image_path, dest_image)
# Write label file
label_content = self.get_labels_for_yolo(idx)
label_path = labels_out / f"{item.image_path.stem}.txt"
with open(label_path, 'w') as f:
f.write(label_content)
count += 1
print(f"Exported {count} items to {output_dir / split_name}")
return count
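    # Resulting layout sketch (paths are illustrative; the image extension depends on
    # the page-rendering step):
    #   <output_dir>/train/images/<doc_id>_page_000.png
    #   <output_dir>/train/labels/<doc_id>_page_000.txt   (one YOLO line per bbox)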
def create_datasets(
images_dir: str | Path,
db: Any,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
limit: int | None = None,
**kwargs
) -> dict[str, DBYOLODataset]:
"""
Create train/val/test datasets.
This function loads data once and shares it across all splits for efficiency.
Args:
images_dir: Directory containing temp/{doc_id}/images/
db: DocumentDB instance
train_ratio: Training set ratio
val_ratio: Validation set ratio
seed: Random seed
limit: Maximum number of documents to use (None for all)
**kwargs: Additional arguments for DBYOLODataset
Returns:
Dict with 'train', 'val', 'test' datasets
"""
# Create first dataset which loads all data
print("Loading dataset (this may take a few minutes for large datasets)...")
first_dataset = DBYOLODataset(
images_dir=images_dir,
db=db,
split='train',
train_ratio=train_ratio,
val_ratio=val_ratio,
seed=seed,
limit=limit,
**kwargs
)
# Create other splits by sharing loaded data
datasets = {'train': first_dataset}
for split in ['val', 'test']:
datasets[split] = DBYOLODataset.from_shared_data(
first_dataset,
split=split,
)
return datasets
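# Usage sketch (illustrative paths and arguments; DocumentDB is instantiated however
# the rest of this package does it):
#   db = DocumentDB()
#   splits = create_datasets("data/temp", db, train_ratio=0.8, val_ratio=0.1, limit=200)
#   for name, ds in splits.items():
#       ds.export_to_yolo_format("datasets/yolo", split_name=name)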