restructure project
This commit is contained in:

20  packages/training/Dockerfile  Normal file
@@ -0,0 +1,20 @@
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1-mesa-glx libglib2.0-0 libpq-dev gcc \
    && rm -rf /var/lib/apt/lists/*

# Install shared package
COPY packages/shared /app/packages/shared
RUN pip install --no-cache-dir -e /app/packages/shared

# Install training package
COPY packages/training /app/packages/training
RUN pip install --no-cache-dir -e /app/packages/training

WORKDIR /app/packages/training

# Shell form so ${TASK_ID} is expanded at runtime; the exec form
# ["python", "run_training.py", "--task-id", "${TASK_ID}"] would pass the
# literal string "${TASK_ID}", because exec form does no shell substitution.
CMD python run_training.py --task-id ${TASK_ID}
4  packages/training/requirements.txt  Normal file
@@ -0,0 +1,4 @@
-e ../shared
ultralytics>=8.1.0
tqdm>=4.65.0
torch>=2.0.0
100  packages/training/run_training.py  Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
Training Service Entry Point.

Runs a specific training task by ID (for Azure ACI on-demand mode)
or polls the database for pending tasks (for local dev).
"""

import argparse
import logging
import sys
import time

from training.data.training_db import TrainingTaskDB

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)


def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
    """Execute a single training task."""
    task_id = task["task_id"]
    config = task.get("config") or {}

    logger.info("Starting training task %s with config: %s", task_id, config)
    db.update_status(task_id, "running")

    try:
        from training.cli.train import run_training

        result = run_training(
            epochs=config.get("epochs", 100),
            batch=config.get("batch_size", 16),
            model=config.get("base_model", "yolo11n.pt"),
            imgsz=config.get("imgsz", 1280),
            name=config.get("name", f"training_{task_id[:8]}"),
        )

        db.complete_task(
            task_id,
            model_path=result.get("model_path", ""),
            metrics=result.get("metrics", {}),
        )
        logger.info("Training task %s completed successfully.", task_id)

    except Exception as e:
        logger.exception("Training task %s failed", task_id)
        db.fail_task(task_id, str(e))
        # Exit non-zero so the container (ACI on-demand mode) reports failure;
        # note this also ends poll mode on the first failed task.
        sys.exit(1)


def main() -> None:
    parser = argparse.ArgumentParser(description="Invoice Training Service")
    parser.add_argument(
        "--task-id",
        help="Specific task ID to run (ACI on-demand mode)",
    )
    parser.add_argument(
        "--poll",
        action="store_true",
        help="Poll database for pending tasks (local dev mode)",
    )
    parser.add_argument(
        "--poll-interval",
        type=int,
        default=60,
        help="Seconds between polls (default: 60)",
    )
    args = parser.parse_args()

    db = TrainingTaskDB()

    if args.task_id:
        task = db.get_task(args.task_id)
        if not task:
            logger.error("Task %s not found", args.task_id)
            sys.exit(1)
        execute_training_task(db, task)

    elif args.poll:
        logger.info(
            "Starting training service in poll mode (interval=%ds)",
            args.poll_interval,
        )
        while True:
            tasks = db.get_pending_tasks(limit=1)
            for task in tasks:
                execute_training_task(db, task)
            time.sleep(args.poll_interval)

    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
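For orientation, the calls above pin down the task-store interface this entry point expects. The Protocol below is a sketch inferred purely from the call sites in run_training.py; it is not the actual training_db module, whose signatures may differ.

from typing import Optional, Protocol


class TaskStoreSketch(Protocol):
    """Hypothetical interface inferred from run_training.py call sites."""

    def get_task(self, task_id: str) -> Optional[dict]: ...
    def get_pending_tasks(self, limit: int = 1) -> list[dict]: ...
    def update_status(self, task_id: str, status: str) -> None: ...
    def complete_task(self, task_id: str, model_path: str, metrics: dict) -> None: ...
    def fail_task(self, task_id: str, error: str) -> None: ...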
13  packages/training/setup.py  Normal file
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages

setup(
    name="invoice-training",
    version="0.1.0",
    packages=find_packages(),
    python_requires=">=3.11",
    install_requires=[
        "invoice-shared",
        "ultralytics>=8.1.0",
        "tqdm>=4.65.0",
    ],
)
0  packages/training/training/__init__.py  Normal file
0  packages/training/training/cli/__init__.py  Normal file
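Taken together, the files in this commit imply roughly the following package layout. Entries marked (inferred) do not appear as files in this diff; they are reconstructed from the import statements elsewhere in the commit and may be organized differently in the actual repository.

packages/training/
├── Dockerfile
├── requirements.txt
├── run_training.py
├── setup.py
└── training/
    ├── __init__.py
    ├── cli/
    │   ├── __init__.py
    │   ├── analyze_labels.py
    │   ├── analyze_report.py
    │   ├── autolabel.py
    │   └── train.py                  (inferred)
    ├── data/
    │   ├── autolabel_report.py       (inferred)
    │   └── training_db.py            (inferred)
    ├── processing/
    │   ├── autolabel_tasks.py        (inferred)
    │   └── document_processor.py     (inferred)
    └── yolo/
        └── annotation_generator.py   (inferred)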
599  packages/training/training/cli/analyze_labels.py  Normal file
@@ -0,0 +1,599 @@
#!/usr/bin/env python3
"""
Label Analysis CLI

Analyzes auto-generated labels to identify failures and diagnose root causes.
Now reads from PostgreSQL database instead of JSONL files.
"""

import argparse
import csv
import json
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from shared.data.db import DocumentDB
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from shared.pdf import is_text_pdf, extract_text_tokens
from training.yolo.annotation_generator import FIELD_CLASSES


@dataclass
class FieldAnalysis:
    """Analysis result for a single field."""
    field_name: str
    csv_value: str
    expected: bool  # True if CSV has value
    labeled: bool   # True if label file has this field
    matched: bool   # True if matcher finds it

    # Diagnosis
    failure_reason: Optional[str] = None
    details: dict = field(default_factory=dict)


@dataclass
class DocumentAnalysis:
    """Analysis result for a document."""
    doc_id: str
    pdf_exists: bool
    pdf_type: str  # "text" or "scanned"
    total_pages: int

    # Per-field analysis
    fields: list[FieldAnalysis] = field(default_factory=list)

    # Summary
    csv_fields_count: int = 0      # Fields with values in CSV
    labeled_fields_count: int = 0  # Fields in label files
    matched_fields_count: int = 0  # Fields matcher can find

    @property
    def has_issues(self) -> bool:
        """Check if document has any labeling issues."""
        return any(f.expected and not f.labeled for f in self.fields)

    @property
    def missing_labels(self) -> list[FieldAnalysis]:
        """Get fields that should be labeled but aren't."""
        return [f for f in self.fields if f.expected and not f.labeled]


class LabelAnalyzer:
    """Analyzes labels and diagnoses failures."""

    def __init__(
        self,
        csv_path: str,
        pdf_dir: str,
        dataset_dir: str,
        use_db: bool = True
    ):
        self.csv_path = Path(csv_path)
        self.pdf_dir = Path(pdf_dir)
        self.dataset_dir = Path(dataset_dir)
        self.use_db = use_db

        self.matcher = FieldMatcher()
        self.csv_data = {}
        self.label_data = {}
        self.report_data = {}

        # Database connection
        self.db = None
        if use_db:
            self.db = DocumentDB()
            self.db.connect()

        # Class ID to name mapping
        self.class_names = list(FIELD_CLASSES.keys())

    def load_csv(self):
        """Load CSV data."""
        with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
            reader = csv.DictReader(f)
            for row in reader:
                doc_id = row['DocumentId']
                self.csv_data[doc_id] = row
        print(f"Loaded {len(self.csv_data)} records from CSV")

    def load_labels(self):
        """Load all label files from dataset."""
        for split in ['train', 'val', 'test']:
            label_dir = self.dataset_dir / split / 'labels'
            if not label_dir.exists():
                continue

            for label_file in label_dir.glob('*.txt'):
                # Parse document ID from filename (uuid_page_XXX.txt)
                name = label_file.stem
                parts = name.rsplit('_page_', 1)
                if len(parts) == 2:
                    doc_id = parts[0]
                    page_no = int(parts[1])
                else:
                    continue

                if doc_id not in self.label_data:
                    self.label_data[doc_id] = {'pages': {}, 'split': split}

                # Parse label file
                labels = []
                with open(label_file, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            class_id = int(parts[0])
                            labels.append({
                                'class_id': class_id,
                                'class_name': self.class_names[class_id],
                                'x_center': float(parts[1]),
                                'y_center': float(parts[2]),
                                'width': float(parts[3]),
                                'height': float(parts[4])
                            })

                self.label_data[doc_id]['pages'][page_no] = labels

        total_docs = len(self.label_data)
        total_labels = sum(
            len(labels)
            for doc in self.label_data.values()
            for labels in doc['pages'].values()
        )
        print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")

    def load_report(self):
        """Load autolabel report from database."""
        if not self.db:
            print("Database not configured, skipping report loading")
            return

        # Get document IDs from CSV to query
        doc_ids = list(self.csv_data.keys())
        if not doc_ids:
            return

        # Query in batches to avoid memory issues
        batch_size = 1000
        loaded = 0

        for i in range(0, len(doc_ids), batch_size):
            batch_ids = doc_ids[i:i + batch_size]
            for doc_id in batch_ids:
                doc = self.db.get_document(doc_id)
                if doc:
                    self.report_data[doc_id] = doc
                    loaded += 1

        print(f"Loaded {loaded} autolabel reports from database")

    def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
        """Analyze a single document."""
        csv_row = self.csv_data.get(doc_id, {})
        label_info = self.label_data.get(doc_id, {'pages': {}})
        report = self.report_data.get(doc_id, {})

        # Check PDF
        pdf_path = self.pdf_dir / f"{doc_id}.pdf"
        pdf_exists = pdf_path.exists()

        # Skip documents without PDF if requested
        if skip_missing_pdf and not pdf_exists:
            return None

        pdf_type = "unknown"
        total_pages = 0

        if pdf_exists:
            pdf_type = "text" if is_text_pdf(pdf_path) else "scanned"
            total_pages = len(label_info['pages']) or report.get('total_pages', 0)

        analysis = DocumentAnalysis(
            doc_id=doc_id,
            pdf_exists=pdf_exists,
            pdf_type=pdf_type,
            total_pages=total_pages
        )

        # Get labeled classes
        labeled_classes = set()
        for page_labels in label_info['pages'].values():
            for label in page_labels:
                labeled_classes.add(label['class_name'])

        # Analyze each field
        for field_name in FIELD_CLASSES.keys():
            csv_value = csv_row.get(field_name, '')
            if csv_value is None:
                csv_value = ''
            csv_value = str(csv_value).strip()

            # Handle datetime values (remove time part)
            if ' 00:00:00' in csv_value:
                csv_value = csv_value.replace(' 00:00:00', '')

            expected = bool(csv_value)
            labeled = field_name in labeled_classes

            field_analysis = FieldAnalysis(
                field_name=field_name,
                csv_value=csv_value,
                expected=expected,
                labeled=labeled,
                matched=False
            )

            if expected:
                analysis.csv_fields_count += 1
                if labeled:
                    analysis.labeled_fields_count += 1

            # Diagnose failures
            if expected and not labeled:
                field_analysis.failure_reason = self._diagnose_failure(
                    doc_id, field_name, csv_value, pdf_path, pdf_type, report
                )
                field_analysis.details = self._get_failure_details(
                    doc_id, field_name, csv_value, pdf_path, pdf_type
                )
            elif not expected and labeled:
                field_analysis.failure_reason = "EXTRA_LABEL"
                field_analysis.details = {'note': 'Labeled but no CSV value'}

            analysis.fields.append(field_analysis)

        return analysis

    def _diagnose_failure(
        self,
        doc_id: str,
        field_name: str,
        csv_value: str,
        pdf_path: Path,
        pdf_type: str,
        report: dict
    ) -> str:
        """Diagnose why a field wasn't labeled."""

        if not pdf_path.exists():
            return "PDF_NOT_FOUND"

        if pdf_type == "scanned":
            return "SCANNED_PDF"

        # Try to match now with current normalizer (not historical report)
        if pdf_path.exists() and pdf_type == "text":
            try:
                # Check all pages
                for page_no in range(10):  # Max 10 pages
                    try:
                        tokens = list(extract_text_tokens(pdf_path, page_no))
                        if not tokens:
                            break

                        normalized = normalize_field(field_name, csv_value)
                        matches = self.matcher.find_matches(tokens, field_name, normalized, page_no)

                        if matches:
                            return "MATCHER_OK_NOW"  # Would match with current normalizer
                    except Exception:
                        break

                return "VALUE_NOT_IN_PDF"

            except Exception as e:
                return f"PDF_ERROR: {str(e)[:50]}"

        return "UNKNOWN"

    def _get_failure_details(
        self,
        doc_id: str,
        field_name: str,
        csv_value: str,
        pdf_path: Path,
        pdf_type: str
    ) -> dict:
        """Get detailed information about a failure."""
        details = {
            'csv_value': csv_value,
            'normalized_candidates': [],
            'pdf_tokens_sample': [],
            'potential_matches': []
        }

        # Get normalized candidates
        try:
            details['normalized_candidates'] = normalize_field(field_name, csv_value)
        except Exception:
            pass

        # Get PDF tokens if available
        if pdf_path.exists() and pdf_type == "text":
            try:
                tokens = list(extract_text_tokens(pdf_path, 0))[:100]

                # Find tokens that might be related
                candidates = details['normalized_candidates']
                for token in tokens:
                    text = token.text.strip()
                    # Check if any candidate is substring or similar
                    for cand in candidates:
                        if cand in text or text in cand:
                            details['potential_matches'].append({
                                'token': text,
                                'candidate': cand,
                                'bbox': token.bbox
                            })
                            break
                    # Also collect date-like or number-like tokens for reference
                    if field_name in ('InvoiceDate', 'InvoiceDueDate'):
                        if any(c.isdigit() for c in text) and len(text) >= 6:
                            details['pdf_tokens_sample'].append(text)
                    elif field_name == 'Amount':
                        if any(c.isdigit() for c in text) and (',' in text or '.' in text or len(text) >= 4):
                            details['pdf_tokens_sample'].append(text)

                # Limit samples
                details['pdf_tokens_sample'] = details['pdf_tokens_sample'][:10]
                details['potential_matches'] = details['potential_matches'][:5]

            except Exception:
                pass

        return details

    def run_analysis(self, limit: Optional[int] = None, skip_missing_pdf: bool = True) -> list[DocumentAnalysis]:
        """Run analysis on all documents."""
        self.load_csv()
        self.load_labels()
        self.load_report()

        results = []
        doc_ids = list(self.csv_data.keys())
        skipped = 0

        for doc_id in doc_ids:
            analysis = self.analyze_document(doc_id, skip_missing_pdf=skip_missing_pdf)
            if analysis is None:
                skipped += 1
                continue
            results.append(analysis)
            if limit and len(results) >= limit:
                break

        if skipped > 0:
            print(f"Skipped {skipped} documents without PDF files")

        return results

    def generate_report(
        self,
        results: list[DocumentAnalysis],
        output_path: str,
        verbose: bool = False
    ):
        """Generate analysis report."""
        output = Path(output_path)
        output.parent.mkdir(parents=True, exist_ok=True)

        # Collect statistics
        stats = {
            'total_documents': len(results),
            'documents_with_issues': 0,
            'total_expected_fields': 0,
            'total_labeled_fields': 0,
            'missing_labels': 0,
            'extra_labels': 0,
            'failure_reasons': defaultdict(int),
            'failures_by_field': defaultdict(lambda: defaultdict(int))
        }

        issues = []

        for analysis in results:
            stats['total_expected_fields'] += analysis.csv_fields_count
            stats['total_labeled_fields'] += analysis.labeled_fields_count

            if analysis.has_issues:
                stats['documents_with_issues'] += 1

            for f in analysis.fields:
                if f.expected and not f.labeled:
                    stats['missing_labels'] += 1
                    stats['failure_reasons'][f.failure_reason] += 1
                    stats['failures_by_field'][f.field_name][f.failure_reason] += 1

                    issues.append({
                        'doc_id': analysis.doc_id,
                        'field': f.field_name,
                        'csv_value': f.csv_value,
                        'reason': f.failure_reason,
                        'details': f.details if verbose else {}
                    })
                elif not f.expected and f.labeled:
                    stats['extra_labels'] += 1

        # Write JSON report (divisions guarded against an empty result set)
        report = {
            'summary': {
                'total_documents': stats['total_documents'],
                'documents_with_issues': stats['documents_with_issues'],
                'issue_rate': f"{stats['documents_with_issues'] / max(1, stats['total_documents']) * 100:.1f}%",
                'total_expected_fields': stats['total_expected_fields'],
                'total_labeled_fields': stats['total_labeled_fields'],
                'label_coverage': f"{stats['total_labeled_fields'] / max(1, stats['total_expected_fields']) * 100:.1f}%",
                'missing_labels': stats['missing_labels'],
                'extra_labels': stats['extra_labels']
            },
            'failure_reasons': dict(stats['failure_reasons']),
            'failures_by_field': {
                field_name: dict(reasons)
                for field_name, reasons in stats['failures_by_field'].items()
            },
            'issues': issues
        }

        with open(output, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)

        print(f"\nReport saved to: {output}")

        return report


def print_summary(report: dict):
    """Print summary to console."""
    summary = report['summary']

    print("\n" + "=" * 60)
    print("LABEL ANALYSIS SUMMARY")
    print("=" * 60)

    print("\nDocuments:")
    print(f"  Total: {summary['total_documents']}")
    print(f"  With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")

    print("\nFields:")
    print(f"  Expected: {summary['total_expected_fields']}")
    print(f"  Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
    print(f"  Missing: {summary['missing_labels']}")
    print(f"  Extra: {summary['extra_labels']}")

    print("\nFailure Reasons:")
    for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
        print(f"  {reason}: {count}")

    print("\nFailures by Field:")
    for field_name, reasons in report['failures_by_field'].items():
        total = sum(reasons.values())
        print(f"  {field_name}: {total}")
        for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f"    - {reason}: {count}")

    # Show sample issues
    if report['issues']:
        print("\n" + "-" * 60)
        print("SAMPLE ISSUES (first 10)")
        print("-" * 60)

        for issue in report['issues'][:10]:
            print(f"\n[{issue['doc_id']}] {issue['field']}")
            print(f"  CSV value: {issue['csv_value']}")
            print(f"  Reason: {issue['reason']}")

            if issue.get('details'):
                details = issue['details']
                if details.get('normalized_candidates'):
                    print(f"  Candidates: {details['normalized_candidates'][:5]}")
                if details.get('pdf_tokens_sample'):
                    print(f"  PDF samples: {details['pdf_tokens_sample'][:5]}")
                if details.get('potential_matches'):
                    print("  Potential matches:")
                    for pm in details['potential_matches'][:3]:
                        print(f"    - token='{pm['token']}' matches candidate='{pm['candidate']}'")


def main():
    parser = argparse.ArgumentParser(
        description='Analyze auto-generated labels and diagnose failures'
    )
    parser.add_argument(
        '--csv', '-c',
        default='data/structured_data/document_export_20260109_220326.csv',
        help='Path to structured data CSV file'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default='data/raw_pdfs',
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--dataset', '-d',
        default='data/dataset',
        help='Dataset directory with labels'
    )
    parser.add_argument(
        '--output', '-o',
        default='reports/label_analysis.json',
        help='Output path for analysis report'
    )
    parser.add_argument(
        '--limit', '-l',
        type=int,
        default=None,
        help='Limit number of documents to analyze'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Include detailed failure information'
    )
    parser.add_argument(
        '--single', '-s',
        help='Analyze single document ID'
    )
    parser.add_argument(
        '--no-db',
        action='store_true',
        help='Skip database, only analyze label files'
    )

    args = parser.parse_args()

    analyzer = LabelAnalyzer(
        csv_path=args.csv,
        pdf_dir=args.pdf_dir,
        dataset_dir=args.dataset,
        use_db=not args.no_db
    )

    if args.single:
        # Analyze single document
        analyzer.load_csv()
        analyzer.load_labels()
        analyzer.load_report()

        analysis = analyzer.analyze_document(args.single)
        if analysis is None:
            # analyze_document returns None when the PDF is missing
            print(f"Document {args.single} has no PDF on disk; nothing to analyze.")
            return

        print(f"\n{'=' * 60}")
        print(f"Document: {analysis.doc_id}")
        print(f"{'=' * 60}")
        print(f"PDF exists: {analysis.pdf_exists}")
        print(f"PDF type: {analysis.pdf_type}")
        print(f"Pages: {analysis.total_pages}")
        print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")

        for f in analysis.fields:
            status = "✓" if f.labeled else ("✗" if f.expected else "-")
            value_str = f.csv_value[:30] if f.csv_value else "(empty)"
            print(f"  [{status}] {f.field_name}: {value_str}")

            if f.failure_reason:
                print(f"      Reason: {f.failure_reason}")
                if f.details.get('normalized_candidates'):
                    print(f"      Candidates: {f.details['normalized_candidates']}")
                if f.details.get('potential_matches'):
                    print("      Potential matches in PDF:")
                    for pm in f.details['potential_matches'][:3]:
                        print(f"        - '{pm['token']}'")
    else:
        # Full analysis
        print("Running label analysis...")
        results = analyzer.run_analysis(limit=args.limit)
        report = analyzer.generate_report(results, args.output, verbose=args.verbose)
        print_summary(report)


if __name__ == '__main__':
    main()
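load_labels above parses standard YOLO-format annotation files: one object per line, a class index followed by normalized center-x, center-y, width, and height. A minimal round trip over one such line (the numbers are made up for illustration):

# One YOLO label line: "<class_id> <x_center> <y_center> <width> <height>",
# with all four coordinates normalized to [0, 1] relative to the page image.
line = "2 0.5132 0.0841 0.2210 0.0305"  # hypothetical label for class index 2
class_id, x_c, y_c, w, h = line.split()
print(int(class_id), float(x_c), float(y_c), float(w), float(h))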
434  packages/training/training/cli/analyze_report.py  Normal file
@@ -0,0 +1,434 @@
#!/usr/bin/env python3
"""
Analyze Auto-Label Report

Generates statistics and insights from the database or autolabel_report.jsonl.
"""

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path


def load_reports_from_db() -> dict:
    """Load statistics directly from database using SQL aggregation."""
    from shared.data.db import DocumentDB

    db = DocumentDB()
    conn = db.connect()

    stats = {
        'total': 0,
        'successful': 0,
        'failed': 0,
        'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
        'by_field': defaultdict(lambda: {
            'total': 0,
            'matched': 0,
            'exact_match': 0,
            'flexible_match': 0,
            'scores': [],
            'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
        }),
        'errors': defaultdict(int),
        'processing_times': [],
    }

    with conn.cursor() as cursor:
        # Overall stats
        cursor.execute("""
            SELECT
                COUNT(*) as total,
                SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful,
                SUM(CASE WHEN NOT success THEN 1 ELSE 0 END) as failed
            FROM documents
        """)
        row = cursor.fetchone()
        stats['total'] = row[0] or 0
        stats['successful'] = row[1] or 0
        stats['failed'] = row[2] or 0

        # By PDF type
        cursor.execute("""
            SELECT
                pdf_type,
                COUNT(*) as total,
                SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful
            FROM documents
            GROUP BY pdf_type
        """)
        for row in cursor.fetchall():
            pdf_type = row[0] or 'unknown'
            stats['by_pdf_type'][pdf_type] = {
                'total': row[1] or 0,
                'successful': row[2] or 0
            }

        # Processing times
        cursor.execute("""
            SELECT AVG(processing_time_ms), MIN(processing_time_ms), MAX(processing_time_ms)
            FROM documents
            WHERE processing_time_ms > 0
        """)
        row = cursor.fetchone()
        if row[0]:
            stats['processing_time_stats'] = {
                'avg_ms': float(row[0]),
                'min_ms': float(row[1]),
                'max_ms': float(row[2])
            }

        # Field stats
        cursor.execute("""
            SELECT
                field_name,
                COUNT(*) as total,
                SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched,
                SUM(CASE WHEN matched AND score >= 0.99 THEN 1 ELSE 0 END) as exact_match,
                SUM(CASE WHEN matched AND score < 0.99 THEN 1 ELSE 0 END) as flexible_match,
                AVG(CASE WHEN matched THEN score END) as avg_score
            FROM field_results
            GROUP BY field_name
            ORDER BY field_name
        """)
        for row in cursor.fetchall():
            field_name = row[0]
            stats['by_field'][field_name] = {
                'total': row[1] or 0,
                'matched': row[2] or 0,
                'exact_match': row[3] or 0,
                'flexible_match': row[4] or 0,
                'avg_score': float(row[5]) if row[5] else 0,
                'scores': [],  # Not loading individual scores for efficiency
                'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
            }

        # Field stats by PDF type
        cursor.execute("""
            SELECT
                fr.field_name,
                d.pdf_type,
                COUNT(*) as total,
                SUM(CASE WHEN fr.matched THEN 1 ELSE 0 END) as matched
            FROM field_results fr
            JOIN documents d ON fr.document_id = d.document_id
            GROUP BY fr.field_name, d.pdf_type
        """)
        for row in cursor.fetchall():
            field_name = row[0]
            pdf_type = row[1] or 'unknown'
            if field_name in stats['by_field']:
                stats['by_field'][field_name]['by_pdf_type'][pdf_type] = {
                    'total': row[2] or 0,
                    'matched': row[3] or 0
                }

    db.close()
    return stats


def load_reports_from_file(report_path: str) -> list[dict]:
    """Load all reports from JSONL file(s). Supports glob patterns."""
    path = Path(report_path)

    # Handle glob pattern
    if '*' in str(path) or '?' in str(path):
        parent = path.parent
        pattern = path.name
        report_files = sorted(parent.glob(pattern))
    else:
        report_files = [path]

    if not report_files:
        return []

    print(f"Reading {len(report_files)} report file(s):")
    for f in report_files:
        print(f"  - {f.name}")

    reports = []
    for report_file in report_files:
        if not report_file.exists():
            continue
        with open(report_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    reports.append(json.loads(line))

    return reports


def analyze_reports(reports: list[dict]) -> dict:
    """Analyze reports and generate statistics."""
    stats = {
        'total': len(reports),
        'successful': 0,
        'failed': 0,
        'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
        'by_field': defaultdict(lambda: {
            'total': 0,
            'matched': 0,
            'exact_match': 0,     # score >= 0.99
            'flexible_match': 0,  # score < 0.99
            'scores': [],
            'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
        }),
        'errors': defaultdict(int),
        'processing_times': [],
    }

    for report in reports:
        pdf_type = report.get('pdf_type') or 'unknown'
        success = report.get('success', False)

        # Overall stats
        if success:
            stats['successful'] += 1
        else:
            stats['failed'] += 1

        # By PDF type
        stats['by_pdf_type'][pdf_type]['total'] += 1
        if success:
            stats['by_pdf_type'][pdf_type]['successful'] += 1

        # Processing time
        proc_time = report.get('processing_time_ms', 0)
        if proc_time > 0:
            stats['processing_times'].append(proc_time)

        # Errors
        for error in report.get('errors', []):
            stats['errors'][error] += 1

        # Field results
        for field_result in report.get('field_results', []):
            field_name = field_result['field_name']
            matched = field_result.get('matched', False)
            score = field_result.get('score', 0.0)

            stats['by_field'][field_name]['total'] += 1
            stats['by_field'][field_name]['by_pdf_type'][pdf_type]['total'] += 1

            if matched:
                stats['by_field'][field_name]['matched'] += 1
                stats['by_field'][field_name]['scores'].append(score)
                stats['by_field'][field_name]['by_pdf_type'][pdf_type]['matched'] += 1

                if score >= 0.99:
                    stats['by_field'][field_name]['exact_match'] += 1
                else:
                    stats['by_field'][field_name]['flexible_match'] += 1

    return stats


def print_report(stats: dict, verbose: bool = False):
    """Print analysis report."""
    print("\n" + "=" * 60)
    print("AUTO-LABEL REPORT ANALYSIS")
    print("=" * 60)

    # Overall stats
    print(f"\n{'OVERALL STATISTICS':^60}")
    print("-" * 60)
    total = stats['total']
    successful = stats['successful']
    failed = stats['failed']
    success_rate = successful / total * 100 if total > 0 else 0

    print(f"Total documents: {total:>8}")
    print(f"Successful:      {successful:>8} ({success_rate:.1f}%)")
    print(f"Failed:          {failed:>8} ({100 - success_rate:.1f}%)")

    # Processing time
    if 'processing_time_stats' in stats:
        pts = stats['processing_time_stats']
        print("\nProcessing time (ms):")
        print(f"  Average: {pts['avg_ms']:>8.1f}")
        print(f"  Min:     {pts['min_ms']:>8.1f}")
        print(f"  Max:     {pts['max_ms']:>8.1f}")
    elif stats.get('processing_times'):
        times = stats['processing_times']
        avg_time = sum(times) / len(times)
        min_time = min(times)
        max_time = max(times)
        print("\nProcessing time (ms):")
        print(f"  Average: {avg_time:>8.1f}")
        print(f"  Min:     {min_time:>8.1f}")
        print(f"  Max:     {max_time:>8.1f}")

    # By PDF type
    print(f"\n{'BY PDF TYPE':^60}")
    print("-" * 60)
    print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
    print("-" * 60)
    for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
        type_total = type_stats['total']
        type_success = type_stats['successful']
        type_rate = type_success / type_total * 100 if type_total > 0 else 0
        print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")

    # By field
    print(f"\n{'FIELD MATCH STATISTICS':^60}")
    print("-" * 60)
    print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
    print("-" * 60)

    for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
        if field_name not in stats['by_field']:
            continue
        field_stats = stats['by_field'][field_name]
        total = field_stats['total']
        matched = field_stats['matched']
        exact = field_stats['exact_match']
        flex = field_stats['flexible_match']
        rate = matched / total * 100 if total > 0 else 0

        # Handle avg_score from either DB or file analysis
        if 'avg_score' in field_stats:
            avg_score = field_stats['avg_score']
        elif field_stats['scores']:
            avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
        else:
            avg_score = 0

        print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")

    # Field match by PDF type
    print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
    print("-" * 60)

    for pdf_type in sorted(stats['by_pdf_type'].keys()):
        print(f"\n[{pdf_type.upper()}]")
        print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
        print("-" * 50)

        for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
            if field_name not in stats['by_field']:
                continue
            type_stats = stats['by_field'][field_name]['by_pdf_type'].get(pdf_type, {'total': 0, 'matched': 0})
            total = type_stats['total']
            matched = type_stats['matched']
            rate = matched / total * 100 if total > 0 else 0
            print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")

    # Errors
    if stats.get('errors') and verbose:
        print(f"\n{'ERRORS':^60}")
        print("-" * 60)
        for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
            print(f"{count:>5}x {error[:50]}")

    print("\n" + "=" * 60)


def export_json(stats: dict, output_path: str):
    """Export statistics to JSON file."""
    # Convert defaultdicts to regular dicts for JSON serialization
    export_data = {
        'total': stats['total'],
        'successful': stats['successful'],
        'failed': stats['failed'],
        'by_pdf_type': dict(stats['by_pdf_type']),
        'by_field': {},
        'errors': dict(stats.get('errors', {})),
    }

    # Processing time stats
    if 'processing_time_stats' in stats:
        export_data['processing_time_stats'] = stats['processing_time_stats']
    elif stats.get('processing_times'):
        times = stats['processing_times']
        export_data['processing_time_stats'] = {
            'avg_ms': sum(times) / len(times),
            'min_ms': min(times),
            'max_ms': max(times),
            'count': len(times)
        }

    # Field stats
    for field_name, field_stats in stats['by_field'].items():
        avg_score = field_stats.get('avg_score', 0)
        if not avg_score and field_stats.get('scores'):
            avg_score = sum(field_stats['scores']) / len(field_stats['scores'])

        export_data['by_field'][field_name] = {
            'total': field_stats['total'],
            'matched': field_stats['matched'],
            'exact_match': field_stats['exact_match'],
            'flexible_match': field_stats['flexible_match'],
            'match_rate': field_stats['matched'] / field_stats['total'] if field_stats['total'] > 0 else 0,
            'avg_score': avg_score,
            'by_pdf_type': dict(field_stats['by_pdf_type'])
        }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(export_data, f, indent=2, ensure_ascii=False)

    print(f"\nStatistics exported to: {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description='Analyze auto-label report'
    )
    parser.add_argument(
        '--report', '-r',
        default=None,
        help='Path to autolabel report JSONL file (uses database if not specified)'
    )
    parser.add_argument(
        '--output', '-o',
        help='Export statistics to JSON file'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Show detailed error messages'
    )
    parser.add_argument(
        '--from-file',
        action='store_true',
        help='Force reading from JSONL file instead of database'
    )

    args = parser.parse_args()

    # Decide source
    use_db = not args.from_file and args.report is None

    if use_db:
        print("Loading statistics from database...")
        stats = load_reports_from_db()
        print(f"Loaded stats for {stats['total']} documents")
    else:
        report_path = args.report or 'reports/autolabel_report.jsonl'
        path = Path(report_path)

        # Check if file exists (handle glob patterns)
        if '*' not in str(path) and '?' not in str(path) and not path.exists():
            print(f"Error: Report file not found: {path}")
            return 1

        print(f"Loading reports from: {report_path}")
        reports = load_reports_from_file(report_path)
        print(f"Loaded {len(reports)} reports")
        stats = analyze_reports(reports)

    print_report(stats, verbose=args.verbose)

    if args.output:
        export_json(stats, args.output)

    return 0


if __name__ == '__main__':
    sys.exit(main())
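For the file-based path, analyze_report.py expects one JSON object per line in the JSONL report. The record shape below (shown as a Python literal) is reconstructed from the keys the functions above read; the values are invented for illustration, and real reports may carry additional fields.

# Hypothetical autolabel report record, inferred from the .get() calls
# in analyze_reports(); only the keys shown here are actually consumed.
record = {
    "success": True,
    "pdf_type": "text",            # or "scanned"
    "processing_time_ms": 412.7,
    "errors": [],
    "field_results": [
        {"field_name": "InvoiceNumber", "matched": True, "score": 1.0},
        {"field_name": "Amount", "matched": True, "score": 0.95},  # flexible (< 0.99)
    ],
}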
752  packages/training/training/cli/autolabel.py  Normal file
@@ -0,0 +1,752 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-labeling CLI
|
||||
|
||||
Generates YOLO training data from PDFs and structured CSV data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import signal
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
|
||||
import multiprocessing
|
||||
|
||||
# Global flag for graceful shutdown
|
||||
_shutdown_requested = False
|
||||
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
"""Handle interrupt signals for graceful shutdown."""
|
||||
global _shutdown_requested
|
||||
_shutdown_requested = True
|
||||
print("\n\nShutdown requested. Finishing current batch and saving progress...")
|
||||
print("(Press Ctrl+C again to force quit)\n")
|
||||
|
||||
# Windows compatibility: use 'spawn' method for multiprocessing
|
||||
# This is required on Windows and is also safer for libraries like PaddleOCR
|
||||
if sys.platform == 'win32':
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
|
||||
from shared.config import get_db_connection_string, PATHS, AUTOLABEL
|
||||
|
||||
# Global OCR engine for worker processes (initialized once per worker)
|
||||
_worker_ocr_engine = None
|
||||
_worker_initialized = False
|
||||
_worker_type = None # 'cpu' or 'gpu'
|
||||
|
||||
|
||||
def _init_cpu_worker():
|
||||
"""Initialize CPU worker (no OCR engine needed)."""
|
||||
global _worker_initialized, _worker_type
|
||||
_worker_initialized = True
|
||||
_worker_type = 'cpu'
|
||||
|
||||
|
||||
def _init_gpu_worker():
|
||||
"""Initialize GPU worker with OCR engine (called once per worker)."""
|
||||
global _worker_ocr_engine, _worker_initialized, _worker_type
|
||||
|
||||
# Suppress PaddlePaddle/PaddleX reinitialization warnings
|
||||
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
|
||||
warnings.filterwarnings('ignore', message='.*reinitialization.*')
|
||||
|
||||
# Set environment variable to suppress paddle warnings
|
||||
os.environ['GLOG_minloglevel'] = '2' # Suppress INFO and WARNING logs
|
||||
|
||||
# OCR engine will be lazily initialized on first use
|
||||
_worker_ocr_engine = None
|
||||
_worker_initialized = True
|
||||
_worker_type = 'gpu'
|
||||
|
||||
|
||||
def _init_worker():
|
||||
"""Initialize worker process with OCR engine (called once per worker).
|
||||
Legacy function for backwards compatibility.
|
||||
"""
|
||||
_init_gpu_worker()
|
||||
|
||||
|
||||
def _get_ocr_engine():
|
||||
"""Get or create OCR engine for current worker."""
|
||||
global _worker_ocr_engine
|
||||
|
||||
if _worker_ocr_engine is None:
|
||||
# Suppress warnings during OCR initialization
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
from shared.ocr import OCREngine
|
||||
_worker_ocr_engine = OCREngine()
|
||||
|
||||
return _worker_ocr_engine
|
||||
|
||||
|
||||
def _save_output_img(output_img, image_path: Path) -> None:
|
||||
"""Save OCR output_img to replace the original rendered image."""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Convert numpy array to PIL Image and save
|
||||
if output_img is not None:
|
||||
img = PILImage.fromarray(output_img)
|
||||
img.save(str(image_path))
|
||||
# If output_img is None, the original image is already saved
|
||||
|
||||
|
||||
def process_single_document(args_tuple):
|
||||
"""
|
||||
Process a single document (worker function for parallel processing).
|
||||
|
||||
Args:
|
||||
args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)
|
||||
|
||||
Returns:
|
||||
dict with results
|
||||
"""
|
||||
import shutil
|
||||
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple
|
||||
|
||||
# Import inside worker to avoid pickling issues
|
||||
from training.data.autolabel_report import AutoLabelReport
|
||||
from shared.pdf import PDFDocument
|
||||
from training.yolo.annotation_generator import FIELD_CLASSES
|
||||
from training.processing.document_processor import process_page, record_unmatched_fields
|
||||
|
||||
start_time = time.time()
|
||||
pdf_path = Path(pdf_path_str)
|
||||
output_dir = Path(output_dir_str)
|
||||
doc_id = row_dict['DocumentId']
|
||||
|
||||
# Clean up existing temp folder for this document (for re-matching)
|
||||
temp_doc_dir = output_dir / 'temp' / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
|
||||
report = AutoLabelReport(document_id=doc_id)
|
||||
report.pdf_path = str(pdf_path)
|
||||
# Store metadata fields from CSV
|
||||
report.split = row_dict.get('split')
|
||||
report.customer_number = row_dict.get('customer_number')
|
||||
report.supplier_name = row_dict.get('supplier_name')
|
||||
report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
|
||||
report.supplier_accounts = row_dict.get('supplier_accounts')
|
||||
|
||||
result = {
|
||||
'doc_id': doc_id,
|
||||
'success': False,
|
||||
'pages': [],
|
||||
'report': None,
|
||||
'stats': {name: 0 for name in FIELD_CLASSES.keys()}
|
||||
}
|
||||
|
||||
try:
|
||||
# Use PDFDocument context manager for efficient PDF handling
|
||||
# Opens PDF only once, caches dimensions, handles cleanup automatically
|
||||
with PDFDocument(pdf_path) as pdf_doc:
|
||||
# Check PDF type (uses cached document)
|
||||
use_ocr = not pdf_doc.is_text_pdf()
|
||||
report.pdf_type = "scanned" if use_ocr else "text"
|
||||
|
||||
# Skip OCR if requested
|
||||
if use_ocr and skip_ocr:
|
||||
report.errors.append("Skipped (scanned PDF)")
|
||||
report.processing_time_ms = (time.time() - start_time) * 1000
|
||||
result['report'] = report.to_dict()
|
||||
return result
|
||||
|
||||
# Get OCR engine from worker cache (only created once per worker)
|
||||
ocr_engine = None
|
||||
if use_ocr:
|
||||
ocr_engine = _get_ocr_engine()
|
||||
|
||||
# Process each page
|
||||
page_annotations = []
|
||||
matched_fields = set()
|
||||
|
||||
# Render all pages and process (uses cached document handle)
|
||||
images_dir = output_dir / 'temp' / doc_id / 'images'
|
||||
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
|
||||
report.total_pages += 1
|
||||
|
||||
# Get dimensions from cache (no additional PDF open)
|
||||
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
|
||||
|
||||
# Extract tokens
|
||||
if use_ocr:
|
||||
# Use extract_with_image to get both tokens and preprocessed image
|
||||
# PaddleOCR coordinates are relative to output_img, not original image
|
||||
ocr_result = ocr_engine.extract_with_image(
|
||||
str(image_path),
|
||||
page_no,
|
||||
scale_to_pdf_points=72 / dpi
|
||||
)
|
||||
tokens = ocr_result.tokens
|
||||
|
||||
# Save output_img to replace the original rendered image
|
||||
# This ensures coordinates match the saved image
|
||||
_save_output_img(ocr_result.output_img, image_path)
|
||||
|
||||
# Update image dimensions to match output_img
|
||||
if ocr_result.output_img is not None:
|
||||
img_height, img_width = ocr_result.output_img.shape[:2]
|
||||
else:
|
||||
# Use cached document for text extraction
|
||||
tokens = list(pdf_doc.extract_text_tokens(page_no))
|
||||
|
||||
# Get page dimensions
|
||||
page = pdf_doc.doc[page_no]
|
||||
page_height = page.rect.height
|
||||
page_width = page.rect.width
|
||||
|
||||
# Use shared processing logic
|
||||
matches = {}
|
||||
annotations, ann_count = process_page(
|
||||
tokens=tokens,
|
||||
row_dict=row_dict,
|
||||
page_no=page_no,
|
||||
page_height=page_height,
|
||||
page_width=page_width,
|
||||
img_width=img_width,
|
||||
img_height=img_height,
|
||||
dpi=dpi,
|
||||
min_confidence=min_confidence,
|
||||
matches=matches,
|
||||
matched_fields=matched_fields,
|
||||
report=report,
|
||||
result_stats=result['stats'],
|
||||
)
|
||||
|
||||
if annotations:
|
||||
page_annotations.append({
|
||||
'image_path': str(image_path),
|
||||
'page_no': page_no,
|
||||
'count': ann_count
|
||||
})
|
||||
report.annotations_generated += ann_count
|
||||
|
||||
# Record unmatched fields using shared logic
|
||||
record_unmatched_fields(row_dict, matched_fields, report)
|
||||
|
||||
if page_annotations:
|
||||
result['pages'] = page_annotations
|
||||
result['success'] = True
|
||||
report.success = True
|
||||
else:
|
||||
report.errors.append("No annotations generated")
|
||||
|
||||
except Exception as e:
|
||||
report.errors.append(str(e))
|
||||
|
||||
report.processing_time_ms = (time.time() - start_time) * 1000
|
||||
result['report'] = report.to_dict()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Generate YOLO annotations from PDFs and CSV data'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--csv', '-c',
|
||||
default=f"{PATHS['csv_dir']}/*.csv",
|
||||
help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pdf-dir', '-p',
|
||||
default=PATHS['pdf_dir'],
|
||||
help='Directory containing PDF files'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output', '-o',
|
||||
default=PATHS['output_dir'],
|
||||
help='Output directory for dataset'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dpi',
|
||||
type=int,
|
||||
default=AUTOLABEL['dpi'],
|
||||
help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--min-confidence',
|
||||
type=float,
|
||||
default=AUTOLABEL['min_confidence'],
|
||||
help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--report',
|
||||
default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
|
||||
help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max-records',
|
||||
type=int,
|
||||
default=10000,
|
||||
help='Max records per report file for sharding (default: 10000, 0 = single file)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--single',
|
||||
help='Process single document ID only'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Verbose output'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--workers', '-w',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--cpu-workers',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Number of CPU workers for text PDFs (enables dual-pool mode)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--gpu-workers',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--skip-ocr',
|
||||
action='store_true',
|
||||
help='Skip scanned PDFs (text-layer only)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--limit', '-l',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Limit number of documents to process (for testing)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, _signal_handler)
|
||||
signal.signal(signal.SIGTERM, _signal_handler)
|
||||
|
||||
# Import here to avoid slow startup
|
||||
from shared.data import CSVLoader
|
||||
from training.data.autolabel_report import AutoLabelReport, FieldMatchResult, ReportWriter
|
||||
from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
|
||||
from shared.pdf.renderer import get_render_dimensions
|
||||
from shared.ocr import OCREngine
|
||||
from shared.matcher import FieldMatcher
|
||||
from shared.normalize import normalize_field
|
||||
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
|
||||
|
||||
# Handle comma-separated CSV paths
|
||||
csv_input = args.csv
|
||||
if ',' in csv_input and '*' not in csv_input:
|
||||
csv_input = [p.strip() for p in csv_input.split(',')]
|
||||
|
||||
# Get list of CSV files (don't load all data at once)
|
||||
temp_loader = CSVLoader(csv_input, args.pdf_dir)
|
||||
csv_files = temp_loader.csv_paths
|
||||
pdf_dir = temp_loader.pdf_dir
|
||||
print(f"Found {len(csv_files)} CSV file(s) to process")
|
||||
|
||||
# Setup output directories
|
||||
output_dir = Path(args.output)
|
||||
# Only create temp directory for images (no train/val/test split during labeling)
|
||||
(output_dir / 'temp').mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Report writer with optional sharding
|
||||
report_path = Path(args.report)
|
||||
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)
|
||||
|
||||
# Database connection for checking existing documents
|
||||
from shared.data.db import DocumentDB
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
db.create_tables() # Ensure tables exist
|
||||
print("Connected to database for status checking")
|
||||
|
||||
# Global stats
|
||||
stats = {
|
||||
'total': 0,
|
||||
'successful': 0,
|
||||
'failed': 0,
|
||||
'skipped': 0,
|
||||
'skipped_db': 0, # Skipped because already in DB
|
||||
'retried': 0, # Re-processed failed ones
|
||||
'annotations': 0,
|
||||
'tasks_submitted': 0, # Tracks tasks submitted across all CSVs for limit
|
||||
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
|
||||
}
|
||||
|
||||
# Track all processed items for final split (write to temp file to save memory)
|
||||
processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
|
||||
processed_items_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
|
||||
processed_count = 0
|
||||
seen_doc_ids = set()
|
||||
|
||||
# Batch for database updates
|
||||
db_batch = []
|
||||
DB_BATCH_SIZE = 100
|
||||
|
||||
# Helper function to handle result and update database
|
||||
# Defined outside the loop so nonlocal can properly reference db_batch
|
||||
def handle_result(result):
|
||||
nonlocal processed_count, db_batch
|
||||
|
||||
# Write report to file
|
||||
if result['report']:
|
||||
report_writer.write_dict(result['report'])
|
||||
# Add to database batch
|
||||
db_batch.append(result['report'])
|
||||
if len(db_batch) >= DB_BATCH_SIZE:
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
if result['success']:
|
||||
# Write to temp file instead of memory
|
||||
import json
|
||||
processed_items_writer.write(json.dumps({
|
||||
'doc_id': result['doc_id'],
|
||||
'pages': result['pages']
|
||||
}) + '\n')
|
||||
processed_items_writer.flush()
|
||||
processed_count += 1
|
||||
stats['successful'] += 1
|
||||
for field, count in result['stats'].items():
|
||||
stats['by_field'][field] += count
|
||||
stats['annotations'] += count
|
||||
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
|
||||
stats['skipped'] += 1
|
||||
else:
|
||||
stats['failed'] += 1
|
||||
|
||||
def handle_error(doc_id, error):
|
||||
nonlocal db_batch
|
||||
stats['failed'] += 1
|
||||
error_report = {
|
||||
'document_id': doc_id,
|
||||
'success': False,
|
||||
'errors': [f"Worker error: {str(error)}"]
|
||||
}
|
||||
report_writer.write_dict(error_report)
|
||||
db_batch.append(error_report)
|
||||
if len(db_batch) >= DB_BATCH_SIZE:
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
if args.verbose:
|
||||
print(f"Error processing {doc_id}: {error}")
|

    # Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
    dual_pool_coordinator = None
    use_dual_pool = args.cpu_workers is not None

    if use_dual_pool:
        from training.processing import DualPoolCoordinator
        from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf

        print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
        dual_pool_coordinator = DualPoolCoordinator(
            cpu_workers=args.cpu_workers,
            gpu_workers=args.gpu_workers,
            gpu_id=0,
            task_timeout=300.0,
        )
        dual_pool_coordinator.start()

    try:
        # Process CSV files one by one (streaming)
        for csv_idx, csv_file in enumerate(csv_files):
            # Check for shutdown request
            if _shutdown_requested:
                print("\nShutdown requested. Stopping after current batch...")
                break

            print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")

            # Load only this CSV file
            single_loader = CSVLoader(str(csv_file), str(pdf_dir))
            rows = single_loader.load_all()

            # Filter to single document if specified
            if args.single:
                rows = [r for r in rows if r.DocumentId == args.single]
                if not rows:
                    continue

            # Deduplicate across CSV files
            rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
            for r in rows:
                seen_doc_ids.add(r.DocumentId)

            if not rows:
                print("  Skipping CSV (no new documents)")
                continue

            # Batch query database for all document IDs in this CSV
            csv_doc_ids = [r.DocumentId for r in rows]
            db_status_map = db.check_documents_status_batch(csv_doc_ids)

            # Count how many are already processed successfully
            already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)

            # Skip entire CSV if all documents are already processed
            if already_processed == len(rows):
                print(f"  Skipping CSV (all {len(rows)} documents already processed)")
                stats['skipped_db'] += len(rows)
                continue

            # Count how many new documents need processing in this CSV
            new_to_process = len(rows) - already_processed
            print(f"  Found {new_to_process} new documents to process ({already_processed} already in DB)")

            stats['total'] += len(rows)

            # Prepare tasks for this CSV
            tasks = []
            skipped_in_csv = 0
            retry_in_csv = 0

            # Calculate how many more we can process if limit is set
            # Use tasks_submitted counter which tracks across all CSVs
            if args.limit:
                remaining_limit = args.limit - stats.get('tasks_submitted', 0)
                if remaining_limit <= 0:
                    print(f"  Reached limit of {args.limit} new documents, stopping.")
                    break
            else:
                remaining_limit = float('inf')

            # Collect doc_ids that need retry (for batch delete)
            retry_doc_ids = []

            for row in rows:
                # Stop adding tasks if we've reached the limit
                if len(tasks) >= remaining_limit:
                    break

                doc_id = row.DocumentId

                # Check document status from batch query result
                db_status = db_status_map.get(doc_id)  # None if not in DB

                # Skip if already successful in database
                if db_status is True:
                    stats['skipped_db'] += 1
                    skipped_in_csv += 1
                    continue

                # Check if this is a retry (was failed before)
                if db_status is False:
                    stats['retried'] += 1
                    retry_in_csv += 1
                    retry_doc_ids.append(doc_id)

                pdf_path = single_loader.get_pdf_path(row)
                if not pdf_path:
                    stats['skipped'] += 1
                    continue

                row_dict = {
                    'DocumentId': row.DocumentId,
                    'InvoiceNumber': row.InvoiceNumber,
                    'InvoiceDate': row.InvoiceDate,
                    'InvoiceDueDate': row.InvoiceDueDate,
                    'OCR': row.OCR,
                    'Bankgiro': row.Bankgiro,
                    'Plusgiro': row.Plusgiro,
                    'Amount': row.Amount,
                    # New fields for matching
                    'supplier_organisation_number': row.supplier_organisation_number,
                    'supplier_accounts': row.supplier_accounts,
                    'customer_number': row.customer_number,
                    # Metadata fields (not for matching, but for database storage)
                    'split': row.split,
                    'supplier_name': row.supplier_name,
                }

                tasks.append((
                    row_dict,
                    str(pdf_path),
                    str(output_dir),
                    args.dpi,
                    args.min_confidence,
                    args.skip_ocr
                ))

            if skipped_in_csv > 0 or retry_in_csv > 0:
                print(f"  Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")

            # Clean up retry documents: delete from database and remove temp folders
            if retry_doc_ids:
                # Batch delete from database (field_results will be cascade deleted).
                # Grab the connection once so the DELETE and the COMMIT are
                # guaranteed to run on the same connection.
                conn = db.connect()
                with conn.cursor() as cursor:
                    cursor.execute(
                        "DELETE FROM documents WHERE document_id = ANY(%s)",
                        (retry_doc_ids,)
                    )
                conn.commit()
                # Remove temp folders
                for doc_id in retry_doc_ids:
                    temp_doc_dir = output_dir / 'temp' / doc_id
                    if temp_doc_dir.exists():
                        shutil.rmtree(temp_doc_dir, ignore_errors=True)
                print(f"  Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")

            if not tasks:
                continue

            # Update tasks_submitted counter for limit tracking
            stats['tasks_submitted'] += len(tasks)

            if use_dual_pool:
                # Dual-pool mode using pre-initialized DualPoolCoordinator
                # (process_text_pdf, process_scanned_pdf already imported above)

                # Convert tasks to new format
                documents = []
                for task in tasks:
                    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
                    # Pre-classify PDF type
                    try:
                        is_text = is_text_pdf(pdf_path_str)
                    except Exception:
                        is_text = False

                    documents.append({
                        "id": row_dict["DocumentId"],
                        "row_dict": row_dict,
                        "pdf_path": pdf_path_str,
                        "output_dir": output_dir_str,
                        "dpi": dpi,
                        "min_confidence": min_confidence,
                        "is_scanned": not is_text,
                        "has_text": is_text,
                        "text_length": 1000 if is_text else 0,  # Approximate
                    })

                # Count task types
                text_count = sum(1 for d in documents if not d["is_scanned"])
                scan_count = len(documents) - text_count
                print(f"  Text PDFs: {text_count}, Scanned PDFs: {scan_count}")

                # Progress tracking with tqdm
                pbar = tqdm(total=len(documents), desc="Processing")

                def on_result(task_result):
                    """Handle successful result."""
                    result = task_result.data
                    handle_result(result)
                    pbar.update(1)

                def on_error(task_id, error):
                    """Handle failed task."""
                    handle_error(task_id, error)
                    pbar.update(1)

                # Process with pre-initialized coordinator (workers stay alive)
                results = dual_pool_coordinator.process_batch(
                    documents=documents,
                    cpu_task_fn=process_text_pdf,
                    gpu_task_fn=process_scanned_pdf,
                    on_result=on_result,
                    on_error=on_error,
                    id_field="id",
                )

                pbar.close()

                # Log summary
                successful = sum(1 for r in results if r.success)
                failed = len(results) - successful
                print(f"  Batch complete: {successful} successful, {failed} failed")

            else:
                # Single-pool mode (original behavior)
                print(f"  Processing {len(tasks)} documents with {args.workers} workers...")

                # Process documents in parallel (inside CSV loop for streaming)
                # Use single process for debugging or when workers=1
                if args.workers == 1:
                    for task in tqdm(tasks, desc="Processing"):
                        result = process_single_document(task)
                        handle_result(result)
                else:
                    # Parallel processing with worker initialization
                    # Each worker initializes OCR engine once and reuses it
                    with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
                        futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                                   for task in tasks}

                        # Per-document timeout: 120 seconds (2 minutes)
                        # This prevents a single stuck document from blocking the entire batch
                        DOCUMENT_TIMEOUT = 120

                        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                            doc_id = futures[future]
                            try:
                                result = future.result(timeout=DOCUMENT_TIMEOUT)
                                handle_result(result)
                            except TimeoutError:
                                handle_error(doc_id, f"Processing timeout after {DOCUMENT_TIMEOUT}s")
                                # Attempt to cancel the stuck future (only possible
                                # if it has not started running yet)
                                future.cancel()
                            except Exception as e:
                                handle_error(doc_id, e)

            # Flush remaining database batch after each CSV
            if db_batch:
                db.save_documents_batch(db_batch)
                db_batch.clear()

    finally:
        # Shutdown dual-pool coordinator if it was started
        if dual_pool_coordinator is not None:
            dual_pool_coordinator.shutdown()

    # Close temp file
    processed_items_writer.close()

    # Use the in-memory counter instead of re-reading the file (performance fix)
    # processed_count already tracks the number of successfully processed items

    # Cleanup processed_items temp file (not needed anymore)
    processed_items_file.unlink(missing_ok=True)

    # Close database connection
    db.close()

    # Print summary
    print("\n" + "=" * 50)
    print("Auto-labeling Complete")
    print("=" * 50)
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (no PDF): {stats['skipped']}")
    print(f"Skipped (in DB): {stats['skipped_db']}")
    print(f"Retried (failed): {stats['retried']}")
    print(f"Total annotations: {stats['annotations']}")
    print(f"\nImages saved to: {output_dir / 'temp'}")
    print("Labels stored in: PostgreSQL database")
    print("\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f"  {field}: {count}")
    shard_files = report_writer.get_shard_files()
    if len(shard_files) > 1:
        print(f"\nReport files ({len(shard_files)}):")
        for sf in shard_files:
            print(f"  - {sf}")
    else:
        print(f"\nReport: {shard_files[0] if shard_files else args.report}")


if __name__ == '__main__':
    main()
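For reference, the result dicts that handle_result consumes have roughly the shape sketched below. The shape is inferred from the reads above (result['doc_id'], result['success'], result['pages'], result['stats'], result['report']); the concrete values are illustrative only, not taken from the repository.

result = {
    'doc_id': 'a1b2c3d4',          # document identifier from the CSV row
    'success': True,               # False routes into the skipped/failed counters
    'pages': 2,                    # written to the processed-items temp file
    'stats': {'InvoiceNumber': 1, 'Amount': 1},  # per-field annotation counts
    'report': {'document_id': 'a1b2c3d4', 'success': True, 'errors': []},  # row saved via save_documents_batch
}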
299
packages/training/training/cli/import_report_to_db.py
Normal file
299
packages/training/training/cli/import_report_to_db.py
Normal file
@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
Import existing JSONL report files into PostgreSQL database.

Usage:
    python -m training.cli.import_report_to_db --report "reports/autolabel_report_v4*.jsonl"
"""

import argparse
import json
from pathlib import Path

import psycopg2
from psycopg2.extras import execute_values

from shared.config import get_db_connection_string, PATHS


def create_tables(conn):
    """Create database tables."""
    with conn.cursor() as cursor:
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS documents (
                document_id TEXT PRIMARY KEY,
                pdf_path TEXT,
                pdf_type TEXT,
                success BOOLEAN,
                total_pages INTEGER,
                fields_matched INTEGER,
                fields_total INTEGER,
                annotations_generated INTEGER,
                processing_time_ms REAL,
                timestamp TIMESTAMPTZ,
                errors JSONB DEFAULT '[]',
                -- New fields for extended CSV format
                split TEXT,
                customer_number TEXT,
                supplier_name TEXT,
                supplier_organisation_number TEXT,
                supplier_accounts TEXT
            );

            CREATE TABLE IF NOT EXISTS field_results (
                id SERIAL PRIMARY KEY,
                document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
                field_name TEXT,
                csv_value TEXT,
                matched BOOLEAN,
                score REAL,
                matched_text TEXT,
                candidate_used TEXT,
                bbox JSONB,
                page_no INTEGER,
                context_keywords JSONB DEFAULT '[]',
                error TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
            CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
            CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
            CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);

            -- Add new columns to existing tables if they don't exist (for migration)
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='split') THEN
                    ALTER TABLE documents ADD COLUMN split TEXT;
                END IF;
                IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='customer_number') THEN
                    ALTER TABLE documents ADD COLUMN customer_number TEXT;
                END IF;
                IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_name') THEN
                    ALTER TABLE documents ADD COLUMN supplier_name TEXT;
                END IF;
                IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_organisation_number') THEN
                    ALTER TABLE documents ADD COLUMN supplier_organisation_number TEXT;
                END IF;
                IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='documents' AND column_name='supplier_accounts') THEN
                    ALTER TABLE documents ADD COLUMN supplier_accounts TEXT;
                END IF;
            END $$;
        """)
    conn.commit()


def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_size: int = 1000) -> dict:
    """Import a single JSONL file into database."""
    stats = {'imported': 0, 'skipped': 0, 'errors': 0}

    # Get existing document IDs if skipping
    existing_ids = set()
    if skip_existing:
        with conn.cursor() as cursor:
            cursor.execute("SELECT document_id FROM documents")
            existing_ids = {row[0] for row in cursor.fetchall()}

    doc_batch = []
    field_batch = []

    def flush_batches():
        nonlocal doc_batch, field_batch
        if doc_batch:
            with conn.cursor() as cursor:
                execute_values(cursor, """
                    INSERT INTO documents
                    (document_id, pdf_path, pdf_type, success, total_pages,
                     fields_matched, fields_total, annotations_generated,
                     processing_time_ms, timestamp, errors,
                     split, customer_number, supplier_name, supplier_organisation_number, supplier_accounts)
                    VALUES %s
                    ON CONFLICT (document_id) DO UPDATE SET
                        pdf_path = EXCLUDED.pdf_path,
                        pdf_type = EXCLUDED.pdf_type,
                        success = EXCLUDED.success,
                        total_pages = EXCLUDED.total_pages,
                        fields_matched = EXCLUDED.fields_matched,
                        fields_total = EXCLUDED.fields_total,
                        annotations_generated = EXCLUDED.annotations_generated,
                        processing_time_ms = EXCLUDED.processing_time_ms,
                        timestamp = EXCLUDED.timestamp,
                        errors = EXCLUDED.errors,
                        split = EXCLUDED.split,
                        customer_number = EXCLUDED.customer_number,
                        supplier_name = EXCLUDED.supplier_name,
                        supplier_organisation_number = EXCLUDED.supplier_organisation_number,
                        supplier_accounts = EXCLUDED.supplier_accounts
                """, doc_batch)
            doc_batch = []

        if field_batch:
            with conn.cursor() as cursor:
                execute_values(cursor, """
                    INSERT INTO field_results
                    (document_id, field_name, csv_value, matched, score,
                     matched_text, candidate_used, bbox, page_no, context_keywords, error)
                    VALUES %s
                """, field_batch)
            field_batch = []

        conn.commit()

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"  Warning: Line {line_no} - JSON parse error: {e}")
                stats['errors'] += 1
                continue

            doc_id = record.get('document_id')
            if not doc_id:
                stats['errors'] += 1
                continue

            # Only import successful documents
            if not record.get('success'):
                stats['skipped'] += 1
                continue

            # Check if already exists
            if skip_existing and doc_id in existing_ids:
                stats['skipped'] += 1
                continue

            # Add to batch
            doc_batch.append((
                doc_id,
                record.get('pdf_path'),
                record.get('pdf_type'),
                record.get('success'),
                record.get('total_pages'),
                record.get('fields_matched'),
                record.get('fields_total'),
                record.get('annotations_generated'),
                record.get('processing_time_ms'),
                record.get('timestamp'),
                json.dumps(record.get('errors', [])),
                # New fields
                record.get('split'),
                record.get('customer_number'),
                record.get('supplier_name'),
                record.get('supplier_organisation_number'),
                record.get('supplier_accounts'),
            ))

            for field in record.get('field_results', []):
                field_batch.append((
                    doc_id,
                    field.get('field_name'),
                    field.get('csv_value'),
                    field.get('matched'),
                    field.get('score'),
                    field.get('matched_text'),
                    field.get('candidate_used'),
                    json.dumps(field.get('bbox')) if field.get('bbox') else None,
                    field.get('page_no'),
                    json.dumps(field.get('context_keywords', [])),
                    field.get('error')
                ))

            stats['imported'] += 1
            existing_ids.add(doc_id)

            # Flush batch if needed
            if len(doc_batch) >= batch_size:
                flush_batches()
                print(f"  Processed {stats['imported'] + stats['skipped']} records...")

    # Final flush
    flush_batches()

    return stats


def main():
    parser = argparse.ArgumentParser(description='Import JSONL reports to PostgreSQL database')
    parser.add_argument('--report', type=str, default=f"{PATHS['reports_dir']}/autolabel_report*.jsonl",
                        help='Report file path or glob pattern')
    parser.add_argument('--db', type=str, default=None,
                        help='PostgreSQL connection string (uses config.py if not specified)')
    parser.add_argument('--no-skip', action='store_true',
                        help='Do not skip existing documents (replace them)')
    parser.add_argument('--batch-size', type=int, default=1000,
                        help='Batch size for bulk inserts')
    args = parser.parse_args()

    # Use config if db not specified
    db_connection = args.db or get_db_connection_string()

    # Find report files
    report_path = Path(args.report)
    if '*' in str(report_path) or '?' in str(report_path):
        parent = report_path.parent
        pattern = report_path.name
        report_files = sorted(parent.glob(pattern))
    else:
        report_files = [report_path] if report_path.exists() else []

    if not report_files:
        print(f"No report files found: {args.report}")
        return

    print(f"Found {len(report_files)} report file(s)")

    # Connect to database
    conn = psycopg2.connect(db_connection)
    create_tables(conn)

    # Import each file
    total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}

    for report_file in report_files:
        print(f"\nImporting: {report_file.name}")
        stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
        print(f"  Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")

        for key in total_stats:
            total_stats[key] += stats[key]

    # Print summary
    print("\n" + "=" * 50)
    print("Import Complete")
    print("=" * 50)
    print(f"Total imported: {total_stats['imported']}")
    print(f"Total skipped: {total_stats['skipped']}")
    print(f"Total errors: {total_stats['errors']}")

    # Quick stats from database
    with conn.cursor() as cursor:
        cursor.execute("SELECT COUNT(*) FROM documents")
        total_docs = cursor.fetchone()[0]

        cursor.execute("SELECT COUNT(*) FROM documents WHERE success = true")
        success_docs = cursor.fetchone()[0]

        cursor.execute("SELECT COUNT(*) FROM field_results")
        total_fields = cursor.fetchone()[0]

        cursor.execute("SELECT COUNT(*) FROM field_results WHERE matched = true")
        matched_fields = cursor.fetchone()[0]

    conn.close()

    print("\nDatabase Stats:")
    print(f"  Documents: {total_docs} ({success_docs} successful)")
    print(f"  Field results: {total_fields} ({matched_fields} matched)")
    if total_fields > 0:
        print(f"  Match rate: {matched_fields / total_fields * 100:.2f}%")


if __name__ == '__main__':
    main()
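To make the expected input concrete: each line of the report JSONL should decode to a record like the sketch below. The key set is inferred from the record.get(...) calls above; the values are illustrative, and keys the importer does not read are simply ignored.

import json

record = {
    'document_id': 'a1b2c3d4',
    'pdf_path': 'data/pdfs/a1b2c3d4.pdf',
    'pdf_type': 'text',
    'success': True,                      # records with success=False are skipped
    'total_pages': 1,
    'fields_matched': 7,
    'fields_total': 8,
    'annotations_generated': 7,
    'processing_time_ms': 412.5,
    'timestamp': '2024-01-01T12:00:00',
    'errors': [],
    'split': 'train',
    'field_results': [
        {'field_name': 'InvoiceNumber', 'csv_value': '12345', 'matched': True,
         'score': 1.0, 'matched_text': '12345',
         'bbox': [100.0, 200.0, 180.0, 215.0], 'page_no': 0},
    ],
}
print(json.dumps(record))  # one such line per document in the .jsonl file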
424
packages/training/training/cli/reprocess_failed.py
Normal file
424
packages/training/training/cli/reprocess_failed.py
Normal file
@@ -0,0 +1,424 @@
#!/usr/bin/env python3
"""
Re-process failed matches and store detailed information including OCR values,
CSV values, and source CSV filename in a new table.
"""

import argparse
import json
import glob
import os
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from tqdm import tqdm

from shared.config import DEFAULT_DPI
from shared.data.db import DocumentDB
from shared.data.csv_loader import CSVLoader
from shared.normalize.normalizer import normalize_field


def create_failed_match_table(db: DocumentDB):
    """Create the failed_match_details table."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            DROP TABLE IF EXISTS failed_match_details;

            CREATE TABLE failed_match_details (
                id SERIAL PRIMARY KEY,
                document_id TEXT NOT NULL,
                field_name TEXT NOT NULL,
                csv_value TEXT,
                csv_value_normalized TEXT,
                ocr_value TEXT,
                ocr_value_normalized TEXT,
                all_ocr_candidates JSONB,
                matched BOOLEAN DEFAULT FALSE,
                match_score REAL,
                pdf_path TEXT,
                pdf_type TEXT,
                csv_filename TEXT,
                page_no INTEGER,
                bbox JSONB,
                error TEXT,
                reprocessed_at TIMESTAMPTZ DEFAULT NOW(),

                UNIQUE(document_id, field_name)
            );

            CREATE INDEX IF NOT EXISTS idx_failed_match_document_id ON failed_match_details(document_id);
            CREATE INDEX IF NOT EXISTS idx_failed_match_field_name ON failed_match_details(field_name);
            CREATE INDEX IF NOT EXISTS idx_failed_match_csv_filename ON failed_match_details(csv_filename);
            CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
        """)
    conn.commit()
    print("Created table: failed_match_details")


def get_failed_documents(db: DocumentDB) -> list:
    """Get all documents that have at least one failed field match."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT DISTINCT fr.document_id, d.pdf_path, d.pdf_type
            FROM field_results fr
            JOIN documents d ON fr.document_id = d.document_id
            WHERE fr.matched = false
            ORDER BY fr.document_id
        """)
        return [{'document_id': row[0], 'pdf_path': row[1], 'pdf_type': row[2]}
                for row in cursor.fetchall()]


def get_failed_fields_for_document(db: DocumentDB, doc_id: str) -> list:
    """Get all failed field results for a document."""
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT field_name, csv_value, error
            FROM field_results
            WHERE document_id = %s AND matched = false
        """, (doc_id,))
        return [{'field_name': row[0], 'csv_value': row[1], 'error': row[2]}
                for row in cursor.fetchall()]


# Cache for CSV data
_csv_cache = {}


def build_csv_cache(csv_files: list):
    """Build a cache of document_id to csv_filename mapping."""
    global _csv_cache
    _csv_cache = {}

    for csv_file in csv_files:
        csv_filename = os.path.basename(csv_file)
        loader = CSVLoader(csv_file)
        for row in loader.iter_rows():
            if row.DocumentId not in _csv_cache:
                _csv_cache[row.DocumentId] = csv_filename


def find_csv_filename(doc_id: str) -> str | None:
    """Find which CSV file contains the document ID."""
    return _csv_cache.get(doc_id, None)


def init_worker():
    """Initialize worker process."""
    import os
    import warnings
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["GLOG_minloglevel"] = "2"
    warnings.filterwarnings("ignore")


def process_single_document(args):
    """Process a single document and extract OCR values for failed fields."""
    doc_info, failed_fields, csv_filename = args
    doc_id = doc_info['document_id']
    pdf_path = doc_info['pdf_path']
    pdf_type = doc_info['pdf_type']

    results = []

    # Try to extract OCR from PDF
    try:
        if pdf_path and os.path.exists(pdf_path):
            from shared.pdf import PDFDocument
            from shared.ocr import OCREngine

            pdf_doc = PDFDocument(pdf_path)
            is_scanned = pdf_doc.detect_type() == "scanned"

            # Collect all OCR text blocks
            all_ocr_texts = []

            if is_scanned:
                # Use OCR for scanned PDFs
                ocr_engine = OCREngine()
                for page_no in range(pdf_doc.page_count):
                    # Render page to image
                    img = pdf_doc.render_page(page_no, dpi=DEFAULT_DPI)
                    if img is None:
                        continue

                    # OCR the image
                    ocr_results = ocr_engine.extract_from_image(img)
                    for block in ocr_results:
                        all_ocr_texts.append({
                            'text': block.get('text', ''),
                            'bbox': block.get('bbox'),
                            'page_no': page_no
                        })
            else:
                # Use text extraction for text PDFs
                for page_no in range(pdf_doc.page_count):
                    tokens = list(pdf_doc.extract_text_tokens(page_no))
                    for token in tokens:
                        all_ocr_texts.append({
                            'text': token.text,
                            'bbox': token.bbox,
                            'page_no': page_no
                        })

            # For each failed field, try to find matching OCR
            for field in failed_fields:
                field_name = field['field_name']
                csv_value = field['csv_value']
                error = field['error']

                # Normalize CSV value
                csv_normalized = normalize_field(field_name, csv_value) if csv_value else None

                # Try to find best match in OCR
                best_score = 0
                best_ocr = None
                best_bbox = None
                best_page = None

                for ocr_block in all_ocr_texts:
                    ocr_text = ocr_block['text']
                    if not ocr_text:
                        continue
                    ocr_normalized = normalize_field(field_name, ocr_text)

                    # Calculate similarity
                    if csv_normalized and ocr_normalized:
                        # Exact match: best possible score, stop searching.
                        # (Checked first; equal strings would otherwise be
                        # swallowed by the substring branch below.)
                        if csv_normalized == ocr_normalized:
                            best_score = 1.0
                            best_ocr = ocr_text
                            best_bbox = ocr_block['bbox']
                            best_page = ocr_block['page_no']
                            break
                        # Otherwise check substring containment in either direction
                        if csv_normalized in ocr_normalized:
                            score = len(csv_normalized) / max(len(ocr_normalized), 1)
                            if score > best_score:
                                best_score = score
                                best_ocr = ocr_text
                                best_bbox = ocr_block['bbox']
                                best_page = ocr_block['page_no']
                        elif ocr_normalized in csv_normalized:
                            score = len(ocr_normalized) / max(len(csv_normalized), 1)
                            if score > best_score:
                                best_score = score
                                best_ocr = ocr_text
                                best_bbox = ocr_block['bbox']
                                best_page = ocr_block['page_no']

                results.append({
                    'document_id': doc_id,
                    'field_name': field_name,
                    'csv_value': csv_value,
                    'csv_value_normalized': csv_normalized,
                    'ocr_value': best_ocr,
                    'ocr_value_normalized': normalize_field(field_name, best_ocr) if best_ocr else None,
                    'all_ocr_candidates': [t['text'] for t in all_ocr_texts[:100]],  # Limit to 100
                    'matched': best_score > 0.8,
                    'match_score': best_score,
                    'pdf_path': pdf_path,
                    'pdf_type': pdf_type,
                    'csv_filename': csv_filename,
                    'page_no': best_page,
                    'bbox': list(best_bbox) if best_bbox else None,
                    'error': error
                })
        else:
            # PDF not found
            for field in failed_fields:
                results.append({
                    'document_id': doc_id,
                    'field_name': field['field_name'],
                    'csv_value': field['csv_value'],
                    'csv_value_normalized': normalize_field(field['field_name'], field['csv_value']) if field['csv_value'] else None,
                    'ocr_value': None,
                    'ocr_value_normalized': None,
                    'all_ocr_candidates': [],
                    'matched': False,
                    'match_score': 0,
                    'pdf_path': pdf_path,
                    'pdf_type': pdf_type,
                    'csv_filename': csv_filename,
                    'page_no': None,
                    'bbox': None,
                    'error': f"PDF not found: {pdf_path}"
                })

    except Exception as e:
        for field in failed_fields:
            results.append({
                'document_id': doc_id,
                'field_name': field['field_name'],
                'csv_value': field['csv_value'],
                'csv_value_normalized': None,
                'ocr_value': None,
                'ocr_value_normalized': None,
                'all_ocr_candidates': [],
                'matched': False,
                'match_score': 0,
                'pdf_path': pdf_path,
                'pdf_type': pdf_type,
                'csv_filename': csv_filename,
                'page_no': None,
                'bbox': None,
                'error': str(e)
            })

    return results


def save_results_batch(db: DocumentDB, results: list):
    """Save results to failed_match_details table."""
    if not results:
        return

    conn = db.connect()
    with conn.cursor() as cursor:
        for r in results:
            cursor.execute("""
                INSERT INTO failed_match_details
                (document_id, field_name, csv_value, csv_value_normalized,
                 ocr_value, ocr_value_normalized, all_ocr_candidates,
                 matched, match_score, pdf_path, pdf_type, csv_filename,
                 page_no, bbox, error)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (document_id, field_name) DO UPDATE SET
                    csv_value = EXCLUDED.csv_value,
                    csv_value_normalized = EXCLUDED.csv_value_normalized,
                    ocr_value = EXCLUDED.ocr_value,
                    ocr_value_normalized = EXCLUDED.ocr_value_normalized,
                    all_ocr_candidates = EXCLUDED.all_ocr_candidates,
                    matched = EXCLUDED.matched,
                    match_score = EXCLUDED.match_score,
                    pdf_path = EXCLUDED.pdf_path,
                    pdf_type = EXCLUDED.pdf_type,
                    csv_filename = EXCLUDED.csv_filename,
                    page_no = EXCLUDED.page_no,
                    bbox = EXCLUDED.bbox,
                    error = EXCLUDED.error,
                    reprocessed_at = NOW()
            """, (
                r['document_id'],
                r['field_name'],
                r['csv_value'],
                r['csv_value_normalized'],
                r['ocr_value'],
                r['ocr_value_normalized'],
                json.dumps(r['all_ocr_candidates']),
                r['matched'],
                r['match_score'],
                r['pdf_path'],
                r['pdf_type'],
                r['csv_filename'],
                r['page_no'],
                json.dumps(r['bbox']) if r['bbox'] else None,
                r['error']
            ))
    conn.commit()


def main():
    parser = argparse.ArgumentParser(description='Re-process failed matches')
    parser.add_argument('--csv', required=True, help='CSV files glob pattern')
    parser.add_argument('--pdf-dir', required=True, help='PDF directory')
    parser.add_argument('--workers', type=int, default=3, help='Number of workers')
    parser.add_argument('--limit', type=int, help='Limit number of documents to process')
    args = parser.parse_args()

    # Expand CSV glob
    csv_files = sorted(glob.glob(args.csv))
    print(f"Found {len(csv_files)} CSV files")

    # Build CSV cache
    print("Building CSV filename cache...")
    build_csv_cache(csv_files)
    print(f"Cached {len(_csv_cache)} document IDs")

    # Connect to database
    db = DocumentDB()
    db.connect()

    # Create new table
    create_failed_match_table(db)

    # Get all failed documents
    print("Fetching failed documents...")
    failed_docs = get_failed_documents(db)
    print(f"Found {len(failed_docs)} documents with failed matches")

    if args.limit:
        failed_docs = failed_docs[:args.limit]
        print(f"Limited to {len(failed_docs)} documents")

    # Prepare tasks
    tasks = []
    for doc in failed_docs:
        failed_fields = get_failed_fields_for_document(db, doc['document_id'])
        csv_filename = find_csv_filename(doc['document_id'])
        if failed_fields:
            tasks.append((doc, failed_fields, csv_filename))

    print(f"Processing {len(tasks)} documents with {args.workers} workers...")

    # Process with multiprocessing
    total_results = 0
    batch_results = []
    batch_size = 50

    with ProcessPoolExecutor(max_workers=args.workers, initializer=init_worker) as executor:
        futures = {executor.submit(process_single_document, task): task[0]['document_id']
                   for task in tasks}

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
            doc_id = futures[future]
            try:
                results = future.result(timeout=120)
                batch_results.extend(results)
                total_results += len(results)

                # Save in batches
                if len(batch_results) >= batch_size:
                    save_results_batch(db, batch_results)
                    batch_results = []

            except TimeoutError:
                print(f"\nTimeout processing {doc_id}")
            except Exception as e:
                print(f"\nError processing {doc_id}: {e}")

    # Save remaining results
    if batch_results:
        save_results_batch(db, batch_results)

    print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")

    # Show summary
    conn = db.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT field_name, COUNT(*) as total,
                   COUNT(*) FILTER (WHERE ocr_value IS NOT NULL) as has_ocr,
                   COALESCE(AVG(match_score), 0) as avg_score
            FROM failed_match_details
            GROUP BY field_name
            ORDER BY total DESC
        """)
        print("\nSummary by field:")
        print("-" * 70)
        print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
        print("-" * 70)
        for row in cursor.fetchall():
            print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")

    db.close()


if __name__ == '__main__':
    main()
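The best-match search in process_single_document boils down to the containment score distilled below. This is a minimal re-statement of the logic above, with plain strings standing in for normalize_field output; useful for sanity-checking the threshold (matched requires a score above 0.8).

def containment_score(csv_norm: str, ocr_norm: str) -> float:
    # Exact match wins outright; otherwise score by relative overlap length
    if csv_norm == ocr_norm:
        return 1.0
    if csv_norm in ocr_norm:
        return len(csv_norm) / max(len(ocr_norm), 1)
    if ocr_norm in csv_norm:
        return len(ocr_norm) / max(len(csv_norm), 1)
    return 0.0

assert containment_score('5561234567', '5561234567') == 1.0
assert containment_score('12345', 'faktura 12345') == 5 / 13   # partial containment
assert containment_score('12345', '99999') == 0.0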
269
packages/training/training/cli/train.py
Normal file
269
packages/training/training/cli/train.py
Normal file
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
"""
Training CLI

Trains YOLO model on dataset with labels from PostgreSQL database.
Images are read from filesystem, labels are dynamically generated from DB.
"""

import argparse
import sys
from pathlib import Path

from shared.config import DEFAULT_DPI, PATHS


def main():
    parser = argparse.ArgumentParser(
        description='Train YOLO model for invoice field detection'
    )
    parser.add_argument(
        '--dataset-dir', '-d',
        default=PATHS['output_dir'],
        help='Dataset directory containing temp/{doc_id}/images/ (default: data/dataset)'
    )
    parser.add_argument(
        '--model', '-m',
        default='yolov8s.pt',
        help='Base model (default: yolov8s.pt)'
    )
    parser.add_argument(
        '--epochs', '-e',
        type=int,
        default=100,
        help='Number of epochs (default: 100)'
    )
    parser.add_argument(
        '--batch', '-b',
        type=int,
        default=16,
        help='Batch size (default: 16)'
    )
    parser.add_argument(
        '--imgsz',
        type=int,
        default=1280,
        help='Image size (default: 1280)'
    )
    parser.add_argument(
        '--project',
        default='runs/train',
        help='Project directory (default: runs/train)'
    )
    parser.add_argument(
        '--name',
        default='invoice_fields',
        help='Run name (default: invoice_fields)'
    )
    parser.add_argument(
        '--device',
        default='0',
        help='Device (0, 1, cpu, mps)'
    )
    parser.add_argument(
        '--resume',
        action='store_true',
        help='Resume from last checkpoint'
    )
    parser.add_argument(
        '--workers',
        type=int,
        default=4,
        help='Number of data loader workers (default: 4, reduce if OOM)'
    )
    parser.add_argument(
        '--cache',
        action='store_true',
        help='Cache images in RAM (faster but uses more memory)'
    )
    parser.add_argument(
        '--low-memory',
        action='store_true',
        help='Enable low memory mode (caps batch at 8, workers at 4, disables cache)'
    )
    parser.add_argument(
        '--train-ratio',
        type=float,
        default=0.8,
        help='Training set ratio (default: 0.8)'
    )
    parser.add_argument(
        '--val-ratio',
        type=float,
        default=0.1,
        help='Validation set ratio (default: 0.1)'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for split (default: 42)'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=DEFAULT_DPI,
        help=f'DPI used for rendering (default: {DEFAULT_DPI}, must match autolabel rendering)'
    )
    parser.add_argument(
        '--export-only',
        action='store_true',
        help='Only export dataset to YOLO format, do not train'
    )
    parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help='Limit number of documents for training (default: all)'
    )

    args = parser.parse_args()

    # Apply low-memory mode if specified
    if args.low_memory:
        print("🔧 Low memory mode enabled")
        args.batch = min(args.batch, 8)      # Cap batch size at 8
        args.workers = min(args.workers, 4)  # Cap data loader workers at 4
        args.cache = False
        print(f"  Batch size: {args.batch}")
        print(f"  Workers: {args.workers}")
        print("  Cache: disabled")

    # Validate dataset directory
    dataset_dir = Path(args.dataset_dir)
    temp_dir = dataset_dir / 'temp'
    if not temp_dir.exists():
        print(f"Error: Temp directory not found: {temp_dir}")
        print("Run autolabel first to generate images.")
        sys.exit(1)

    print("=" * 60)
    print("YOLO Training with Database Labels")
    print("=" * 60)
    print(f"Dataset dir: {dataset_dir}")
    print(f"Model: {args.model}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch}")
    print(f"Image size: {args.imgsz}")
    print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
    if args.limit:
        print(f"Document limit: {args.limit}")

    # Connect to database
    from shared.data.db import DocumentDB

    print("\nConnecting to database...")
    db = DocumentDB()
    db.connect()

    # Create datasets from database
    from training.yolo.db_dataset import create_datasets

    print("Loading dataset from database...")
    datasets = create_datasets(
        images_dir=dataset_dir,
        db=db,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        seed=args.seed,
        dpi=args.dpi,
        limit=args.limit
    )

    print("\nDataset splits:")
    print(f"  Train: {len(datasets['train'])} items")
    print(f"  Val: {len(datasets['val'])} items")
    print(f"  Test: {len(datasets['test'])} items")

    if len(datasets['train']) == 0:
        print("\nError: No training data found!")
        print("Make sure autolabel has been run and images exist in temp directory.")
        db.close()
        sys.exit(1)

    # Export to YOLO format (required for Ultralytics training)
    print("\nExporting dataset to YOLO format...")
    for split_name, dataset in datasets.items():
        count = dataset.export_to_yolo_format(dataset_dir, split_name)
        print(f"  {split_name}: {count} items exported")

    # Generate YOLO config files
    from training.yolo.annotation_generator import AnnotationGenerator

    AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
    AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
    print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")

    if args.export_only:
        print("\nExport complete (--export-only specified, skipping training)")
        db.close()
        return

    # Start training
    print("\n" + "=" * 60)
    print("Starting YOLO Training")
    print("=" * 60)

    from ultralytics import YOLO

    # Load model
    last_checkpoint = Path(args.project) / args.name / 'weights' / 'last.pt'
    if args.resume and last_checkpoint.exists():
        print(f"Resuming from: {last_checkpoint}")
        model = YOLO(str(last_checkpoint))
    else:
        model = YOLO(args.model)

    # Training arguments
    data_yaml = dataset_dir / 'dataset.yaml'
    train_args = {
        'data': str(data_yaml.absolute()),
        'epochs': args.epochs,
        'batch': args.batch,
        'imgsz': args.imgsz,
        'project': args.project,
        'name': args.name,
        'device': args.device,
        'exist_ok': True,
        'pretrained': True,
        'verbose': True,
        'workers': args.workers,
        'cache': args.cache,
        'resume': args.resume and last_checkpoint.exists(),
        # Document-specific augmentation settings
        'degrees': 5.0,
        'translate': 0.05,
        'scale': 0.2,
        'shear': 0.0,
        'perspective': 0.0,
        'flipud': 0.0,
        'fliplr': 0.0,
        'mosaic': 0.0,
        'mixup': 0.0,
        'hsv_h': 0.0,
        'hsv_s': 0.1,
        'hsv_v': 0.2,
    }

    # Train
    results = model.train(**train_args)

    # Print results
    print("\n" + "=" * 60)
    print("Training Complete")
    print("=" * 60)
    print(f"Best model: {args.project}/{args.name}/weights/best.pt")
    print(f"Last model: {args.project}/{args.name}/weights/last.pt")

    # Validate on test set
    print("\nRunning validation on test set...")
    metrics = model.val(split='test')
    print(f"mAP50: {metrics.box.map50:.4f}")
    print(f"mAP50-95: {metrics.box.map:.4f}")

    # Close database
    db.close()


if __name__ == '__main__':
    main()
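The same flow can be driven from Python instead of the CLI, which is handy in notebooks. A minimal sketch using the Ultralytics API already imported above; the dataset.yaml path assumes the export step in this file has been run, and the hyperparameter values are illustrative rather than the project's tuned settings.

from ultralytics import YOLO

model = YOLO('yolov8s.pt')
model.train(data='data/dataset/dataset.yaml', epochs=100, batch=16, imgsz=1280, device='0')
metrics = model.val(split='test')   # same validation call as above
print(f"mAP50: {metrics.box.map50:.4f}")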
336
packages/training/training/cli/validate.py
Normal file
336
packages/training/training/cli/validate.py
Normal file
@@ -0,0 +1,336 @@
#!/usr/bin/env python3
"""
CLI for cross-validation of invoice field extraction using LLM.

Validates documents with failed field matches by sending them to an LLM
and comparing the extraction results.
"""

import argparse
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(
        description='Cross-validate invoice field extraction using LLM'
    )

    subparsers = parser.add_subparsers(dest='command', help='Commands')

    # Stats command
    subparsers.add_parser('stats', help='Show failed match statistics')

    # Validate command
    validate_parser = subparsers.add_parser('validate', help='Validate documents with failed matches')
    validate_parser.add_argument(
        '--limit', '-l',
        type=int,
        default=10,
        help='Maximum number of documents to validate (default: 10)'
    )
    validate_parser.add_argument(
        '--provider', '-p',
        choices=['openai', 'anthropic'],
        default='openai',
        help='LLM provider to use (default: openai)'
    )
    validate_parser.add_argument(
        '--model', '-m',
        help='Model to use (default: gpt-4o for OpenAI, claude-sonnet-4-20250514 for Anthropic)'
    )
    validate_parser.add_argument(
        '--single', '-s',
        help='Validate a single document ID'
    )

    # Compare command
    compare_parser = subparsers.add_parser('compare', help='Compare validation results')
    compare_parser.add_argument(
        'document_id',
        nargs='?',
        help='Document ID to compare (or omit to show all)'
    )
    compare_parser.add_argument(
        '--limit', '-l',
        type=int,
        default=20,
        help='Maximum number of results to show (default: 20)'
    )

    # Report command
    report_parser = subparsers.add_parser('report', help='Generate validation report')
    report_parser.add_argument(
        '--output', '-o',
        default='reports/llm_validation_report.json',
        help='Output file path (default: reports/llm_validation_report.json)'
    )

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    from inference.validation import LLMValidator

    validator = LLMValidator()
    validator.connect()
    validator.create_validation_table()

    if args.command == 'stats':
        show_stats(validator)

    elif args.command == 'validate':
        if args.single:
            validate_single(validator, args.single, args.provider, args.model)
        else:
            validate_batch(validator, args.limit, args.provider, args.model)

    elif args.command == 'compare':
        if args.document_id:
            compare_single(validator, args.document_id)
        else:
            compare_all(validator, args.limit)

    elif args.command == 'report':
        generate_report(validator, args.output)

    validator.close()


def show_stats(validator):
    """Show statistics about failed matches."""
    stats = validator.get_failed_match_stats()

    print("\n" + "=" * 50)
    print("Failed Match Statistics")
    print("=" * 50)
    print(f"\nDocuments with failures: {stats['documents_with_failures']}")
    print(f"Already validated: {stats['already_validated']}")
    print(f"Remaining to validate: {stats['remaining']}")
    print("\nFailures by field:")
    for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]):
        print(f"  {field}: {count}")


def validate_single(validator, doc_id: str, provider: str, model: str):
    """Validate a single document."""
    print(f"\nValidating document: {doc_id}")
    print(f"Provider: {provider}, Model: {model or 'default'}")
    print()

    result = validator.validate_document(doc_id, provider, model)

    if result.error:
        print(f"ERROR: {result.error}")
        return

    print(f"Processing time: {result.processing_time_ms:.0f}ms")
    print(f"Model used: {result.model_used}")
    print("\nExtracted fields:")
    print(f"  Invoice Number: {result.invoice_number}")
    print(f"  Invoice Date: {result.invoice_date}")
    print(f"  Due Date: {result.invoice_due_date}")
    print(f"  OCR: {result.ocr_number}")
    print(f"  Bankgiro: {result.bankgiro}")
    print(f"  Plusgiro: {result.plusgiro}")
    print(f"  Amount: {result.amount}")
    print(f"  Org Number: {result.supplier_organisation_number}")

    # Show comparison
    print("\n" + "-" * 50)
    print("Comparison with autolabel:")
    comparison = validator.compare_results(doc_id)
    for field, data in comparison.items():
        if data.get('csv_value'):
            status = "✓" if data['agreement'] else "✗"
            auto_status = "matched" if data['autolabel_matched'] else "FAILED"
            print(f"  {status} {field}:")
            print(f"    CSV: {data['csv_value']}")
            print(f"    Autolabel: {data['autolabel_text']} ({auto_status})")
            print(f"    LLM: {data['llm_value']}")


def validate_batch(validator, limit: int, provider: str, model: str):
    """Validate a batch of documents."""
    print(f"\nValidating up to {limit} documents with failed matches")
    print(f"Provider: {provider}, Model: {model or 'default'}")
    print()

    results = validator.validate_batch(
        limit=limit,
        provider=provider,
        model=model,
        verbose=True
    )

    # Summary
    success = sum(1 for r in results if not r.error)
    failed = len(results) - success
    total_time = sum(r.processing_time_ms or 0 for r in results)

    print("\n" + "=" * 50)
    print("Validation Complete")
    print("=" * 50)
    print(f"Total: {len(results)}")
    print(f"Success: {success}")
    print(f"Failed: {failed}")
    print(f"Total time: {total_time/1000:.1f}s")
    if success > 0:
        print(f"Avg time: {total_time/success:.0f}ms per document")


def compare_single(validator, doc_id: str):
    """Compare results for a single document."""
    comparison = validator.compare_results(doc_id)

    if 'error' in comparison:
        print(f"Error: {comparison['error']}")
        return

    print(f"\nComparison for document: {doc_id}")
    print("=" * 60)

    for field, data in comparison.items():
        if data.get('csv_value') is None:
            continue

        status = "✓" if data['agreement'] else "✗"
        auto_status = "matched" if data['autolabel_matched'] else "FAILED"

        print(f"\n{status} {field}:")
        print(f"  CSV value: {data['csv_value']}")
        print(f"  Autolabel: {data['autolabel_text']} ({auto_status})")
        print(f"  LLM extracted: {data['llm_value']}")


def compare_all(validator, limit: int):
    """Show comparison summary for all validated documents."""
    conn = validator.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT document_id FROM llm_validations
            WHERE error IS NULL
            ORDER BY created_at DESC
            LIMIT %s
        """, (limit,))

        doc_ids = [row[0] for row in cursor.fetchall()]

    if not doc_ids:
        print("No validated documents found.")
        return

    print(f"\nComparison Summary ({len(doc_ids)} documents)")
    print("=" * 80)

    # Aggregate stats
    field_stats = {}

    for doc_id in doc_ids:
        comparison = validator.compare_results(doc_id)
        if 'error' in comparison:
            continue

        for field, data in comparison.items():
            if data.get('csv_value') is None:
                continue

            if field not in field_stats:
                field_stats[field] = {
                    'total': 0,
                    'autolabel_matched': 0,
                    'llm_agrees': 0,
                    'llm_correct_auto_wrong': 0,
                }

            stats = field_stats[field]
            stats['total'] += 1

            if data['autolabel_matched']:
                stats['autolabel_matched'] += 1

            if data['agreement']:
                stats['llm_agrees'] += 1

            # LLM found correct value when autolabel failed
            if not data['autolabel_matched'] and data['agreement']:
                stats['llm_correct_auto_wrong'] += 1

    print(f"\n{'Field':<30} {'Total':>6} {'Auto OK':>8} {'LLM Agrees':>10} {'LLM Found':>10}")
    print("-" * 80)

    for field, stats in sorted(field_stats.items()):
        print(f"{field:<30} {stats['total']:>6} {stats['autolabel_matched']:>8} "
              f"{stats['llm_agrees']:>10} {stats['llm_correct_auto_wrong']:>10}")


def generate_report(validator, output_path: str):
    """Generate a detailed validation report."""
    import json
    from datetime import datetime

    conn = validator.connect()
    with conn.cursor() as cursor:
        # Get all validated documents
        cursor.execute("""
            SELECT document_id, invoice_number, invoice_date, invoice_due_date,
                   ocr_number, bankgiro, plusgiro, amount,
                   supplier_organisation_number, model_used, processing_time_ms,
                   error, created_at
            FROM llm_validations
            ORDER BY created_at DESC
        """)

        validations = []
        for row in cursor.fetchall():
            doc_id = row[0]
            comparison = validator.compare_results(doc_id) if not row[11] else {}

            validations.append({
                'document_id': doc_id,
                'llm_extraction': {
                    'invoice_number': row[1],
                    'invoice_date': row[2],
                    'invoice_due_date': row[3],
                    'ocr_number': row[4],
                    'bankgiro': row[5],
                    'plusgiro': row[6],
                    'amount': row[7],
                    'supplier_organisation_number': row[8],
                },
                'model_used': row[9],
                'processing_time_ms': row[10],
                'error': row[11],
                'created_at': str(row[12]) if row[12] else None,
                'comparison': comparison,
            })

    # Calculate summary stats
    stats = validator.get_failed_match_stats()

    report = {
        'generated_at': datetime.now().isoformat(),
        'summary': {
            'total_documents_with_failures': stats['documents_with_failures'],
            'documents_validated': stats['already_validated'],
            'failures_by_field': stats['failures_by_field'],
        },
        'validations': validations,
    }

    # Write report
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, ensure_ascii=False)

    print(f"\nReport generated: {output_path}")
    print(f"Total validations: {len(validations)}")


if __name__ == '__main__':
    main()
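Turning the counters that compare_all accumulates into rates is a one-liner per field. A small sketch with made-up counts; the dict shape mirrors field_stats as built above.

field_stats = {
    'invoice_number': {'total': 40, 'autolabel_matched': 31, 'llm_agrees': 36, 'llm_correct_auto_wrong': 5},
}
for field, s in field_stats.items():
    recovered = s['llm_correct_auto_wrong']
    auto_failed = s['total'] - s['autolabel_matched']
    print(f"{field}: LLM agrees {s['llm_agrees'] / s['total']:.0%}, "
          f"recovered {recovered}/{auto_failed} autolabel failures")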
0
packages/training/training/data/__init__.py
Normal file
0
packages/training/training/data/__init__.py
Normal file
313
packages/training/training/data/autolabel_report.py
Normal file
313
packages/training/training/data/autolabel_report.py
Normal file
@@ -0,0 +1,313 @@
"""
Auto-Label Report Generator

Generates quality control reports for auto-labeling process.
"""

import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any


@dataclass
class FieldMatchResult:
    """Result of matching a single field."""
    field_name: str
    csv_value: str | None
    matched: bool
    score: float = 0.0
    matched_text: str | None = None
    candidate_used: str | None = None  # Which normalized variant matched
    bbox: tuple[float, float, float, float] | None = None
    page_no: int = 0
    context_keywords: list[str] = field(default_factory=list)
    error: str | None = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        # Convert bbox to native Python floats to avoid numpy serialization issues
        bbox_list = None
        if self.bbox:
            bbox_list = [float(x) for x in self.bbox]

        return {
            'field_name': self.field_name,
            'csv_value': self.csv_value,
            'matched': self.matched,
            'score': float(self.score) if self.score else 0.0,
            'matched_text': self.matched_text,
            'candidate_used': self.candidate_used,
            'bbox': bbox_list,
            'page_no': int(self.page_no) if self.page_no else 0,
            'context_keywords': self.context_keywords,
            'error': self.error
        }


@dataclass
class AutoLabelReport:
    """Report for a single document's auto-labeling process."""
    document_id: str
    pdf_path: str | None = None
    pdf_type: str | None = None  # 'text' | 'scanned' | 'mixed'
    success: bool = False
    total_pages: int = 0
    fields_matched: int = 0
    fields_total: int = 0
    field_results: list[FieldMatchResult] = field(default_factory=list)
    annotations_generated: int = 0
    image_paths: list[str] = field(default_factory=list)
    label_paths: list[str] = field(default_factory=list)
    processing_time_ms: float = 0.0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    errors: list[str] = field(default_factory=list)
    # New metadata fields (from CSV, not for matching)
    split: str | None = None
    customer_number: str | None = None
    supplier_name: str | None = None
    supplier_organisation_number: str | None = None
    supplier_accounts: str | None = None

    def add_field_result(self, result: FieldMatchResult) -> None:
        """Add a field matching result."""
        self.field_results.append(result)
        self.fields_total += 1
        if result.matched:
            self.fields_matched += 1

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            'document_id': self.document_id,
            'pdf_path': self.pdf_path,
            'pdf_type': self.pdf_type,
            'success': self.success,
            'total_pages': self.total_pages,
            'fields_matched': self.fields_matched,
            'fields_total': self.fields_total,
            'field_results': [r.to_dict() for r in self.field_results],
            'annotations_generated': self.annotations_generated,
            'image_paths': self.image_paths,
            'label_paths': self.label_paths,
            'processing_time_ms': self.processing_time_ms,
            'timestamp': self.timestamp,
            'errors': self.errors,
            # New metadata fields
            'split': self.split,
            'customer_number': self.customer_number,
            'supplier_name': self.supplier_name,
            'supplier_organisation_number': self.supplier_organisation_number,
            'supplier_accounts': self.supplier_accounts,
        }

    def to_json(self, indent: int | None = None) -> str:
        """Convert to JSON string."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    @property
    def match_rate(self) -> float:
        """Calculate field match rate."""
        if self.fields_total == 0:
            return 0.0
        return self.fields_matched / self.fields_total

    def get_summary(self) -> dict:
        """Get a summary of the report."""
        return {
            'document_id': self.document_id,
            'success': self.success,
            'match_rate': f"{self.match_rate:.1%}",
            'fields': f"{self.fields_matched}/{self.fields_total}",
            'annotations': self.annotations_generated,
            'errors': len(self.errors)
        }

||||
class ReportWriter:
|
||||
"""Writes auto-label reports to file with optional sharding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_path: str | Path,
|
||||
max_records_per_file: int = 0
|
||||
):
|
||||
"""
|
||||
Initialize report writer.
|
||||
|
||||
Args:
|
||||
output_path: Path to output JSONL file (base name if sharding)
|
||||
max_records_per_file: Max records per file (0 = no limit, single file)
|
||||
"""
|
||||
self.output_path = Path(output_path)
|
||||
self.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.max_records_per_file = max_records_per_file
|
||||
|
||||
# Sharding state
|
||||
self._current_shard = 0
|
||||
self._records_in_current_shard = 0
|
||||
self._shard_files: list[Path] = []
|
||||
|
||||
def _get_shard_path(self) -> Path:
|
||||
"""Get the path for current shard."""
|
||||
if self.max_records_per_file > 0:
|
||||
base = self.output_path.stem
|
||||
suffix = self.output_path.suffix
|
||||
shard_path = self.output_path.parent / f"{base}_part{self._current_shard:03d}{suffix}"
|
||||
else:
|
||||
shard_path = self.output_path
|
||||
|
||||
if shard_path not in self._shard_files:
|
||||
self._shard_files.append(shard_path)
|
||||
|
||||
return shard_path
|
||||
|
||||
def _check_shard_rotation(self) -> None:
|
||||
"""Check if we need to rotate to a new shard file."""
|
||||
if self.max_records_per_file > 0:
|
||||
if self._records_in_current_shard >= self.max_records_per_file:
|
||||
self._current_shard += 1
|
||||
self._records_in_current_shard = 0
|
||||
|
||||
def write(self, report: AutoLabelReport) -> None:
|
||||
"""Append a report to the output file."""
|
||||
self._check_shard_rotation()
|
||||
shard_path = self._get_shard_path()
|
||||
with open(shard_path, 'a', encoding='utf-8') as f:
|
||||
f.write(report.to_json() + '\n')
|
||||
self._records_in_current_shard += 1
|
||||
|
||||
def write_dict(self, report_dict: dict) -> None:
|
||||
"""Append a report dict to the output file (for parallel processing)."""
|
||||
self._check_shard_rotation()
|
||||
shard_path = self._get_shard_path()
|
||||
with open(shard_path, 'a', encoding='utf-8') as f:
|
||||
f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
|
||||
f.flush()
|
||||
self._records_in_current_shard += 1
|
||||
|
||||
def write_batch(self, reports: list[AutoLabelReport]) -> None:
|
||||
"""Write multiple reports."""
|
||||
for report in reports:
|
||||
self.write(report)
|
||||
|
||||
def get_shard_files(self) -> list[Path]:
|
||||
"""Get list of all shard files created."""
|
||||
return self._shard_files.copy()
|
||||
|
||||
|
||||
class ReportReader:
|
||||
"""Reads auto-label reports from file(s)."""
|
||||
|
||||
def __init__(self, input_path: str | Path):
|
||||
"""
|
||||
Initialize report reader.
|
||||
|
||||
Args:
|
||||
input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl')
|
||||
"""
|
||||
self.input_path = Path(input_path)
|
||||
|
||||
# Handle glob pattern
|
||||
if '*' in str(input_path) or '?' in str(input_path):
|
||||
parent = self.input_path.parent
|
||||
pattern = self.input_path.name
|
||||
self.input_paths = sorted(parent.glob(pattern))
|
||||
else:
|
||||
self.input_paths = [self.input_path]
|
||||
|
||||
def read_all(self) -> list[AutoLabelReport]:
|
||||
"""Read all reports from file(s)."""
|
||||
reports = []
|
||||
|
||||
for input_path in self.input_paths:
|
||||
if not input_path.exists():
|
||||
continue
|
||||
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
data = json.loads(line)
|
||||
report = self._dict_to_report(data)
|
||||
reports.append(report)
|
||||
|
||||
return reports
|
||||
|
||||
def _dict_to_report(self, data: dict) -> AutoLabelReport:
|
||||
"""Convert dictionary to AutoLabelReport."""
|
||||
field_results = []
|
||||
for fr_data in data.get('field_results', []):
|
||||
bbox = tuple(fr_data['bbox']) if fr_data.get('bbox') else None
|
||||
field_results.append(FieldMatchResult(
|
||||
field_name=fr_data['field_name'],
|
||||
csv_value=fr_data.get('csv_value'),
|
||||
matched=fr_data.get('matched', False),
|
||||
score=fr_data.get('score', 0.0),
|
||||
matched_text=fr_data.get('matched_text'),
|
||||
candidate_used=fr_data.get('candidate_used'),
|
||||
bbox=bbox,
|
||||
page_no=fr_data.get('page_no', 0),
|
||||
context_keywords=fr_data.get('context_keywords', []),
|
||||
error=fr_data.get('error')
|
||||
))
|
||||
|
||||
return AutoLabelReport(
|
||||
document_id=data['document_id'],
|
||||
pdf_path=data.get('pdf_path'),
|
||||
pdf_type=data.get('pdf_type'),
|
||||
success=data.get('success', False),
|
||||
total_pages=data.get('total_pages', 0),
|
||||
fields_matched=data.get('fields_matched', 0),
|
||||
fields_total=data.get('fields_total', 0),
|
||||
field_results=field_results,
|
||||
annotations_generated=data.get('annotations_generated', 0),
|
||||
image_paths=data.get('image_paths', []),
|
||||
label_paths=data.get('label_paths', []),
|
||||
processing_time_ms=data.get('processing_time_ms', 0.0),
|
||||
timestamp=data.get('timestamp', ''),
|
||||
errors=data.get('errors', [])
|
||||
)
|
||||
|
||||
def get_statistics(self) -> dict:
|
||||
"""Calculate statistics from all reports."""
|
||||
reports = self.read_all()
|
||||
|
||||
if not reports:
|
||||
return {'total': 0}
|
||||
|
||||
successful = sum(1 for r in reports if r.success)
|
||||
total_fields_matched = sum(r.fields_matched for r in reports)
|
||||
total_fields = sum(r.fields_total for r in reports)
|
||||
total_annotations = sum(r.annotations_generated for r in reports)
|
||||
|
||||
# Per-field statistics
|
||||
field_stats = {}
|
||||
for report in reports:
|
||||
for fr in report.field_results:
|
||||
if fr.field_name not in field_stats:
|
||||
field_stats[fr.field_name] = {'matched': 0, 'total': 0, 'avg_score': 0.0}
|
||||
field_stats[fr.field_name]['total'] += 1
|
||||
if fr.matched:
|
||||
field_stats[fr.field_name]['matched'] += 1
|
||||
field_stats[fr.field_name]['avg_score'] += fr.score
|
||||
|
||||
# Calculate averages
|
||||
for field_name, stats in field_stats.items():
|
||||
if stats['matched'] > 0:
|
||||
stats['avg_score'] /= stats['matched']
|
||||
stats['match_rate'] = stats['matched'] / stats['total'] if stats['total'] > 0 else 0
|
||||
|
||||
return {
|
||||
'total': len(reports),
|
||||
'successful': successful,
|
||||
'success_rate': successful / len(reports),
|
||||
'total_fields_matched': total_fields_matched,
|
||||
'total_fields': total_fields,
|
||||
'overall_match_rate': total_fields_matched / total_fields if total_fields > 0 else 0,
|
||||
'total_annotations': total_annotations,
|
||||
'field_statistics': field_stats
|
||||
}
|
||||
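
A minimal usage sketch of the writer/reader pair above; the output path and shard size are hypothetical, but the calls match the API defined in this file:

from training.data.autolabel_report import AutoLabelReport, ReportWriter, ReportReader

writer = ReportWriter("reports/autolabel.jsonl", max_records_per_file=1000)
writer.write(AutoLabelReport(document_id="doc-001", success=True))

reader = ReportReader("reports/autolabel_part*.jsonl")  # glob across shards
stats = reader.get_statistics()
print(stats["total"], stats.get("overall_match_rate"))

Because write() appends line by line, the writer is safe to reuse across a long batch run; sharding only changes which file each line lands in.
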
134
packages/training/training/data/training_db.py
Normal file
134
packages/training/training/data/training_db.py
Normal file
@@ -0,0 +1,134 @@
"""Database operations for training tasks."""

import json
import logging
from datetime import datetime, timezone

import psycopg2
import psycopg2.extras

from shared.config import get_db_connection_string

logger = logging.getLogger(__name__)


class TrainingTaskDB:
    """Read/write training_tasks table."""

    def _connect(self):
        return psycopg2.connect(get_db_connection_string())

    def get_task(self, task_id: str) -> dict | None:
        """Get a single training task by ID."""
        conn = self._connect()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    "SELECT * FROM training_tasks WHERE task_id = %s",
                    (task_id,),
                )
                return cur.fetchone()
        finally:
            conn.close()

    def get_pending_tasks(self, limit: int = 1) -> list[dict]:
        """Get pending tasks ordered by creation time."""
        conn = self._connect()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    """
                    SELECT * FROM training_tasks
                    WHERE status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT %s
                    """,
                    (limit,),
                )
                return cur.fetchall()
        finally:
            conn.close()

    def update_status(self, task_id: str, status: str) -> None:
        """Update task status with timestamp."""
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                if status == "running":
                    cur.execute(
                        "UPDATE training_tasks SET status = %s, started_at = %s WHERE task_id = %s",
                        (status, datetime.now(timezone.utc), task_id),
                    )
                else:
                    cur.execute(
                        "UPDATE training_tasks SET status = %s WHERE task_id = %s",
                        (status, task_id),
                    )
            conn.commit()
        finally:
            conn.close()

    def complete_task(
        self, task_id: str, model_path: str, metrics: dict
    ) -> None:
        """Mark task as completed with results."""
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE training_tasks
                    SET status = 'completed',
                        completed_at = %s,
                        model_path = %s,
                        metrics = %s
                    WHERE task_id = %s
                    """,
                    (
                        datetime.now(timezone.utc),
                        model_path,
                        json.dumps(metrics),
                        task_id,
                    ),
                )
            conn.commit()
        finally:
            conn.close()

    def fail_task(self, task_id: str, error_message: str) -> None:
        """Mark task as failed."""
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE training_tasks
                    SET status = 'failed',
                        completed_at = %s,
                        error_message = %s
                    WHERE task_id = %s
                    """,
                    (datetime.now(timezone.utc), error_message[:2000], task_id),
                )
            conn.commit()
        finally:
            conn.close()

    def create_task(self, config: dict) -> str:
        """Create a new training task. Returns task_id."""
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO training_tasks (config)
                    VALUES (%s)
                    RETURNING task_id
                    """,
                    (json.dumps(config),),
                )
                task_id = str(cur.fetchone()[0])
            conn.commit()
            return task_id
        finally:
            conn.close()
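
A hedged sketch of the task lifecycle this class supports, assuming the shared connection string is configured and the training_tasks table exists with the columns referenced above (config values are hypothetical):

from training.data.training_db import TrainingTaskDB

db = TrainingTaskDB()
task_id = db.create_task({"epochs": 50, "batch_size": 8})

for task in db.get_pending_tasks(limit=1):
    db.update_status(task["task_id"], "running")
    # ... run training here ...
    db.complete_task(task["task_id"], model_path="runs/best.pt", metrics={"mAP50": 0.91})

Each method opens and closes its own connection, which keeps the class fork-safe for the worker-pool code later in this commit at the cost of a connection per call.
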
22
packages/training/training/processing/__init__.py
Normal file
22
packages/training/training/processing/__init__.py
Normal file
@@ -0,0 +1,22 @@
"""
Processing module for multi-pool parallel processing.

This module provides a robust dual-pool architecture for processing
documents with both CPU-bound and GPU-bound tasks.
"""

from training.processing.worker_pool import WorkerPool, TaskResult
from training.processing.cpu_pool import CPUWorkerPool
from training.processing.gpu_pool import GPUWorkerPool
from training.processing.task_dispatcher import TaskDispatcher, TaskType
from training.processing.dual_pool_coordinator import DualPoolCoordinator

__all__ = [
    "WorkerPool",
    "TaskResult",
    "CPUWorkerPool",
    "GPUWorkerPool",
    "TaskDispatcher",
    "TaskType",
    "DualPoolCoordinator",
]
323
packages/training/training/processing/autolabel_tasks.py
Normal file
323
packages/training/training/processing/autolabel_tasks.py
Normal file
@@ -0,0 +1,323 @@
"""
Task functions for autolabel processing in the multi-pool architecture.

Provides CPU and GPU task functions that can be called from worker pools.
"""

from __future__ import annotations

import os
import time
import warnings
from pathlib import Path
from typing import Any, Dict, Optional

from shared.config import DEFAULT_DPI

# Global OCR instance (initialized once per GPU worker process)
_ocr_engine: Optional[Any] = None


def init_cpu_worker() -> None:
    """
    Initialize a CPU worker process.

    Disables GPU access and sets conservative threading limits.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")


def init_gpu_worker(gpu_id: int = 0, gpu_mem: int = 4000) -> None:
    """
    Initialize a GPU worker process with PaddleOCR.

    Args:
        gpu_id: GPU device ID.
        gpu_mem: Maximum GPU memory in MB.
    """
    global _ocr_engine

    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    os.environ["GLOG_minloglevel"] = "2"

    # Suppress PaddleX warnings
    warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
    warnings.filterwarnings("ignore", message=".*reinitialization.*")

    # Lazy initialization - the OCR engine is created on first use
    _ocr_engine = None


def _get_ocr_engine():
    """Get or create the OCR engine for the current GPU worker."""
    global _ocr_engine

    if _ocr_engine is None:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            from shared.ocr import OCREngine
            _ocr_engine = OCREngine()

    return _ocr_engine


def _save_output_img(output_img, image_path: Path) -> None:
    """Save the OCR-preprocessed image, replacing the rendered image."""
    from PIL import Image as PILImage

    if output_img is not None:
        img = PILImage.fromarray(output_img)
        img.save(str(image_path))


def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Process a text PDF (CPU task - no OCR needed).

    Args:
        task_data: Dictionary with keys:
            - row_dict: Document fields from CSV
            - pdf_path: Path to PDF file
            - output_dir: Output directory
            - dpi: Rendering DPI
            - min_confidence: Minimum match confidence

    Returns:
        Result dictionary with success status, annotations, and report.
    """
    import shutil
    from training.data.autolabel_report import AutoLabelReport
    from shared.pdf import PDFDocument
    from training.yolo.annotation_generator import FIELD_CLASSES
    from training.processing.document_processor import process_page, record_unmatched_fields

    row_dict = task_data["row_dict"]
    pdf_path = Path(task_data["pdf_path"])
    output_dir = Path(task_data["output_dir"])
    dpi = task_data.get("dpi", DEFAULT_DPI)
    min_confidence = task_data.get("min_confidence", 0.5)

    start_time = time.time()
    doc_id = row_dict["DocumentId"]

    # Clean up any existing temp folder for this document (for re-matching)
    temp_doc_dir = output_dir / "temp" / doc_id
    if temp_doc_dir.exists():
        shutil.rmtree(temp_doc_dir, ignore_errors=True)

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)
    report.pdf_type = "text"
    # Store metadata fields from CSV (same as single document mode)
    report.split = row_dict.get('split')
    report.customer_number = row_dict.get('customer_number')
    report.supplier_name = row_dict.get('supplier_name')
    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        "doc_id": doc_id,
        "success": False,
        "pages": [],
        "report": None,
        "stats": {name: 0 for name in FIELD_CLASSES.keys()},
    }

    try:
        with PDFDocument(pdf_path) as pdf_doc:
            page_annotations = []
            matched_fields = set()

            images_dir = output_dir / "temp" / doc_id / "images"
            for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
                report.total_pages += 1
                img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)

                # Text extraction (no OCR)
                tokens = list(pdf_doc.extract_text_tokens(page_no))

                # Get page dimensions for payment line detection
                page = pdf_doc.doc[page_no]
                page_height = page.rect.height
                page_width = page.rect.width

                # Use shared processing logic (same as single document mode)
                matches = {}
                annotations, ann_count = process_page(
                    tokens=tokens,
                    row_dict=row_dict,
                    page_no=page_no,
                    page_height=page_height,
                    page_width=page_width,
                    img_width=img_width,
                    img_height=img_height,
                    dpi=dpi,
                    min_confidence=min_confidence,
                    matches=matches,
                    matched_fields=matched_fields,
                    report=report,
                    result_stats=result["stats"],
                )

                if annotations:
                    page_annotations.append(
                        {
                            "image_path": str(image_path),
                            "page_no": page_no,
                            "count": ann_count,
                        }
                    )
                    report.annotations_generated += ann_count

            # Record unmatched fields using shared logic
            record_unmatched_fields(row_dict, matched_fields, report)

            if page_annotations:
                result["pages"] = page_annotations
                result["success"] = True
                report.success = True
            else:
                report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result["report"] = report.to_dict()

    return result


def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Process a scanned PDF (GPU task - requires OCR).

    Args:
        task_data: Dictionary with keys:
            - row_dict: Document fields from CSV
            - pdf_path: Path to PDF file
            - output_dir: Output directory
            - dpi: Rendering DPI
            - min_confidence: Minimum match confidence

    Returns:
        Result dictionary with success status, annotations, and report.
    """
    import shutil
    from training.data.autolabel_report import AutoLabelReport
    from shared.pdf import PDFDocument
    from training.yolo.annotation_generator import FIELD_CLASSES
    from training.processing.document_processor import process_page, record_unmatched_fields

    row_dict = task_data["row_dict"]
    pdf_path = Path(task_data["pdf_path"])
    output_dir = Path(task_data["output_dir"])
    dpi = task_data.get("dpi", DEFAULT_DPI)
    min_confidence = task_data.get("min_confidence", 0.5)

    start_time = time.time()
    doc_id = row_dict["DocumentId"]

    # Clean up any existing temp folder for this document (for re-matching)
    temp_doc_dir = output_dir / "temp" / doc_id
    if temp_doc_dir.exists():
        shutil.rmtree(temp_doc_dir, ignore_errors=True)

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)
    report.pdf_type = "scanned"
    # Store metadata fields from CSV (same as single document mode)
    report.split = row_dict.get('split')
    report.customer_number = row_dict.get('customer_number')
    report.supplier_name = row_dict.get('supplier_name')
    report.supplier_organisation_number = row_dict.get('supplier_organisation_number')
    report.supplier_accounts = row_dict.get('supplier_accounts')

    result = {
        "doc_id": doc_id,
        "success": False,
        "pages": [],
        "report": None,
        "stats": {name: 0 for name in FIELD_CLASSES.keys()},
    }

    try:
        # Get the OCR engine from the worker cache
        ocr_engine = _get_ocr_engine()

        with PDFDocument(pdf_path) as pdf_doc:
            page_annotations = []
            matched_fields = set()

            images_dir = output_dir / "temp" / doc_id / "images"
            for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
                report.total_pages += 1
                img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)

                # Get page dimensions for payment line detection
                page = pdf_doc.doc[page_no]
                page_height = page.rect.height
                page_width = page.rect.width

                # OCR extraction
                ocr_result = ocr_engine.extract_with_image(
                    str(image_path),
                    page_no,
                    scale_to_pdf_points=72 / dpi,
                )
                tokens = ocr_result.tokens

                # Save the preprocessed image
                _save_output_img(ocr_result.output_img, image_path)

                # Update dimensions to match the OCR output
                if ocr_result.output_img is not None:
                    img_height, img_width = ocr_result.output_img.shape[:2]

                # Use shared processing logic (same as single document mode)
                matches = {}
                annotations, ann_count = process_page(
                    tokens=tokens,
                    row_dict=row_dict,
                    page_no=page_no,
                    page_height=page_height,
                    page_width=page_width,
                    img_width=img_width,
                    img_height=img_height,
                    dpi=dpi,
                    min_confidence=min_confidence,
                    matches=matches,
                    matched_fields=matched_fields,
                    report=report,
                    result_stats=result["stats"],
                )

                if annotations:
                    page_annotations.append(
                        {
                            "image_path": str(image_path),
                            "page_no": page_no,
                            "count": ann_count,
                        }
                    )
                    report.annotations_generated += ann_count

            # Record unmatched fields using shared logic
            record_unmatched_fields(row_dict, matched_fields, report)

            if page_annotations:
                result["pages"] = page_annotations
                result["success"] = True
                report.success = True
            else:
                report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result["report"] = report.to_dict()

    return result
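
A sketch of the payload these task functions expect. Only DocumentId is confirmed as required by the code above; the other row_dict fields and the paths are hypothetical, and the call is runnable only inside the training environment where the shared packages are installed:

task_data = {
    "row_dict": {"DocumentId": "doc-001", "Amount": "1234.50", "split": "train"},
    "pdf_path": "data/pdfs/doc-001.pdf",
    "output_dir": "data/autolabel",
    "dpi": 150,
    "min_confidence": 0.5,
}
result = process_text_pdf(task_data)  # or process_scanned_pdf inside a GPU worker
print(result["success"], result["report"]["fields_matched"])
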
71
packages/training/training/processing/cpu_pool.py
Normal file
71
packages/training/training/processing/cpu_pool.py
Normal file
@@ -0,0 +1,71 @@
"""
CPU Worker Pool for text PDF processing.

This pool handles CPU-bound tasks like text extraction from native PDFs
that don't require OCR.
"""

from __future__ import annotations

import logging
import os
from typing import Callable, Optional

from training.processing.worker_pool import WorkerPool

logger = logging.getLogger(__name__)

# Global flag for CPU workers (set once per process)
_cpu_initialized: bool = False


def _init_cpu_worker() -> None:
    """
    Initialize a CPU worker process.

    Disables GPU access and sets up a CPU-only environment.
    """
    global _cpu_initialized

    # Disable GPU access for CPU workers
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Set threading limits for better CPU utilization
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")

    _cpu_initialized = True
    logger.debug(f"CPU worker initialized in process {os.getpid()}")


class CPUWorkerPool(WorkerPool):
    """
    Worker pool for CPU-bound tasks.

    Handles text PDF processing that doesn't require OCR.
    Each worker is initialized with CUDA disabled to prevent
    accidental GPU memory consumption.

    Example:
        with CPUWorkerPool(max_workers=4) as pool:
            future = pool.submit(process_text_pdf, pdf_path)
            result = future.result()
    """

    def __init__(self, max_workers: int = 4) -> None:
        """
        Initialize CPU worker pool.

        Args:
            max_workers: Number of CPU worker processes.
                Defaults to 4 for balanced performance.
        """
        super().__init__(max_workers=max_workers, use_gpu=False, gpu_id=-1)

    def get_initializer(self) -> Optional[Callable[..., None]]:
        """Return the CPU worker initializer."""
        return _init_cpu_worker

    def get_init_args(self) -> tuple:
        """Return empty args for the CPU initializer."""
        return ()
448
packages/training/training/processing/document_processor.py
Normal file
448
packages/training/training/processing/document_processor.py
Normal file
@@ -0,0 +1,448 @@
"""
Shared document processing logic for autolabel.

This module provides the core processing functions used by both
single document mode and batch processing mode to ensure consistent
matching and annotation logic.
"""

from __future__ import annotations

from typing import Any, Dict, Optional, Set, Tuple

from training.data.autolabel_report import FieldMatchResult
from shared.matcher import FieldMatcher
from shared.normalize import normalize_field
from shared.ocr.machine_code_parser import MachineCodeParser
from training.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES


def match_supplier_accounts(
    tokens: list,
    supplier_accounts_value: str,
    matcher: FieldMatcher,
    page_no: int,
    matches: Dict[str, list],
    matched_fields: Set[str],
    report: Any,
) -> None:
    """
    Match the supplier_accounts field and map it to Bankgiro/Plusgiro.

    This logic is shared between single document mode and batch mode
    to ensure consistent BG/PG type detection.

    Args:
        tokens: List of text tokens from the page
        supplier_accounts_value: Raw value from CSV (e.g., "BG:xxx | PG:yyy")
        matcher: FieldMatcher instance
        page_no: Current page number
        matches: Dictionary to store matched fields (modified in place)
        matched_fields: Set of matched field names (modified in place)
        report: AutoLabelReport instance
    """
    if not supplier_accounts_value:
        return

    # Parse accounts: "BG:xxx | PG:yyy" format
    accounts = [acc.strip() for acc in str(supplier_accounts_value).split('|')]

    for account in accounts:
        account = account.strip()
        if not account:
            continue

        # Determine the account type (BG or PG) and extract the account number
        account_type = None
        account_number = account  # Default to the full value

        if account.upper().startswith('BG:'):
            account_type = 'Bankgiro'
            account_number = account[3:].strip()  # Remove "BG:" prefix
        elif account.upper().startswith('BG '):
            account_type = 'Bankgiro'
            account_number = account[2:].strip()  # Remove "BG" prefix
        elif account.upper().startswith('PG:'):
            account_type = 'Plusgiro'
            account_number = account[3:].strip()  # Remove "PG:" prefix
        elif account.upper().startswith('PG '):
            account_type = 'Plusgiro'
            account_number = account[2:].strip()  # Remove "PG" prefix
        else:
            # Try to guess from the format - Plusgiro often has format XXXXXXX-X
            digits = ''.join(c for c in account if c.isdigit())
            if len(digits) == 8 and '-' in account:
                account_type = 'Plusgiro'
            elif len(digits) in (7, 8):
                account_type = 'Bankgiro'  # Default to Bankgiro

        if not account_type:
            continue

        # Normalize and match using the account number (without prefix)
        normalized = normalize_field('supplier_accounts', account_number)
        field_matches = matcher.find_matches(tokens, account_type, normalized, page_no)

        if field_matches:
            best = field_matches[0]
            # Add to matches under the target class (Bankgiro/Plusgiro)
            if account_type not in matches:
                matches[account_type] = []
            matches[account_type].extend(field_matches)
            matched_fields.add('supplier_accounts')

            report.add_field_result(FieldMatchResult(
                field_name=f'supplier_accounts({account_type})',
                csv_value=account_number,  # Store without prefix
                matched=True,
                score=best.score,
                matched_text=best.matched_text,
                candidate_used=best.value,
                bbox=best.bbox,
                page_no=page_no,
                context_keywords=best.context_keywords
            ))


def detect_payment_line(
    tokens: list,
    page_height: float,
    page_width: float,
) -> Optional[Any]:
    """
    Detect the payment line (machine code) and return the parsed result.

    This function only detects and parses the payment line, without generating
    annotations. The caller can use the result to extract the amount for
    cross-validation.

    Args:
        tokens: List of text tokens from the page
        page_height: Page height in PDF points
        page_width: Page width in PDF points

    Returns:
        MachineCodeResult if a standard format is detected (confidence >= 0.95),
        None otherwise
    """
    # Use 55% of page height as the bottom region to catch payment lines
    # that may be in the middle of the page (e.g., payment slips)
    mc_parser = MachineCodeParser(bottom_region_ratio=0.55)
    mc_result = mc_parser.parse(tokens, page_height, page_width)

    # Only return if we found a STANDARD payment line format
    # (confidence 0.95 means the standard pattern matched with # and > symbols)
    is_standard_format = mc_result.confidence >= 0.95
    if is_standard_format:
        return mc_result
    return None


def match_payment_line(
    tokens: list,
    page_height: float,
    page_width: float,
    min_confidence: float,
    generator: AnnotationGenerator,
    annotations: list,
    img_width: int,
    img_height: int,
    dpi: int,
    matched_fields: Set[str],
    report: Any,
    page_no: int,
    mc_result: Optional[Any] = None,
) -> None:
    """
    Annotate the payment line (machine code) using a pre-detected result.

    This logic is shared between single document mode and batch mode
    to ensure consistent payment_line detection.

    Args:
        tokens: List of text tokens from the page
        page_height: Page height in PDF points
        page_width: Page width in PDF points
        min_confidence: Minimum confidence threshold
        generator: AnnotationGenerator instance
        annotations: List of annotations (modified in place)
        img_width: Image width in pixels
        img_height: Image height in pixels
        dpi: DPI used for rendering
        matched_fields: Set of matched field names (modified in place)
        report: AutoLabelReport instance
        page_no: Current page number
        mc_result: Pre-detected MachineCodeResult (from detect_payment_line)
    """
    # Use the pre-detected result if provided, otherwise detect now
    if mc_result is None:
        mc_result = detect_payment_line(tokens, page_height, page_width)

    # Only add payment_line if we have a valid standard format result
    if mc_result is None:
        return

    if mc_result.confidence >= min_confidence:
        region_bbox = mc_result.get_region_bbox()
        if region_bbox:
            generator.add_payment_line_annotation(
                annotations, region_bbox, mc_result.confidence,
                img_width, img_height, dpi=dpi
            )
            # Record the payment_line result in the report
            matched_fields.add('payment_line')
            report.add_field_result(FieldMatchResult(
                field_name='payment_line',
                csv_value=mc_result.raw_line[:200] if mc_result.raw_line else '',
                matched=True,
                score=mc_result.confidence,
                matched_text=f"OCR:{mc_result.ocr or ''} Amount:{mc_result.amount or ''} BG:{mc_result.bankgiro or ''}",
                candidate_used='machine_code_parser',
                bbox=region_bbox,
                page_no=page_no,
                context_keywords=['payment_line', 'machine_code']
            ))


def match_standard_fields(
    tokens: list,
    row_dict: Dict[str, Any],
    matcher: FieldMatcher,
    page_no: int,
    matches: Dict[str, list],
    matched_fields: Set[str],
    report: Any,
    payment_line_amount: Optional[str] = None,
    payment_line_bbox: Optional[tuple] = None,
) -> None:
    """
    Match standard fields from CSV to tokens.

    This excludes payment_line (detected separately) and supplier_accounts
    (handled by match_supplier_accounts).

    Args:
        tokens: List of text tokens from the page
        row_dict: Dictionary of field values from CSV
        matcher: FieldMatcher instance
        page_no: Current page number
        matches: Dictionary to store matched fields (modified in place)
        matched_fields: Set of matched field names (modified in place)
        report: AutoLabelReport instance
        payment_line_amount: Amount extracted from payment_line (takes priority over CSV)
        payment_line_bbox: Bounding box of the payment_line region (used as a fallback for Amount)
    """
    for field_name in FIELD_CLASSES.keys():
        # Skip fields handled separately
        if field_name == 'payment_line':
            continue
        if field_name in ('Bankgiro', 'Plusgiro'):
            continue  # Handled via supplier_accounts

        value = row_dict.get(field_name)

        # For the Amount field: only use the payment_line amount if it matches the CSV value
        use_payment_line_amount = False
        if field_name == 'Amount' and payment_line_amount and value:
            # Parse both amounts and check if they're close
            try:
                csv_amt = float(str(value).replace(',', '.').replace(' ', ''))
                pl_amt = float(str(payment_line_amount).replace(',', '.').replace(' ', ''))
                if abs(csv_amt - pl_amt) < 0.01:
                    # The payment line amount matches the CSV; use it for a better bbox
                    value = payment_line_amount
                    use_payment_line_amount = True
                # Otherwise keep the CSV value for matching
            except (ValueError, TypeError):
                pass

        if not value:
            continue

        normalized = normalize_field(field_name, str(value))
        field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

        if field_matches:
            best = field_matches[0]
            matches[field_name] = field_matches
            matched_fields.add(field_name)

            # For Amount: note if we used the payment_line amount
            csv_value_display = str(row_dict.get(field_name, value))
            if field_name == 'Amount' and use_payment_line_amount:
                csv_value_display = f"{row_dict.get(field_name)} (matched via payment_line: {payment_line_amount})"

            report.add_field_result(FieldMatchResult(
                field_name=field_name,
                csv_value=csv_value_display,
                matched=True,
                score=best.score,
                matched_text=best.matched_text,
                candidate_used=best.value,
                bbox=best.bbox,
                page_no=page_no,
                context_keywords=best.context_keywords
            ))
        elif field_name == 'Amount' and use_payment_line_amount and payment_line_bbox:
            # Fallback: Amount not found via token matching, but payment_line
            # successfully extracted a matching amount. Use the payment_line bbox.
            # This handles cases where text PDFs merge multiple values into one token.
            from shared.matcher.field_matcher import Match

            fallback_match = Match(
                field='Amount',
                value=payment_line_amount,
                bbox=payment_line_bbox,
                page_no=page_no,
                score=0.9,
                matched_text=f"Amount:{payment_line_amount}",
                context_keywords=['payment_line', 'amount']
            )
            matches[field_name] = [fallback_match]
            matched_fields.add(field_name)
            csv_value_display = f"{row_dict.get(field_name)} (via payment_line: {payment_line_amount})"

            report.add_field_result(FieldMatchResult(
                field_name=field_name,
                csv_value=csv_value_display,
                matched=True,
                score=0.9,  # High confidence since payment_line parsing succeeded
                matched_text=f"Amount:{payment_line_amount}",
                candidate_used='payment_line_fallback',
                bbox=payment_line_bbox,
                page_no=page_no,
                context_keywords=['payment_line', 'amount']
            ))


def record_unmatched_fields(
    row_dict: Dict[str, Any],
    matched_fields: Set[str],
    report: Any,
) -> None:
    """
    Record fields from CSV that were not matched.

    Args:
        row_dict: Dictionary of field values from CSV
        matched_fields: Set of matched field names
        report: AutoLabelReport instance
    """
    for field_name in FIELD_CLASSES.keys():
        if field_name == 'payment_line':
            continue  # payment_line doesn't come from CSV
        if field_name in ('Bankgiro', 'Plusgiro'):
            continue  # These come from supplier_accounts

        value = row_dict.get(field_name)
        if value and field_name not in matched_fields:
            report.add_field_result(FieldMatchResult(
                field_name=field_name,
                csv_value=str(value),
                matched=False,
                page_no=-1
            ))

    # Check if supplier_accounts was not matched
    if row_dict.get('supplier_accounts') and 'supplier_accounts' not in matched_fields:
        report.add_field_result(FieldMatchResult(
            field_name='supplier_accounts',
            csv_value=str(row_dict.get('supplier_accounts')),
            matched=False,
            page_no=-1
        ))


def process_page(
    tokens: list,
    row_dict: Dict[str, Any],
    page_no: int,
    page_height: float,
    page_width: float,
    img_width: int,
    img_height: int,
    dpi: int,
    min_confidence: float,
    matches: Dict[str, list],
    matched_fields: Set[str],
    report: Any,
    result_stats: Dict[str, int],
) -> Tuple[list, int]:
    """
    Process a single page: match fields and generate annotations.

    This is the main entry point for page processing, used by both
    single document mode and batch mode.

    Processing order:
    1. Detect payment_line first to extract the amount
    2. Match standard fields (using the payment_line amount if available)
    3. Match supplier_accounts
    4. Generate annotations

    Args:
        tokens: List of text tokens from the page
        row_dict: Dictionary of field values from CSV
        page_no: Current page number
        page_height: Page height in PDF points
        page_width: Page width in PDF points
        img_width: Image width in pixels
        img_height: Image height in pixels
        dpi: DPI used for rendering
        min_confidence: Minimum confidence threshold
        matches: Dictionary to store matched fields (modified in place)
        matched_fields: Set of matched field names (modified in place)
        report: AutoLabelReport instance
        result_stats: Dictionary of annotation stats (modified in place)

    Returns:
        Tuple of (annotations list, annotation count)
    """
    matcher = FieldMatcher()
    generator = AnnotationGenerator(min_confidence=min_confidence)

    # Step 1: Detect payment_line FIRST to extract the amount.
    # This allows us to use the payment_line amount when matching the Amount field.
    mc_result = detect_payment_line(tokens, page_height, page_width)

    # Extract amount and bbox from payment_line if available
    payment_line_amount = None
    payment_line_bbox = None
    if mc_result and mc_result.amount:
        payment_line_amount = mc_result.amount
        payment_line_bbox = mc_result.get_region_bbox()

    # Step 2: Match standard fields (using the payment_line amount if available)
    match_standard_fields(
        tokens, row_dict, matcher, page_no,
        matches, matched_fields, report,
        payment_line_amount=payment_line_amount,
        payment_line_bbox=payment_line_bbox
    )

    # Step 3: Match supplier_accounts -> Bankgiro/Plusgiro
    supplier_accounts_value = row_dict.get('supplier_accounts')
    if supplier_accounts_value:
        match_supplier_accounts(
            tokens, supplier_accounts_value, matcher, page_no,
            matches, matched_fields, report
        )

    # Generate annotations from matches
    annotations = generator.generate_from_matches(
        matches, img_width, img_height, dpi=dpi
    )

    # Step 4: Add the payment_line annotation (reusing the pre-detected result)
    match_payment_line(
        tokens, page_height, page_width, min_confidence,
        generator, annotations, img_width, img_height, dpi,
        matched_fields, report, page_no,
        mc_result=mc_result
    )

    # Update stats
    for ann in annotations:
        class_name = list(FIELD_CLASSES.keys())[ann.class_id]
        result_stats[class_name] += 1

    return annotations, len(annotations)
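
Two standalone illustrations of the normalization steps above; the values are hypothetical but the string handling is exactly what the code does. First, the Amount cross-check normalizes both strings the same way before comparing:

csv_value = "1 234,50"           # as the amount might appear in the CSV
payment_line_amount = "1234.50"  # as parsed from the machine code

csv_amt = float(csv_value.replace(',', '.').replace(' ', ''))
pl_amt = float(payment_line_amount.replace(',', '.').replace(' ', ''))
assert abs(csv_amt - pl_amt) < 0.01  # close enough: prefer the payment_line bbox

Second, supplier_accounts splits on '|' before prefix routing (account numbers hypothetical):

accounts = [a.strip() for a in "BG:5050-1055 | PG:4808-4".split('|')]
# -> ['BG:5050-1055', 'PG:4808-4']; each prefix routes to Bankgiro or Plusgiro
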
339
packages/training/training/processing/dual_pool_coordinator.py
Normal file
339
packages/training/training/processing/dual_pool_coordinator.py
Normal file
@@ -0,0 +1,339 @@
"""
Dual Pool Coordinator for managing CPU and GPU worker pools.

Coordinates task distribution between CPU and GPU pools, handles result
collection using as_completed(), and provides callbacks for progress tracking.
"""

from __future__ import annotations

import logging
import time
from concurrent.futures import Future, TimeoutError, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

from training.processing.cpu_pool import CPUWorkerPool
from training.processing.gpu_pool import GPUWorkerPool
from training.processing.task_dispatcher import Task, TaskDispatcher, TaskType
from training.processing.worker_pool import TaskResult

logger = logging.getLogger(__name__)


@dataclass
class BatchStats:
    """Statistics for a batch processing run."""

    total: int = 0
    cpu_submitted: int = 0
    gpu_submitted: int = 0
    successful: int = 0
    failed: int = 0
    cpu_time: float = 0.0
    gpu_time: float = 0.0
    errors: List[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Calculate the success rate as a percentage."""
        if self.total == 0:
            return 0.0
        return (self.successful / self.total) * 100


class DualPoolCoordinator:
    """
    Coordinates CPU and GPU worker pools for parallel document processing.

    Uses separate ProcessPoolExecutor instances for CPU and GPU tasks,
    with as_completed() for efficient result collection across both pools.

    Key features:
    - Automatic task classification (CPU vs GPU)
    - Parallel submission to both pools
    - Unified result collection with timeouts
    - Progress callbacks for UI integration
    - Proper resource cleanup

    Example:
        with DualPoolCoordinator(cpu_workers=4, gpu_workers=1) as coord:
            results = coord.process_batch(
                documents=docs,
                cpu_task_fn=process_text_pdf,
                gpu_task_fn=process_scanned_pdf,
                on_result=lambda r: save_to_db(r),
            )
    """

    def __init__(
        self,
        cpu_workers: int = 4,
        gpu_workers: int = 1,
        gpu_id: int = 0,
        task_timeout: float = 300.0,
    ) -> None:
        """
        Initialize the dual pool coordinator.

        Args:
            cpu_workers: Number of CPU worker processes.
            gpu_workers: Number of GPU worker processes (usually 1).
            gpu_id: GPU device ID to use.
            task_timeout: Timeout in seconds for individual tasks.
        """
        self.cpu_workers = cpu_workers
        self.gpu_workers = gpu_workers
        self.gpu_id = gpu_id
        self.task_timeout = task_timeout

        self._cpu_pool: Optional[CPUWorkerPool] = None
        self._gpu_pool: Optional[GPUWorkerPool] = None
        self._dispatcher = TaskDispatcher()
        self._started = False

    def start(self) -> None:
        """Start both worker pools."""
        if self._started:
            raise RuntimeError("Coordinator already started")

        logger.info(
            f"Starting DualPoolCoordinator: "
            f"{self.cpu_workers} CPU workers, {self.gpu_workers} GPU workers"
        )

        self._cpu_pool = CPUWorkerPool(max_workers=self.cpu_workers)
        self._gpu_pool = GPUWorkerPool(
            max_workers=self.gpu_workers,
            gpu_id=self.gpu_id,
        )

        self._cpu_pool.start()
        self._gpu_pool.start()
        self._started = True

    def shutdown(self, wait: bool = True) -> None:
        """Shut down both worker pools."""
        logger.info("Shutting down DualPoolCoordinator")

        if self._cpu_pool is not None:
            self._cpu_pool.shutdown(wait=wait)
            self._cpu_pool = None

        if self._gpu_pool is not None:
            self._gpu_pool.shutdown(wait=wait)
            self._gpu_pool = None

        self._started = False

    def __enter__(self) -> "DualPoolCoordinator":
        """Context manager entry."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit."""
        self.shutdown(wait=True)

    def process_batch(
        self,
        documents: List[dict],
        cpu_task_fn: Callable[[dict], Any],
        gpu_task_fn: Callable[[dict], Any],
        on_result: Optional[Callable[[TaskResult], None]] = None,
        on_error: Optional[Callable[[str, Exception], None]] = None,
        on_progress: Optional[Callable[[int, int], None]] = None,
        id_field: str = "id",
    ) -> List[TaskResult]:
        """
        Process a batch of documents using both CPU and GPU pools.

        Documents are automatically classified and routed to the appropriate
        pool. Results are collected as they complete.

        Args:
            documents: List of document info dicts to process.
            cpu_task_fn: Function to process text PDFs (called in the CPU pool).
            gpu_task_fn: Function to process scanned PDFs (called in the GPU pool).
            on_result: Callback for each successful result.
            on_error: Callback for each failed task (task_id, exception).
            on_progress: Callback for progress updates (completed, total).
            id_field: Field name to use as the task ID.

        Returns:
            List of TaskResult objects for all tasks.

        Raises:
            RuntimeError: If the coordinator is not started.
        """
        if not self._started:
            raise RuntimeError("Coordinator not started. Use context manager or call start().")

        if not documents:
            return []

        stats = BatchStats(total=len(documents))

        # Create and partition tasks
        tasks = self._dispatcher.create_tasks(documents, id_field=id_field)
        cpu_tasks, gpu_tasks = self._dispatcher.partition_tasks(tasks)

        # Submit tasks to both pools
        futures_map: Dict[Future, Task] = {}

        # Submit CPU tasks
        for task in cpu_tasks:
            future = self._cpu_pool.submit(cpu_task_fn, task.data)
            futures_map[future] = task
            stats.cpu_submitted += 1

        # Submit GPU tasks
        for task in gpu_tasks:
            future = self._gpu_pool.submit(gpu_task_fn, task.data)
            futures_map[future] = task
            stats.gpu_submitted += 1

        logger.info(
            f"Submitted {stats.cpu_submitted} CPU tasks, {stats.gpu_submitted} GPU tasks"
        )

        # Collect results as they complete
        results: List[TaskResult] = []
        completed = 0

        for future in as_completed(futures_map.keys(), timeout=self.task_timeout * len(documents)):
            task = futures_map[future]
            pool_type = "CPU" if task.task_type == TaskType.CPU else "GPU"
            # Note: futures yielded by as_completed() are already finished, so this
            # measures result retrieval time, not time spent inside the worker.
            start_time = time.time()

            try:
                data = future.result(timeout=self.task_timeout)
                processing_time = time.time() - start_time

                result = TaskResult(
                    task_id=task.id,
                    success=True,
                    data=data,
                    pool_type=pool_type,
                    processing_time=processing_time,
                )
                stats.successful += 1

                if pool_type == "CPU":
                    stats.cpu_time += processing_time
                else:
                    stats.gpu_time += processing_time

                if on_result is not None:
                    try:
                        on_result(result)
                    except Exception as e:
                        logger.warning(f"on_result callback failed: {e}")

            except TimeoutError:
                error_msg = f"Task timed out after {self.task_timeout}s"
                logger.error(f"[{pool_type}] Task {task.id}: {error_msg}")

                result = TaskResult(
                    task_id=task.id,
                    success=False,
                    data=None,
                    error=error_msg,
                    pool_type=pool_type,
                )
                stats.failed += 1
                stats.errors.append(f"{task.id}: {error_msg}")

                if on_error is not None:
                    try:
                        on_error(task.id, TimeoutError(error_msg))
                    except Exception as e:
                        logger.warning(f"on_error callback failed: {e}")

            except Exception as e:
                error_msg = str(e)
                logger.error(f"[{pool_type}] Task {task.id} failed: {error_msg}")

                result = TaskResult(
                    task_id=task.id,
                    success=False,
                    data=None,
                    error=error_msg,
                    pool_type=pool_type,
                )
                stats.failed += 1
                stats.errors.append(f"{task.id}: {error_msg}")

                if on_error is not None:
                    try:
                        on_error(task.id, e)
                    except Exception as callback_error:
                        logger.warning(f"on_error callback failed: {callback_error}")

            results.append(result)
            completed += 1

            if on_progress is not None:
                try:
                    on_progress(completed, stats.total)
                except Exception as e:
                    logger.warning(f"on_progress callback failed: {e}")

        # Log final stats
        logger.info(
            f"Batch complete: {stats.successful}/{stats.total} successful "
            f"({stats.success_rate:.1f}%), {stats.failed} failed"
        )
        if stats.cpu_submitted > 0:
            logger.info(f"CPU: {stats.cpu_submitted} tasks, {stats.cpu_time:.2f}s total")
        if stats.gpu_submitted > 0:
            logger.info(f"GPU: {stats.gpu_submitted} tasks, {stats.gpu_time:.2f}s total")

        return results

    def process_single(
        self,
        document: dict,
        cpu_task_fn: Callable[[dict], Any],
        gpu_task_fn: Callable[[dict], Any],
        id_field: str = "id",
    ) -> TaskResult:
        """
        Process a single document.

        Convenience method for processing one document at a time.

        Args:
            document: Document info dict.
            cpu_task_fn: Function for text PDF processing.
            gpu_task_fn: Function for scanned PDF processing.
            id_field: Field name for the task ID.

        Returns:
            TaskResult for the document.
        """
        results = self.process_batch(
            documents=[document],
            cpu_task_fn=cpu_task_fn,
            gpu_task_fn=gpu_task_fn,
            id_field=id_field,
        )
        return results[0] if results else TaskResult(
            task_id=str(document.get(id_field, "unknown")),
            success=False,
            data=None,
            error="No result returned",
        )

    @property
    def is_running(self) -> bool:
        """Check whether both pools are running."""
        return (
            self._started
            and self._cpu_pool is not None
            and self._gpu_pool is not None
            and self._cpu_pool.is_running
            and self._gpu_pool.is_running
        )
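
A sketch of how the pieces in this commit compose. It assumes each entry in `docs` carries a "doc_id" key plus the task_data fields the autolabel task functions expect, and that TaskDispatcher.create_tasks handles the CPU/GPU classification; the output path is hypothetical:

from training.processing import DualPoolCoordinator
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
from training.data.autolabel_report import ReportWriter

writer = ReportWriter("reports/autolabel.jsonl")
docs: list[dict] = []  # task_data payloads built upstream; empty here for the sketch

def save_report(task_result):
    # task_result.data is the dict returned by the task functions
    if task_result.success and task_result.data.get("report"):
        writer.write_dict(task_result.data["report"])

with DualPoolCoordinator(cpu_workers=4, gpu_workers=1) as coord:
    coord.process_batch(
        documents=docs,
        cpu_task_fn=process_text_pdf,
        gpu_task_fn=process_scanned_pdf,
        on_result=save_report,
        on_progress=lambda done, total: print(f"{done}/{total} documents"),
        id_field="doc_id",
    )

Using write_dict in the on_result callback keeps serialization in the parent process, so worker processes never contend for the report file.
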
110
packages/training/training/processing/gpu_pool.py
Normal file
110
packages/training/training/processing/gpu_pool.py
Normal file
@@ -0,0 +1,110 @@
"""
GPU Worker Pool for OCR processing.

This pool handles GPU-bound tasks like PaddleOCR for scanned PDF processing.
"""

from __future__ import annotations

import logging
import os
from typing import Any, Callable, Optional

from training.processing.worker_pool import WorkerPool

logger = logging.getLogger(__name__)

# Global OCR instance for GPU workers (initialized once per process)
_ocr_instance: Optional[Any] = None
_gpu_initialized: bool = False


def _init_gpu_worker(gpu_id: int = 0) -> None:
    """
    Initialize a GPU worker process with PaddleOCR.

    Args:
        gpu_id: GPU device ID to use.
    """
    global _ocr_instance, _gpu_initialized

    # Set GPU device before importing paddle
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Reduce logging noise
    os.environ["GLOG_minloglevel"] = "2"

    # Suppress PaddleX warnings
    import warnings
    warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
    warnings.filterwarnings("ignore", message=".*reinitialization.*")

    try:
        # Import PaddleOCR after setting environment
        # PaddleOCR 3.x uses paddle.set_device() for GPU control, not use_gpu param
        import paddle

        # CUDA_VISIBLE_DEVICES above hides every other card, so the selected
        # GPU is always visible as device 0 inside this worker process.
        paddle.set_device("gpu:0")

        from paddleocr import PaddleOCR

        # PaddleOCR 3.x init - minimal params, GPU controlled via paddle.set_device
        _ocr_instance = PaddleOCR(lang="en")
        _gpu_initialized = True
        logger.info(f"GPU worker initialized on GPU {gpu_id} in process {os.getpid()}")

    except Exception as e:
        logger.error(f"Failed to initialize GPU worker: {e}")
        raise


def get_ocr_instance() -> Any:
    """
    Get the initialized OCR instance for the current worker.

    Returns:
        PaddleOCR instance.

    Raises:
        RuntimeError: If OCR is not initialized.
    """
    global _ocr_instance
    if _ocr_instance is None:
        raise RuntimeError("OCR not initialized. This function must be called from a GPU worker.")
    return _ocr_instance


class GPUWorkerPool(WorkerPool):
    """
    Worker pool for GPU-bound OCR tasks.

    Handles scanned PDF processing using PaddleOCR with GPU acceleration.
    Typically limited to 1 worker to avoid GPU memory conflicts.

    Example:
        with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
            future = pool.submit(process_scanned_pdf, pdf_path)
            result = future.result()
    """

    def __init__(
        self,
        max_workers: int = 1,
        gpu_id: int = 0,
    ) -> None:
        """
        Initialize GPU worker pool.

        Args:
            max_workers: Number of GPU worker processes.
                Defaults to 1 to avoid GPU memory conflicts.
            gpu_id: GPU device ID to use.
        """
        super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)

    def get_initializer(self) -> Optional[Callable[..., None]]:
        """Return the GPU worker initializer."""
        return _init_gpu_worker

    def get_init_args(self) -> tuple:
        """Return args for GPU initializer."""
        return (self.gpu_id,)
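A quick usage sketch for the pool above (assumes PaddleOCR and a CUDA-capable GPU are available; run_ocr is an illustrative task function, not part of this commit, and the exact OCR call depends on the installed PaddleOCR version):

from training.processing.gpu_pool import GPUWorkerPool, get_ocr_instance

def run_ocr(image_path: str):
    # Executes inside a GPU worker; the pool initializer has already
    # created the per-process PaddleOCR instance.
    ocr = get_ocr_instance()
    # Version-dependent call: .ocr() is the classic PaddleOCR API.
    return ocr.ocr(image_path)

if __name__ == "__main__":
    # 'spawn' workers re-import this module, so run_ocr must live at module level.
    with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
        future = pool.submit(run_ocr, "page_000.png")
        print(future.result())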
174
packages/training/training/processing/task_dispatcher.py
Normal file
@@ -0,0 +1,174 @@
"""
Task Dispatcher for classifying and routing tasks to appropriate worker pools.

Determines whether a document should be processed by CPU (text PDF) or
GPU (scanned PDF requiring OCR) workers.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Tuple

logger = logging.getLogger(__name__)


class TaskType(Enum):
    """Task type classification."""

    CPU = auto()  # Text PDF - no OCR needed
    GPU = auto()  # Scanned PDF - requires OCR


@dataclass
class Task:
    """
    Represents a processing task.

    Attributes:
        id: Unique task identifier.
        task_type: Whether task needs CPU or GPU processing.
        data: Task payload (document info, paths, etc.).
    """

    id: str
    task_type: TaskType
    data: Any


class TaskDispatcher:
    """
    Classifies and partitions tasks for CPU and GPU worker pools.

    Uses PDF characteristics to determine if OCR is needed:
    - Text PDFs with extractable text -> CPU
    - Scanned PDFs / image-based PDFs -> GPU (OCR)

    Example:
        dispatcher = TaskDispatcher()
        tasks = [Task(id="1", task_type=dispatcher.classify_document(doc), data=doc) for doc in docs]
        cpu_tasks, gpu_tasks = dispatcher.partition_tasks(tasks)
    """

    def __init__(
        self,
        text_char_threshold: int = 100,
        ocr_ratio_threshold: float = 0.3,
    ) -> None:
        """
        Initialize the task dispatcher.

        Args:
            text_char_threshold: Minimum average characters per page to treat
                a PDF as text-based.
            ocr_ratio_threshold: If text/expected ratio is below this, use OCR.
        """
        self.text_char_threshold = text_char_threshold
        self.ocr_ratio_threshold = ocr_ratio_threshold

    def classify_by_pdf_info(
        self,
        has_text: bool,
        text_length: int,
        page_count: int = 1,
    ) -> TaskType:
        """
        Classify task based on PDF text extraction info.

        Args:
            has_text: Whether PDF has extractable text layer.
            text_length: Number of characters extracted.
            page_count: Number of pages in PDF.

        Returns:
            TaskType.CPU for text PDFs, TaskType.GPU for scanned PDFs.
        """
        if not has_text:
            return TaskType.GPU

        # Check if text density is reasonable
        avg_chars_per_page = text_length / max(page_count, 1)
        if avg_chars_per_page < self.text_char_threshold:
            return TaskType.GPU

        return TaskType.CPU

    def classify_document(self, doc_info: dict) -> TaskType:
        """
        Classify a document based on its metadata.

        Args:
            doc_info: Document information dict with keys like:
                - 'is_scanned': bool (if known)
                - 'text_length': int
                - 'page_count': int
                - 'pdf_path': str

        Returns:
            TaskType for the document.
        """
        # If explicitly marked as scanned
        if doc_info.get("is_scanned", False):
            return TaskType.GPU

        # If we have text extraction info
        text_length = doc_info.get("text_length", 0)
        page_count = doc_info.get("page_count", 1)
        has_text = doc_info.get("has_text", text_length > 0)

        return self.classify_by_pdf_info(
            has_text=has_text,
            text_length=text_length,
            page_count=page_count,
        )

    def partition_tasks(
        self,
        tasks: List[Task],
    ) -> Tuple[List[Task], List[Task]]:
        """
        Partition tasks into CPU and GPU groups.

        Args:
            tasks: List of Task objects with task_type set.

        Returns:
            Tuple of (cpu_tasks, gpu_tasks).
        """
        cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
        gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]

        logger.info(
            f"Task partition: {len(cpu_tasks)} CPU tasks, {len(gpu_tasks)} GPU tasks"
        )

        return cpu_tasks, gpu_tasks

    def create_tasks(
        self,
        documents: List[dict],
        id_field: str = "id",
    ) -> List[Task]:
        """
        Create Task objects from document dicts.

        Args:
            documents: List of document info dicts.
            id_field: Field name to use as task ID.

        Returns:
            List of Task objects with types classified.
        """
        tasks = []
        for doc in documents:
            task_id = str(doc.get(id_field, id(doc)))
            task_type = self.classify_document(doc)
            tasks.append(Task(id=task_id, task_type=task_type, data=doc))

        cpu_count = sum(1 for t in tasks if t.task_type == TaskType.CPU)
        gpu_count = len(tasks) - cpu_count
        logger.debug(f"Created {len(tasks)} tasks: {cpu_count} CPU, {gpu_count} GPU")

        return tasks
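To make the classification rules concrete, a small sketch (the document dicts are illustrative):

from training.processing.task_dispatcher import TaskDispatcher

docs = [
    {"id": "a", "text_length": 2400, "page_count": 2},  # 1200 chars/page -> CPU
    {"id": "b", "is_scanned": True},                     # explicit flag   -> GPU
    {"id": "c", "text_length": 40, "page_count": 1},     # sparse text     -> GPU
]

dispatcher = TaskDispatcher(text_char_threshold=100)
tasks = dispatcher.create_tasks(docs, id_field="id")
cpu_tasks, gpu_tasks = dispatcher.partition_tasks(tasks)
assert [t.id for t in cpu_tasks] == ["a"]
assert [t.id for t in gpu_tasks] == ["b", "c"]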
182
packages/training/training/processing/worker_pool.py
Normal file
@@ -0,0 +1,182 @@
"""
Abstract base class for worker pools.

Provides a unified interface for CPU and GPU worker pools with proper
initialization, task submission, and resource cleanup.
"""

from __future__ import annotations

import logging
import multiprocessing as mp
from abc import ABC, abstractmethod
from concurrent.futures import Future, ProcessPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


@dataclass
class TaskResult:
    """Container for task execution results."""

    task_id: str
    success: bool
    data: Any
    error: Optional[str] = None
    processing_time: float = 0.0
    pool_type: str = ""
    extra: dict = field(default_factory=dict)


class WorkerPool(ABC):
    """
    Abstract base class for worker pools.

    Provides a common interface for ProcessPoolExecutor-based worker pools
    with proper initialization using the 'spawn' start method for CUDA
    compatibility.

    Attributes:
        max_workers: Maximum number of worker processes.
        use_gpu: Whether this pool uses GPU resources.
        gpu_id: GPU device ID (only relevant if use_gpu=True).
    """

    def __init__(
        self,
        max_workers: int,
        use_gpu: bool = False,
        gpu_id: int = 0,
    ) -> None:
        """
        Initialize the worker pool configuration.

        Args:
            max_workers: Maximum number of worker processes.
            use_gpu: Whether this pool uses GPU resources.
            gpu_id: GPU device ID for GPU pools.
        """
        self.max_workers = max_workers
        self.use_gpu = use_gpu
        self.gpu_id = gpu_id
        self._executor: Optional[ProcessPoolExecutor] = None
        self._started = False

    @property
    def name(self) -> str:
        """Return the pool name for logging."""
        return self.__class__.__name__

    @abstractmethod
    def get_initializer(self) -> Optional[Callable[..., None]]:
        """
        Return the worker initialization function.

        This function is called once per worker process when it starts.
        Use it to load models, set environment variables, etc.

        Returns:
            Callable to initialize each worker, or None if no initialization needed.
        """
        pass

    @abstractmethod
    def get_init_args(self) -> tuple:
        """
        Return arguments for the initializer function.

        Returns:
            Tuple of arguments to pass to the initializer.
        """
        pass

    def start(self) -> None:
        """
        Start the worker pool.

        Creates a ProcessPoolExecutor with the 'spawn' start method
        for CUDA compatibility.

        Raises:
            RuntimeError: If the pool is already started.
        """
        if self._started:
            raise RuntimeError(f"{self.name} is already started")

        # Use 'spawn' for CUDA compatibility
        ctx = mp.get_context("spawn")

        initializer = self.get_initializer()
        initargs = self.get_init_args()

        logger.info(
            f"Starting {self.name} with {self.max_workers} workers "
            f"(GPU: {self.use_gpu}, GPU ID: {self.gpu_id})"
        )

        self._executor = ProcessPoolExecutor(
            max_workers=self.max_workers,
            mp_context=ctx,
            initializer=initializer,
            initargs=initargs if initializer else (),
        )
        self._started = True

    def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Future:
        """
        Submit a task to the worker pool.

        Args:
            fn: Function to execute.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            Future representing the pending result.

        Raises:
            RuntimeError: If the pool is not started.
        """
        if not self._started or self._executor is None:
            raise RuntimeError(f"{self.name} is not started. Call start() first.")

        return self._executor.submit(fn, *args, **kwargs)

    def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None:
        """
        Shutdown the worker pool.

        Args:
            wait: If True, wait for all pending futures to complete.
            cancel_futures: If True, cancel all pending futures.
        """
        if self._executor is not None:
            logger.info(f"Shutting down {self.name} (wait={wait})")
            self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
            self._executor = None
        self._started = False

    @property
    def is_running(self) -> bool:
        """Check if the pool is currently running."""
        return self._started and self._executor is not None

    def __enter__(self) -> "WorkerPool":
        """Context manager entry - start the pool."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit - shutdown the pool."""
        self.shutdown(wait=True)

    def __repr__(self) -> str:
        status = "running" if self.is_running else "stopped"
        return (
            f"{self.__class__.__name__}("
            f"workers={self.max_workers}, "
            f"gpu={self.use_gpu}, "
            f"status={status})"
        )
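The ABC contract is small: a concrete pool only supplies an initializer and its args. The commit presumably ships its own CPU pool (it is referenced as _cpu_pool earlier); the subclass below is only an illustration of the contract:

from typing import Callable, Optional
from training.processing.worker_pool import WorkerPool

class CPUWorkerPool(WorkerPool):
    """Minimal concrete pool: no per-worker initialization needed."""

    def __init__(self, max_workers: int = 4) -> None:
        super().__init__(max_workers=max_workers, use_gpu=False)

    def get_initializer(self) -> Optional[Callable[..., None]]:
        return None

    def get_init_args(self) -> tuple:
        return ()

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    # 'spawn' re-imports this module, so square must be defined at module level.
    with CPUWorkerPool(max_workers=2) as pool:
        futures = [pool.submit(square, n) for n in range(4)]
        print([f.result() for f in futures])  # [0, 1, 4, 9]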
5
packages/training/training/yolo/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .annotation_generator import AnnotationGenerator, generate_annotations
from .dataset_builder import DatasetBuilder
from .db_dataset import DBYOLODataset, create_datasets

__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder', 'DBYOLODataset', 'create_datasets']
386
packages/training/training/yolo/annotation_generator.py
Normal file
@@ -0,0 +1,386 @@
"""
YOLO Annotation Generator

Generates YOLO format annotations from matched fields.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Any


# Field class mapping for YOLO
# Note: supplier_accounts is not a separate class - its matches are mapped to Bankgiro/Plusgiro
FIELD_CLASSES = {
    'InvoiceNumber': 0,
    'InvoiceDate': 1,
    'InvoiceDueDate': 2,
    'OCR': 3,
    'Bankgiro': 4,
    'Plusgiro': 5,
    'Amount': 6,
    'supplier_organisation_number': 7,
    'customer_number': 8,
    'payment_line': 9,  # Machine code payment line at bottom of invoice
}

# Fields that need matching but map to other YOLO classes
# supplier_accounts matches are classified as Bankgiro or Plusgiro based on account type
ACCOUNT_FIELD_MAPPING = {
    'supplier_accounts': {
        'BG': 'Bankgiro',  # BG:xxx -> Bankgiro class
        'PG': 'Plusgiro',  # PG:xxx -> Plusgiro class
    }
}

CLASS_NAMES = [
    'invoice_number',
    'invoice_date',
    'invoice_due_date',
    'ocr_number',
    'bankgiro',
    'plusgiro',
    'amount',
    'supplier_org_number',
    'customer_number',
    'payment_line',  # Machine code payment line at bottom of invoice
]


@dataclass
class YOLOAnnotation:
    """Represents a single YOLO annotation."""
    class_id: int
    x_center: float  # normalized 0-1
    y_center: float  # normalized 0-1
    width: float  # normalized 0-1
    height: float  # normalized 0-1
    confidence: float = 1.0

    def to_string(self) -> str:
        """Convert to YOLO format string."""
        return f"{self.class_id} {self.x_center:.6f} {self.y_center:.6f} {self.width:.6f} {self.height:.6f}"
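For reference, a worked example of the emitted line format, assuming an A4 page rendered at 300 DPI (2480x3508 px; the box itself is illustrative):

from training.yolo.annotation_generator import FIELD_CLASSES, YOLOAnnotation

# A 200x40 px box centred at (1240, 3288), near the bottom of the page:
ann = YOLOAnnotation(
    class_id=FIELD_CLASSES['payment_line'],
    x_center=1240 / 2480,
    y_center=3288 / 3508,
    width=200 / 2480,
    height=40 / 3508,
)
print(ann.to_string())  # 9 0.500000 0.937286 0.080645 0.011403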
class AnnotationGenerator:
    """Generates YOLO annotations from document matches."""

    def __init__(
        self,
        min_confidence: float = 0.7,
        bbox_padding_px: int = 20,  # Absolute padding in pixels
        min_bbox_height_px: int = 30  # Minimum bbox height
    ):
        """
        Initialize annotation generator.

        Args:
            min_confidence: Minimum match score to include in training
            bbox_padding_px: Absolute padding in pixels to add around bboxes
            min_bbox_height_px: Minimum bbox height in pixels
        """
        self.min_confidence = min_confidence
        self.bbox_padding_px = bbox_padding_px
        self.min_bbox_height_px = min_bbox_height_px

    def generate_from_matches(
        self,
        matches: dict[str, list[Any]],  # field_name -> list of Match
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Generate YOLO annotations from field matches.

        Args:
            matches: Dict of field_name -> list of Match objects
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering (needed to convert PDF coords to pixels)

        Returns:
            List of YOLOAnnotation objects
        """
        annotations = []

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        for field_name, field_matches in matches.items():
            if field_name not in FIELD_CLASSES:
                continue

            class_id = FIELD_CLASSES[field_name]

            # Take only the best match per field
            if field_matches:
                best_match = field_matches[0]  # Already sorted by score

                if best_match.score < self.min_confidence:
                    continue

                # best_match.bbox is in PDF points (72 DPI), convert to pixels
                x0, y0, x1, y1 = best_match.bbox
                x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale

                # Add absolute padding
                pad = self.bbox_padding_px
                x0 = max(0, x0 - pad)
                y0 = max(0, y0 - pad)
                x1 = min(image_width, x1 + pad)
                y1 = min(image_height, y1 + pad)

                # Ensure minimum height
                current_height = y1 - y0
                if current_height < self.min_bbox_height_px:
                    extra = (self.min_bbox_height_px - current_height) / 2
                    y0 = max(0, y0 - extra)
                    y1 = min(image_height, y1 + extra)

                # Convert to YOLO format (normalized center + size)
                x_center = (x0 + x1) / 2 / image_width
                y_center = (y0 + y1) / 2 / image_height
                width = (x1 - x0) / image_width
                height = (y1 - y0) / image_height

                # Clamp values to 0-1
                x_center = max(0, min(1, x_center))
                y_center = max(0, min(1, y_center))
                width = max(0, min(1, width))
                height = max(0, min(1, height))

                annotations.append(YOLOAnnotation(
                    class_id=class_id,
                    x_center=x_center,
                    y_center=y_center,
                    width=width,
                    height=height,
                    confidence=best_match.score
                ))

        return annotations

    def add_payment_line_annotation(
        self,
        annotations: list[YOLOAnnotation],
        payment_line_bbox: tuple[float, float, float, float],
        confidence: float,
        image_width: float,
        image_height: float,
        dpi: int = 300
    ) -> list[YOLOAnnotation]:
        """
        Add payment_line annotation from machine code parser result.

        Args:
            annotations: Existing list of annotations to append to
            payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
            confidence: Confidence score from machine code parser
            image_width: Width of the rendered image in pixels
            image_height: Height of the rendered image in pixels
            dpi: DPI used for rendering

        Returns:
            Updated annotations list with payment_line annotation added
        """
        if not payment_line_bbox or confidence < self.min_confidence:
            return annotations

        # Scale factor to convert PDF points (72 DPI) to rendered pixels
        scale = dpi / 72.0

        x0, y0, x1, y1 = payment_line_bbox
        x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale

        # Add absolute padding
        pad = self.bbox_padding_px
        x0 = max(0, x0 - pad)
        y0 = max(0, y0 - pad)
        x1 = min(image_width, x1 + pad)
        y1 = min(image_height, y1 + pad)

        # Convert to YOLO format (normalized center + size)
        x_center = (x0 + x1) / 2 / image_width
        y_center = (y0 + y1) / 2 / image_height
        width = (x1 - x0) / image_width
        height = (y1 - y0) / image_height

        # Clamp values to 0-1
        x_center = max(0, min(1, x_center))
        y_center = max(0, min(1, y_center))
        width = max(0, min(1, width))
        height = max(0, min(1, height))

        annotations.append(YOLOAnnotation(
            class_id=FIELD_CLASSES['payment_line'],
            x_center=x_center,
            y_center=y_center,
            width=width,
            height=height,
            confidence=confidence
        ))

        return annotations

    def save_annotations(
        self,
        annotations: list[YOLOAnnotation],
        output_path: str | Path
    ) -> None:
        """Save annotations to a YOLO format text file."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            for ann in annotations:
                f.write(ann.to_string() + '\n')

    @staticmethod
    def generate_classes_file(output_path: str | Path) -> None:
        """Generate the classes.txt file for YOLO."""
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            for class_name in CLASS_NAMES:
                f.write(class_name + '\n')

    @staticmethod
    def generate_yaml_config(
        output_path: str | Path,
        train_path: str = 'train/images',
        val_path: str = 'val/images',
        test_path: str = 'test/images',
        use_wsl_paths: bool | None = None
    ) -> None:
        """
        Generate YOLO dataset YAML configuration.

        Args:
            output_path: Path to output YAML file
            train_path: Relative path to training images
            val_path: Relative path to validation images
            test_path: Relative path to test images
            use_wsl_paths: If True, convert Windows paths to WSL format.
                If None, auto-detect based on environment.
        """
        import os
        import platform

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        dataset_dir = output_path.parent.absolute()
        dataset_path_str = str(dataset_dir)

        # Auto-detect WSL environment
        if use_wsl_paths is None:
            # Check if running inside WSL
            is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
            # Check WSL_DISTRO_NAME environment variable (set when running in WSL)
            is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
            use_wsl_paths = is_wsl

        # Convert path format based on environment
        if use_wsl_paths:
            # Running in WSL: convert Windows paths to /mnt/c/... format
            dataset_path_str = dataset_path_str.replace('\\', '/')
            if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
                drive = dataset_path_str[0].lower()
                dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
        elif platform.system() == 'Windows':
            # Running on native Windows: use forward slashes for YOLO compatibility
            dataset_path_str = dataset_path_str.replace('\\', '/')

        config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}

# Classes
names:
"""
        for i, name in enumerate(CLASS_NAMES):
            config += f"  {i}: {name}\n"

        with open(output_path, 'w') as f:
            f.write(config)


def generate_annotations(
    pdf_path: str | Path,
    structured_data: dict[str, Any],
    output_dir: str | Path,
    dpi: int = 300
) -> list[Path]:
    """
    Generate YOLO annotations for a PDF using structured data.

    Args:
        pdf_path: Path to the PDF file
        structured_data: Dict with field values from CSV
        output_dir: Directory to save images and labels
        dpi: Resolution for rendering

    Returns:
        List of paths to generated annotation files
    """
    from shared.pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from shared.pdf.renderer import get_render_dimensions
    from shared.ocr import OCREngine
    from shared.matcher import FieldMatcher
    from shared.normalize import normalize_field

    output_dir = Path(output_dir)
    images_dir = output_dir / 'images'
    labels_dir = output_dir / 'labels'
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    generator = AnnotationGenerator()
    matcher = FieldMatcher()
    annotation_files = []

    # Check PDF type
    use_ocr = not is_text_pdf(pdf_path)

    # Initialize OCR if needed
    ocr_engine = OCREngine() if use_ocr else None

    # Process each page
    for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
        # Get image dimensions
        img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

        # Extract tokens
        if use_ocr:
            tokens = ocr_engine.extract_from_image(str(image_path), page_no)
        else:
            tokens = list(extract_text_tokens(pdf_path, page_no))

        # Match fields
        matches = {}
        for field_name in FIELD_CLASSES.keys():
            value = structured_data.get(field_name)
            if value is None or str(value).strip() == '':
                continue

            normalized = normalize_field(field_name, str(value))
            field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
            if field_matches:
                matches[field_name] = field_matches

        # Generate annotations (pass DPI for coordinate conversion)
        annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

        # Save annotations
        if annotations:
            label_path = labels_dir / f"{image_path.stem}.txt"
            generator.save_annotations(annotations, label_path)
            annotation_files.append(label_path)

    return annotation_files
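The two static helpers can be exercised on their own; a minimal sketch (paths are illustrative):

from training.yolo.annotation_generator import AnnotationGenerator

# Writes one class name per line, in class-id order.
AnnotationGenerator.generate_classes_file('dataset/classes.txt')

# Writes dataset.yaml with an absolute 'path:' resolved for the current
# environment (WSL, native Windows, or Linux).
AnnotationGenerator.generate_yaml_config('dataset/dataset.yaml')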
249
packages/training/training/yolo/dataset_builder.py
Normal file
@@ -0,0 +1,249 @@
"""
YOLO Dataset Builder

Builds training dataset from PDFs and structured CSV data.
"""

import csv
import shutil
from dataclasses import dataclass
from pathlib import Path
import random


@dataclass
class DatasetStats:
    """Statistics about the generated dataset."""
    total_documents: int
    successful: int
    failed: int
    total_annotations: int
    annotations_per_class: dict[str, int]
    train_count: int
    val_count: int
    test_count: int


class DatasetBuilder:
    """Builds YOLO training dataset from PDFs and CSV data."""

    def __init__(
        self,
        pdf_dir: str | Path,
        csv_path: str | Path,
        output_dir: str | Path,
        document_id_column: str = 'DocumentId',
        dpi: int = 300,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1
    ):
        """
        Initialize dataset builder.

        Args:
            pdf_dir: Directory containing PDF files
            csv_path: Path to structured data CSV
            output_dir: Output directory for dataset
            document_id_column: Column name for document ID
            dpi: Resolution for rendering
            train_ratio: Fraction for training set
            val_ratio: Fraction for validation set
            test_ratio: Fraction for test set
        """
        self.pdf_dir = Path(pdf_dir)
        self.csv_path = Path(csv_path)
        self.output_dir = Path(output_dir)
        self.document_id_column = document_id_column
        self.dpi = dpi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

    def load_structured_data(self) -> dict[str, dict]:
        """Load structured data from CSV."""
        data = {}

        with open(self.csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                doc_id = row.get(self.document_id_column)
                if doc_id:
                    data[doc_id] = row

        return data

    def find_pdf_for_document(self, doc_id: str) -> Path | None:
        """Find PDF file for a document ID."""
        # Try common naming patterns
        patterns = [
            f"{doc_id}.pdf",
            f"{doc_id.lower()}.pdf",
            f"{doc_id.upper()}.pdf",
            f"*{doc_id}*.pdf",
        ]

        for pattern in patterns:
            matches = list(self.pdf_dir.glob(pattern))
            if matches:
                return matches[0]

        return None

    def build(self, seed: int = 42) -> DatasetStats:
        """
        Build the complete dataset.

        Args:
            seed: Random seed for train/val/test split

        Returns:
            DatasetStats with build results
        """
        from .annotation_generator import AnnotationGenerator, CLASS_NAMES, generate_annotations

        random.seed(seed)

        # Setup output directories
        for split in ['train', 'val', 'test']:
            (self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
            (self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

        # Generate config files
        AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt')
        AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml')

        # Load structured data
        structured_data = self.load_structured_data()

        # Stats tracking
        stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'annotations': 0,
            'per_class': {name: 0 for name in CLASS_NAMES},
            'splits': {'train': 0, 'val': 0, 'test': 0}
        }

        # Process each document
        processed_items = []

        for doc_id, data in structured_data.items():
            stats['total'] += 1

            pdf_path = self.find_pdf_for_document(doc_id)
            if not pdf_path:
                print(f"Warning: PDF not found for document {doc_id}")
                stats['failed'] += 1
                continue

            try:
                # Generate to temp dir first
                temp_dir = self.output_dir / 'temp' / doc_id
                temp_dir.mkdir(parents=True, exist_ok=True)

                annotation_files = generate_annotations(
                    pdf_path, data, temp_dir, self.dpi
                )

                if annotation_files:
                    processed_items.append({
                        'doc_id': doc_id,
                        'temp_dir': temp_dir,
                        'annotation_files': annotation_files
                    })
                    stats['successful'] += 1
                else:
                    print(f"Warning: No annotations generated for {doc_id}")
                    stats['failed'] += 1

            except Exception as e:
                print(f"Error processing {doc_id}: {e}")
                stats['failed'] += 1

        # Shuffle and split
        random.shuffle(processed_items)

        n_train = int(len(processed_items) * self.train_ratio)
        n_val = int(len(processed_items) * self.val_ratio)

        splits = {
            'train': processed_items[:n_train],
            'val': processed_items[n_train:n_train + n_val],
            'test': processed_items[n_train + n_val:]
        }

        # Move files to final locations
        for split_name, items in splits.items():
            for item in items:
                temp_dir = item['temp_dir']

                # Move images
                for img in (temp_dir / 'images').glob('*'):
                    dest = self.output_dir / split_name / 'images' / img.name
                    shutil.move(str(img), str(dest))

                # Move labels and count annotations
                for label in (temp_dir / 'labels').glob('*.txt'):
                    dest = self.output_dir / split_name / 'labels' / label.name
                    shutil.move(str(label), str(dest))

                    # Count annotations per class
                    with open(dest, 'r') as f:
                        for line in f:
                            class_id = int(line.strip().split()[0])
                            if 0 <= class_id < len(CLASS_NAMES):
                                stats['per_class'][CLASS_NAMES[class_id]] += 1
                            stats['annotations'] += 1

                stats['splits'][split_name] += 1

        # Cleanup temp dir
        shutil.rmtree(self.output_dir / 'temp', ignore_errors=True)

        return DatasetStats(
            total_documents=stats['total'],
            successful=stats['successful'],
            failed=stats['failed'],
            total_annotations=stats['annotations'],
            annotations_per_class=stats['per_class'],
            train_count=stats['splits']['train'],
            val_count=stats['splits']['val'],
            test_count=stats['splits']['test']
        )

    def process_single_document(
        self,
        doc_id: str,
        data: dict,
        split: str = 'train'
    ) -> bool:
        """
        Process a single document (for incremental building).

        Args:
            doc_id: Document ID
            data: Structured data dict
            split: Which split to add to

        Returns:
            True if successful
        """
        from .annotation_generator import generate_annotations

        pdf_path = self.find_pdf_for_document(doc_id)
        if not pdf_path:
            return False

        try:
            output_subdir = self.output_dir / split
            annotation_files = generate_annotations(
                pdf_path, data, output_subdir, self.dpi
            )
            return len(annotation_files) > 0
        except Exception as e:
            print(f"Error processing {doc_id}: {e}")
            return False
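End-to-end use of the builder is short; a sketch with illustrative paths (the CSV must carry a DocumentId column, per the default document_id_column):

from training.yolo.dataset_builder import DatasetBuilder

builder = DatasetBuilder(
    pdf_dir='data/pdfs',
    csv_path='data/structured.csv',
    output_dir='data/yolo_dataset',
)
stats = builder.build(seed=42)
print(f"{stats.successful}/{stats.total_documents} documents, "
      f"{stats.total_annotations} annotations "
      f"(train={stats.train_count}, val={stats.val_count}, test={stats.test_count})")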
728
packages/training/training/yolo/db_dataset.py
Normal file
@@ -0,0 +1,728 @@
"""
Database-backed YOLO Dataset

Loads images from filesystem and labels from PostgreSQL database.
Generates YOLO format labels dynamically at training time.
"""

from __future__ import annotations

import logging
import random
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional

import numpy as np
from PIL import Image

from shared.config import DEFAULT_DPI
from .annotation_generator import FIELD_CLASSES, YOLOAnnotation

logger = logging.getLogger(__name__)


# Module-level LRU cache for image loading (shared across dataset instances)
@lru_cache(maxsize=256)
def _load_image_cached(image_path: str) -> tuple[np.ndarray, int, int]:
    """
    Load and cache image from disk.

    Args:
        image_path: Path to image file (must be string for hashability)

    Returns:
        Tuple of (image_array, width, height)
    """
    image = Image.open(image_path).convert('RGB')
    width, height = image.size
    image_array = np.array(image)
    return image_array, width, height


def clear_image_cache():
    """Clear the image cache to free memory."""
    _load_image_cached.cache_clear()
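One sizing note on the cache: at 300 DPI an A4 page is roughly 2480x3508 px, i.e. about 26 MB as an RGB array, so a full 256-entry cache can approach 6-7 GB of RAM. The standard lru_cache introspection hooks apply; a sketch:

from training.yolo.db_dataset import _load_image_cached, clear_image_cache

# Inspect hit/miss behaviour and current size during training...
print(_load_image_cached.cache_info())  # e.g. CacheInfo(hits=..., misses=..., maxsize=256, currsize=...)

# ...and release the memory once an epoch or export is done.
clear_image_cache()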
@dataclass
class DatasetItem:
    """Single item in the dataset."""
    document_id: str
    image_path: Path
    page_no: int
    labels: list[YOLOAnnotation]
    is_scanned: bool = False  # Retained for compatibility; since the OCR coordinate fix all bboxes are stored in PDF points
    csv_split: str | None = None  # CSV-defined split ('train', 'test', etc.)


class DBYOLODataset:
    """
    YOLO Dataset that reads labels from database.

    This dataset:
    1. Scans temp directory for rendered images
    2. Queries database for bbox data
    3. Generates YOLO labels dynamically
    4. Performs train/val/test split at runtime
    """

    def __init__(
        self,
        images_dir: str | Path,
        db: Any,  # DocumentDB instance
        split: str = 'train',
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        seed: int = 42,
        dpi: int = DEFAULT_DPI,  # Must match the DPI used in autolabel_tasks.py for rendering
        min_confidence: float = 0.7,
        bbox_padding_px: int = 20,
        min_bbox_height_px: int = 30,
        limit: int | None = None,
    ):
        """
        Initialize database-backed YOLO dataset.

        Args:
            images_dir: Directory containing temp/{doc_id}/images/*.png
            db: DocumentDB instance for label queries
            split: Which split to use ('train', 'val', 'test')
            train_ratio: Ratio for training set
            val_ratio: Ratio for validation set
            seed: Random seed for reproducible splits
            dpi: DPI used for rendering (for coordinate conversion)
            min_confidence: Minimum match score to include
            bbox_padding_px: Padding around bboxes
            min_bbox_height_px: Minimum bbox height
            limit: Maximum number of documents to use (None for all)
        """
        self.images_dir = Path(images_dir)
        self.db = db
        self.split = split
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.seed = seed
        self.dpi = dpi
        self.min_confidence = min_confidence
        self.bbox_padding_px = bbox_padding_px
        self.min_bbox_height_px = min_bbox_height_px
        self.limit = limit

        # Load and split dataset
        self.items: list[DatasetItem] = []
        self._all_items: list[DatasetItem] = []  # Cache all items for sharing
        self._doc_ids_ordered: list[str] = []  # Cache ordered doc IDs for consistent splits
        self._load_dataset()

    @classmethod
    def from_shared_data(
        cls,
        source_dataset: 'DBYOLODataset',
        split: str,
    ) -> 'DBYOLODataset':
        """
        Create a new dataset instance sharing data from an existing one.

        This avoids re-loading data from filesystem and database.

        Args:
            source_dataset: Dataset to share data from
            split: Which split to use ('train', 'val', 'test')

        Returns:
            New dataset instance with shared data
        """
        # Create instance without loading (we'll share data)
        instance = object.__new__(cls)

        # Copy configuration
        instance.images_dir = source_dataset.images_dir
        instance.db = source_dataset.db
        instance.split = split
        instance.train_ratio = source_dataset.train_ratio
        instance.val_ratio = source_dataset.val_ratio
        instance.seed = source_dataset.seed
        instance.dpi = source_dataset.dpi
        instance.min_confidence = source_dataset.min_confidence
        instance.bbox_padding_px = source_dataset.bbox_padding_px
        instance.min_bbox_height_px = source_dataset.min_bbox_height_px
        instance.limit = source_dataset.limit

        # Share loaded data
        instance._all_items = source_dataset._all_items
        instance._doc_ids_ordered = source_dataset._doc_ids_ordered

        # Split items for this split
        instance.items = instance._split_dataset_from_cache()
        print(f"Split '{split}': {len(instance.items)} items")

        return instance

    def _load_dataset(self):
        """Load dataset items from filesystem and database."""
        # Find all document directories
        temp_dir = self.images_dir / 'temp'
        if not temp_dir.exists():
            print(f"Temp directory not found: {temp_dir}")
            return

        # Collect all document IDs with images
        doc_image_map: dict[str, list[Path]] = {}
        for doc_dir in temp_dir.iterdir():
            if not doc_dir.is_dir():
                continue

            images_path = doc_dir / 'images'
            if not images_path.exists():
                continue

            images = list(images_path.glob('*.png'))
            if images:
                doc_image_map[doc_dir.name] = sorted(images)

        print(f"Found {len(doc_image_map)} documents with images")

        # Query database for all document labels
        doc_ids = list(doc_image_map.keys())
        doc_labels = self._load_labels_from_db(doc_ids)

        print(f"Loaded labels for {len(doc_labels)} documents from database")

        # Build dataset items
        all_items: list[DatasetItem] = []
        skipped_no_labels = 0
        skipped_no_db_record = 0
        total_images = 0

        for doc_id, images in doc_image_map.items():
            doc_data = doc_labels.get(doc_id)

            # Skip documents that don't exist in database
            if doc_data is None:
                skipped_no_db_record += len(images)
                total_images += len(images)
                continue

            labels_by_page, is_scanned, csv_split = doc_data

            for image_path in images:
                total_images += 1
                # Extract page number from filename (e.g., "doc_page_000.png")
                page_no = self._extract_page_no(image_path.stem)

                # Get labels for this page
                page_labels = labels_by_page.get(page_no, [])

                if page_labels:  # Only include pages with labels
                    all_items.append(DatasetItem(
                        document_id=doc_id,
                        image_path=image_path,
                        page_no=page_no,
                        labels=page_labels,
                        is_scanned=is_scanned,
                        csv_split=csv_split
                    ))
                else:
                    skipped_no_labels += 1

        print(f"Total images found: {total_images}")
        print(f"Images with labels: {len(all_items)}")
        if skipped_no_db_record > 0:
            print(f"Skipped {skipped_no_db_record} images (document not in database)")
        if skipped_no_labels > 0:
            print(f"Skipped {skipped_no_labels} images (no labels for page)")

        # Cache all items for sharing with other splits
        self._all_items = all_items

        # Split dataset
        self.items, self._doc_ids_ordered = self._split_dataset(all_items)
        print(f"Split '{self.split}': {len(self.items)} items")

    def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
        """
        Load labels from database for given document IDs using batch queries.

        Returns:
            Dict of doc_id -> (page_labels, is_scanned, split)
            where page_labels is {page_no -> list[YOLOAnnotation]},
            is_scanned indicates the document was OCR-processed, and
            split is the CSV-defined split ('train', 'test', etc.) or None.
        """
        result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]] = {}

        # Query in batches using efficient batch method
        batch_size = 500
        for i in range(0, len(doc_ids), batch_size):
            batch_ids = doc_ids[i:i + batch_size]

            # Use batch query instead of individual queries (N+1 fix)
            docs_batch = self.db.get_documents_batch(batch_ids)

            for doc_id, doc in docs_batch.items():
                if not doc.get('success'):
                    continue

                # Check if scanned PDF (processed via OCR)
                is_scanned = doc.get('pdf_type') == 'scanned'

                # Get CSV-defined split
                csv_split = doc.get('split')

                page_labels: dict[int, list[YOLOAnnotation]] = {}

                for field_result in doc.get('field_results', []):
                    if not field_result.get('matched'):
                        continue

                    field_name = field_result.get('field_name')

                    # Map supplier_accounts(X) to the actual class name (Bankgiro/Plusgiro)
                    yolo_class_name = field_name
                    if field_name and field_name.startswith('supplier_accounts('):
                        # Extract the account type: "supplier_accounts(Bankgiro)" -> "Bankgiro"
                        yolo_class_name = field_name.split('(')[1].rstrip(')')

                    if yolo_class_name not in FIELD_CLASSES:
                        continue

                    score = field_result.get('score', 0)
                    if score < self.min_confidence:
                        continue

                    bbox = field_result.get('bbox')
                    page_no = field_result.get('page_no', 0)

                    if bbox and len(bbox) == 4:
                        annotation = self._create_annotation(
                            field_name=yolo_class_name,  # Use mapped class name
                            bbox=bbox,
                            score=score
                        )

                        if page_no not in page_labels:
                            page_labels[page_no] = []
                        page_labels[page_no].append(annotation)

                if page_labels:
                    result[doc_id] = (page_labels, is_scanned, csv_split)

        return result

    def _create_annotation(
        self,
        field_name: str,
        bbox: list[float],
        score: float
    ) -> YOLOAnnotation:
        """
        Create a YOLO annotation from bbox.

        Note: bbox is in PDF points (72 DPI), will be normalized later.
        """
        class_id = FIELD_CLASSES[field_name]
        x0, y0, x1, y1 = bbox

        # Store raw PDF coordinates - will be normalized when getting item
        return YOLOAnnotation(
            class_id=class_id,
            x_center=(x0 + x1) / 2,  # center in PDF points
            y_center=(y0 + y1) / 2,
            width=x1 - x0,
            height=y1 - y0,
            confidence=score
        )

    def _extract_page_no(self, stem: str) -> int:
        """Extract page number from image filename."""
        # Format: "{doc_id}_page_{page_no:03d}"
        parts = stem.rsplit('_', 1)
        if len(parts) == 2:
            try:
                return int(parts[1])
            except ValueError:
                pass
        return 0

    def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
        """
        Split items into train/val/test based on CSV-defined split field.

        If CSV has 'split' field, use it directly.
        Otherwise, fall back to random splitting based on train_ratio/val_ratio.

        Returns:
            Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
            reused for consistent splits across shared datasets.
        """
        # Group by document ID for proper splitting
        doc_items: dict[str, list[DatasetItem]] = {}
        doc_csv_split: dict[str, str | None] = {}  # Track CSV split per document

        for item in items:
            if item.document_id not in doc_items:
                doc_items[item.document_id] = []
                doc_csv_split[item.document_id] = item.csv_split
            doc_items[item.document_id].append(item)

        # Check if we have CSV-defined splits
        has_csv_splits = any(split is not None for split in doc_csv_split.values())

        doc_ids = list(doc_items.keys())

        if has_csv_splits:
            # Use CSV-defined splits
            print("Using CSV-defined split field for train/val/test assignment")

            # Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
            # 'val' is carved out of the train set using val_ratio
            split_doc_ids = []

            if self.split == 'train':
                # Get documents marked as 'train' or with no split defined
                train_docs = [doc_id for doc_id in doc_ids
                              if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]

                # Take train_ratio of train docs for actual training, rest for val
                random.seed(self.seed)
                random.shuffle(train_docs)

                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
                split_doc_ids = train_docs[:n_train]

            elif self.split == 'val':
                # Get documents marked as 'train' and take the val portion
                train_docs = [doc_id for doc_id in doc_ids
                              if doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]

                random.seed(self.seed)
                random.shuffle(train_docs)

                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
                split_doc_ids = train_docs[n_train:]

            else:  # test
                # Get documents marked as 'test'
                split_doc_ids = [doc_id for doc_id in doc_ids
                                 if doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]

            # Apply limit if specified
            if self.limit is not None and self.limit < len(split_doc_ids):
                split_doc_ids = split_doc_ids[:self.limit]
                print(f"Limited to {self.limit} documents")

        else:
            # Fall back to random splitting (original behavior)
            print("No CSV split field found, using random splitting")

            random.seed(self.seed)
            random.shuffle(doc_ids)

            # Apply limit if specified (before splitting)
            if self.limit is not None and self.limit < len(doc_ids):
                doc_ids = doc_ids[:self.limit]
                print(f"Limited to {self.limit} documents")

            # Calculate split indices
            n_total = len(doc_ids)
            n_train = int(n_total * self.train_ratio)
            n_val = int(n_total * self.val_ratio)

            # Split document IDs
            if self.split == 'train':
                split_doc_ids = doc_ids[:n_train]
            elif self.split == 'val':
                split_doc_ids = doc_ids[n_train:n_train + n_val]
            else:  # test
                split_doc_ids = doc_ids[n_train + n_val:]

        # Collect items for this split
        split_items = []
        for doc_id in split_doc_ids:
            split_items.extend(doc_items[doc_id])

        return split_items, doc_ids

    def _split_dataset_from_cache(self) -> list[DatasetItem]:
        """
        Split items using cached data from a shared dataset.

        Uses the pre-computed doc_ids order for consistent splits.
        Respects CSV-defined splits if available.
        """
        # Group by document ID and track CSV splits
        doc_items: dict[str, list[DatasetItem]] = {}
        doc_csv_split: dict[str, str | None] = {}

        for item in self._all_items:
            if item.document_id not in doc_items:
                doc_items[item.document_id] = []
                doc_csv_split[item.document_id] = item.csv_split
            doc_items[item.document_id].append(item)

        # Check if we have CSV-defined splits
        has_csv_splits = any(split is not None for split in doc_csv_split.values())

        doc_ids = self._doc_ids_ordered

        if has_csv_splits:
            # Use CSV-defined splits
            if self.split == 'train':
                train_docs = [doc_id for doc_id in doc_ids
                              if doc_id in doc_csv_split and
                              doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]

                random.seed(self.seed)
                random.shuffle(train_docs)

                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
                split_doc_ids = train_docs[:n_train]

            elif self.split == 'val':
                train_docs = [doc_id for doc_id in doc_ids
                              if doc_id in doc_csv_split and
                              doc_csv_split[doc_id] in (None, 'train', 'Train', 'TRAIN')]

                random.seed(self.seed)
                random.shuffle(train_docs)

                n_train = int(len(train_docs) * (self.train_ratio / (self.train_ratio + self.val_ratio)))
                split_doc_ids = train_docs[n_train:]

            else:  # test
                split_doc_ids = [doc_id for doc_id in doc_ids
                                 if doc_id in doc_csv_split and
                                 doc_csv_split[doc_id] in ('test', 'Test', 'TEST')]

        else:
            # Fall back to random splitting
            n_total = len(doc_ids)
            n_train = int(n_total * self.train_ratio)
            n_val = int(n_total * self.val_ratio)

            if self.split == 'train':
                split_doc_ids = doc_ids[:n_train]
            elif self.split == 'val':
                split_doc_ids = doc_ids[n_train:n_train + n_val]
            else:  # test
                split_doc_ids = doc_ids[n_train + n_val:]

        # Collect items for this split
        split_items = []
        for doc_id in split_doc_ids:
            if doc_id in doc_items:
                split_items.extend(doc_items[doc_id])

        return split_items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """
        Get image and labels for index.

        Returns:
            (image, labels) where:
            - image: numpy array (H, W, C)
            - labels: numpy array (N, 5) with [class_id, x_center, y_center, width, height]
        """
        item = self.items[idx]

        # Load image using LRU cache (significant speedup during training)
        image_array, img_width, img_height = _load_image_cached(str(item.image_path))

        # Convert annotations to YOLO format (normalized)
        labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)

        return image_array, labels

    def _convert_labels(
        self,
        annotations: list[YOLOAnnotation],
        img_width: int,
        img_height: int,
        is_scanned: bool = False
    ) -> np.ndarray:
        """
        Convert annotations to normalized YOLO format.

        Args:
            annotations: List of annotations
            img_width: Actual image width in pixels
            img_height: Actual image height in pixels
            is_scanned: Retained for compatibility; scaling no longer depends on it

        Returns:
            numpy array (N, 5) with [class_id, x_center, y_center, width, height]
        """
        if not annotations:
            return np.zeros((0, 5), dtype=np.float32)

        # Scale factor: PDF points (72 DPI) -> rendered pixels.
        # Note: After the OCR coordinate fix, ALL bboxes (both text and scanned PDF)
        # are stored in PDF points, so we always apply the same scaling.
        scale = self.dpi / 72.0

        labels = []
        for ann in annotations:
            # Convert PDF points to pixels
            x_center_px = ann.x_center * scale
            y_center_px = ann.y_center * scale
            width_px = ann.width * scale
            height_px = ann.height * scale

            # Add padding
            pad = self.bbox_padding_px
            width_px += 2 * pad
            height_px += 2 * pad

            # Ensure minimum height
            if height_px < self.min_bbox_height_px:
                height_px = self.min_bbox_height_px

            # Normalize to 0-1
            x_center = x_center_px / img_width
            y_center = y_center_px / img_height
            width = width_px / img_width
            height = height_px / img_height

            # Clamp to valid range
            x_center = max(0, min(1, x_center))
            y_center = max(0, min(1, y_center))
            width = max(0, min(1, width))
            height = max(0, min(1, height))

            labels.append([ann.class_id, x_center, y_center, width, height])

        return np.array(labels, dtype=np.float32)

    def get_image_path(self, idx: int) -> Path:
        """Get image path for index."""
        return self.items[idx].image_path

    def get_labels_for_yolo(self, idx: int) -> str:
        """
        Get YOLO format labels as string for index.

        Returns:
            String with YOLO format labels (one per line)
        """
        item = self.items[idx]

        # Use cached image loading to avoid duplicate disk reads
        _, img_width, img_height = _load_image_cached(str(item.image_path))

        labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)

        lines = []
        for label in labels:
            class_id = int(label[0])
            x_center, y_center, width, height = label[1:5]
            lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

        return '\n'.join(lines)

    def export_to_yolo_format(
        self,
        output_dir: str | Path,
        split_name: Optional[str] = None
    ) -> int:
        """
        Export dataset to standard YOLO format (images + label files).

        This is useful for training with standard YOLO training scripts.

        Args:
            output_dir: Output directory
            split_name: Name for the split subdirectory (default: self.split)

        Returns:
            Number of items exported
        """
        import shutil

        output_dir = Path(output_dir)
        split_name = split_name or self.split

        images_out = output_dir / split_name / 'images'
        labels_out = output_dir / split_name / 'labels'

        # Clear existing directories before export
        if images_out.exists():
            shutil.rmtree(images_out)
        if labels_out.exists():
            shutil.rmtree(labels_out)

        images_out.mkdir(parents=True, exist_ok=True)
        labels_out.mkdir(parents=True, exist_ok=True)

        count = 0
        for idx in range(len(self)):
            item = self.items[idx]

            # Copy image
            dest_image = images_out / item.image_path.name
            shutil.copy2(item.image_path, dest_image)

            # Write label file
            label_content = self.get_labels_for_yolo(idx)
            label_path = labels_out / f"{item.image_path.stem}.txt"
            with open(label_path, 'w') as f:
                f.write(label_content)

            count += 1

        print(f"Exported {count} items to {output_dir / split_name}")
        return count


def create_datasets(
    images_dir: str | Path,
    db: Any,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
    limit: int | None = None,
    **kwargs
) -> dict[str, DBYOLODataset]:
    """
    Create train/val/test datasets.

    This function loads data once and shares it across all splits for efficiency.

    Args:
        images_dir: Directory containing temp/{doc_id}/images/
        db: DocumentDB instance
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        seed: Random seed
        limit: Maximum number of documents to use (None for all)
        **kwargs: Additional arguments for DBYOLODataset

    Returns:
        Dict with 'train', 'val', 'test' datasets
    """
    # Create first dataset which loads all data
    print("Loading dataset (this may take a few minutes for large datasets)...")
    first_dataset = DBYOLODataset(
        images_dir=images_dir,
        db=db,
        split='train',
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        seed=seed,
        limit=limit,
        **kwargs
    )

    # Create other splits by sharing loaded data
    datasets = {'train': first_dataset}
    for split in ['val', 'test']:
        datasets[split] = DBYOLODataset.from_shared_data(
            first_dataset,
            split=split,
        )

    return datasets
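Putting it together, a sketch of building the splits and materializing them for ultralytics-style training (the DocumentDB import path below is an assumption; the class is defined elsewhere in this commit):

from shared.data.document_db import DocumentDB  # assumed module path
from training.yolo.db_dataset import create_datasets

db = DocumentDB()  # connection settings depend on the deployment
datasets = create_datasets('data/cache', db=db, limit=500)
print({name: len(ds) for name, ds in datasets.items()})

# Export to plain images/ + labels/ folders for standard YOLO training scripts.
for name, ds in datasets.items():
    ds.export_to_yolo_format('data/yolo_dataset', split_name=name)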